import logging
import os
import tempfile
from pathlib import Path

import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Speech-to-Text API",
    description="API for speech-to-text transcription using the CrisperWhisper model",
    version="1.0.0",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Load the model and processor once at startup and keep them in module globals.
@app.on_event("startup")
async def load_model():
    global asr_pipeline, device
    logger.info("Loading CrisperWhisper model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained("nyrahealth/CrisperWhisper")
    model = AutoModelForSpeechSeq2Seq.from_pretrained("nyrahealth/CrisperWhisper").to(device)
    model.eval()
    # Word-level timestamps are exposed through the transformers ASR
    # pipeline via return_timestamps="word".
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        chunk_length_s=30,
        return_timestamps="word",
        device=device,
    )
    logger.info(f"Model loaded successfully on {device}")
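# With return_timestamps="word", a call such as
#   asr_pipeline({"raw": samples, "sampling_rate": 16000})
# returns a dict shaped like (timings here are illustrative, not real output):
#   {"text": "hello world",
#    "chunks": [{"text": " hello", "timestamp": (0.0, 0.42)},
#               {"text": " world", "timestamp": (0.5, 0.9)}]}
# The /transcribe endpoint below reformats this into its response JSON.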
# Create a temporary directory to store uploaded files
TEMP_DIR = Path(tempfile.mkdtemp())
ALLOWED_EXTENSIONS = {"mp3", "wav", "flac", "ogg", "m4a", "mp4"}


def is_valid_audio_file(filename: str) -> bool:
    return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS


@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """
    Transcribe an audio file and return word-level timestamps.

    - **file**: Audio file to transcribe (MP3, WAV, FLAC, OGG, M4A, MP4)

    Returns a JSON object with the transcription and word-level timestamps.
    """
    # Check that a file was selected
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file selected")

    # Check that the file type is allowed
    if not is_valid_audio_file(file.filename):
        raise HTTPException(
            status_code=400,
            detail=f"File type not allowed. Supported formats: {', '.join(sorted(ALLOWED_EXTENSIONS))}",
        )

    try:
        # Build a safe filename by replacing anything that is not
        # alphanumeric, dot, underscore, hyphen, or space
        safe_filename = "".join(c if c.isalnum() or c in "._- " else "_" for c in file.filename)
        file_path = TEMP_DIR / safe_filename

        # Save the uploaded file
        with open(file_path, "wb") as buffer:
            content = await file.read()
            buffer.write(content)

        logger.info(f"Processing file: {safe_filename}")

        # Load the audio file
        waveform, sample_rate = torchaudio.load(file_path)

        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample to 16 kHz if needed (Whisper models expect 16 kHz input)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
            sample_rate = 16000

        # Run transcription with word-level timestamps
        with torch.no_grad():
            result = asr_pipeline({"raw": waveform.squeeze().numpy(), "sampling_rate": sample_rate})

        full_text = result["text"]

        # Keep only non-empty chunks, trimming surrounding whitespace
        chunks = []
        for chunk in result["chunks"]:
            if chunk["text"].strip():
                chunks.append(
                    {
                        "timestamp": [chunk["timestamp"][0], chunk["timestamp"][1]],
                        "text": chunk["text"].strip(),
                    }
                )

        # Create the output JSON
        output = {"text": full_text, "chunks": chunks}

        # Clean up the file immediately to save space
        os.remove(file_path)

        # Return the JSON directly
        return output

    except Exception as e:
        logger.error(f"Error during transcription: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    """Health check endpoint for Cloud Run."""
    return {"status": "healthy"}


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 8080))
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)
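# --- Example usage -----------------------------------------------------------
# A quick smoke test against a local run; the file name and timings below are
# illustrative, not real output:
#
#   curl -X POST "http://localhost:8080/transcribe" -F "file=@sample.wav"
#
#   {
#     "text": "hello world",
#     "chunks": [
#       {"timestamp": [0.12, 0.48], "text": "hello"},
#       {"timestamp": [0.55, 0.97], "text": "world"}
#     ]
#   }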