# Hugging Face Spaces page header captured with this source (Space status: "Build error").
import spaces | |
import torch | |
import gradio as gr | |
import whisperx | |
from transformers.pipelines.audio_utils import ffmpeg_read | |
import tempfile | |
import gc | |
import os | |
import time | |
# Constants
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 2  # Reduced batch size
COMPUTE_TYPE = "int8"  # Changed to int8 for lower memory usage
FILE_LIMIT_MB = 25  # Reduced file size limit
def clean_gpu_memory():
    """Free unreferenced Python objects, then release cached CUDA memory.

    Garbage collection runs first so that objects kept alive only by
    reference cycles are destroyed before the CUDA caching allocator is asked
    to give back its unused blocks. Collection also runs on CPU-only hosts,
    where the CUDA step is skipped.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def transcribe_audio(inputs, task):
    """Transcribe an audio file with WhisperX and label segments by speaker.

    Stages: size check, audio loading, Whisper transcription, word alignment,
    speaker diarization, speaker assignment, text formatting. Each model is
    dropped and GPU memory is cleaned before the next stage loads its own.

    Args:
        inputs: Path to the audio file, or ``None`` when nothing was submitted.
        task: Whisper task name (the UI offers ``"transcribe"`` and
            ``"translate"``); forwarded to the transcription model.

    Returns:
        A string with one ``"<speaker>: <text>"`` line per segment, each
        terminated by a newline. Segments without a speaker label use
        ``"Unknown Speaker"``.

    Raises:
        gr.Error: If no file was submitted, the file exceeds
            ``FILE_LIMIT_MB``, the audio cannot be loaded, or any processing
            stage fails.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    try:
        # Check file size
        file_size = os.path.getsize(inputs) / (1024 * 1024)  # Convert to MB
        if file_size > FILE_LIMIT_MB:
            raise gr.Error(f"File size ({file_size:.1f}MB) exceeds limit of {FILE_LIMIT_MB}MB")

        # Load audio with error handling
        try:
            audio = whisperx.load_audio(inputs)
        except Exception as e:
            raise gr.Error(f"Error loading audio file: {str(e)}") from e

        # 1. Transcribe with base Whisper model
        model = None
        try:
            model = whisperx.load_model("large-v3", device=DEVICE, compute_type=COMPUTE_TYPE)
            # Forward the user's choice; previously `task` was silently ignored.
            result = model.transcribe(audio, batch_size=BATCH_SIZE, task=task)
        finally:
            # Drop the reference first, otherwise the weights are still alive
            # when the cache is emptied and nothing is actually released.
            del model
            clean_gpu_memory()

        # 2. Align whisper output
        # NOTE(review): with task="translate" the text is translated while
        # result["language"] is the detected source language -- confirm the
        # alignment quality is acceptable in that mode.
        model_a = None
        try:
            model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=DEVICE)
            result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)
        finally:
            del model_a
            clean_gpu_memory()

        # 3. Diarize audio
        diarize_model = None
        try:
            diarize_model = whisperx.DiarizationPipeline(use_auth_token=os.environ.get("HF_TOKEN"), device=DEVICE)
            diarize_segments = diarize_model(audio)
        finally:
            del diarize_model
            clean_gpu_memory()

        # 4. Assign speaker labels
        result = whisperx.assign_word_speakers(diarize_segments, result)

        # Format output: one "<speaker>: <text>" line per segment.
        return "".join(
            f"{segment.get('speaker', 'Unknown Speaker')}: {segment['text']}\n"
            for segment in result['segments']
        )
    except gr.Error:
        # Already a user-facing message (size limit, load failure): do not
        # re-wrap it as "Error processing audio: ...".
        raise
    except Exception as e:
        clean_gpu_memory()
        raise gr.Error(f"Error processing audio: {str(e)}") from e
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# WhisperX: Advanced Speech Recognition with Speaker Diarization")

    with gr.Row():
        # Left column: audio source, task selector and the trigger button.
        with gr.Column():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label=f"Audio Input (Max {FILE_LIMIT_MB}MB)",
            )
            task = gr.Radio(
                choices=["transcribe", "translate"],
                value="transcribe",
                label="Task",
            )
            submit_button = gr.Button("Process Audio")

        # Right column: the speaker-labelled transcript.
        with gr.Column():
            output_text = gr.Textbox(
                label="Transcription with Speaker Diarization",
                lines=10,
                placeholder="Transcribed text will appear here...",
            )

    # Usage notes shown below the two columns.
    gr.Markdown(f"""
### Features:
- High-accuracy transcription using WhisperX
- Automatic speaker diarization
- Support for both microphone recording and file upload
- File size limit: {FILE_LIMIT_MB}MB
### Note:
- Processing may take a few moments
- For optimal results, use clear audio with minimal background noise
- If you encounter errors, try with a shorter audio clip
""")

    # Wire the button to the transcription pipeline.
    submit_button.click(
        fn=transcribe_audio,
        inputs=[audio_input, task],
        outputs=output_text,
    )

# Allow a single pending request in the queue, then start the server.
demo.queue(max_size=1)
demo.launch(
    share=False,
    debug=True,
    show_error=True,
    ssr_mode=False,
)