import os
import logging
import json
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    AlgoOptions,
    SileroVadOptions,
    audio_to_bytes,
)
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available
from utils.logger_config import setup_logging
from utils.device import get_device, get_torch_and_np_dtypes
from utils.turn_server import get_rtc_credentials

# Load environment variables
load_dotenv()
load_dotenv(".env.local")  # Load local environment variables if they exist

# Set deployment mode
os.environ["APP_MODE"] = "deployed"
os.environ["UI_MODE"] = "fastapi"

# Verify HF_TOKEN is set
if not os.getenv("HF_TOKEN"):
    raise ValueError(
        "HF_TOKEN environment variable is not set. "
        "Please set it in .env.local or as a secret in Hugging Face Spaces."
    )

setup_logging()
logger = logging.getLogger(__name__)

MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")

device = get_device(force_cpu=False)
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
logger.info(f"Using attention: {attention}")
logger.info(f"Loading Whisper model: {MODEL_ID}")
try:
model = AutoModelForSpeechSeq2Seq.from_pretrained(
MODEL_ID,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation=attention,
token=os.getenv("HF_TOKEN") # Use HF_TOKEN for model loading
)
model.to(device)
except Exception as e:
logger.error(f"Error loading ASR model: {e}")
logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
raise
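
# The processor bundles the checkpoint's tokenizer and feature extractor,
# both of which the pipeline below needs.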
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    token=os.getenv("HF_TOKEN"),  # Use HF_TOKEN for processor loading
)
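
# Wire the model and processor into a transformers ASR pipeline.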
transcribe_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

# Warm up the model with empty audio so the first real request doesn't pay
# the one-time allocation cost.
logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.zeros((16000,), dtype=np_dtype)  # 1s of silence
transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")
async def transcribe(audio: tuple[int, np.ndarray]):
    sample_rate, audio_array = audio
    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")
    outputs = transcribe_pipeline(
        audio_to_bytes(audio),
        chunk_length_s=3,  # chunked long-form decoding in 3-second windows
        batch_size=1,
        generate_kwargs={
            "task": "transcribe",
            "language": "english",
        },
    )
    yield AdditionalOutputs(outputs["text"].strip())
logger.info("Initializing FastRTC stream")
stream = Stream(
    handler=ReplyOnPause(
        transcribe,
        algo_options=AlgoOptions(
            audio_chunk_duration=0.6,       # seconds of audio per analysis chunk
            started_talking_threshold=0.2,  # speech fraction that marks the turn as started
            speech_threshold=0.1,           # speech fraction below which the user is pausing
        ),
        model_options=SileroVadOptions(
            threshold=0.5,                  # VAD speech-probability cutoff
            min_speech_duration_ms=250,     # ignore bursts shorter than this
            max_speech_duration_s=30,       # hard cap on a single segment
            min_silence_duration_ms=2000,   # silence required to end a turn
            window_size_samples=1024,
            speech_pad_ms=400,              # context kept around detected speech
        ),
    ),
    modality="audio",
    mode="send",  # audio flows client -> server; text returns via additional outputs
    additional_outputs=[
        gr.Textbox(label="Transcript"),
    ],
    # Append each new utterance to the running transcript.
    additional_outputs_handler=lambda current, new: current + " " + new,
    rtc_configuration=get_rtc_credentials(provider="hf", token=os.getenv("HF_TOKEN")),
)

app = FastAPI()
stream.mount(app)  # mounts the stream's WebRTC signalling routes on the app
@app.get("/")
async def index():
html_content = open("index.html").read()
rtc_config = get_rtc_credentials(provider="hf", token=os.getenv("HF_TOKEN")) # Pass HF_TOKEN to get_rtc_credentials
return HTMLResponse(content=html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config)))
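

# Server-sent events endpoint: the client opens /transcript?webrtc_id=... and
# receives one `output` event per transcribed utterance.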
@app.get("/transcript")
def _(webrtc_id: str):
logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")
async def output_stream():
try:
async for output in stream.output_stream(webrtc_id):
transcript = output.args[0]
logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
yield f"event: output\ndata: {transcript}\n\n"
except Exception as e:
logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
raise
return StreamingResponse(output_stream(), media_type="text/event-stream")
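
# To run locally (assuming this file is saved as app.py and index.html sits
# alongside it):
#   uvicorn app:app --host 0.0.0.0 --port 7860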