# videocr/video.py
1from __future__ import annotations
2from concurrent import futures
3import datetime
4import pytesseract
5import cv2
6import timeit
7
8from .models import PredictedFrame, PredictedSubtitle
9
10
class Video:
    """OCR a time range of a video file and turn the results into SRT text.

    Typical use: construct, call ``run_ocr()``, then ``get_subtitles()`` or
    ``save_subtitles_to_file()``.
    """

    path: str
    lang: str
    use_fullframe: bool
    num_frames: int
    fps: float
    ocr_frame_start: int
    num_ocr_frames: int
    # Class-level None defaults let _generate_subtitles detect that run_ocr()
    # has not been called yet, instead of failing with a bare
    # "object has no attribute" error.
    pred_frames: list[PredictedFrame] | None = None
    pred_subs: list[PredictedSubtitle] | None = None

    def __init__(self, path: str, lang: str, use_fullframe=False,
                 time_start='0:00', time_end=''):
        """Read video metadata and resolve the frame range to OCR.

        Args:
            path: Video file opened with ``cv2.VideoCapture``.
            lang: Tesseract language passed to ``pytesseract``.
            use_fullframe: OCR the whole frame instead of only its bottom half.
            time_start: Start of the OCR range, ``"H:M:S"`` or ``"M:S"``.
            time_end: End of the OCR range in the same format; empty means
                the end of the video.

        Raises:
            OSError: If OpenCV cannot open ``path``.
            ValueError: If a time string is malformed or outside the video,
                or ``time_start`` is later than ``time_end``.
        """
        self.path = path
        self.lang = lang
        self.use_fullframe = use_fullframe

        cap = cv2.VideoCapture(path)
        try:
            if not cap.isOpened():
                # OpenCV reports 0 for properties it cannot query, so an
                # unopened file would otherwise yield a silently empty run.
                raise OSError('cannot open video file "{}"'.format(path))
            self.num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            self.fps = cap.get(cv2.CAP_PROP_FPS)
        finally:
            cap.release()

        self.ocr_frame_start = self._frame_index(time_start)
        ocr_end = self._frame_index(time_end) if time_end else self.num_frames
        self.num_ocr_frames = ocr_end - self.ocr_frame_start

        if self.num_ocr_frames < 0:
            raise ValueError('time_start is later than time_end')

    def _frame_index(self, time: str) -> int:
        """Convert an ``"H:M:S"`` or ``"M:S"`` string to a frame index.

        Raises:
            ValueError: If the string does not have 2 or 3 integer fields,
                or the resulting index lies outside ``[0, num_frames]``.
        """
        fields = [int(part) for part in time.split(':')]
        if len(fields) == 3:
            td = datetime.timedelta(hours=fields[0], minutes=fields[1],
                                    seconds=fields[2])
        elif len(fields) == 2:
            td = datetime.timedelta(minutes=fields[0], seconds=fields[1])
        else:
            raise ValueError(
                'time data "{}" does not match format "%H:%M:%S"'.format(time))

        index = int(td.total_seconds() * self.fps)
        if index > self.num_frames or index < 0:
            raise ValueError(
                'time data "{}" exceeds video duration'.format(time))

        return index

    def run_ocr(self) -> None:
        """OCR every frame of the configured range in a process pool.

        Stores one ``PredictedFrame`` per successfully decoded frame in
        ``pred_frames``, indexed from ``ocr_frame_start``.
        """
        cap = cv2.VideoCapture(self.path)
        try:
            cap.set(cv2.CAP_PROP_POS_FRAMES, self.ocr_frame_start)
            frames = self._read_frames(cap)

            # NOTE(review): Executor.map collects its iterable immediately,
            # so every frame in the range is decoded and held in memory
            # before results come back; may be costly for long ranges.
            with futures.ProcessPoolExecutor() as pool:
                ocr_map = pool.map(self._single_frame_ocr, frames,
                                   chunksize=10)
                self.pred_frames = [
                    PredictedFrame(i + self.ocr_frame_start, data)
                    for i, data in enumerate(ocr_map)]
        finally:
            cap.release()

    def _read_frames(self, cap):
        """Yield up to ``num_ocr_frames`` frames from ``cap``.

        Stops at the first failed read instead of yielding ``None`` (which
        would crash ``_single_frame_ocr`` in a worker); the container's
        frame count is not guaranteed to match the decodable frames.
        """
        for _ in range(self.num_ocr_frames):
            ok, frame = cap.read()
            if not ok:
                break
            yield frame

    def _single_frame_ocr(self, img) -> str:
        """Return Tesseract ``image_to_data`` output for one frame.

        Unless ``use_fullframe`` is set, only the bottom half of the frame
        (rows from ``height // 2`` down) is passed to Tesseract.
        """
        if not self.use_fullframe:
            img = img[img.shape[0] // 2:, :]
        return pytesseract.image_to_data(img, lang=self.lang)

    def get_subtitles(self) -> str:
        """Build SRT text from ``pred_frames`` (call ``run_ocr()`` first).

        Entries are numbered from 1, as the SRT format expects.
        """
        self._generate_subtitles()
        return ''.join(
            '{}\n{} --> {}\n{}\n\n'.format(
                i,
                self._srt_timestamp(sub.index_start),
                self._srt_timestamp(sub.index_end),
                sub.text)
            for i, sub in enumerate(self.pred_subs, start=1))

    def _generate_subtitles(self) -> None:
        """Group runs of similar frames into subtitles stored in ``pred_subs``.

        Each frame is compared (``PredictedFrame.is_similar_to``) with the
        first frame of the current paragraph. Once more than ``fps // 2``
        consecutive frames are dissimilar, the paragraph is closed just
        before that dissimilar run, which then starts the next paragraph.

        Raises:
            AttributeError: If ``run_ocr()`` has not been called.
        """
        if self.pred_frames is None:
            raise AttributeError(
                'Please call self.run_ocr() first to perform ocr on frames')
        self.pred_subs = []

        frames = self.pred_frames
        # Bound by the frames actually OCR'd: run_ocr may stop before
        # num_ocr_frames when decoding fails.
        n = len(frames)
        win_bound = int(self.fps // 2)  # 1/2 sec sliding window boundary
        bound = win_bound
        i, j = 0, 1
        while j < n:
            if frames[i].is_similar_to(frames[j]):
                bound = win_bound
            elif bound > 0:
                bound -= 1
            else:
                # divide subtitle paragraphs at the start of the dissimilar run
                para_new = j - win_bound
                self._append_sub(PredictedSubtitle(frames[i:para_new]))
                i = para_new
                j = i
                bound = win_bound

            j += 1

        # also handle the last remaining frames (a lone final frame is skipped)
        if i < n - 1:
            self._append_sub(PredictedSubtitle(frames[i:]))

    def _append_sub(self, sub: PredictedSubtitle) -> None:
        """Append ``sub`` to ``pred_subs``, merging it with similar predecessors.

        Subtitles with empty text are dropped. While the last stored
        subtitle is similar to ``sub``, it is removed and its frames are
        prepended to ``sub``'s.
        """
        if not sub.text:
            return

        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
            last = self.pred_subs.pop()
            sub = PredictedSubtitle(last.frames + sub.frames)

        self.pred_subs.append(sub)

    def _srt_timestamp(self, frame_index: int) -> str:
        """Format ``frame_index`` as an SRT timestamp ``HH:MM:SS,mmm``.

        Milliseconds are truncated; hours keep counting past 24.
        """
        td = datetime.timedelta(seconds=frame_index / self.fps)
        ms = td.microseconds // 1000
        # timedelta normalises into days + seconds; fold the days back in so
        # timestamps beyond 24 hours do not wrap around to hour 0.
        m, s = divmod(td.days * 86400 + td.seconds, 60)
        h, m = divmod(m, 60)
        return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)

    def save_subtitles_to_file(self, path='subtitle.srt') -> None:
        """Write ``get_subtitles()`` output to ``path`` as UTF-8 text."""
        # Explicit encoding: OCR text may be non-ASCII (e.g. Chinese) and the
        # platform default encoding is not guaranteed to represent it.
        with open(path, 'w', encoding='utf-8') as f:
            f.write(self.get_subtitles())
146
147
def main(video_path: str = '1.mp4', lang: str = 'HanS') -> None:
    """Time OCR and subtitle export for one video (ad-hoc benchmark).

    Prints the wall-clock time of construction + ``run_ocr()`` and of
    ``save_subtitles_to_file()``.
    """
    start = timeit.default_timer()
    video = Video(video_path, lang)
    video.run_ocr()
    print('time for ocr: ', timeit.default_timer() - start)

    start = timeit.default_timer()
    video.save_subtitles_to_file()
    print('time for save sub: ', timeit.default_timer() - start)


# Guarded so importing this module has no side effects. Worker processes of
# ProcessPoolExecutor may re-import the main module (spawn start method).
# Run as ``python -m videocr.video``: the relative ``from .models`` import
# only resolves when this file is loaded as part of its package.
if __name__ == '__main__':
    main()