add api definition
Yi Ge me@yige.ch
Sun, 28 Apr 2019 15:46:24 +0200
3 files changed,
28 insertions(+),
33 deletions(-)
A
videocr/api.py
@@ -0,0 +1,17 @@
+ +from .video import Video + + +def get_subtitles(video_path: str, lang='eng', + time_start='0:00', time_end='', use_fullframe=False) -> str: + v = Video(video_path) + v.run_ocr(lang, time_start, time_end, use_fullframe) + return v.get_subtitles() + + +def save_subtitles_to_file( + video_path: str, file_path='subtitle.srt', lang='eng', + time_start='0:00', time_end='', use_fullframe=False) -> None: + with open(file_path, 'w+') as f: + f.write(get_subtitles( + video_path, lang, time_start, time_end, use_fullframe))
M
videocr/models.py
→
videocr/models.py
@@ -48,9 +48,8 @@ self.confidence = sum(word.confidence for word in self.words)
self.text = ' '.join(word.text for word in self.words) # remove chars that are obviously ocr errors - translate_table = {ord(c): None for c in '<>{}[];`@#$%^*_=~\\'} - translate_table[ord('|')] = 'I' - self.text = self.text.translate(translate_table).strip() + table = str.maketrans('|', 'I', '<>{}[];`@#$%^*_=~\\') + self.text = self.text.translate(table).replace(' \n ', '\n').strip() def is_similar_to(self, other: PredictedFrame, threshold=70) -> bool: return fuzz.ratio(self.text, other.text) >= threshold@@ -58,14 +57,13 @@
class PredictedSubtitle: frames: List[PredictedFrame] + text: str def __init__(self, frames: List[PredictedFrame]): self.frames = [f for f in frames if f.confidence > 0] if self.frames: - conf_max = max(f.confidence for f in self.frames) - self.text = next(f.text for f in self.frames - if f.confidence == conf_max) + self.text = max(self.frames, key=lambda f: f.confidence).text else: self.text = ''
M
videocr/video.py
→
videocr/video.py
@@ -3,7 +3,6 @@ from concurrent import futures
import datetime import pytesseract import cv2 -import timeit from .models import PredictedFrame, PredictedSubtitle@@ -24,19 +23,16 @@ self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
self.fps = v.get(cv2.CAP_PROP_FPS) v.release() - def run_ocr(self, lang: str, use_fullframe=False, - time_start='0:00', time_end='') -> None: + def run_ocr(self, lang: str, time_start: str, time_end: str, + use_fullframe: bool) -> None: self.lang = lang self.use_fullframe = use_fullframe - ocr_start = self._frame_index(time_start) + ocr_start = self._frame_index(time_start) if time_start else 0 + ocr_end = self._frame_index(time_end) if time_end else self.num_frames - if time_end: - ocr_end = self._frame_index(time_end) - if ocr_end < ocr_start: - raise ValueError('time_start is later than time_end') - else: - ocr_end = self.num_frames + if ocr_end < ocr_start: + raise ValueError('time_start is later than time_end') num_ocr_frames = ocr_end - ocr_start # get frames from ocr_start to ocr_end@@ -55,7 +51,7 @@
# convert time str to frame index def _frame_index(self, time: str) -> int: t = time.split(':') - t = list(map(int, t)) + t = list(map(float, t)) if len(t) == 3: td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2]) elif len(t) == 2:@@ -139,19 +135,3 @@ ms = td.microseconds // 1000
m, s = divmod(td.seconds, 60) h, m = divmod(m, 60) return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms) - - def save_subtitles_to_file(self, path='subtitle.srt') -> None: - with open(path, 'w+') as f: - f.write(self.get_subtitles()) - - -time_start = timeit.default_timer() -v = Video('1.mp4', 'HanS') -v.run_ocr() -time_stop = timeit.default_timer() -print('time for ocr: ', time_stop - time_start) - -time_start = timeit.default_timer() -v.save_subtitles_to_file() -time_stop = timeit.default_timer() -print('time for save sub: ', time_stop - time_start)