import asyncio
import logging
import time
from fastapi import FastAPI, File, UploadFile
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
import io
import soundfile as sf
import numpy as np
import torchaudio
import psutil
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
# Load the model and processor
model_name = "ihanif/whisper-medium-urdu"
try:
    logger.info(f"Loading processor for {model_name}")
    processor = WhisperProcessor.from_pretrained(
        model_name,
        language="urdu",
        task="transcribe",
        clean_up_tokenization_spaces=True,  # Suppress FutureWarning
    )
    logger.info(f"Loading model for {model_name}")
    model = WhisperForConditionalGeneration.from_pretrained(model_name, low_cpu_mem_usage=True)
except Exception as e:
    logger.error(f"Error loading model or processor: {str(e)}")
    raise
# Set Urdu language and task
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ur", task="transcribe")
logger.info("Set forced_decoder_ids for Urdu transcription")
# Move model to CPU
device = "cpu"
model.to(device)
logger.info(f"Model loaded and moved to {device}")
# Log memory usage
def log_memory_usage():
    process = psutil.Process()
    mem_info = process.memory_info()
    logger.info(f"Memory usage: {mem_info.rss / 1024**2:.2f} MB")
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    try:
        start_time = time.time()
        log_memory_usage()
        # Read audio file
        logger.info("Reading audio file")
        audio_data, sample_rate = sf.read(io.BytesIO(await file.read()))
        logger.info(f"Audio read in {time.time() - start_time:.2f} seconds")
        # Ensure audio is mono
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)
        # Resample to 16 kHz if necessary (Whisper models expect 16 kHz input)
        target_sample_rate = 16000
        if sample_rate != target_sample_rate:
            logger.info(f"Resampling audio from {sample_rate} Hz to {target_sample_rate} Hz")
            audio_tensor = torch.from_numpy(audio_data).float()
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
            audio_data = resampler(audio_tensor).numpy()
            sample_rate = target_sample_rate
        # Process audio input
        logger.info("Processing audio input")
        inputs = processor(audio_data, sampling_rate=sample_rate, return_tensors="pt")
        input_features = inputs.input_features.to(device)
        # Generate transcription with a timeout. model.generate is a blocking,
        # CPU-bound call: run directly inside a coroutine it would block the
        # event loop and asyncio.wait_for could never interrupt it, so it is
        # dispatched to a worker thread instead.
        logger.info("Generating transcription")

        def generate_transcription():
            with torch.no_grad():
                return model.generate(
                    input_features,
                    max_new_tokens=225,
                    num_beams=1,
                    length_penalty=0.0,
                )

        try:
            # 60-second timeout; the worker thread itself is not killed on
            # timeout, but the request returns immediately with an error.
            generated_ids = await asyncio.wait_for(asyncio.to_thread(generate_transcription), timeout=60)
        except asyncio.TimeoutError:
            logger.error("Transcription timed out after 60 seconds")
            return {"error": "Transcription took too long. Try a smaller model (e.g., whisper-small) or upgrade to a paid Hugging Face Space with GPU."}
        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        total_time = time.time() - start_time
        logger.info(f"Total transcription time: {total_time:.2f} seconds")
        log_memory_usage()
        return {"transcription": transcription}
    except Exception as e:
        logger.error(f"Error during transcription: {str(e)}")
        return {"error": str(e)}
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
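
# Example request (a sketch; assumes the server is running locally on port 7860
# and that sample.wav exists on disk, the filename being hypothetical):
#   curl -X POST "http://localhost:7860/transcribe" -F "file=@sample.wav"
# The form field name must be "file" to match the UploadFile parameter above.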