Dominik Macháček
committed on
Commit
·
8f32dea
1
Parent(s):
bd0d848
logfile reviewed, whisper_timestamped loading module and vad
Browse files- whisper_online.py +33 -20
whisper_online.py
CHANGED
|
@@ -26,12 +26,15 @@ class ASRBase:
|
|
| 26 |
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
| 27 |
# "" for faster-whisper because it emits the spaces when needed)
|
| 28 |
|
| 29 |
-
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None):
|
|
|
|
|
|
|
| 30 |
self.transcribe_kargs = {}
|
| 31 |
self.original_language = lan
|
| 32 |
|
| 33 |
self.model = self.load_model(modelsize, cache_dir, model_dir)
|
| 34 |
|
|
|
|
| 35 |
def load_model(self, modelsize, cache_dir):
|
| 36 |
raise NotImplemented("must be implemented in the child class")
|
| 37 |
|
|
@@ -50,15 +53,18 @@ class WhisperTimestampedASR(ASRBase):
|
|
| 50 |
sep = " "
|
| 51 |
|
| 52 |
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
| 53 |
-
global whisper_timestamped # has to be global as it is used at each `transcribe` call
|
| 54 |
import whisper
|
| 55 |
-
import
|
|
|
|
| 56 |
if model_dir is not None:
|
| 57 |
print("ignoring model_dir, not implemented",file=self.logfile)
|
| 58 |
return whisper.load_model(modelsize, download_root=cache_dir)
|
| 59 |
|
| 60 |
def transcribe(self, audio, init_prompt=""):
|
| 61 |
-
result =
|
|
|
|
|
|
|
|
|
|
| 62 |
return result
|
| 63 |
|
| 64 |
def ts_words(self,r):
|
|
@@ -74,7 +80,12 @@ class WhisperTimestampedASR(ASRBase):
|
|
| 74 |
return [s["end"] for s in res["segments"]]
|
| 75 |
|
| 76 |
def use_vad(self):
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
class FasterWhisperASR(ASRBase):
|
|
@@ -135,7 +146,6 @@ class FasterWhisperASR(ASRBase):
|
|
| 135 |
class HypothesisBuffer:
|
| 136 |
|
| 137 |
def __init__(self, logfile=sys.stderr):
|
| 138 |
-
"""output: where to store the log. Leave it unchanged to print to terminal."""
|
| 139 |
self.commited_in_buffer = []
|
| 140 |
self.buffer = []
|
| 141 |
self.new = []
|
|
@@ -205,7 +215,7 @@ class OnlineASRProcessor:
|
|
| 205 |
def __init__(self, asr, tokenizer, logfile=sys.stderr):
|
| 206 |
"""asr: WhisperASR object
|
| 207 |
tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer.
|
| 208 |
-
|
| 209 |
"""
|
| 210 |
self.asr = asr
|
| 211 |
self.tokenizer = tokenizer
|
|
@@ -468,21 +478,24 @@ if __name__ == "__main__":
|
|
| 468 |
parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
|
| 469 |
args = parser.parse_args()
|
| 470 |
|
|
|
|
|
|
|
|
|
|
| 471 |
if args.offline and args.comp_unaware:
|
| 472 |
-
print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=
|
| 473 |
sys.exit(1)
|
| 474 |
|
| 475 |
audio_path = args.audio_path
|
| 476 |
|
| 477 |
SAMPLING_RATE = 16000
|
| 478 |
duration = len(load_audio(audio_path))/SAMPLING_RATE
|
| 479 |
-
print("Audio duration is: %2.2f seconds" % duration, file=
|
| 480 |
|
| 481 |
size = args.model
|
| 482 |
language = args.lan
|
| 483 |
|
| 484 |
t = time.time()
|
| 485 |
-
print(f"Loading Whisper {size} model for {language}...",file=
|
| 486 |
|
| 487 |
if args.backend == "faster-whisper":
|
| 488 |
asr_cls = FasterWhisperASR
|
|
@@ -499,15 +512,15 @@ if __name__ == "__main__":
|
|
| 499 |
|
| 500 |
|
| 501 |
e = time.time()
|
| 502 |
-
print(f"done. It took {round(e-t,2)} seconds.",file=
|
| 503 |
|
| 504 |
if args.vad:
|
| 505 |
-
print("setting VAD filter",file=
|
| 506 |
asr.use_vad()
|
| 507 |
|
| 508 |
|
| 509 |
min_chunk = args.min_chunk_size
|
| 510 |
-
online = OnlineASRProcessor(asr,create_tokenizer(tgt_language))
|
| 511 |
|
| 512 |
|
| 513 |
# load the audio into the LRU cache before we start the timer
|
|
@@ -529,10 +542,10 @@ if __name__ == "__main__":
|
|
| 529 |
if now is None:
|
| 530 |
now = time.time()-start
|
| 531 |
if o[0] is not None:
|
| 532 |
-
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=
|
| 533 |
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
|
| 534 |
else:
|
| 535 |
-
print(o,file=
|
| 536 |
|
| 537 |
if args.offline: ## offline mode processing (for testing/debugging)
|
| 538 |
a = load_audio(audio_path)
|
|
@@ -540,7 +553,7 @@ if __name__ == "__main__":
|
|
| 540 |
try:
|
| 541 |
o = online.process_iter()
|
| 542 |
except AssertionError:
|
| 543 |
-
print("assertion error",file=
|
| 544 |
pass
|
| 545 |
else:
|
| 546 |
output_transcript(o)
|
|
@@ -553,12 +566,12 @@ if __name__ == "__main__":
|
|
| 553 |
try:
|
| 554 |
o = online.process_iter()
|
| 555 |
except AssertionError:
|
| 556 |
-
print("assertion error",file=
|
| 557 |
pass
|
| 558 |
else:
|
| 559 |
output_transcript(o, now=end)
|
| 560 |
|
| 561 |
-
print(f"## last processed {end:.2f}s",file=
|
| 562 |
|
| 563 |
beg = end
|
| 564 |
end += min_chunk
|
|
@@ -580,12 +593,12 @@ if __name__ == "__main__":
|
|
| 580 |
try:
|
| 581 |
o = online.process_iter()
|
| 582 |
except AssertionError:
|
| 583 |
-
print("assertion error",file=
|
| 584 |
pass
|
| 585 |
else:
|
| 586 |
output_transcript(o)
|
| 587 |
now = time.time() - start
|
| 588 |
-
print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=
|
| 589 |
|
| 590 |
if end >= duration:
|
| 591 |
break
|
|
|
|
| 26 |
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
| 27 |
# "" for faster-whisper because it emits the spaces when needed)
|
| 28 |
|
| 29 |
+
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
    """Base ASR wrapper: store configuration and load the backend model.

    lan: source-language code handed to the backend.
    modelsize, cache_dir, model_dir: forwarded verbatim to load_model().
    logfile: stream for diagnostic messages (defaults to stderr).
    """
    self.logfile = logfile
    self.original_language = lan
    # extra keyword arguments merged into each transcribe() call
    self.transcribe_kargs = {}
    # model loading is delegated to the subclass implementation
    self.model = self.load_model(modelsize, cache_dir, model_dir)
|
| 36 |
|
| 37 |
def load_model(self, modelsize, cache_dir, model_dir=None):
    """Load and return the backend model; must be overridden by subclasses.

    Raises NotImplementedError: `NotImplemented` is a comparison singleton,
    not an exception class — raising it would produce a TypeError instead of
    the intended error. `model_dir` is accepted (with a default, so existing
    two-argument callers keep working) because __init__ invokes
    `self.load_model(modelsize, cache_dir, model_dir)` with three arguments.
    """
    raise NotImplementedError("must be implemented in the child class")
|
| 40 |
|
|
|
|
| 53 |
sep = " "
|
| 54 |
|
| 55 |
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
    """Load the openai-whisper model and bind whisper_timestamped's transcriber.

    modelsize: whisper model name; cache_dir: download root for the weights.
    model_dir is not supported by this backend and is only reported ignored.
    """
    import whisper
    from whisper_timestamped import transcribe_timestamped

    # Keep the timestamped entry point on the instance so transcribe()
    # can invoke it later without importing again.
    self.transcribe_timestamped = transcribe_timestamped

    if model_dir is not None:
        print("ignoring model_dir, not implemented", file=self.logfile)
    return whisper.load_model(modelsize, download_root=cache_dir)
|
| 62 |
|
| 63 |
def transcribe(self, audio, init_prompt=""):
    """Run whisper_timestamped on `audio`, seeding it with `init_prompt`.

    Returns the raw result object produced by transcribe_timestamped.
    Any options accumulated in self.transcribe_kargs (e.g. vad, task)
    are forwarded as extra keyword arguments.
    """
    return self.transcribe_timestamped(
        self.model,
        audio,
        language=self.original_language,
        initial_prompt=init_prompt,
        verbose=None,
        condition_on_previous_text=True,
        **self.transcribe_kargs,
    )
|
| 69 |
|
| 70 |
def ts_words(self,r):
|
|
|
|
| 80 |
return [s["end"] for s in res["segments"]]
|
| 81 |
|
| 82 |
def use_vad(self):
    """Turn on voice-activity detection for subsequent transcribe() calls."""
    self.transcribe_kargs.update(vad=True)
|
| 84 |
+
|
| 85 |
+
def set_translate_task(self):
    """Switch the backend task option to "translate" for later transcribe() calls."""
    # NOTE(review): presumably Whisper's translate task targets English —
    # confirm against the backend documentation.
    self.transcribe_kargs.update(task="translate")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
|
| 90 |
|
| 91 |
class FasterWhisperASR(ASRBase):
|
|
|
|
| 146 |
class HypothesisBuffer:
|
| 147 |
|
| 148 |
def __init__(self, logfile=sys.stderr):
|
|
|
|
| 149 |
self.commited_in_buffer = []
|
| 150 |
self.buffer = []
|
| 151 |
self.new = []
|
|
|
|
| 215 |
def __init__(self, asr, tokenizer, logfile=sys.stderr):
|
| 216 |
"""asr: WhisperASR object
|
| 217 |
tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer.
|
| 218 |
+
logfile: where to store the log.
|
| 219 |
"""
|
| 220 |
self.asr = asr
|
| 221 |
self.tokenizer = tokenizer
|
|
|
|
| 478 |
parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
|
| 479 |
args = parser.parse_args()
|
| 480 |
|
| 481 |
+
# reset to store stderr to different file stream, e.g. open(os.devnull,"w")
|
| 482 |
+
logfile = sys.stderr
|
| 483 |
+
|
| 484 |
if args.offline and args.comp_unaware:
|
| 485 |
+
print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=logfile)
|
| 486 |
sys.exit(1)
|
| 487 |
|
| 488 |
audio_path = args.audio_path
|
| 489 |
|
| 490 |
SAMPLING_RATE = 16000
|
| 491 |
duration = len(load_audio(audio_path))/SAMPLING_RATE
|
| 492 |
+
print("Audio duration is: %2.2f seconds" % duration, file=logfile)
|
| 493 |
|
| 494 |
size = args.model
|
| 495 |
language = args.lan
|
| 496 |
|
| 497 |
t = time.time()
|
| 498 |
+
print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
|
| 499 |
|
| 500 |
if args.backend == "faster-whisper":
|
| 501 |
asr_cls = FasterWhisperASR
|
|
|
|
| 512 |
|
| 513 |
|
| 514 |
e = time.time()
|
| 515 |
+
print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
|
| 516 |
|
| 517 |
if args.vad:
|
| 518 |
+
print("setting VAD filter",file=logfile)
|
| 519 |
asr.use_vad()
|
| 520 |
|
| 521 |
|
| 522 |
min_chunk = args.min_chunk_size
|
| 523 |
+
online = OnlineASRProcessor(asr,create_tokenizer(tgt_language),logfile=logfile)
|
| 524 |
|
| 525 |
|
| 526 |
# load the audio into the LRU cache before we start the timer
|
|
|
|
| 542 |
if now is None:
|
| 543 |
now = time.time()-start
|
| 544 |
if o[0] is not None:
|
| 545 |
+
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=logfile,flush=True)
|
| 546 |
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
|
| 547 |
else:
|
| 548 |
+
print(o,file=logfile,flush=True)
|
| 549 |
|
| 550 |
if args.offline: ## offline mode processing (for testing/debugging)
|
| 551 |
a = load_audio(audio_path)
|
|
|
|
| 553 |
try:
|
| 554 |
o = online.process_iter()
|
| 555 |
except AssertionError:
|
| 556 |
+
print("assertion error",file=logfile)
|
| 557 |
pass
|
| 558 |
else:
|
| 559 |
output_transcript(o)
|
|
|
|
| 566 |
try:
|
| 567 |
o = online.process_iter()
|
| 568 |
except AssertionError:
|
| 569 |
+
print("assertion error",file=logfile)
|
| 570 |
pass
|
| 571 |
else:
|
| 572 |
output_transcript(o, now=end)
|
| 573 |
|
| 574 |
+
print(f"## last processed {end:.2f}s",file=logfile,flush=True)
|
| 575 |
|
| 576 |
beg = end
|
| 577 |
end += min_chunk
|
|
|
|
| 593 |
try:
|
| 594 |
o = online.process_iter()
|
| 595 |
except AssertionError:
|
| 596 |
+
print("assertion error",file=logfile)
|
| 597 |
pass
|
| 598 |
else:
|
| 599 |
output_transcript(o)
|
| 600 |
now = time.time() - start
|
| 601 |
+
print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=logfile,flush=True)
|
| 602 |
|
| 603 |
if end >= duration:
|
| 604 |
break
|