all repos — videocr @ bccdcc02fc1fbda93da12642ad2c72a145f2011f

Extract hardcoded subtitles from videos using machine learning

videocr/video.py (view raw)

  1from __future__ import annotations
  2from concurrent import futures
  3import datetime
  4import pytesseract
  5import cv2
  6
  7from .models import PredictedFrame, PredictedSubtitle
  8
  9
 10class Video:
 11    path: str
 12    lang: str
 13    use_fullframe: bool
 14    num_frames: int
 15    fps: float
 16    pred_frames: List[PredictedFrame]
 17    pred_subs: List[PredictedSubtitle]
 18
 19    def __init__(self, path: str):
 20        self.path = path
 21        v = cv2.VideoCapture(path)
 22        self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
 23        self.fps = v.get(cv2.CAP_PROP_FPS)
 24        v.release()
 25
 26    def run_ocr(self, lang: str, time_start: str, time_end: str,
 27                use_fullframe: bool) -> None:
 28        self.lang = lang
 29        self.use_fullframe = use_fullframe
 30
 31        ocr_start = self._frame_index(time_start) if time_start else 0
 32        ocr_end = self._frame_index(time_end) if time_end else self.num_frames
 33
 34        if ocr_end < ocr_start:
 35            raise ValueError('time_start is later than time_end')
 36        num_ocr_frames = ocr_end - ocr_start
 37
 38        # get frames from ocr_start to ocr_end
 39        v = cv2.VideoCapture(self.path)
 40        v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
 41        frames = (v.read()[1] for _ in range(num_ocr_frames))
 42
 43        # perform ocr to frames in parallel
 44        with futures.ProcessPoolExecutor() as pool:
 45            ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10)
 46            self.pred_frames = [PredictedFrame(i + ocr_start, data) 
 47                                for i, data in enumerate(ocr_map)]
 48
 49        v.release()
 50
 51    # convert time str to frame index
 52    def _frame_index(self, time: str) -> int:
 53        t = time.split(':')
 54        t = list(map(float, t))
 55        if len(t) == 3:
 56            td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
 57        elif len(t) == 2:
 58            td = datetime.timedelta(minutes=t[0], seconds=t[1])
 59        else:
 60            raise ValueError(
 61                'time data "{}" does not match format "%H:%M:%S"'.format(time))
 62
 63        index = int(td.total_seconds() * self.fps)
 64        if index > self.num_frames or index < 0:
 65            raise ValueError(
 66                'time data "{}" exceeds video duration'.format(time))
 67
 68        return index
 69
 70    def _single_frame_ocr(self, img) -> str:
 71        if not self.use_fullframe:
 72            # only use bottom half of the frame by default
 73            img = img[img.shape[0] // 2:, :]
 74        return pytesseract.image_to_data(img, lang=self.lang)
 75
 76    def get_subtitles(self) -> str:
 77        self._generate_subtitles()
 78        return ''.join(
 79            '{}\n{} --> {}\n{}\n\n'.format(
 80                i,
 81                self._srt_timestamp(sub.index_start),
 82                self._srt_timestamp(sub.index_end),
 83                sub.text)
 84            for i, sub in enumerate(self.pred_subs))
 85
 86    def _generate_subtitles(self) -> None:
 87        self.pred_subs = []
 88
 89        if self.pred_frames is None:
 90            raise AttributeError(
 91                'Please call self.run_ocr() first to perform ocr on frames')
 92
 93        # divide ocr of frames into subtitle paragraphs using sliding window
 94        WIN_BOUND = int(self.fps // 2)  # 1/2 sec sliding window boundary
 95        bound = WIN_BOUND
 96        i = 0
 97        j = 1
 98        while j < len(self.pred_frames):
 99            fi, fj = self.pred_frames[i], self.pred_frames[j]
100
101            if fi.is_similar_to(fj):
102                bound = WIN_BOUND
103            elif bound > 0:
104                bound -= 1
105            else:
106                # divide subtitle paragraphs
107                para_new = j - WIN_BOUND
108                self._append_sub(
109                    PredictedSubtitle(self.pred_frames[i:para_new]))
110                i = para_new
111                j = i
112                bound = WIN_BOUND
113
114            j += 1
115
116        # also handle the last remaining frames
117        if i < len(self.pred_frames) - 1:
118            self._append_sub(PredictedSubtitle(self.pred_frames[i:]))
119
120    def _append_sub(self, sub: PredictedSubtitle) -> None:
121        if len(sub.text) == 0:
122            return
123
124        # merge new sub to the last subs if they are similar
125        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
126            ls = self.pred_subs[-1]
127            del self.pred_subs[-1]
128            sub = PredictedSubtitle(ls.frames + sub.frames)
129
130        self.pred_subs.append(sub)
131
132    def _srt_timestamp(self, frame_index: int) -> str:
133        td = datetime.timedelta(seconds=frame_index / self.fps)
134        ms = td.microseconds // 1000
135        m, s = divmod(td.seconds, 60)
136        h, m = divmod(m, 60)
137        return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)