move video parameters to run_ocr() function
Yi Ge me@yige.ch
Sat, 27 Apr 2019 21:41:19 +0200
1 file changed,
25 insertions(+),
25 deletions(-)
jump to
M
videocr/video.py
→
videocr/video.py
@@ -14,32 +14,45 @@ lang: str
use_fullframe: bool num_frames: int fps: float - ocr_frame_start: int - num_ocr_frames: int pred_frames: List[PredictedFrame] pred_subs: List[PredictedSubtitle] - def __init__(self, path: str, lang: str, use_fullframe=False, - time_start='0:00', time_end=''): + def __init__(self, path: str): self.path = path - self.lang = lang - self.use_fullframe = use_fullframe v = cv2.VideoCapture(path) self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT)) self.fps = v.get(cv2.CAP_PROP_FPS) v.release() - self.ocr_frame_start = self._frame_index(time_start) + def run_ocr(self, lang: str, use_fullframe=False, + time_start='0:00', time_end='') -> None: + self.lang = lang + self.use_fullframe = use_fullframe + + ocr_start = self._frame_index(time_start) if time_end: ocr_end = self._frame_index(time_end) + if ocr_end < ocr_start: + raise ValueError('time_start is later than time_end') else: ocr_end = self.num_frames - self.num_ocr_frames = ocr_end - self.ocr_frame_start + num_ocr_frames = ocr_end - ocr_start - if self.num_ocr_frames < 0: - raise ValueError('time_start is later than time_end') + # get frames from ocr_start to ocr_end + v = cv2.VideoCapture(self.path) + v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start) + frames = (v.read()[1] for _ in range(num_ocr_frames)) + # perform ocr to frames in parallel + with futures.ProcessPoolExecutor() as pool: + ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10) + self.pred_frames = [PredictedFrame(i + ocr_start, data) + for i, data in enumerate(ocr_map)] + + v.release() + + # convert time str to frame index def _frame_index(self, time: str) -> int: t = time.split(':') t = list(map(int, t))@@ -58,19 +71,6 @@ 'time data "{}" exceeds video duration'.format(time))
return index - def run_ocr(self) -> None: - v = cv2.VideoCapture(self.path) - v.set(cv2.CAP_PROP_POS_FRAMES, self.ocr_frame_start) - frames = (v.read()[1] for _ in range(self.num_ocr_frames)) - - # perform ocr to all frames in parallel - with futures.ProcessPoolExecutor() as pool: - ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10) - self.pred_frames = [PredictedFrame(i + self.ocr_frame_start, data) - for i, data in enumerate(ocr_map)] - - v.release() - def _single_frame_ocr(self, img) -> str: if not self.use_fullframe: # only use bottom half of the frame by default@@ -99,7 +99,7 @@ WIN_BOUND = int(self.fps // 2) # 1/2 sec sliding window boundary
bound = WIN_BOUND i = 0 j = 1 - while j < self.num_ocr_frames: + while j < len(self.pred_frames): fi, fj = self.pred_frames[i], self.pred_frames[j] if fi.is_similar_to(fj):@@ -118,7 +118,7 @@
j += 1 # also handle the last remaining frames - if i < self.num_ocr_frames - 1: + if i < len(self.pred_frames) - 1: self._append_sub(PredictedSubtitle(self.pred_frames[i:])) def _append_sub(self, sub: PredictedSubtitle) -> None: