Tijs Zwinkels
committed on
Commit · 8896389
1 Parent(s): 5929a82
Fix crash when using openai-api with whisper_online_server
- whisper_online.py +32 -21
- whisper_online_server.py +1 -24
whisper_online.py
CHANGED
@@ -548,6 +548,37 @@ def add_shared_args(parser):
     parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.')
     parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.')
 
+def asr_factory(args, logfile=sys.stderr):
+    """
+    Creates and configures an ASR instance based on the specified backend and arguments.
+    """
+    backend = args.backend
+    if backend == "openai-api":
+        print("Using OpenAI API.", file=logfile)
+        asr = OpenaiApiASR(lan=args.lan)
+    else:
+        if backend == "faster-whisper":
+            from faster_whisper import FasterWhisperASR
+            asr_cls = FasterWhisperASR
+        else:
+            from whisper_timestamped import WhisperTimestampedASR
+            asr_cls = WhisperTimestampedASR
+
+        # Only for FasterWhisperASR and WhisperTimestampedASR
+        size = args.model
+        t = time.time()
+        print(f"Loading Whisper {size} model for {args.lan}...", file=logfile, end=" ", flush=True)
+        asr = asr_cls(modelsize=size, lan=args.lan, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+        e = time.time()
+        print(f"done. It took {round(e-t,2)} seconds.", file=logfile)
+
+    # Apply common configurations
+    if getattr(args, 'vad', False):  # Checks if VAD argument is present and True
+        print("Setting VAD filter", file=logfile)
+        asr.use_vad()
+
+    return asr
+
 ## main:
 
 if __name__ == "__main__":
@@ -575,28 +606,8 @@ if __name__ == "__main__":
     duration = len(load_audio(audio_path))/SAMPLING_RATE
     print("Audio duration is: %2.2f seconds" % duration, file=logfile)
 
+    asr = asr_factory(args, logfile=logfile)
     language = args.lan
-
-    if args.backend == "openai-api":
-        print("Using OpenAI API.",file=logfile)
-        asr = OpenaiApiASR(lan=language)
-    else:
-        if args.backend == "faster-whisper":
-            asr_cls = FasterWhisperASR
-        else:
-            asr_cls = WhisperTimestampedASR
-
-        size = args.model
-        t = time.time()
-        print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
-        asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
-        e = time.time()
-        print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
-
-        if args.vad:
-            print("setting VAD filter",file=logfile)
-            asr.use_vad()
-
     if args.task == "translate":
         asr.set_translate_task()
         tgt_language = "en"  # Whisper translates into English
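For context, a minimal sketch of how the new asr_factory() could be called from other code, assuming whisper_online is importable; the Namespace fields mirror the options the factory reads above, and the values are purely illustrative:

    import argparse
    import sys

    from whisper_online import asr_factory

    # Illustrative values only; in the scripts these fields are populated by
    # the argparse options registered in add_shared_args().
    args = argparse.Namespace(
        backend="openai-api",    # or "faster-whisper"; anything else falls back to whisper_timestamped
        lan="en",
        model="large-v2",        # only used by the non-API backends
        model_cache_dir=None,
        model_dir=None,
        vad=False,
    )

    asr = asr_factory(args, logfile=sys.stderr)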
whisper_online_server.py
CHANGED
@@ -24,36 +24,13 @@ SAMPLING_RATE = 16000
 size = args.model
 language = args.lan
 
-
-print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
-
-if args.backend == "faster-whisper":
-    from faster_whisper import WhisperModel
-    asr_cls = FasterWhisperASR
-elif args.backend == "openai-api":
-    asr_cls = OpenaiApiASR
-else:
-    import whisper
-    import whisper_timestamped
-    # from whisper_timestamped_model import WhisperTimestampedASR
-    asr_cls = WhisperTimestampedASR
-
-asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
-
+asr = asr_factory(args)
 if args.task == "translate":
     asr.set_translate_task()
     tgt_language = "en"
 else:
     tgt_language = language
 
-e = time.time()
-print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
-
-if args.vad:
-    print("setting VAD filter",file=sys.stderr)
-    asr.use_vad()
-
-
 min_chunk = args.min_chunk_size
 
 if args.buffer_trimming == "sentence":
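The crash mentioned in the commit message can be seen in the removed server code above: with --backend openai-api the server selected OpenaiApiASR as asr_cls but still constructed it with model-loading keywords (modelsize, cache_dir, model_dir), and later printed timing based on a t that is never assigned in this file. If, as in the asr_factory() added above, OpenaiApiASR is constructed with only lan=..., that call fails before the server starts. A minimal sketch of this failure mode with a hypothetical stand-in class (not the real OpenaiApiASR):

    # Hypothetical stand-in whose constructor is shaped like the
    # OpenaiApiASR(lan=...) call used in asr_factory().
    class OpenaiApiASRStub:
        def __init__(self, lan=None):
            self.lan = lan

    asr_cls = OpenaiApiASRStub
    try:
        # The removed server code made this call regardless of backend;
        # the extra keywords are not accepted here, so it raises TypeError.
        asr = asr_cls(modelsize="large-v2", lan="en", cache_dir=None, model_dir=None)
    except TypeError as exc:
        print(f"crash reproduced: {exc}")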