all repos — videocr @ v0.1

Extract hardcoded subtitles from videos using machine learning

videocr/models.py (view raw)

 1from __future__ import annotations
 2from typing import List
 3from dataclasses import dataclass
 4from fuzzywuzzy import fuzz
 5
 6
 7@dataclass
 8class PredictedWord:
 9    __slots__ = 'confidence', 'text'
10    confidence: int
11    text: str
12
13
class PredictedFrame:
    """OCR prediction for a single video frame.

    ``pred_data`` is presumably Tesseract TSV output (``image_to_data``);
    the caller is not visible here, so confirm. The parser relies only on
    the following layout:

    * the first line is a header and is skipped;
    * word rows have at least 12 whitespace-separated fields;
    * in a word row, the 3rd field is the block number, the
      second-to-last is the confidence, and the last is the word text.

    Rows with fewer than 12 fields (no text) are ignored.
    """

    index: int  # 0-based index of the frame
    words: List[PredictedWord]
    confidence: int  # total confidence of all words
    text: str

    # Remove characters that are obviously OCR errors in subtitle text, and
    # map '|' to 'I'. Built once here instead of once per frame.
    _OCR_FIX_TABLE = str.maketrans('|', 'I', '<>{}[];`@#$%^*_=~\\')

    def __init__(self, index: int, pred_data: str, conf_threshold: int):
        """Parse the OCR output of one frame.

        Args:
            index: 0-based index of the frame.
            pred_data: tabular OCR output (see class docstring).
            conf_threshold: words whose confidence is below this value are
                dropped.
        """
        self.index = index
        self.words = []

        block = 0  # highest block number seen so far; used for line breaks

        for line in pred_data.splitlines()[1:]:
            word_data = line.split()
            if len(word_data) < 12:
                # no word is predicted on this row
                continue
            _, _, block_num, *_, conf, text = word_data
            block_num = int(block_num)
            # Newer Tesseract releases (5.x) print confidences as floats,
            # e.g. "96.044136"; plain int() would raise ValueError on those.
            conf = int(float(conf))

            # A new block starts a new line: insert a '\n' marker word, but
            # never at the start and never two in a row.
            if block < block_num:
                block = block_num
                if self.words and self.words[-1].text != '\n':
                    self.words.append(PredictedWord(0, '\n'))

            # word predictions with low confidence are filtered out
            if conf >= conf_threshold:
                self.words.append(PredictedWord(conf, text))

        self.confidence = sum(word.confidence for word in self.words)

        # Words are space-joined, so '\n' markers appear as ' \n '; collapse
        # the surrounding spaces after removing OCR-error characters.
        joined = ' '.join(word.text for word in self.words)
        self.text = (
            joined.translate(self._OCR_FIX_TABLE)
            .replace(' \n ', '\n')
            .strip()
        )

    def is_similar_to(self, other: PredictedFrame, threshold=70) -> bool:
        """Return True if ``fuzz.ratio`` of the two texts is >= ``threshold``."""
        return fuzz.ratio(self.text, other.text) >= threshold
53
54
class PredictedSubtitle:
    """A group of frames that together make up one subtitle.

    Only frames with a positive total confidence are retained. The
    subtitle's text comes from the retained frame with the highest
    confidence; the earliest such frame wins ties.
    """

    frames: List[PredictedFrame]
    sim_threshold: int
    text: str

    def __init__(self, frames: List[PredictedFrame], sim_threshold: int):
        """Keep the confident frames and pick the subtitle text.

        Args:
            frames: candidate frames, in the order they should be kept.
            sim_threshold: minimum ``fuzz.partial_ratio`` score used by
                ``is_similar_to``.
        """
        confident = [fr for fr in frames if fr.confidence > 0]
        self.frames = confident
        self.sim_threshold = sim_threshold
        self.text = (
            max(confident, key=lambda fr: fr.confidence).text
            if confident
            else ''
        )

    @property
    def index_start(self) -> int:
        """Index of the first retained frame, or 0 if none were retained."""
        return self.frames[0].index if self.frames else 0

    @property
    def index_end(self) -> int:
        """Index of the last retained frame, or 0 if none were retained."""
        return self.frames[-1].index if self.frames else 0

    def is_similar_to(self, other: PredictedSubtitle) -> bool:
        """Return True if the texts' partial-ratio score meets this subtitle's threshold."""
        score = fuzz.partial_ratio(self.text, other.text)
        return score >= self.sim_threshold

    def __repr__(self):
        start, end = self.index_start, self.index_end
        return '{} - {}. {}'.format(start, end, self.text)