videocr/video.py (view raw)
1from __future__ import annotations
2from typing import List
3import sys
4import multiprocessing
5import pytesseract
6import cv2
7
8from . import constants
9from . import utils
10from .models import PredictedFrame, PredictedSubtitle
11from .opencv_adapter import Capture
12
13
class Video:
    """A video file to OCR and turn into SRT subtitles.

    Typical use: construct with a path, call :meth:`run_ocr`, then
    :meth:`get_subtitles`.
    """

    path: str  # set in __init__
    lang: str  # set by run_ocr()
    use_fullframe: bool  # set by run_ocr()
    num_frames: int  # set in __init__ from CAP_PROP_FRAME_COUNT
    fps: float  # set in __init__ from CAP_PROP_FPS
    height: int  # set in __init__ from CAP_PROP_FRAME_HEIGHT
    pred_frames: List[PredictedFrame]  # set by run_ocr()
    pred_subs: List[PredictedSubtitle]  # set by _generate_subtitles()

    def __init__(self, path: str):
        """Open the video at *path* and record its frame count, FPS and height."""
        self.path = path
        with Capture(path) as v:
            # NOTE: the reported frame count may not match the number of
            # frames that can actually be decoded; run_ocr() copes with that
            # by stopping at the first failed read.
            self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
            self.fps = v.get(cv2.CAP_PROP_FPS)
            self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))

    def run_ocr(self, lang: str, time_start: str, time_end: str,
                conf_threshold: int, use_fullframe: bool) -> None:
        """OCR every frame in the half-open range [time_start, time_end).

        Results are stored in ``self.pred_frames`` as one PredictedFrame per
        frame that could be decoded, in order.

        :param lang: language passed to ``pytesseract.image_to_data``.
        :param time_start: start time; empty/falsy means the first frame.
        :param time_end: end time; empty/falsy means the reported last frame.
        :param conf_threshold: forwarded to each PredictedFrame.
        :param use_fullframe: OCR the whole frame instead of the bottom half.
        :raises ValueError: if time_start is later than time_end.
        :raises RuntimeError: if Tesseract fails on any frame.
        """
        self.lang = lang
        self.use_fullframe = use_fullframe

        ocr_start = utils.get_frame_index(time_start, self.fps) if time_start else 0
        ocr_end = utils.get_frame_index(time_end, self.fps) if time_end else self.num_frames

        if ocr_end < ocr_start:
            raise ValueError('time_start is later than time_end')
        num_ocr_frames = ocr_end - ocr_start

        # get frames from ocr_start to ocr_end
        with Capture(self.path) as v, multiprocessing.Pool() as pool:
            v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
            # Stop at the first failed read instead of handing None frames
            # to the workers (time_end may lie past the end of the video).
            frames = self._read_frames(v, num_ocr_frames)

            # perform ocr to frames in parallel
            it_ocr = pool.imap(self._image_to_data, frames, chunksize=10)
            self.pred_frames = [
                PredictedFrame(i + ocr_start, data, conf_threshold)
                for i, data in enumerate(it_ocr)
            ]

    @staticmethod
    def _read_frames(capture, count: int):
        """Yield up to *count* successive frames from *capture*.

        Stops as soon as ``capture.read()`` reports failure, so callers never
        receive the ``None`` frame that accompanies a failed read.
        """
        for _ in range(count):
            ok, frame = capture.read()
            if not ok:
                return
            yield frame

    def _image_to_data(self, img) -> str:
        """OCR one frame and return the output of ``pytesseract.image_to_data``.

        Executed in a multiprocessing worker (see run_ocr()).

        :raises RuntimeError: if pytesseract fails; the message keeps the
            original exception's class name and text.
        """
        if not self.use_fullframe:
            # only use bottom half of the frame by default
            img = img[self.height // 2:, :]
        config = '--tessdata-dir "{}"'.format(constants.TESSDATA_DIR)
        try:
            return pytesseract.image_to_data(img, lang=self.lang, config=config)
        except Exception as e:
            # Never sys.exit() here: SystemExit kills the pool worker without
            # reporting a result, and Pool.imap in run_ocr() then waits
            # forever.  A regular exception is shipped back to the parent and
            # re-raised from the imap iterator instead.
            raise RuntimeError('{}: {}'.format(e.__class__.__name__, e)) from e

    def get_subtitles(self, sim_threshold: int) -> str:
        """Group the OCR'd frames into subtitles and render them as SRT text.

        :param sim_threshold: forwarded to each PredictedSubtitle.
        :return: the SRT document, with cues numbered from 1.
        :raises AttributeError: if run_ocr() has not been called yet.
        """
        self._generate_subtitles(sim_threshold)
        # SRT cue counters start at 1.
        return ''.join(
            '{}\n{} --> {}\n{}\n\n'.format(
                i,
                utils.get_srt_timestamp(sub.index_start, self.fps),
                utils.get_srt_timestamp(sub.index_end, self.fps),
                sub.text)
            for i, sub in enumerate(self.pred_subs, start=1))

    def _generate_subtitles(self, sim_threshold: int) -> None:
        """Split ``self.pred_frames`` into subtitles stored in ``self.pred_subs``.

        Each frame is compared with the first frame of the current paragraph.
        A paragraph ends once ``win_bound + 1`` consecutive frames all differ
        from that first frame, where ``win_bound`` is half a second worth of
        frames; the next paragraph starts at the first of those frames.

        :raises AttributeError: if run_ocr() has not been called yet.
        """
        # pred_frames only exists after run_ocr(); the class-level annotation
        # does not create the attribute, so look it up defensively.
        if getattr(self, 'pred_frames', None) is None:
            raise AttributeError(
                'Please call self.run_ocr() first to perform ocr on frames')

        self.pred_subs = []
        frames = self.pred_frames

        # divide ocr of frames into subtitle paragraphs using sliding window
        win_bound = int(self.fps // 2)  # 1/2 sec sliding window boundary
        bound = win_bound
        i = 0
        j = 1
        while j < len(frames):
            if frames[i].is_similar_to(frames[j]):
                bound = win_bound
            elif bound > 0:
                bound -= 1
            else:
                # divide subtitle paragraphs at the first differing frame
                para_new = j - win_bound
                self._append_sub(
                    PredictedSubtitle(frames[i:para_new], sim_threshold))
                i = para_new
                j = i
                bound = win_bound
            j += 1

        # also handle the last remaining frames, even if only one is left
        if i < len(frames):
            self._append_sub(PredictedSubtitle(frames[i:], sim_threshold))

    def _append_sub(self, sub: PredictedSubtitle) -> None:
        """Append *sub* to ``self.pred_subs``, merging it with similar predecessors.

        Subtitles with empty text are discarded.  While the last stored
        subtitle is similar to *sub*, it is removed and its frames are
        prepended to *sub*'s, so a run of similar subtitles collapses into one.
        """
        if not sub.text:
            return

        # merge new sub to the last subs if they are similar
        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
            last = self.pred_subs.pop()
            sub = PredictedSubtitle(last.frames + sub.frames, sub.sim_threshold)

        self.pred_subs.append(sub)
120 self.pred_subs.append(sub)