import spaces
import torch
import gradio as gr
import whisperx
import gc
import os

# Constants
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 2  # Small batch size to limit GPU memory usage
COMPUTE_TYPE = "int8"  # int8 quantization for lower memory usage
FILE_LIMIT_MB = 25  # Maximum upload size


def clean_gpu_memory():
    """Collect dropped Python references, then release cached GPU memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


@spaces.GPU
def transcribe_audio(inputs, task):
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    try:
        # Check file size
        file_size = os.path.getsize(inputs) / (1024 * 1024)  # Convert bytes to MB
        if file_size > FILE_LIMIT_MB:
            raise gr.Error(f"File size ({file_size:.1f}MB) exceeds limit of {FILE_LIMIT_MB}MB")

        # Load audio with error handling
        try:
            audio = whisperx.load_audio(inputs)
        except Exception as e:
            raise gr.Error(f"Error loading audio file: {str(e)}")

        # 1. Transcribe with the base Whisper model, honoring the selected task
        #    ("transcribe" or "translate"); previously the task argument was
        #    accepted but never used.
        try:
            model = whisperx.load_model("large-v3", device=DEVICE, compute_type=COMPUTE_TYPE, task=task)
            result = model.transcribe(audio, batch_size=BATCH_SIZE)
        finally:
            # Drop the model reference *before* clearing the CUDA cache,
            # otherwise the model's memory is still held when the cache is emptied
            if 'model' in locals():
                del model
            clean_gpu_memory()

        # 2. Align Whisper output for word-level timestamps
        #    (alignment uses the detected source language, so it is most
        #    reliable with task="transcribe")
        try:
            model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=DEVICE)
            result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)
        finally:
            if 'model_a' in locals():
                del model_a
            clean_gpu_memory()

        # 3. Diarize audio (requires an HF_TOKEN with access to the pyannote models)
        try:
            diarize_model = whisperx.DiarizationPipeline(use_auth_token=os.environ.get("HF_TOKEN"), device=DEVICE)
            diarize_segments = diarize_model(audio)
        finally:
            if 'diarize_model' in locals():
                del diarize_model
            clean_gpu_memory()

        # 4. Assign speaker labels to each segment
        result = whisperx.assign_word_speakers(diarize_segments, result)

        # Format output as "SPEAKER: text" lines
        output_text = ""
        for segment in result['segments']:
            speaker = segment.get('speaker', 'Unknown Speaker')
            text = segment['text']
            output_text += f"{speaker}: {text}\n"

        return output_text

    except Exception as e:
        clean_gpu_memory()
        raise gr.Error(f"Error processing audio: {str(e)}")


# Create Gradio interface
demo = gr.Blocks(theme=gr.themes.Ocean())

with demo:
    gr.Markdown("# WhisperX: Advanced Speech Recognition with Speaker Diarization")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label=f"Audio Input (Max {FILE_LIMIT_MB}MB)",
            )
            task = gr.Radio(
                ["transcribe", "translate"],
                label="Task",
                value="transcribe"
            )
            submit_button = gr.Button("Process Audio")

        with gr.Column():
            output_text = gr.Textbox(
                label="Transcription with Speaker Diarization",
                lines=10,
                placeholder="Transcribed text will appear here..."
            )

    gr.Markdown(f"""
    ### Features:
    - High-accuracy transcription using WhisperX
    - Automatic speaker diarization
    - Support for both microphone recording and file upload
    - File size limit: {FILE_LIMIT_MB}MB

    ### Note:
    - Processing may take a few moments
    - For optimal results, use clear audio with minimal background noise
    - If you encounter errors, try a shorter audio clip
    """)

    submit_button.click(
        fn=transcribe_audio,
        inputs=[audio_input, task],
        outputs=output_text
    )

demo.queue(max_size=1).launch(
    share=False,
    debug=True,
    show_error=True,
    ssr_mode=False
)