from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from transformers import pipeline
import soundfile as sf
import numpy as np
import torch
from pydub import AudioSegment
import tempfile
import os
import logging
from contextlib import contextmanager
from timeout_decorator import timeout
import librosa

app = FastAPI()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Starting FastAPI app...")

# Global ASR pipeline; populated once in the startup event below.
pipe = None
@app.get("/")
async def root():
logger.info("Root endpoint accessed")
return {"message": "Whisper API is running"}
@app.get("/health")
async def health():
logger.info("Health check accessed")
device = "cuda" if torch.cuda.is_available() else "cpu"
return {"status": "ok", "model": "whisper-medium", "device": device}
@app.on_event("startup")
async def startup_event():
print("Uvicorn started successfully")
print("Loading Whisper-medium...")
try:
pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-medium",
torch_dtype=torch.float32,
device="cpu",
model_kwargs={"use_safetensors": True}
)
logger.info("Model loaded successfully")
except Exception as e:
logger.error(f"Model loading failed: {str(e)}")
raise e

@contextmanager
def temp_file(suffix):
    """Yield a temporary file path that is removed on exit."""
    temp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    temp.close()  # keep only the path; callers reopen it by name
    try:
        yield temp.name
    finally:
        os.unlink(temp.name)

@timeout(30, use_signals=False)  # 30s timeout
def transcribe_audio(audio_data, language):
    logger.info("Starting transcription pipeline...")
    result = pipe(audio_data, generate_kwargs={"language": language, "task": "transcribe"})
    logger.info("Transcription pipeline completed")
    return result
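
# Note on the return value: per the Transformers ASR pipeline contract,
# `result` is a dict whose transcript lives under result["text"],
# e.g. {"text": "hello world"}.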
@app.post("/transcribe")
async def transcribe(audio: UploadFile = File(...), language: str = Form(...)):
logger.info(f"Received request: language={language}, filename={audio.filename}")
try:
valid_languages = {"en", "ur", "ar"}
if language.lower() not in valid_languages:
raise HTTPException(400, detail="Invalid language. Use 'en', 'ur', or 'ar'")
audio_bytes = await audio.read()
if not audio_bytes:
raise HTTPException(400, detail="Empty audio file")
ext = os.path.splitext(audio.filename)[1].lower() or ".mp3"
logger.info(f"Processing audio with extension: {ext}")
with temp_file(ext) as temp_audio_path:
with open(temp_audio_path, "wb") as f:
f.write(audio_bytes)
duration = librosa.get_duration(path=temp_audio_path)
logger.info(f"Audio duration: {duration} seconds")
if duration > 60:
raise HTTPException(400, detail="Audio too long, max 60s")
with temp_file(".wav") as temp_wav_path:
if ext != ".wav":
logger.info(f"Converting {temp_audio_path} to WAV...")
audio_segment = AudioSegment.from_file(temp_audio_path)
audio_segment = audio_segment.set_frame_rate(16000).set_channels(1)
audio_segment.export(temp_wav_path, format="wav")
else:
logger.info("Skipping conversion for WAV input")
temp_wav_path = temp_audio_path
audio_data, sample_rate = sf.read(temp_wav_path)
logger.info(f"Audio data shape: {audio_data.shape}, sample rate: {sample_rate}")
if len(audio_data.shape) > 1:
audio_data = np.mean(audio_data, axis=1)
if sample_rate != 16000:
raise HTTPException(500, detail="Audio is not 16kHz")
if np.max(np.abs(audio_data)) < 1e-5:
raise HTTPException(400, detail="Audio is silent")
logger.info("Transcribing...")
try:
result = transcribe_audio(audio_data, language.lower())
logger.info(f"Transcription: {result['text']}")
return {"text": result["text"]}
except Exception as e:
logger.error(f"Transcription failed: {str(e)}")
raise HTTPException(500, detail=f"Transcription failed: {str(e)}")
except HTTPException as e:
raise e
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
raise HTTPException(500, detail=f"Unexpected error: {str(e)}")