# WhisperX-V2 / app.py
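"""WhisperX demo Space: transcribe audio with Whisper large-v3, align the
output for word-level timestamps, run speaker diarization, and return a
speaker-labelled transcript through a Gradio interface."""
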
import spaces
import torch
import gradio as gr
import whisperx
import gc
import os

# Constants
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 2          # small batch to limit peak GPU memory
COMPUTE_TYPE = "int8"   # int8 quantized inference for lower memory usage
FILE_LIMIT_MB = 25      # reject uploads larger than this

def clean_gpu_memory():
    """Collect garbage first so freed tensors can be released, then empty the CUDA cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

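# @spaces.GPU asks Hugging Face ZeroGPU Spaces to attach a GPU for the
# duration of each call (it is a no-op outside the Spaces environment)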
@spaces.GPU
def transcribe_audio(inputs, task):
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    try:
        # Check file size
        file_size = os.path.getsize(inputs) / (1024 * 1024)  # bytes -> MB
        if file_size > FILE_LIMIT_MB:
            raise gr.Error(f"File size ({file_size:.1f}MB) exceeds limit of {FILE_LIMIT_MB}MB")

        # Load audio with error handling
        try:
            audio = whisperx.load_audio(inputs)
        except Exception as e:
            raise gr.Error(f"Error loading audio file: {str(e)}")

        # 1. Transcribe with the base Whisper model, passing the selected
        # task through ("translate" produces English text)
        try:
            model = whisperx.load_model("large-v3", device=DEVICE, compute_type=COMPUTE_TYPE, task=task)
            result = model.transcribe(audio, batch_size=BATCH_SIZE)
        finally:
            # Drop the model reference *before* clearing the cache so its
            # memory can actually be released
            if 'model' in locals():
                del model
            clean_gpu_memory()

        # 2. Align Whisper output to get accurate word-level timestamps
        try:
            model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=DEVICE)
            result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)
        finally:
            if 'model_a' in locals():
                del model_a
            clean_gpu_memory()

        # 3. Diarize audio (the underlying pyannote models are gated, so
        # HF_TOKEN must grant access to them)
        try:
            diarize_model = whisperx.DiarizationPipeline(use_auth_token=os.environ.get("HF_TOKEN"), device=DEVICE)
            diarize_segments = diarize_model(audio)
        finally:
            if 'diarize_model' in locals():
                del diarize_model
            clean_gpu_memory()

        # 4. Assign speaker labels to the aligned segments
        result = whisperx.assign_word_speakers(diarize_segments, result)

        # Format output as one "SPEAKER: text" line per segment
        output_text = ""
        for segment in result['segments']:
            speaker = segment.get('speaker', 'Unknown Speaker')
            text = segment['text']
            output_text += f"{speaker}: {text}\n"
        return output_text

    except gr.Error:
        clean_gpu_memory()
        raise
    except Exception as e:
        clean_gpu_memory()
        raise gr.Error(f"Error processing audio: {str(e)}")

# Create Gradio interface
demo = gr.Blocks(theme=gr.themes.Ocean())

with demo:
    gr.Markdown("# WhisperX: Advanced Speech Recognition with Speaker Diarization")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label=f"Audio Input (Max {FILE_LIMIT_MB}MB)",
            )
            task = gr.Radio(
                ["transcribe", "translate"],
                label="Task",
                value="transcribe"
            )
            submit_button = gr.Button("Process Audio")

        with gr.Column():
            output_text = gr.Textbox(
                label="Transcription with Speaker Diarization",
                lines=10,
                placeholder="Transcribed text will appear here..."
            )

    gr.Markdown(f"""
    ### Features:
    - High-accuracy transcription using WhisperX
    - Automatic speaker diarization
    - Support for both microphone recording and file upload
    - File size limit: {FILE_LIMIT_MB}MB

    ### Note:
    - Processing may take a few moments
    - For optimal results, use clear audio with minimal background noise
    - If you encounter errors, try with a shorter audio clip
    """)

    submit_button.click(
        fn=transcribe_audio,
        inputs=[audio_input, task],
        outputs=output_text
    )

demo.queue(max_size=1).launch(
    share=False,
    debug=True,
    show_error=True,
    ssr_mode=False
)
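
# To run locally (a sketch; assumes the whisperx and gradio packages are
# installed and that HF_TOKEN holds a token with access to the gated
# pyannote diarization models):
#   HF_TOKEN=<your token> python app.py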