Tijs Zwinkels committed

Commit: c30969f
Parent(s): 1f2352f

OpenAI Whisper API backend

Files changed:
- whisper_online.py +75 -1
- whisper_online_server.py +2 -0

whisper_online.py
CHANGED

@@ -4,6 +4,8 @@ import numpy as np
 import librosa
 from functools import lru_cache
 import time
+import io
+import soundfile as sf

@@ -142,6 +144,76 @@ class FasterWhisperASR(ASRBase):
         self.transcribe_kargs["task"] = "translate"
 
 
+class OpenaiApiASR(ASRBase):
+    """Uses OpenAI's Whisper API for audio transcription."""
+
+    def __init__(self, modelsize=None, lan=None, cache_dir=None, model_dir=None, response_format="verbose_json", temperature=0):
+        self.modelname = "whisper-1"  # modelsize is not used but kept for interface consistency
+        self.language = lan  # ISO-639-1 language code
+        self.response_format = response_format
+        self.temperature = temperature
+        self.model = self.load_model(modelsize, cache_dir, model_dir)
+
+    def load_model(self, *args, **kwargs):
+        from openai import OpenAI
+        self.client = OpenAI()
+        # Since we're using the OpenAI API, there's no model to load locally.
+        print("Model configuration is set to use the OpenAI Whisper API.")
+
+    def ts_words(self, segments):
+        o = []
+        for segment in segments:
+            # Skip segments containing no speech
+            if segment["no_speech_prob"] > 0.8:
+                continue
+
+            # Splitting the text into words and filtering out empty strings
+            words = [word.strip() for word in segment["text"].split() if word.strip()]
+
+            if not words:
+                continue
+
+            # Assign start and end times for each word
+            # We only have timestamps per segment, so interpolating start and end-times
+            # assuming equal duration per word
+            segment_duration = segment["end"] - segment["start"]
+            duration_per_word = segment_duration / len(words)
+            start_time = segment["start"]
+            for word in words:
+                end_time = start_time + duration_per_word
+                o.append((start_time, end_time, word))
+                start_time = end_time
+
+        return o
+
+
+    def segments_end_ts(self, res):
+        return [s["end"] for s in res]
+
+    def transcribe(self, audio_data, prompt=None, *args, **kwargs):
+        # Write the audio data to a buffer
+        buffer = io.BytesIO()
+        buffer.name = "temp.wav"
+        sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
+        buffer.seek(0)  # Reset buffer's position to the beginning
+
+        # Prepare transcription parameters
+        transcription_params = {
+            "model": self.modelname,
+            "file": buffer,
+            "response_format": self.response_format,
+            "temperature": self.temperature
+        }
+        if self.language:
+            transcription_params["language"] = self.language
+        if prompt:
+            transcription_params["prompt"] = prompt
+
+        # Perform the transcription
+        transcript = self.client.audio.transcriptions.create(**transcription_params)
+
+        return transcript.segments
+
 
 class HypothesisBuffer:
 

@@ -453,7 +525,7 @@ def add_shared_args(parser):
     parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
     parser.add_argument('--lan', '--language', type=str, default='en', help="Language code for transcription, e.g. en,de,cs.")
     parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.")
-    parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped"],help='Load only this backend for Whisper processing.')
+    parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "openai-api"],help='Load only this backend for Whisper processing.')
     parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
     parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.')
    parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.')

@@ -493,6 +565,8 @@ if __name__ == "__main__":
 
     if args.backend == "faster-whisper":
         asr_cls = FasterWhisperASR
+    elif args.backend == "openai-api":
+        asr_cls = OpenaiApiASR
     else:
         asr_cls = WhisperTimestampedASR
 
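For reference, a minimal sketch of driving the new class directly, as the diff defines it. This snippet is not part of the commit: the file name audio.wav is a placeholder, the import assumes the module name whisper_online, and OPENAI_API_KEY must be set in the environment for the OpenAI() client constructed in load_model.

    # Hypothetical usage sketch of OpenaiApiASR; audio.wav is a placeholder.
    import librosa
    from whisper_online import OpenaiApiASR

    asr = OpenaiApiASR(lan="en")  # ISO-639-1 language code, forwarded to the API

    # transcribe() writes the buffer as 16 kHz PCM WAV, so resample on load
    audio, _ = librosa.load("audio.wav", sr=16000, mono=True)

    segments = asr.transcribe(audio)

    # ts_words() interpolates per-word times from the segment timestamps,
    # assuming equal duration per word within a segment
    for start, end, word in asr.ts_words(segments):
        print(f"{start:6.2f} {end:6.2f} {word}")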
    	
whisper_online_server.py
CHANGED

@@ -29,6 +29,8 @@ print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",
 if args.backend == "faster-whisper":
     from faster_whisper import WhisperModel
     asr_cls = FasterWhisperASR
+elif args.backend == "openai-api":
+    asr_cls = OpenaiApiASR
 else:
     import whisper
     import whisper_timestamped
