all repos — videocr @ 720c9d479ffc8e6f314d823102d99ef2b581cf66

Extract hardcoded subtitles from videos using machine learning

videocr/video.py (view raw)

  1from __future__ import annotations
  2from multiprocessing import Pool
  3import datetime
  4import pytesseract
  5import cv2
  6
  7from . import constants
  8from .models import PredictedFrame, PredictedSubtitle
  9
 10
 11class Video:
 12    path: str
 13    lang: str
 14    use_fullframe: bool
 15    num_frames: int
 16    fps: float
 17    height: int
 18    pred_frames: List[PredictedFrame]
 19    pred_subs: List[PredictedSubtitle]
 20
 21    def __init__(self, path: str):
 22        self.path = path
 23        v = cv2.VideoCapture(path)
 24        if not v.isOpened():
 25            raise IOError('can not open video format {}'.format(path))
 26        self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
 27        self.fps = v.get(cv2.CAP_PROP_FPS)
 28        self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
 29        v.release()
 30
 31    def run_ocr(self, lang: str, time_start: str, time_end: str,
 32                conf_threshold:int, use_fullframe: bool) -> None:
 33        self.lang = lang
 34        self.use_fullframe = use_fullframe
 35
 36        ocr_start = self._frame_index(time_start) if time_start else 0
 37        ocr_end = self._frame_index(time_end) if time_end else self.num_frames
 38
 39        if ocr_end < ocr_start:
 40            raise ValueError('time_start is later than time_end')
 41        num_ocr_frames = ocr_end - ocr_start
 42
 43        # get frames from ocr_start to ocr_end
 44        v = cv2.VideoCapture(self.path)
 45        v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
 46        frames = (v.read()[1] for _ in range(num_ocr_frames))
 47
 48        # perform ocr to frames in parallel
 49        with Pool() as pool:
 50            it_ocr = pool.imap(self._single_frame_ocr, frames, chunksize=10)
 51            self.pred_frames = [
 52                PredictedFrame(i + ocr_start, data, conf_threshold) 
 53                for i, data in enumerate(it_ocr)]
 54
 55        v.release()
 56
 57    # convert time str to frame index
 58    def _frame_index(self, time: str) -> int:
 59        t = time.split(':')
 60        t = list(map(float, t))
 61        if len(t) == 3:
 62            td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
 63        elif len(t) == 2:
 64            td = datetime.timedelta(minutes=t[0], seconds=t[1])
 65        else:
 66            raise ValueError(
 67                'time data "{}" does not match format "%H:%M:%S"'.format(time))
 68
 69        index = int(td.total_seconds() * self.fps)
 70        if index > self.num_frames or index < 0:
 71            raise ValueError(
 72                'time data "{}" exceeds video duration'.format(time))
 73
 74        return index
 75
 76    def _single_frame_ocr(self, img) -> str:
 77        if not self.use_fullframe:
 78            # only use bottom half of the frame by default
 79            img = img[self.height // 2:, :]
 80        config = '--tessdata-dir "{}"'.format(constants.TESSDATA_DIR)
 81        return pytesseract.image_to_data(img, lang=self.lang, config=config)
 82
 83    def get_subtitles(self, sim_threshold: int) -> str:
 84        self._generate_subtitles(sim_threshold)
 85        return ''.join(
 86            '{}\n{} --> {}\n{}\n\n'.format(
 87                i,
 88                self._srt_timestamp(sub.index_start),
 89                self._srt_timestamp(sub.index_end),
 90                sub.text)
 91            for i, sub in enumerate(self.pred_subs))
 92
 93    def _generate_subtitles(self, sim_threshold: int) -> None:
 94        self.pred_subs = []
 95
 96        if self.pred_frames is None:
 97            raise AttributeError(
 98                'Please call self.run_ocr() first to perform ocr on frames')
 99
100        # divide ocr of frames into subtitle paragraphs using sliding window
101        WIN_BOUND = int(self.fps // 2)  # 1/2 sec sliding window boundary
102        bound = WIN_BOUND
103        i = 0
104        j = 1
105        while j < len(self.pred_frames):
106            fi, fj = self.pred_frames[i], self.pred_frames[j]
107
108            if fi.is_similar_to(fj):
109                bound = WIN_BOUND
110            elif bound > 0:
111                bound -= 1
112            else:
113                # divide subtitle paragraphs
114                para_new = j - WIN_BOUND
115                self._append_sub(PredictedSubtitle(
116                    self.pred_frames[i:para_new], sim_threshold))
117                i = para_new
118                j = i
119                bound = WIN_BOUND
120
121            j += 1
122
123        # also handle the last remaining frames
124        if i < len(self.pred_frames) - 1:
125            self._append_sub(PredictedSubtitle(
126                self.pred_frames[i:], sim_threshold))
127
128    def _append_sub(self, sub: PredictedSubtitle) -> None:
129        if len(sub.text) == 0:
130            return
131
132        # merge new sub to the last subs if they are similar
133        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
134            ls = self.pred_subs[-1]
135            del self.pred_subs[-1]
136            sub = PredictedSubtitle(ls.frames + sub.frames, sub.sim_threshold)
137
138        self.pred_subs.append(sub)
139
140    def _srt_timestamp(self, frame_index: int) -> str:
141        td = datetime.timedelta(seconds=frame_index / self.fps)
142        ms = td.microseconds // 1000
143        m, s = divmod(td.seconds, 60)
144        h, m = divmod(m, 60)
145        return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)