all repos — videocr @ f5d27a7a46f369e269580735566b5a5d03e7d378

Extract hardcoded subtitles from videos using machine learning

videocr/video.py (view raw)

  1from __future__ import annotations
  2from concurrent import futures
  3import datetime
  4import pytesseract
  5import cv2
  6import timeit
  7
  8from .models import PredictedFrame, PredictedSubtitle
  9
 10
 11class Video:
 12    path: str
 13    lang: str
 14    use_fullframe: bool
 15    num_frames: int
 16    fps: float
 17    pred_frames: List[PredictedFrame]
 18    pred_subs: List[PredictedSubtitle]
 19
 20    def __init__(self, path, lang, use_fullframe=False):
 21        self.path = path
 22        self.lang = lang
 23        self.use_fullframe = use_fullframe
 24        v = cv2.VideoCapture(path)
 25        self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
 26        self.fps = v.get(cv2.CAP_PROP_FPS)
 27        v.release()
 28
 29    def _single_frame_ocr(self, img) -> str:
 30        if not self.use_fullframe:
 31            # only use bottom half of the frame by default
 32            img = img[img.shape[0] // 2:, :]
 33        data = pytesseract.image_to_data(img, lang=self.lang)
 34        return data
 35
 36    def run_ocr(self) -> None:
 37        v = cv2.VideoCapture(self.path)
 38        frames = (v.read()[1] for _ in range(self.num_frames))
 39
 40        # perform ocr to all frames in parallel
 41        with futures.ProcessPoolExecutor() as pool:
 42            frames_ocr = pool.map(self._single_frame_ocr, frames, chunksize=10)
 43            self.pred_frames = [PredictedFrame(i, data) 
 44                                for i, data in enumerate(frames_ocr)]
 45
 46        v.release()
 47
 48    def get_subtitles(self) -> str:
 49        self._generate_subtitles()
 50        return ''.join(
 51            '{}\n{} --> {}\n{}\n'.format(
 52                i,
 53                self._srt_timestamp(sub.index_start),
 54                self._srt_timestamp(sub.index_end),
 55                sub.text)
 56            for i, sub in enumerate(self.pred_subs))
 57
 58    def _generate_subtitles(self) -> None:
 59        self.pred_subs = []
 60
 61        if self.pred_frames is None:
 62            raise AttributeError(
 63                'Please call self.run_ocr() first to generate ocr of frames')
 64
 65        # divide ocr of frames into subtitle paragraphs using sliding window
 66        WIN_BOUND = int(self.fps / 2)  # 1/2 sec sliding window boundary
 67        bound = WIN_BOUND
 68        i = 0
 69        j = 1
 70        while j < self.num_frames:
 71            fi, fj = self.pred_frames[i], self.pred_frames[j]
 72
 73            if fi.is_similar_to(fj):
 74                bound = WIN_BOUND
 75            elif bound > 0:
 76                bound -= 1
 77            else:
 78                # divide subtitle paragraphs
 79                para_new = j - WIN_BOUND
 80                self._append_sub(
 81                    PredictedSubtitle(self.pred_frames[i:para_new]))
 82                i = para_new
 83                j = i
 84                bound = WIN_BOUND
 85
 86            j += 1
 87
 88        # also handle the last remaining frames
 89        if i < self.num_frames - 1:
 90            self._append_sub(PredictedSubtitle(self.pred_frames[i:]))
 91
 92    def _append_sub(self, sub: PredictedSubtitle) -> None:
 93        if len(sub.text) == 0:
 94            return
 95
 96        # merge new sub to the last subs if they are similar
 97        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
 98            lsub = self.pred_subs[-1]
 99            del self.pred_subs[-1]
100            sub = PredictedSubtitle(lsub.frames + sub.frames)
101
102        self.pred_subs.append(sub)
103
104    def _srt_timestamp(self, frame_index) -> str:
105        time = str(datetime.timedelta(seconds=frame_index / self.fps))
106        return time.replace('.', ',')  # srt uses comma, not dot
107
108
if __name__ == '__main__':
    # Ad-hoc timing run over a sample video. It is guarded so that importing
    # this module does not start OCR, and so that worker processes which
    # re-import the module do not run it again. The module uses a relative
    # import, so it has to be run as part of its package (``python -m ...``).
    time_start = timeit.default_timer()
    v = Video('1.mp4', 'HanS')
    v.run_ocr()
    time_stop = timeit.default_timer()
    print('time for ocr: ', time_stop - time_start)

    time_start = timeit.default_timer()
    v.get_subtitles()
    time_stop = timeit.default_timer()
    print('time for get sub: ', time_stop - time_start)