import os
import logging
import json

import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    AlgoOptions,
    SileroVadOptions,
    audio_to_bytes,
)
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available

from utils.logger_config import setup_logging
from utils.device import get_device, get_torch_and_np_dtypes
from utils.turn_server import get_rtc_credentials

# Load environment variables
load_dotenv()
load_dotenv(".env.local")  # Load local environment variables if they exist

# Set deployment mode
os.environ["APP_MODE"] = "deployed"
os.environ["UI_MODE"] = "fastapi"

# Verify HF_TOKEN is set
if not os.getenv("HF_TOKEN"):
    raise ValueError(
        "HF_TOKEN environment variable is not set. "
        "Please set it in .env.local or as a secret in Hugging Face Spaces."
    )

setup_logging()
logger = logging.getLogger(__name__)

MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")

device = get_device(force_cpu=False)
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")

# Prefer Flash Attention 2 when available, otherwise fall back to SDPA
attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
logger.info(f"Using attention: {attention}")

logger.info(f"Loading Whisper model: {MODEL_ID}")
try:
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation=attention,
        token=os.getenv("HF_TOKEN"),  # Use HF_TOKEN for model loading
    )
    model.to(device)
except Exception as e:
    logger.error(f"Error loading ASR model: {e}")
    logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
    raise

processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    token=os.getenv("HF_TOKEN"),  # Use HF_TOKEN for processor loading
)

transcribe_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

# Warm up the model with empty audio so the first real request isn't slow
logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.zeros((16000,), dtype=np_dtype)  # 1s of silence at 16 kHz
transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")


async def transcribe(audio: tuple[int, np.ndarray]):
    sample_rate, audio_array = audio
    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")

    # Encode the (sample_rate, samples) tuple into audio bytes the pipeline can decode
    outputs = transcribe_pipeline(
        audio_to_bytes(audio),
        chunk_length_s=3,
        batch_size=1,
        generate_kwargs={
            "task": "transcribe",
            "language": "english",
        },
    )
    yield AdditionalOutputs(outputs["text"].strip())


logger.info("Initializing FastRTC stream")
stream = Stream(
    handler=ReplyOnPause(
        transcribe,
        algo_options=AlgoOptions(
            audio_chunk_duration=0.6,
            started_talking_threshold=0.2,
            speech_threshold=0.1,
        ),
        model_options=SileroVadOptions(
            threshold=0.5,
            min_speech_duration_ms=250,
            max_speech_duration_s=30,
            min_silence_duration_ms=2000,
            window_size_samples=1024,
            speech_pad_ms=400,
        ),
    ),
    modality="audio",
    mode="send",  # client sends audio; transcripts flow back via AdditionalOutputs
    additional_outputs=[
        gr.Textbox(label="Transcript"),
    ],
    # Append each new transcript segment to the running transcript
    additional_outputs_handler=lambda current, new: current + " " + new,
    rtc_configuration=get_rtc_credentials(provider="hf", token=os.getenv("HF_TOKEN")),  # Pass HF_TOKEN to get_rtc_credentials
)

app = FastAPI()
stream.mount(app)


@app.get("/")
async def index():
    html_content = open("index.html").read()
    rtc_config = get_rtc_credentials(provider="hf", token=os.getenv("HF_TOKEN"))  # Pass HF_TOKEN to get_rtc_credentials
    return HTMLResponse(content=html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config)))


@app.get("/transcript")
def _(webrtc_id: str):
    logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")

    async def output_stream():
        try:
            async for output in stream.output_stream(webrtc_id):
                transcript = output.args[0]
                logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
                # Server-sent events framing: "event:" line, "data:" line, blank line
                yield f"event: output\ndata: {transcript}\n\n"
        except Exception as e:
            logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
            raise

    return StreamingResponse(output_stream(), media_type="text/event-stream")