import os
import logging
import json

import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    AlgoOptions,
    SileroVadOptions,
    audio_to_bytes,
)
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available

from utils.logger_config import setup_logging
from utils.device import get_device, get_torch_and_np_dtypes
from utils.turn_server import get_rtc_credentials

# Load environment variables
load_dotenv()
load_dotenv(".env.local")  # Load local environment variables if they exist

# Set deployment mode
os.environ["APP_MODE"] = "deployed"
os.environ["UI_MODE"] = "fastapi"

# Verify HF_TOKEN is set
if not os.getenv("HF_TOKEN"):
    raise ValueError(
        "HF_TOKEN environment variable is not set. "
        "Please set it in .env.local or as a secret in Hugging Face Spaces."
    )

setup_logging()
logger = logging.getLogger(__name__)

MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")

device = get_device(force_cpu=False)
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")

attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
logger.info(f"Using attention: {attention}")

logger.info(f"Loading Whisper model: {MODEL_ID}")
try:
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation=attention,
        token=os.getenv("HF_TOKEN"),  # Use HF_TOKEN for model loading
    )
    model.to(device)
except Exception as e:
    logger.error(f"Error loading ASR model: {e}")
    logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
    raise

processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    token=os.getenv("HF_TOKEN"),  # Use HF_TOKEN for processor loading
)

transcribe_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

# Warm up the model with empty audio
logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.zeros((16000,), dtype=np_dtype)  # 1s of silence
transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")

async def transcribe(audio: tuple[int, np.ndarray]):
    sample_rate, audio_array = audio
    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")

    outputs = transcribe_pipeline(
        audio_to_bytes(audio),
        chunk_length_s=3,
        batch_size=1,
        generate_kwargs={
            'task': 'transcribe',
            'language': 'english',
        },
    )
    yield AdditionalOutputs(outputs["text"].strip())
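
# ReplyOnPause buffers incoming audio while the user speaks and invokes
# `transcribe` once the Silero VAD detects a pause; AdditionalOutputs routes
# the resulting text to the Textbox registered below.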
logger.info("Initializing FastRTC stream") | |
stream = Stream( | |
handler=ReplyOnPause( | |
transcribe, | |
algo_options=AlgoOptions( | |
audio_chunk_duration=0.6, | |
started_talking_threshold=0.2, | |
speech_threshold=0.1, | |
), | |
model_options=SileroVadOptions( | |
threshold=0.5, | |
min_speech_duration_ms=250, | |
max_speech_duration_s=30, | |
min_silence_duration_ms=2000, | |
window_size_samples=1024, | |
speech_pad_ms=400, | |
), | |
), | |
modality="audio", | |
mode="send", | |
additional_outputs=[ | |
gr.Textbox(label="Transcript"), | |
], | |
additional_outputs_handler=lambda current, new: current + " " + new, | |
rtc_configuration=get_rtc_credentials(provider="hf", token=os.getenv("HF_TOKEN")) # Pass HF_TOKEN to get_rtc_credentials | |
) | |

app = FastAPI()
stream.mount(app)
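# Mounting the stream above registers fastrtc's WebRTC signalling routes
# (e.g. the offer endpoint) on the FastAPI app, alongside the routes below.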

@app.get("/")
async def index():
    html_content = open("index.html").read()
    rtc_config = get_rtc_credentials(provider="hf", token=os.getenv("HF_TOKEN"))  # Pass HF_TOKEN to get_rtc_credentials
    return HTMLResponse(
        content=html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    )

@app.get("/transcript")  # Route path must match what index.html requests
def _(webrtc_id: str):
    logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")

    async def output_stream():
        try:
            async for output in stream.output_stream(webrtc_id):
                transcript = output.args[0]
                logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
                yield f"event: output\ndata: {transcript}\n\n"
        except Exception as e:
            logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
            raise

    return StreamingResponse(output_stream(), media_type="text/event-stream")
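
# Entry point for running the server directly; a minimal sketch, assuming
# standard uvicorn usage (Hugging Face Spaces expects apps to serve on 7860).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)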