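"""FastAPI service exposing openai/whisper-medium for short-audio
transcription (English, Urdu, Arabic; 60 s max) on CPU."""
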
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from transformers import pipeline
import soundfile as sf
import numpy as np
import torch
from pydub import AudioSegment
import tempfile
import os
import logging
from contextlib import contextmanager
from timeout_decorator import timeout
import librosa
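
# Dependency notes: `timeout` comes from the `timeout-decorator` PyPI
# package, and pydub needs an ffmpeg binary on the host to decode
# non-WAV formats.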
app = FastAPI()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

print("Starting FastAPI app...")

# Global ASR pipeline; populated once at startup.
pipe = None

@app.get("/")
async def root():
    logger.info("Root endpoint accessed")
    return {"message": "Whisper API is running"}

@app.get("/health")
async def health():
    logger.info("Health check accessed")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return {"status": "ok", "model": "whisper-medium", "device": device}
@app.on_event("startup")
async def startup_event():
    """Load the Whisper model once when Uvicorn starts."""
    global pipe
    print("Uvicorn started successfully")
    print("Loading Whisper-medium...")
    try:
        pipe = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-medium",
            torch_dtype=torch.float32,
            device="cpu",
            model_kwargs={"use_safetensors": True},
        )
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Model loading failed: {str(e)}")
        raise
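# Note: @app.on_event("startup") (used above) is deprecated in recent
# FastAPI releases in favor of lifespan handlers.
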
@contextmanager
def temp_file(suffix):
    """Yield a temporary file path and remove the file afterwards."""
    temp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    temp.close()  # close the handle; the file itself persists (delete=False)
    try:
        yield temp.name
    finally:
        os.unlink(temp.name)
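# Usage sketch: `with temp_file(".wav") as path:` yields a disposable path
# that is unlinked when the block exits (see the /transcribe handler below).
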
@timeout(30)  # abort transcriptions that run longer than 30 s
def transcribe_audio(audio_data, language):
    logger.info("Starting transcription pipeline...")
    result = pipe(audio_data, generate_kwargs={"language": language, "task": "transcribe"})
    logger.info("Transcription pipeline completed")
    return result
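# Note: timeout_decorator uses SIGALRM by default, which only works in the
# main thread; pass use_signals=False to @timeout if this function is ever
# called from a worker thread.
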
@app.post("/transcribe")
async def transcribe(audio: UploadFile = File(...), language: str = Form(...)):
    logger.info(f"Received request: language={language}, filename={audio.filename}")
    try:
        valid_languages = {"en", "ur", "ar"}
        if language.lower() not in valid_languages:
            raise HTTPException(400, detail="Invalid language. Use 'en', 'ur', or 'ar'")
        audio_bytes = await audio.read()
        if not audio_bytes:
            raise HTTPException(400, detail="Empty audio file")
        # Default to .mp3 when the upload has no file extension.
        ext = os.path.splitext(audio.filename)[1].lower() or ".mp3"
        logger.info(f"Processing audio with extension: {ext}")
        with temp_file(ext) as temp_audio_path:
            with open(temp_audio_path, "wb") as f:
                f.write(audio_bytes)
            duration = librosa.get_duration(path=temp_audio_path)
            logger.info(f"Audio duration: {duration} seconds")
            if duration > 60:
                raise HTTPException(400, detail="Audio too long, max 60s")
            with temp_file(".wav") as temp_wav_path:
                if ext != ".wav":
                    logger.info(f"Converting {temp_audio_path} to WAV...")
                    audio_segment = AudioSegment.from_file(temp_audio_path)
                    audio_segment = audio_segment.set_frame_rate(16000).set_channels(1)
                    audio_segment.export(temp_wav_path, format="wav")
                else:
                    logger.info("Skipping conversion for WAV input")
                    # Reuse the uploaded file; the unused temp .wav is still
                    # cleaned up by the context manager.
                    temp_wav_path = temp_audio_path
                audio_data, sample_rate = sf.read(temp_wav_path)
                logger.info(f"Audio data shape: {audio_data.shape}, sample rate: {sample_rate}")
                if audio_data.ndim > 1:
                    audio_data = np.mean(audio_data, axis=1)  # downmix to mono
                if sample_rate != 16000:
                    # Only reachable for WAV uploads that skipped conversion.
                    raise HTTPException(400, detail="Audio must be sampled at 16kHz")
                if np.max(np.abs(audio_data)) < 1e-5:
                    raise HTTPException(400, detail="Audio is silent")
logger.info("Transcribing...") | |
try: | |
result = transcribe_audio(audio_data, language.lower()) | |
logger.info(f"Transcription: {result['text']}") | |
return {"text": result["text"]} | |
except Exception as e: | |
logger.error(f"Transcription failed: {str(e)}") | |
raise HTTPException(500, detail=f"Transcription failed: {str(e)}") | |
except HTTPException as e: | |
raise e | |
except Exception as e: | |
logger.error(f"Unexpected error: {str(e)}") | |
raise HTTPException(500, detail=f"Unexpected error: {str(e)}") |
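
# Example request (a sketch; assumes the server listens on port 7860, the
# Hugging Face Spaces default; adjust host/port to your deployment):
#
#   curl -X POST http://localhost:7860/transcribe \
#        -F "audio=@sample.mp3" \
#        -F "language=en"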