import gradio as gr
from transformers import pipeline
import numpy as np
import os
import torch
import torchaudio  # For resampling audio before VAD

print(f"DEBUG: Gradio version being used: {gr.__version__}")

# --- Configuration ---
MODEL_NAME = os.getenv("ASR_MODEL", "openai/whisper-base.en")
DEVICE = "cuda" if torch.cuda.is_available() and os.getenv("USE_GPU", "false").lower() == "true" else "cpu"
VAD_SAMPLE_RATE = 16000  # Silero VAD only accepts 8 kHz or 16 kHz audio
print(f"Using device: {DEVICE}")

# --- Global Variables ---
asr_pipeline = None
vad_model = None
vad_utils = None
audio_buffer = []  # To accumulate audio chunks
MAX_BUFFER_SECONDS = 10  # Max audio to buffer before forcing transcription
SILENCE_THRESHOLD_SECONDS = 1.5  # How long silence lasts before processing a speech segment

# --- Load Models ---
def load_models():
    global asr_pipeline, vad_model, vad_utils
    try:
        print(f"Loading ASR model: {MODEL_NAME} on device: {DEVICE}")
        asr_pipeline = pipeline(
            task="automatic-speech-recognition",
            model=MODEL_NAME,
            device=DEVICE if DEVICE == "cuda" else -1
        )
        print("ASR model loaded successfully.")

        print("Loading Silero VAD model...")
        # The Silero VAD model itself is small and runs efficiently on CPU
        vad_model, vad_utils_tuple = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad',
            force_reload=False,  # Set to True if you have issues
            onnx=True  # Use ONNX for better CPU performance
        )
        (get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = vad_utils_tuple
        vad_utils = {
            "get_speech_timestamps": get_speech_timestamps,
            "VADIterator": VADIterator
        }
        print("Silero VAD model loaded successfully.")
    except Exception as e:
        print(f"Error loading models: {e}")
        if asr_pipeline is None:
            print("ASR pipeline failed to load.")
        if vad_model is None:
            print("VAD model failed to load.")

load_models()  # Load models at startup

# --- Core Transcription Logic with VAD ---
def transcribe_with_vad(new_chunk_audio, history_state):
    global audio_buffer

    if new_chunk_audio is None or asr_pipeline is None or vad_model is None:
        return history_state.get("full_text", ""), history_state

    sample_rate, audio_data = new_chunk_audio
    # Normalize integer PCM (Gradio's microphone default) to float32 in [-1, 1]
    if np.issubdtype(audio_data.dtype, np.integer):
        audio_data_float32 = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
    else:
        audio_data_float32 = audio_data.astype(np.float32)

    # Append to buffer
    audio_buffer.append(audio_data_float32)

    # Check buffer length; if too short, wait for more audio
    current_buffer_duration = sum(len(chunk) / sample_rate for chunk in audio_buffer)

    # If the buffer is empty or too short, just return the current state
    if not audio_buffer or current_buffer_duration < 0.2:  # Minimum duration to process
        return history_state.get("full_text", ""), history_state

    # Concatenate buffer for VAD processing
    full_audio_np = np.concatenate(audio_buffer)
    full_audio_tensor = torch.from_numpy(full_audio_np).float()

    # Use VAD to find speech timestamps.
    # We're looking for the *end* of speech segments.
    # This is a simplified approach: we process if VAD detects no speech in the latest
    # part of the buffer, or if the buffer gets too long.
    try:
        # For simplicity, analyze the buffered audio for silence.
        # A more robust VADIterator approach would be better for continuous streaming,
        # but it is more complex to manage with Gradio's chunking.
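        # For reference, a minimal VADIterator sketch (not wired into this app; the chunk
        # source, chunk size, and 16 kHz input are assumptions you would need to adapt):
        #
        #   vad_iterator = vad_utils["VADIterator"](vad_model, sampling_rate=16000)
        #   for chunk in stream_of_fixed_size_chunks:        # hypothetical chunk source
        #       event = vad_iterator(chunk, return_seconds=True)
        #       if event and "end" in event:                 # end of a speech segment
        #           pass                                     # transcribe the buffered segment here
        #   vad_iterator.reset_states()                      # reset between audio streams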
        # Simpler approach used here: run VAD over the whole buffered audio and check
        # where the detected speech ends relative to the end of the buffer.

        # Silero VAD only supports 8 kHz / 16 kHz input, so resample the buffer for VAD
        # if the microphone stream uses a different rate (e.g. 44.1 kHz or 48 kHz).
        if sample_rate != VAD_SAMPLE_RATE:
            vad_audio_tensor = torchaudio.functional.resample(full_audio_tensor, sample_rate, VAD_SAMPLE_RATE)
        else:
            vad_audio_tensor = full_audio_tensor

        speech_timestamps = vad_utils["get_speech_timestamps"](
            vad_audio_tensor,
            vad_model,
            sampling_rate=VAD_SAMPLE_RATE,
            min_silence_duration_ms=500  # ms of silence to consider a break
        )

        # Heuristic: process if no speech is detected in the buffer, OR if the buffer is
        # getting long, OR if there's a significant pause after the last speech segment.
        process_now = False
        transcribed_text_segment = ""

        if not speech_timestamps:
            # No speech detected in the current combined buffer...
            if current_buffer_duration > SILENCE_THRESHOLD_SECONDS:
                # ...and we have enough audio to assume it's silence after speech
                process_now = True
        elif current_buffer_duration > MAX_BUFFER_SECONDS:
            # Buffer is too long, process it
            process_now = True
        else:
            # Speech was detected: check whether the last speech segment ends significantly
            # before the end of the buffer, which indicates a pause after speech.
            last_speech_end_s = speech_timestamps[-1]['end'] / VAD_SAMPLE_RATE
            if current_buffer_duration - last_speech_end_s > SILENCE_THRESHOLD_SECONDS:
                process_now = True

        if process_now and full_audio_np.any():  # Ensure there's actual audio data
            print(f"Processing {current_buffer_duration:.2f}s of buffered audio.")
            # Transcribe the entire current buffer
            transcription_result = asr_pipeline(
                {"sampling_rate": sample_rate, "raw": full_audio_np.copy()},  # Send a copy
                # You can add Whisper-specific args here if needed, e.g. chunk_length_s for long-form audio
                # generate_kwargs={"task": "transcribe", "language": "<|en|>"}  # for multilingual models
            )
            new_text = transcription_result["text"].strip()
            if new_text:
                transcribed_text_segment = new_text + " "
                history_state["full_text"] = history_state.get("full_text", "") + transcribed_text_segment
                print(f"VAD processed: '{new_text}'")
            audio_buffer = []  # Clear buffer after processing

    except Exception as e:
        print(f"Error during VAD/transcription: {e}")
        # Fallback: transcribe the accumulated buffer on error, then clear it
        if audio_buffer:
            try:
                full_audio_fallback = np.concatenate(audio_buffer)
                if full_audio_fallback.any():
                    transcription_result = asr_pipeline(
                        {"sampling_rate": sample_rate, "raw": full_audio_fallback.copy()}
                    )
                    new_text = transcription_result["text"].strip()
                    if new_text:
                        history_state["full_text"] = history_state.get("full_text", "") + new_text + " "
                        print(f"Fallback processed: '{new_text}'")
            except Exception as fallback_e:
                print(f"Error during fallback transcription: {fallback_e}")
            audio_buffer = []  # Clear buffer

    return history_state.get("full_text", ""), history_state


# --- Gradio UI (largely the same, just point to the new function and manage state) ---
with gr.Blocks(title="Live Transcription with VAD") as demo:
    gr.Markdown(
        f"""
        # 🎙️ Live Speech-to-Text with VAD & Hugging Face Whisper
        Speak into your microphone. Transcription will appear after speech segments.
        Using model: `{MODEL_NAME}` on device: `{DEVICE}`. VAD: Silero VAD
        """
    )

    if asr_pipeline is None or vad_model is None:
        gr.Markdown("## ⚠️ Error: Models Not Loaded. Check logs. ⚠️")
⚠️") transcription_history = gr.State({"full_text": ""}) with gr.Row(): audio_input = gr.Audio( sources=["microphone"], type="numpy", streaming=True, label="Speak Here (Streaming Active with VAD)", ) transcription_output = gr.Textbox( label="Live Transcription", lines=15, interactive=False, show_copy_button=True ) # Adjust 'every' based on how frequently you want to check the VAD buffer # Smaller 'every' means more frequent checks, potentially more responsive VAD # but also more frequent function calls. audio_input.stream( fn=transcribe_with_vad, inputs=[audio_input, transcription_history], outputs=[transcription_output, transcription_history], every=0.5 # Check buffer and VAD every 0.5 seconds ) def clear_transcription_state(current_state): global audio_buffer audio_buffer = [] # Also clear the audio buffer current_state["full_text"] = "" print("Transcription and audio buffer cleared.") return "", current_state clear_button = gr.Button("Clear Transcription & Buffer") clear_button.click( fn=clear_transcription_state, inputs=[transcription_history], outputs=[transcription_output, transcription_history] ) gr.Markdown("---") if __name__ == "__main__": # os.environ["ASR_MODEL"] = "openai/whisper-tiny.en" # os.environ["USE_GPU"] = "False" # load_models() # Ensure models are loaded if running locally demo.queue().launch(debug=True, share=False)