# videocr/video.py
1from __future__ import annotations
2from multiprocessing import Pool
3import datetime
4import pytesseract
5import cv2
6
7from . import constants
8from .models import PredictedFrame, PredictedSubtitle
9
10
class Video:
    """A video file whose frames are OCR'd with Tesseract and merged into SRT subtitles.

    Usage: construct with a path, call ``run_ocr()`` to OCR a frame range,
    then ``get_subtitles()`` to get an SRT-formatted string.
    """

    path: str
    lang: str
    use_fullframe: bool
    num_frames: int
    fps: float
    height: int
    pred_frames: list[PredictedFrame] | None
    pred_subs: list[PredictedSubtitle]

    def __init__(self, path: str):
        """Open ``path`` once to read its frame count, frame rate and frame height.

        Raises:
            IOError: if OpenCV cannot open the file.
        """
        self.path = path
        # Initialized here so _generate_subtitles() can tell that run_ocr()
        # has not been called yet.
        self.pred_frames = None
        self.pred_subs = []

        v = cv2.VideoCapture(path)
        try:
            if not v.isOpened():
                raise IOError('can not open video format {}'.format(path))
            self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
            self.fps = v.get(cv2.CAP_PROP_FPS)
            self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            v.release()

    def run_ocr(self, lang: str, time_start: str, time_end: str,
                conf_threshold: int, use_fullframe: bool) -> None:
        """OCR every frame from ``time_start`` up to ``time_end`` in parallel.

        Args:
            lang: language passed to ``pytesseract.image_to_data``.
            time_start: ``[HH:]MM:SS`` start time; falsy means the first frame.
            time_end: ``[HH:]MM:SS`` end time; falsy means the reported last frame.
            conf_threshold: forwarded to every ``PredictedFrame`` (presumably a
                minimum word confidence -- see ``PredictedFrame``).
            use_fullframe: OCR the whole frame instead of only its bottom half.

        The results are stored in ``self.pred_frames``, one per frame read.
        Each frame's index is its absolute position in the video.

        Raises:
            ValueError: if a time string is malformed or out of range, or if
                ``time_start`` is later than ``time_end``.
        """
        self.lang = lang
        self.use_fullframe = use_fullframe

        ocr_start = self._frame_index(time_start) if time_start else 0
        ocr_end = self._frame_index(time_end) if time_end else self.num_frames

        if ocr_end < ocr_start:
            raise ValueError('time_start is later than time_end')
        num_ocr_frames = ocr_end - ocr_start

        v = cv2.VideoCapture(self.path)
        try:
            v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
            frames = self._read_frames(v, num_ocr_frames)

            # self._single_frame_ocr is a bound method, so this Video instance
            # is pickled and shipped to the worker processes with the tasks.
            with Pool() as pool:
                it_ocr = pool.imap(self._single_frame_ocr, frames, chunksize=10)
                self.pred_frames = [
                    PredictedFrame(i + ocr_start, data, conf_threshold)
                    for i, data in enumerate(it_ocr)]
        finally:
            v.release()

    @staticmethod
    def _read_frames(capture, count: int):
        """Yield up to ``count`` consecutive frames from ``capture``.

        Stops at the first failed ``read()`` instead of yielding ``None``.
        The reported frame count may exceed the frames that can actually be
        decoded, and a ``None`` frame would crash the OCR worker.
        """
        for _ in range(count):
            ok, frame = capture.read()
            if not ok:
                break
            yield frame

    def _frame_index(self, time: str) -> int:
        """Convert ``time`` (``HH:MM:SS`` or ``MM:SS``) to a frame index.

        Every field is parsed as a float, so fractional values are accepted.
        The index is ``int(seconds * fps)``, i.e. truncated.

        Raises:
            ValueError: if ``time`` does not have 2 or 3 numeric fields, or if
                the resulting index is negative or beyond ``num_frames``.
        """
        t = list(map(float, time.split(':')))
        if len(t) == 3:
            td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
        elif len(t) == 2:
            td = datetime.timedelta(minutes=t[0], seconds=t[1])
        else:
            raise ValueError(
                'time data "{}" does not match format "%H:%M:%S"'.format(time))

        index = int(td.total_seconds() * self.fps)
        if index > self.num_frames or index < 0:
            raise ValueError(
                'time data "{}" exceeds video duration'.format(time))

        return index

    def _single_frame_ocr(self, img) -> str:
        """Return ``pytesseract.image_to_data`` output for one frame.

        Unless ``use_fullframe`` is set, only the bottom half of the frame
        (rows from ``height // 2`` down) is OCR'd.
        """
        if not self.use_fullframe:
            img = img[self.height // 2:, :]
        config = '--tessdata-dir "{}"'.format(constants.TESSDATA_DIR)
        return pytesseract.image_to_data(img, lang=self.lang, config=config)

    def get_subtitles(self, sim_threshold: int) -> str:
        """Group the OCR'd frames into subtitles and render them as SRT text.

        Entries are numbered from 1, as the SRT format expects.

        Args:
            sim_threshold: forwarded to every ``PredictedSubtitle``.

        Raises:
            AttributeError: if ``run_ocr()`` has not been called.
        """
        self._generate_subtitles(sim_threshold)
        return ''.join(
            '{}\n{} --> {}\n{}\n\n'.format(
                i,
                self._srt_timestamp(sub.index_start),
                self._srt_timestamp(sub.index_end),
                sub.text)
            for i, sub in enumerate(self.pred_subs, start=1))

    def _generate_subtitles(self, sim_threshold: int) -> None:
        """Split ``pred_frames`` into subtitle paragraphs and store them in ``pred_subs``.

        A sliding window of ``WIN_BOUND`` frames (about half a second) is used.
        A paragraph ends once more than ``WIN_BOUND`` consecutive frames are
        not similar to the paragraph's first frame. The next paragraph starts
        at the first frame of that dissimilar run.

        Raises:
            AttributeError: if ``run_ocr()`` has not been called.
        """
        # getattr() also covers instances whose __init__ predates pred_frames.
        if getattr(self, 'pred_frames', None) is None:
            raise AttributeError(
                'Please call self.run_ocr() first to perform ocr on frames')

        self.pred_subs = []

        WIN_BOUND = int(self.fps // 2)  # 1/2 sec sliding window boundary
        bound = WIN_BOUND
        i = 0
        j = 1
        while j < len(self.pred_frames):
            fi, fj = self.pred_frames[i], self.pred_frames[j]

            if fi.is_similar_to(fj):
                bound = WIN_BOUND
            elif bound > 0:
                bound -= 1
            else:
                # divide subtitle paragraphs
                para_new = j - WIN_BOUND
                self._append_sub(PredictedSubtitle(
                    self.pred_frames[i:para_new], sim_threshold))
                i = para_new
                j = i
                bound = WIN_BOUND

            j += 1

        # Also handle the trailing paragraph. Only runs when at least two
        # frames remain, so a lone trailing frame is dropped.
        # NOTE(review): confirm that dropping it is intended.
        if i < len(self.pred_frames) - 1:
            self._append_sub(PredictedSubtitle(
                self.pred_frames[i:], sim_threshold))

    def _append_sub(self, sub: PredictedSubtitle) -> None:
        """Append ``sub`` to ``pred_subs``, merging it into similar trailing subtitles.

        Subtitles with empty text are discarded. While ``sub`` is similar to
        the last stored subtitle, that subtitle is removed. Its frames are
        prepended to ``sub``'s frames before the next comparison.
        """
        if len(sub.text) == 0:
            return

        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
            ls = self.pred_subs.pop()
            sub = PredictedSubtitle(ls.frames + sub.frames, sub.sim_threshold)

        self.pred_subs.append(sub)

    def _srt_timestamp(self, frame_index: int) -> str:
        """Format the time of ``frame_index`` as an SRT ``HH:MM:SS,mmm`` timestamp.

        Milliseconds are truncated. Hours are not wrapped at 24, so timestamps
        in videos longer than a day stay correct.
        """
        td = datetime.timedelta(seconds=frame_index / self.fps)
        ms = td.microseconds // 1000
        # td.seconds excludes whole days, which live in td.days.
        total_seconds = td.days * 86400 + td.seconds
        m, s = divmod(total_seconds, 60)
        h, m = divmod(m, 60)
        return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)