all repos — videocr @ 3f73cb9bcafbd639ef5791a846b861d633cdb9dc

Extract hardcoded subtitles from videos using machine learning

videocr/video.py (view raw)

  1from __future__ import annotations
  2from concurrent import futures
  3import datetime
  4import pytesseract
  5import cv2
  6import timeit
  7
  8from .models import PredictedFrame, PredictedSubtitle
  9
 10
 11class Video:
 12    path: str
 13    lang: str
 14    use_fullframe: bool
 15    num_frames: int
 16    fps: float
 17    ocr_frame_start: int
 18    num_ocr_frames: int
 19    pred_frames: List[PredictedFrame]
 20    pred_subs: List[PredictedSubtitle]
 21
 22    def __init__(self, path: str, lang: str, use_fullframe=False,
 23                 time_start='0:00', time_end=''):
 24        self.path = path
 25        self.lang = lang
 26        self.use_fullframe = use_fullframe
 27        v = cv2.VideoCapture(path)
 28        self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
 29        self.fps = v.get(cv2.CAP_PROP_FPS)
 30        v.release()
 31
 32        self.ocr_frame_start = self._frame_index(time_start)
 33
 34        if time_end:
 35            ocr_end = self._frame_index(time_end)
 36        else:
 37            ocr_end = self.num_frames
 38        self.num_ocr_frames = ocr_end - self.ocr_frame_start
 39
 40        if self.num_ocr_frames < 0:
 41            raise ValueError('time_start is later than time_end')
 42
 43    def _frame_index(self, time: str) -> int:
 44        t = time.split(':')
 45        t = list(map(int, t))
 46        if len(t) == 3:
 47            td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
 48        elif len(t) == 2:
 49            td = datetime.timedelta(minutes=t[0], seconds=t[1])
 50        else:
 51            raise ValueError(
 52                'time data "{}" does not match format "%H:%M:%S"'.format(time))
 53
 54        index = int(td.total_seconds() * self.fps)
 55        if index > self.num_frames or index < 0:
 56            raise ValueError(
 57                'time data "{}" exceeds video duration'.format(time))
 58
 59        return index
 60
 61    def run_ocr(self) -> None:
 62        v = cv2.VideoCapture(self.path)
 63        v.set(cv2.CAP_PROP_POS_FRAMES, self.ocr_frame_start)
 64        frames = (v.read()[1] for _ in range(self.num_ocr_frames))
 65
 66        # perform ocr to all frames in parallel
 67        with futures.ProcessPoolExecutor() as pool:
 68            ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10)
 69            self.pred_frames = [PredictedFrame(i + self.ocr_frame_start, data) 
 70                                for i, data in enumerate(ocr_map)]
 71
 72        v.release()
 73
 74    def _single_frame_ocr(self, img) -> str:
 75        if not self.use_fullframe:
 76            # only use bottom half of the frame by default
 77            img = img[img.shape[0] // 2:, :]
 78        return pytesseract.image_to_data(img, lang=self.lang)
 79
 80    def get_subtitles(self) -> str:
 81        self._generate_subtitles()
 82        return ''.join(
 83            '{}\n{} --> {}\n{}\n\n'.format(
 84                i,
 85                self._srt_timestamp(sub.index_start),
 86                self._srt_timestamp(sub.index_end),
 87                sub.text)
 88            for i, sub in enumerate(self.pred_subs))
 89
 90    def _generate_subtitles(self) -> None:
 91        self.pred_subs = []
 92
 93        if self.pred_frames is None:
 94            raise AttributeError(
 95                'Please call self.run_ocr() first to perform ocr on frames')
 96
 97        # divide ocr of frames into subtitle paragraphs using sliding window
 98        WIN_BOUND = int(self.fps // 2)  # 1/2 sec sliding window boundary
 99        bound = WIN_BOUND
100        i = 0
101        j = 1
102        while j < self.num_ocr_frames:
103            fi, fj = self.pred_frames[i], self.pred_frames[j]
104
105            if fi.is_similar_to(fj):
106                bound = WIN_BOUND
107            elif bound > 0:
108                bound -= 1
109            else:
110                # divide subtitle paragraphs
111                para_new = j - WIN_BOUND
112                self._append_sub(
113                    PredictedSubtitle(self.pred_frames[i:para_new]))
114                i = para_new
115                j = i
116                bound = WIN_BOUND
117
118            j += 1
119
120        # also handle the last remaining frames
121        if i < self.num_ocr_frames - 1:
122            self._append_sub(PredictedSubtitle(self.pred_frames[i:]))
123
124    def _append_sub(self, sub: PredictedSubtitle) -> None:
125        if len(sub.text) == 0:
126            return
127
128        # merge new sub to the last subs if they are similar
129        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
130            ls = self.pred_subs[-1]
131            del self.pred_subs[-1]
132            sub = PredictedSubtitle(ls.frames + sub.frames)
133
134        self.pred_subs.append(sub)
135
136    def _srt_timestamp(self, frame_index: int) -> str:
137        td = datetime.timedelta(seconds=frame_index / self.fps)
138        ms = td.microseconds // 1000
139        m, s = divmod(td.seconds, 60)
140        h, m = divmod(m, 60)
141        return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)
142
143    def save_subtitles_to_file(self, path='subtitle.srt') -> None:
144        with open(path, 'w+') as f:
145            f.write(self.get_subtitles())
146
147
148time_start = timeit.default_timer()
149v = Video('1.mp4', 'HanS')
150v.run_ocr()
151time_stop = timeit.default_timer()
152print('time for ocr: ', time_stop - time_start)
153
154time_start = timeit.default_timer()
155v.save_subtitles_to_file()
156time_stop = timeit.default_timer()
157print('time for save sub: ', time_stop - time_start)