Commit
·
c0dd2e2
1
Parent(s):
2249846
import backend from __init__
Browse files- whisper_online.py +19 -8
whisper_online.py
CHANGED
|
@@ -23,15 +23,19 @@ def load_audio_chunk(fname, beg, end):
|
|
| 23 |
|
| 24 |
class ASRBase:
|
| 25 |
|
| 26 |
-
# join transcribe words with this character (" " for whisper_timestamped,
|
| 27 |
-
|
| 28 |
|
| 29 |
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None):
|
| 30 |
self.transcribe_kargs = {}
|
| 31 |
self.original_language = lan
|
| 32 |
|
|
|
|
| 33 |
self.model = self.load_model(modelsize, cache_dir, model_dir)
|
| 34 |
|
|
|
|
|
|
|
|
|
|
| 35 |
def load_model(self, modelsize, cache_dir):
|
| 36 |
raise NotImplemented("must be implemented in the child class")
|
| 37 |
|
|
@@ -49,11 +53,14 @@ class ASRBase:
|
|
| 49 |
class WhisperTimestampedASR(ASRBase):
|
| 50 |
"""Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
|
| 51 |
On the other hand, the installation for GPU could be easier.
|
|
|
|
| 52 |
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
| 54 |
import whisper
|
| 55 |
import whisper_timestamped
|
| 56 |
-
"""
|
| 57 |
|
| 58 |
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
| 59 |
if model_dir is not None:
|
|
@@ -89,8 +96,12 @@ class FasterWhisperASR(ASRBase):
|
|
| 89 |
|
| 90 |
sep = ""
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
| 93 |
-
from faster_whisper import WhisperModel
|
| 94 |
|
| 95 |
|
| 96 |
if model_dir is not None:
|
|
@@ -465,11 +476,11 @@ if __name__ == "__main__":
|
|
| 465 |
#asr = WhisperASR(lan=language, modelsize=size)
|
| 466 |
|
| 467 |
if args.backend == "faster-whisper":
|
| 468 |
-
from faster_whisper import WhisperModel
|
| 469 |
asr_cls = FasterWhisperASR
|
| 470 |
else:
|
| 471 |
-
import whisper
|
| 472 |
-
import whisper_timestamped
|
| 473 |
# from whisper_timestamped_model import WhisperTimestampedASR
|
| 474 |
asr_cls = WhisperTimestampedASR
|
| 475 |
|
|
|
|
| 23 |
|
| 24 |
class ASRBase:
|
| 25 |
|
| 26 |
+
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
| 27 |
+
# "" for faster-whisper because it emits the spaces when needed)
|
| 28 |
|
| 29 |
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None):
|
| 30 |
self.transcribe_kargs = {}
|
| 31 |
self.original_language = lan
|
| 32 |
|
| 33 |
+
self.import_backend()
|
| 34 |
self.model = self.load_model(modelsize, cache_dir, model_dir)
|
| 35 |
|
| 36 |
+
def import_backend(self):
|
| 37 |
+
raise NotImplemented("must be implemented in the child class")
|
| 38 |
+
|
| 39 |
def load_model(self, modelsize, cache_dir):
|
| 40 |
raise NotImplemented("must be implemented in the child class")
|
| 41 |
|
|
|
|
| 53 |
class WhisperTimestampedASR(ASRBase):
|
| 54 |
"""Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
|
| 55 |
On the other hand, the installation for GPU could be easier.
|
| 56 |
+
"""
|
| 57 |
|
| 58 |
+
sep = " "
|
| 59 |
+
|
| 60 |
+
def import_backend(self):
|
| 61 |
+
global whisper, whisper_timestamped
|
| 62 |
import whisper
|
| 63 |
import whisper_timestamped
|
|
|
|
| 64 |
|
| 65 |
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
| 66 |
if model_dir is not None:
|
|
|
|
| 96 |
|
| 97 |
sep = ""
|
| 98 |
|
| 99 |
+
def import_backend(self):
|
| 100 |
+
global faster_whisper
|
| 101 |
+
import faster_whisper
|
| 102 |
+
|
| 103 |
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
| 104 |
+
#from faster_whisper import WhisperModel
|
| 105 |
|
| 106 |
|
| 107 |
if model_dir is not None:
|
|
|
|
| 476 |
#asr = WhisperASR(lan=language, modelsize=size)
|
| 477 |
|
| 478 |
if args.backend == "faster-whisper":
|
| 479 |
+
#from faster_whisper import WhisperModel
|
| 480 |
asr_cls = FasterWhisperASR
|
| 481 |
else:
|
| 482 |
+
#import whisper
|
| 483 |
+
#import whisper_timestamped
|
| 484 |
# from whisper_timestamped_model import WhisperTimestampedASR
|
| 485 |
asr_cls = WhisperTimestampedASR
|
| 486 |
|