videocr/video.py (view raw)
1from __future__ import annotations
2from concurrent import futures
3import datetime
4import pytesseract
5import cv2
6
7from .models import PredictedFrame, PredictedSubtitle
8
9
class Video:
    """A video file whose frames are OCR'd and grouped into SRT subtitles.

    Usage: construct with a path, call :meth:`run_ocr`, then call
    :meth:`get_subtitles` to obtain the SRT document as a string.
    """

    path: str
    lang: str
    use_fullframe: bool
    num_frames: int
    fps: float
    # Per-frame OCR results; stays None until run_ocr() has completed.
    pred_frames: list[PredictedFrame] | None = None
    # Subtitle paragraphs, filled by _generate_subtitles().
    pred_subs: list[PredictedSubtitle]

    def __init__(self, path: str):
        """Open *path* once to read its frame count and frame rate.

        Raises:
            OSError: if OpenCV cannot open the video at *path*.
        """
        self.path = path
        v = cv2.VideoCapture(path)
        try:
            # an unopenable file would otherwise give num_frames == 0 and
            # fps == 0, which only fails much later (division by fps)
            if not v.isOpened():
                raise OSError('cannot open video "{}"'.format(path))
            self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
            self.fps = v.get(cv2.CAP_PROP_FPS)
        finally:
            v.release()

    def run_ocr(self, lang: str, time_start: str, time_end: str,
                use_fullframe: bool) -> None:
        """OCR every frame between *time_start* and *time_end* in parallel.

        Args:
            lang: language passed through to pytesseract.
            time_start: start time as "H:M:S" or "M:S"; falsy means the
                first frame of the video.
            time_end: end time in the same format; falsy means the end of
                the video.
            use_fullframe: OCR the whole frame instead of its bottom half.

        The results are stored in ``self.pred_frames``. Each entry is a
        PredictedFrame keyed by its absolute frame index.

        Raises:
            ValueError: if a time string is malformed or out of range, or if
                *time_start* is later than *time_end*.
        """
        self.lang = lang
        self.use_fullframe = use_fullframe

        ocr_start = self._frame_index(time_start) if time_start else 0
        ocr_end = self._frame_index(time_end) if time_end else self.num_frames

        if ocr_end < ocr_start:
            raise ValueError('time_start is later than time_end')
        num_ocr_frames = ocr_end - ocr_start

        # get frames from ocr_start to ocr_end; release the capture even if
        # the OCR below raises
        v = cv2.VideoCapture(self.path)
        try:
            v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
            frames = self._read_frames(v, num_ocr_frames)

            # perform ocr to frames in parallel
            with futures.ProcessPoolExecutor() as pool:
                ocr_map = pool.map(self._single_frame_ocr, frames,
                                   chunksize=10)
                self.pred_frames = [PredictedFrame(i + ocr_start, data)
                                    for i, data in enumerate(ocr_map)]
        finally:
            v.release()

    @staticmethod
    def _read_frames(v, count: int):
        """Yield up to *count* frames from capture *v*.

        Reading stops at the first failed ``read()``. A failed read returns
        no image, and the container's reported frame count may overestimate
        how many frames can actually be read.
        """
        for _ in range(count):
            ok, frame = v.read()
            if not ok:
                break
            yield frame

    # convert time str to frame index
    def _frame_index(self, time: str) -> int:
        """Convert "H:M:S" or "M:S" (fields may be fractional) to a frame index.

        Raises:
            ValueError: if *time* has the wrong number of fields, a field is
                not numeric, or the resulting index lies outside the video.
        """
        t = time.split(':')
        t = list(map(float, t))
        if len(t) == 3:
            td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
        elif len(t) == 2:
            td = datetime.timedelta(minutes=t[0], seconds=t[1])
        else:
            raise ValueError(
                'time data "{}" does not match format "%H:%M:%S"'.format(time))

        index = int(td.total_seconds() * self.fps)
        # index == num_frames is allowed: it denotes the end of the video
        if index > self.num_frames or index < 0:
            raise ValueError(
                'time data "{}" exceeds video duration'.format(time))

        return index

    def _single_frame_ocr(self, img) -> str:
        """Return pytesseract's ``image_to_data`` output for one frame.

        *img* is presumably the image array returned by ``cv2.VideoCapture.read``
        (it is sliced by rows below). Unless ``use_fullframe`` is set, only the
        bottom half of the frame is OCR'd.
        """
        if not self.use_fullframe:
            # only use bottom half of the frame by default
            img = img[img.shape[0] // 2:, :]
        # NOTE(review): "tessdata" is a relative path, presumably resolved
        # against the working directory — confirm.
        config = r'--tessdata-dir "tessdata"'
        return pytesseract.image_to_data(img, lang=self.lang, config=config)

    def get_subtitles(self) -> str:
        """Group the OCR results into subtitles and render them as SRT text.

        Entries are numbered from 1, as the SRT format requires.

        Raises:
            AttributeError: if run_ocr() has not been called yet.
        """
        self._generate_subtitles()
        return ''.join(
            '{}\n{} --> {}\n{}\n\n'.format(
                i,
                self._srt_timestamp(sub.index_start),
                self._srt_timestamp(sub.index_end),
                sub.text)
            for i, sub in enumerate(self.pred_subs, start=1))

    def _generate_subtitles(self) -> None:
        """Split ``self.pred_frames`` into subtitle paragraphs.

        The result is stored in ``self.pred_subs``.

        Raises:
            AttributeError: if run_ocr() has not been called yet.
        """
        if self.pred_frames is None:
            raise AttributeError(
                'Please call self.run_ocr() first to perform ocr on frames')

        self.pred_subs = []

        # divide ocr of frames into subtitle paragraphs using sliding window:
        # frame j is compared with the paragraph's first frame i, and the
        # paragraph ends once WIN_BOUND + 1 consecutive frames all differ
        # from frame i
        WIN_BOUND = int(self.fps // 2)  # 1/2 sec sliding window boundary
        bound = WIN_BOUND
        i = 0
        j = 1
        while j < len(self.pred_frames):
            fi, fj = self.pred_frames[i], self.pred_frames[j]

            if fi.is_similar_to(fj):
                bound = WIN_BOUND
            elif bound > 0:
                bound -= 1
            else:
                # divide subtitle paragraphs at the first frame of the
                # dissimilar run (always > i, so the loop makes progress)
                para_new = j - WIN_BOUND
                self._append_sub(
                    PredictedSubtitle(self.pred_frames[i:para_new]))
                i = para_new
                j = i
                bound = WIN_BOUND

            j += 1

        # also handle the last remaining frames (a single trailing frame is
        # not emitted as a subtitle on its own)
        if i < len(self.pred_frames) - 1:
            self._append_sub(PredictedSubtitle(self.pred_frames[i:]))

    def _append_sub(self, sub: PredictedSubtitle) -> None:
        """Append *sub* to ``self.pred_subs``, skipping empty text.

        If *sub* is similar to the trailing entries, those entries are merged
        into it first.
        """
        if len(sub.text) == 0:
            return

        # merge new sub to the last subs if they are similar
        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
            ls = self.pred_subs[-1]
            del self.pred_subs[-1]
            sub = PredictedSubtitle(ls.frames + sub.frames)

        self.pred_subs.append(sub)

    def _srt_timestamp(self, frame_index: int) -> str:
        """Convert a frame index to an SRT timestamp ``HH:MM:SS,mmm``.

        Hours are not wrapped at 24, so times of a day or more keep
        increasing.
        """
        td = datetime.timedelta(seconds=frame_index / self.fps)
        ms = td.microseconds // 1000
        # timedelta.seconds excludes whole days, so fold them back in
        m, s = divmod(td.days * 86400 + td.seconds, 60)
        h, m = divmod(m, 60)
        return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)