Dominik Macháček committed

Commit b1878ce · 1 Parent(s): 8116b21

offline option
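This commit adds an --offline mode: instead of simulating real-time input by feeding the audio file in timed chunks, the whole file is loaded at once and transcribed in a single pass, which is useful for testing and debugging. It also turns the module-level to_flush() helper into a method of OnlineASRProcessor so it can pick a backend-specific word separator (ASRBase.sep), and makes the online loop report its current latency.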
Files changed: whisper_online.py (+54 -35)

whisper_online.py CHANGED
@@ -22,6 +22,8 @@ def load_audio_chunk(fname, beg, end):
 
 class ASRBase:
 
+    sep = " "
+
     def __init__(self, modelsize, lan, cache_dir):
         self.original_language = lan
 
@@ -74,6 +76,8 @@ class FasterWhisperASR(ASRBase):
        import faster_whisper
    """
 
+    sep = ""
+
     def load_model(self, modelsize, cache_dir):
         # cache_dir is not set, it seemed not working. Default ~/.cache/huggingface/hub is used.
 
@@ -98,8 +102,8 @@ class FasterWhisperASR(ASRBase):
         o = []
         for segment in segments:
             for word in segment.words:
-                # stripping the spaces
-                w = word.word
+                # not stripping the spaces -- should not be merged with them!
+                w = word.word
                 t = (word.start, word.end, w)
                 o.append(t)
         return o
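A note on the two sep attributes introduced above: faster-whisper emits word tokens that keep their leading space (hence the "not stripping the spaces" comment), so its words must be joined with an empty separator, while the base class default joins bare words with a plain space. A minimal illustration with made-up word lists (the token shapes are assumptions, not taken from this diff):

    fw_words = [" Hello", " world"]   # hypothetical faster-whisper tokens, leading spaces kept
    wt_words = ["Hello", "world"]     # hypothetical bare tokens from another backend

    print("".join(fw_words))    # " Hello world"  -- sep = "" (FasterWhisperASR)
    print(" ".join(wt_words))   # "Hello world"   -- sep = " " (ASRBase default)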
@@ -109,19 +113,6 @@ class FasterWhisperASR(ASRBase):
 
 
 
-def to_flush(sents, offset=0):
-    # concatenates the timestamped words or sentences into one sequence that is flushed in one line
-    # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
-    # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
-    t = " ".join(s[2] for s in sents)
-    if len(sents) == 0:
-        b = None
-        e = None
-    else:
-        b = offset + sents[0][0]
-        e = offset + sents[-1][1]
-    return (b,e,t)
-
 class HypothesisBuffer:
 
     def __init__(self):
@@ -254,8 +245,8 @@ class OnlineASRProcessor:
         self.transcript_buffer.insert(tsw, self.buffer_time_offset)
         o = self.transcript_buffer.flush()
         self.commited.extend(o)
-        print(">>>>COMPLETE NOW:",to_flush(o),file=sys.stderr,flush=True)
-        print("INCOMPLETE:",to_flush(self.transcript_buffer.complete()),file=sys.stderr,flush=True)
+        print(">>>>COMPLETE NOW:",self.to_flush(o),file=sys.stderr,flush=True)
+        print("INCOMPLETE:",self.to_flush(self.transcript_buffer.complete()),file=sys.stderr,flush=True)
 
         # there is a newly confirmed text
         if o:
@@ -301,7 +292,7 @@ class OnlineASRProcessor:
             #self.chunk_at(t)
 
         print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=sys.stderr)
-        return to_flush(o)
+        return self.to_flush(o)
 
     def chunk_completed_sentence(self):
         if self.commited == []: return
@@ -383,11 +374,26 @@ class OnlineASRProcessor:
         Returns: the same format as self.process_iter()
         """
         o = self.transcript_buffer.complete()
-        f = to_flush(o)
+        f = self.to_flush(o)
         print("last, noncommited:",f,file=sys.stderr)
         return f
 
 
+    def to_flush(self, sents, sep=None, offset=0, ):
+        # concatenates the timestamped words or sentences into one sequence that is flushed in one line
+        # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
+        # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
+        if sep is None:
+            sep = self.asr.sep
+        t = sep.join(s[2] for s in sents)
+        if len(sents) == 0:
+            b = None
+            e = None
+        else:
+            b = offset + sents[0][0]
+            e = offset + sents[-1][1]
+        return (b,e,t)
+
 
 
 ## main:
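to_flush() is now a method so it can default to the backend's separator via self.asr.sep. A standalone sketch of the same logic, with made-up timestamps, to show the contract:

    def to_flush_demo(sents, sep=" ", offset=0):
        # same logic as OnlineASRProcessor.to_flush, copied here for illustration
        t = sep.join(s[2] for s in sents)
        if len(sents) == 0:
            b = None
            e = None
        else:
            b = offset + sents[0][0]
            e = offset + sents[-1][1]
        return (b, e, t)

    print(to_flush_demo([(0.0, 1.2, "Hello"), (1.3, 2.0, "world.")], offset=10.0))
    # (10.0, 12.0, 'Hello world.')
    print(to_flush_demo([]))
    # (None, None, '')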
@@ -401,6 +407,7 @@ parser.add_argument('--model_dir', type=str, default='disk-cache-dir', help="the
 parser.add_argument('--lan', '--language', type=str, default='en', help="Language code for transcription, e.g. en,de,cs.")
 parser.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
 parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped"],help='Load only this backend for Whisper processing.')
+parser.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
 args = parser.parse_args()
 
 audio_path = args.audio_path
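The new flag is a plain store_true switch, so a hypothetical invocation (the audio file name is made up) would be:

    python3 whisper_online.py audio.wav --offline

Without the flag, the script falls through to the simulated real-time branch added below.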
@@ -440,6 +447,9 @@ a = load_audio_chunk(audio_path,0,1)
 # warm up the ASR, because the very first transcribe takes much more time than the other
 asr.transcribe(a)
 
+beg = args.start_at
+start = time.time()-beg
+
 def output_transcript(o):
     # output format in stdout is like:
     # 4186.3606 0 1720 Takhle to je
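Shifting the wall-clock origin by beg means that now = time.time() - start directly reports a position in the audio file, so --start_at works with the unchanged loop arithmetic. With made-up numbers:

    # if time.time() == 1000.0 and beg == 5.0, then start == 995.0,
    # and immediately afterwards:
    # now = time.time() - start  ==  5.0   # i.e. 5 s into the file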
@@ -453,18 +463,9 @@ def output_transcript(o):
     else:
         print(o,file=sys.stderr,flush=True)
 
-
-
-start = time.time()-beg
-while True:
-    now = time.time() - start
-    if now < end+min_chunk:
-        time.sleep(min_chunk+end-now)
-    end = time.time() - start
-    a = load_audio_chunk(audio_path,beg,end)
-    beg = end
+if args.offline: ## offline mode processing (for testing/debugging)
+    a = load_audio(audio_path)
     online.insert_audio_chunk(a)
-
     try:
         o = online.process_iter()
     except AssertionError:
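In the offline branch, load_audio (presumably defined earlier in the file, next to load_audio_chunk) reads the entire file, so a single insert_audio_chunk()/process_iter() pass covers all of it; the final online.finish() at the end of the script then flushes whatever remained unconfirmed.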
@@ -472,13 +473,31 @@ while True:
         pass
     else:
         output_transcript(o)
-
-
+else: # online = simultaneous mode
+    end = 0
+    while True:
+        now = time.time() - start
+        if now < end+min_chunk:
+            time.sleep(min_chunk+end-now)
+        end = time.time() - start
+        a = load_audio_chunk(audio_path,beg,end)
+        beg = end
+        online.insert_audio_chunk(a)
+
+        try:
+            o = online.process_iter()
+        except AssertionError:
+            print("assertion error",file=sys.stderr)
+            pass
+        else:
+            output_transcript(o)
+        now = time.time() - start
+        print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=sys.stderr)
 
-
+        print(file=sys.stderr,flush=True)
 
-
-
+        if end >= duration:
+            break
 
 o = online.finish()
 output_transcript(o)
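The pacing arithmetic in the online branch: end is how many seconds of audio have been consumed so far, now is elapsed wall time, and now - end is the reported latency. The loop sleeps whenever it is ahead of real time and stops once end reaches the file duration. With toy values:

    min_chunk, end, now = 1.0, 10.0, 10.4   # made-up numbers for illustration
    if now < end + min_chunk:               # 10.4 < 11.0: ahead of real time
        sleep_for = min_chunk + end - now   # 0.6 s until the next chunk is due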
