import asyncio
import io
import logging
import time

import numpy as np
import psutil
import soundfile as sf
import torch
import torchaudio
from fastapi import FastAPI, File, UploadFile
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()
# Load the model and processor
model_name = "ihanif/whisper-medium-urdu"
try:
    logger.info(f"Loading processor for {model_name}")
    processor = WhisperProcessor.from_pretrained(
        model_name,
        language="urdu",
        task="transcribe",
        clean_up_tokenization_spaces=True,  # Suppress FutureWarning
    )
    logger.info(f"Loading model for {model_name}")
    model = WhisperForConditionalGeneration.from_pretrained(model_name, low_cpu_mem_usage=True)
except Exception as e:
    logger.error(f"Error loading model or processor: {str(e)}")
    raise
# Set Urdu language and task
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ur", task="transcribe")
logger.info("Set forced_decoder_ids for Urdu transcription")

# Move model to CPU
device = "cpu"
model.to(device)
logger.info(f"Model loaded and moved to {device}")
# Log memory usage
def log_memory_usage():
    process = psutil.Process()
    mem_info = process.memory_info()
    logger.info(f"Memory usage: {mem_info.rss / 1024**2:.2f} MB")
# Register the handler as a POST route; the original listing omitted the
# decorator, so FastAPI would never expose the endpoint. The "/transcribe"
# path is an assumed name here.
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    try:
        start_time = time.time()
        log_memory_usage()

        # Read audio file
        logger.info("Reading audio file")
        audio_data, sample_rate = sf.read(io.BytesIO(await file.read()))
        logger.info(f"Audio read in {time.time() - start_time:.2f} seconds")

        # Ensure audio is mono
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)

        # Resample to 16 kHz if necessary (Whisper expects 16 kHz input)
        target_sample_rate = 16000
        if sample_rate != target_sample_rate:
            logger.info(f"Resampling audio from {sample_rate} Hz to {target_sample_rate} Hz")
            audio_tensor = torch.from_numpy(audio_data).float()
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
            audio_data = resampler(audio_tensor).numpy()
            sample_rate = target_sample_rate

        # Process audio input
        logger.info("Processing audio input")
        inputs = processor(audio_data, sampling_rate=sample_rate, return_tensors="pt")
        input_features = inputs.input_features.to(device)
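        # The processor pads/truncates to Whisper's fixed 30-second window,
        # so input_features is a log-Mel tensor of shape (batch, 80, 3000)
        # for this medium-size checkpoint.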
        # Generate transcription with an async timeout. model.generate() is a
        # blocking call: wrapping it in an `async def` (as the original did)
        # stalls the event loop, so asyncio.wait_for could never cancel it.
        # Running it in a worker thread via asyncio.to_thread (Python 3.9+)
        # keeps the loop responsive and makes the timeout effective.
        logger.info("Generating transcription")

        def generate_transcription():
            with torch.no_grad():
                return model.generate(
                    input_features,
                    max_new_tokens=225,
                    num_beams=1,
                    length_penalty=0.0,
                )

        try:
            generated_ids = await asyncio.wait_for(
                asyncio.to_thread(generate_transcription), timeout=60  # 60-second timeout
            )
        except asyncio.TimeoutError:
            logger.error("Transcription timed out after 60 seconds")
            return {"error": "Transcription took too long. Try a smaller model (e.g., whisper-small) or upgrade to a paid Hugging Face Space with GPU."}
        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        total_time = time.time() - start_time
        logger.info(f"Total transcription time: {total_time:.2f} seconds")
        log_memory_usage()
        return {"transcription": transcription}
    except Exception as e:
        logger.error(f"Error during transcription: {str(e)}")
        return {"error": str(e)}
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
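
# Example request (a sketch, assuming the "/transcribe" route registered
# above and the default Spaces port 7860; "sample.wav" is a placeholder):
#   curl -X POST "http://localhost:7860/transcribe" -F "file=@sample.wav"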