videocr/video.py (view raw)
1from __future__ import annotations
2import multiprocessing
3import pytesseract
4import cv2
5
6from . import constants
7from . import utils
8from .models import PredictedFrame, PredictedSubtitle
9from .opencv_adapter import Capture
10
11
class Video:
    """A video file whose frames can be OCR'd and turned into SRT subtitles.

    Typical use: construct with a path, call ``run_ocr()`` to OCR a range of
    frames, then call ``get_subtitles()`` to build SRT text from the results.
    """

    path: str
    lang: str
    use_fullframe: bool
    num_frames: int
    fps: float
    height: int
    pred_frames: list[PredictedFrame] | None
    pred_subs: list[PredictedSubtitle]

    def __init__(self, path: str):
        """Open the video at *path* and record its frame count, fps and height.

        Args:
            path: path of the video file, passed to ``Capture``.
        """
        self.path = path
        # Initialised here so _generate_subtitles() can detect that run_ocr()
        # has not been called yet and raise its explanatory error.
        self.pred_frames = None
        self.pred_subs = []
        with Capture(path) as v:
            self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
            self.fps = v.get(cv2.CAP_PROP_FPS)
            self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))

    def run_ocr(self, lang: str, time_start: str, time_end: str,
                conf_threshold: int, use_fullframe: bool) -> None:
        """OCR every frame between *time_start* and *time_end* in parallel.

        Results are stored in ``self.pred_frames`` as one ``PredictedFrame``
        per frame successfully read, in frame order.

        Args:
            lang: language string passed to ``pytesseract.image_to_data``.
            time_start: start time converted by ``utils.get_frame_index``;
                a falsy value (e.g. ``''``) means the first frame.
            time_end: end time converted by ``utils.get_frame_index``;
                a falsy value means ``self.num_frames``.
            conf_threshold: passed to each ``PredictedFrame``.
            use_fullframe: OCR the whole frame instead of only its bottom half.

        Raises:
            ValueError: if *time_start* is later than *time_end*.
        """
        self.lang = lang
        self.use_fullframe = use_fullframe

        ocr_start = utils.get_frame_index(time_start, self.fps) if time_start else 0
        ocr_end = utils.get_frame_index(time_end, self.fps) if time_end else self.num_frames

        if ocr_end < ocr_start:
            raise ValueError('time_start is later than time_end')
        num_ocr_frames = ocr_end - ocr_start

        # get frames from ocr_start to ocr_end
        with Capture(self.path) as v, multiprocessing.Pool() as pool:
            v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
            frames = self._read_frames(v, num_ocr_frames)

            # perform ocr to frames in parallel; Pool.imap yields results in
            # input order, so enumerate() recovers each frame's offset
            it_ocr = pool.imap(self._image_to_data, frames, chunksize=10)
            self.pred_frames = [
                PredictedFrame(i + ocr_start, data, conf_threshold)
                for i, data in enumerate(it_ocr)
            ]

    @staticmethod
    def _read_frames(v, count: int):
        """Yield up to *count* frames from *v*, stopping at the first failed read.

        Assumes ``v.read()`` returns an ``(ok, frame)`` pair like
        ``cv2.VideoCapture.read()``. Stopping early (e.g. when ``time_end``
        lies past the end of the video, since it is not clamped) keeps a
        failed read from handing ``None`` to the OCR workers.
        """
        for _ in range(count):
            ok, frame = v.read()
            if not ok:
                break
            yield frame

    def _image_to_data(self, img) -> str:
        """Return ``pytesseract.image_to_data`` output for *img*.

        Unless ``self.use_fullframe`` is set, only rows from
        ``self.height // 2`` downward are OCR'd. Assumes *img* is an array
        indexed ``[row, col]`` such as an OpenCV frame — TODO confirm.
        """
        if not self.use_fullframe:
            # only use bottom half of the frame by default
            img = img[self.height // 2:, :]
        config = '--tessdata-dir "{}"'.format(constants.TESSDATA_DIR)
        return pytesseract.image_to_data(img, lang=self.lang, config=config)

    def get_subtitles(self, sim_threshold: int) -> str:
        """Build SRT-formatted subtitles from the OCR results.

        Args:
            sim_threshold: similarity threshold passed to each
                ``PredictedSubtitle``.

        Returns:
            All subtitles as one SRT string, with entries numbered from 1.

        Raises:
            AttributeError: if ``run_ocr()`` has not been called.
        """
        self._generate_subtitles(sim_threshold)
        # SRT sequence numbers start at 1, not 0.
        return ''.join(
            '{}\n{} --> {}\n{}\n\n'.format(
                i,
                utils.get_srt_timestamp(sub.index_start, self.fps),
                utils.get_srt_timestamp(sub.index_end, self.fps),
                sub.text)
            for i, sub in enumerate(self.pred_subs, start=1))

    def _generate_subtitles(self, sim_threshold: int) -> None:
        """Group ``self.pred_frames`` into subtitle paragraphs in ``self.pred_subs``.

        Raises:
            AttributeError: if ``run_ocr()`` has not been called.
        """
        # getattr also covers instances whose __init__ did not run
        if getattr(self, 'pred_frames', None) is None:
            raise AttributeError(
                'Please call self.run_ocr() first to perform ocr on frames')

        self.pred_subs = []

        # divide ocr of frames into subtitle paragraphs using sliding window
        win_bound = int(self.fps // 2)  # 1/2 sec sliding window boundary
        bound = win_bound
        i = 0  # first frame of the current paragraph
        j = 1  # frame currently compared against frame i
        while j < len(self.pred_frames):
            fi, fj = self.pred_frames[i], self.pred_frames[j]

            if fi.is_similar_to(fj):
                # same paragraph: restore the full tolerance window
                bound = win_bound
            elif bound > 0:
                # tolerate a short run of dissimilar frames
                bound -= 1
            else:
                # dissimilar for longer than the window: the new paragraph
                # starts at the first frame of this dissimilar run
                para_new = j - win_bound
                self._append_sub(PredictedSubtitle(
                    self.pred_frames[i:para_new], sim_threshold))
                i = para_new
                j = i
                bound = win_bound

            j += 1

        # also handle the last remaining frames (a lone trailing frame is
        # not emitted)
        if i < len(self.pred_frames) - 1:
            self._append_sub(PredictedSubtitle(
                self.pred_frames[i:], sim_threshold))

    def _append_sub(self, sub: PredictedSubtitle) -> None:
        """Append *sub* to ``self.pred_subs``.

        Subtitles with empty text are dropped. If *sub* is similar to the
        current last subtitle(s), those are popped and merged into it first.
        """
        if len(sub.text) == 0:
            return

        # merge new sub to the last subs if they are similar
        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
            ls = self.pred_subs[-1]
            del self.pred_subs[-1]
            sub = PredictedSubtitle(ls.frames + sub.frames, sub.sim_threshold)

        self.pred_subs.append(sub)
115 self.pred_subs.append(sub)