all repos — videocr @ 3a73f1f508933f27a5646f9643a642fc9871c020

Extract hardcoded subtitles from videos using machine learning

videocr/video.py (view raw)

  1from __future__ import annotations
  2from concurrent import futures
  3import datetime
  4import pytesseract
  5import cv2
  6import timeit
  7
  8from .models import PredictedFrame, PredictedSubtitle
  9
 10
 11class Video:
 12    path: str
 13    lang: str
 14    use_fullframe: bool
 15    num_frames: int
 16    fps: float
 17    pred_frames: List[PredictedFrame]
 18    pred_subs: List[PredictedSubtitle]
 19
 20    def __init__(self, path, lang, use_fullframe=False):
 21        self.path = path
 22        self.lang = lang
 23        self.use_fullframe = use_fullframe
 24        v = cv2.VideoCapture(path)
 25        self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
 26        self.fps = v.get(cv2.CAP_PROP_FPS)
 27        v.release()
 28
 29    def _single_frame_ocr(self, img) -> str:
 30        if not self.use_fullframe:
 31            # only use bottom half of the frame by default
 32            img = img[img.shape[0] // 2:, :]
 33        data = pytesseract.image_to_data(img, lang=self.lang)
 34        return data
 35
 36    def run_ocr(self) -> None:
 37        v = cv2.VideoCapture(self.path)
 38        frames = (v.read()[1] for _ in range(self.num_frames))
 39
 40        # perform ocr to all frames in parallel
 41        with futures.ProcessPoolExecutor() as pool:
 42            frames_ocr = pool.map(self._single_frame_ocr, frames, chunksize=10)
 43            self.pred_frames = [PredictedFrame(i, data) 
 44                                for i, data in enumerate(frames_ocr)]
 45
 46        v.release()
 47
 48    def get_subtitles(self) -> str:
 49        if self.pred_frames is None:
 50            raise AttributeError(
 51                'Please call self.run_ocr() first to generate ocr of frames')
 52
 53        self.pred_subs = []
 54
 55        # divide ocr of frames into subtitle paragraphs using sliding window
 56        WIN_BOUND = int(self.fps / 2)  # 1/2 sec sliding window boundary
 57        bound = WIN_BOUND
 58        i = 0
 59        j = 1
 60        while j < self.num_frames:
 61            fi, fj = self.pred_frames[i], self.pred_frames[j]
 62
 63            if fi.is_similar_to(fj):
 64                bound = WIN_BOUND
 65            elif bound > 0:
 66                bound -= 1
 67            else:
 68                # divide subtitle paragraphs
 69                para_new = j - WIN_BOUND
 70                self._append_sub(
 71                    PredictedSubtitle(self.pred_frames[i:para_new]))
 72                i = para_new
 73                j = i
 74                bound = WIN_BOUND
 75
 76            j += 1
 77
 78        if i < self.num_frames - 1:
 79            self._append_sub(PredictedSubtitle(self.pred_frames[i:]))
 80
 81        for i, sub in enumerate(self.pred_subs):
 82            print('{}\n{} --> {}\n{}\n'.format(
 83                i,
 84                self._srt_timestamp(sub.index_start),
 85                self._srt_timestamp(sub.index_end),
 86                sub.text))
 87
 88        return ''
 89
 90    def _append_sub(self, sub: PredictedSubtitle) -> None:
 91        if len(sub.text) == 0:
 92            return
 93
 94        # merge new sub to the last subs if they are similar
 95        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
 96            lsub = self.pred_subs[-1]
 97            del self.pred_subs[-1]
 98            sub = PredictedSubtitle(lsub.frames + sub.frames)
 99
100        self.pred_subs.append(sub)
101
102    def _srt_timestamp(self, frame_index) -> str:
103        time = str(datetime.timedelta(seconds=frame_index / self.fps))
104        return time.replace('.', ',')  # srt uses comma as fractional separator
105
106
# Ad-hoc timing run. Guarded so that importing this module (including the
# re-import done by ProcessPoolExecutor workers under the "spawn" start
# method) does not kick off an OCR pass on a hard-coded file.
if __name__ == '__main__':
    time_start = timeit.default_timer()
    v = Video('1.mp4', 'HanS')
    v.run_ocr()
    time_stop = timeit.default_timer()
    print('time for ocr: ', time_stop - time_start)

    time_start = timeit.default_timer()
    v.get_subtitles()
    time_stop = timeit.default_timer()
    print('time for get sub: ', time_stop - time_start)