all repos — videocr @ d15676700bb14da69667c6d08958c8c2ca44c7b9

Extract hardcoded subtitles from videos using machine learning

videocr/video.py (view raw)

  1from __future__ import annotations
  2from typing import List
  3import sys
  4import multiprocessing
  5import pytesseract
  6import cv2
  7
  8from . import constants
  9from . import utils
 10from .models import PredictedFrame, PredictedSubtitle
 11from .opencv_adapter import Capture
 12
 13
 14class Video:
 15    path: str
 16    lang: str
 17    use_fullframe: bool
 18    num_frames: int
 19    fps: float
 20    height: int
 21    pred_frames: List[PredictedFrame]
 22    pred_subs: List[PredictedSubtitle]
 23
 24    def __init__(self, path: str):
 25        self.path = path
 26        with Capture(path) as v:
 27            self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
 28            self.fps = v.get(cv2.CAP_PROP_FPS)
 29            self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
 30
 31    def run_ocr(self, lang: str, time_start: str, time_end: str,
 32                conf_threshold: int, use_fullframe: bool) -> None:
 33        self.lang = lang
 34        self.use_fullframe = use_fullframe
 35
 36        ocr_start = utils.get_frame_index(time_start, self.fps) if time_start else 0
 37        ocr_end = utils.get_frame_index(time_end, self.fps) if time_end else self.num_frames
 38
 39        if ocr_end < ocr_start:
 40            raise ValueError('time_start is later than time_end')
 41        num_ocr_frames = ocr_end - ocr_start
 42
 43        # get frames from ocr_start to ocr_end
 44        with Capture(self.path) as v, multiprocessing.Pool() as pool:
 45            v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
 46            frames = (v.read()[1] for _ in range(num_ocr_frames))
 47
 48            # perform ocr to frames in parallel
 49            it_ocr = pool.imap(self._image_to_data, frames, chunksize=10)
 50            self.pred_frames = [
 51                PredictedFrame(i + ocr_start, data, conf_threshold)
 52                for i, data in enumerate(it_ocr)
 53            ]
 54
 55    def _image_to_data(self, img) -> str:
 56        if not self.use_fullframe:
 57            # only use bottom half of the frame by default
 58            img = img[self.height // 2:, :]
 59        config = '--tessdata-dir "{}"'.format(constants.TESSDATA_DIR)
 60        try:
 61            return pytesseract.image_to_data(img, lang=self.lang, config=config)
 62        except Exception as e:
 63            sys.exit('{}: {}'.format(e.__class__.__name__, e))
 64
 65    def get_subtitles(self, sim_threshold: int) -> str:
 66        self._generate_subtitles(sim_threshold)
 67        return ''.join(
 68            '{}\n{} --> {}\n{}\n\n'.format(
 69                i,
 70                utils.get_srt_timestamp(sub.index_start, self.fps),
 71                utils.get_srt_timestamp(sub.index_end, self.fps),
 72                sub.text)
 73            for i, sub in enumerate(self.pred_subs))
 74
 75    def _generate_subtitles(self, sim_threshold: int) -> None:
 76        self.pred_subs = []
 77
 78        if self.pred_frames is None:
 79            raise AttributeError(
 80                'Please call self.run_ocr() first to perform ocr on frames')
 81
 82        # divide ocr of frames into subtitle paragraphs using sliding window
 83        WIN_BOUND = int(self.fps // 2)  # 1/2 sec sliding window boundary
 84        bound = WIN_BOUND
 85        i = 0
 86        j = 1
 87        while j < len(self.pred_frames):
 88            fi, fj = self.pred_frames[i], self.pred_frames[j]
 89
 90            if fi.is_similar_to(fj):
 91                bound = WIN_BOUND
 92            elif bound > 0:
 93                bound -= 1
 94            else:
 95                # divide subtitle paragraphs
 96                para_new = j - WIN_BOUND
 97                self._append_sub(PredictedSubtitle(
 98                    self.pred_frames[i:para_new], sim_threshold))
 99                i = para_new
100                j = i
101                bound = WIN_BOUND
102
103            j += 1
104
105        # also handle the last remaining frames
106        if i < len(self.pred_frames) - 1:
107            self._append_sub(PredictedSubtitle(
108                self.pred_frames[i:], sim_threshold))
109
110    def _append_sub(self, sub: PredictedSubtitle) -> None:
111        if len(sub.text) == 0:
112            return
113
114        # merge new sub to the last subs if they are similar
115        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
116            ls = self.pred_subs[-1]
117            del self.pred_subs[-1]
118            sub = PredictedSubtitle(ls.frames + sub.frames, sub.sim_threshold)
119
120        self.pred_subs.append(sub)