on1onmangoes committed (verified)
Commit 74fb39f · 1 Parent(s): aafa40b

Upload app.py with huggingface_hub
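For reference, an upload like this is typically done with the huggingface_hub client; a minimal sketch, where the token and repo ID are illustrative placeholders rather than values taken from this commit:

    from huggingface_hub import HfApi

    api = HfApi(token="hf_...")  # token with write access to the Space
    api.upload_file(
        path_or_fileobj="app.py",       # local file to upload
        path_in_repo="app.py",          # destination path inside the repo
        repo_id="user/space-name",      # placeholder Space ID
        repo_type="space",
        commit_message="Upload app.py with huggingface_hub",
    )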

Files changed (1): app.py (+154 -0)
app.py ADDED
@@ -0,0 +1,154 @@
import os
import logging
import json

import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    AlgoOptions,
    SileroVadOptions,
    audio_to_bytes,
)
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available

from utils.logger_config import setup_logging
from utils.device import get_device, get_torch_and_np_dtypes
from utils.turn_server import get_rtc_credentials

# Load environment variables
load_dotenv()
load_dotenv('.env.local')  # Load local environment variables if they exist

# Set deployment mode
os.environ["APP_MODE"] = "deployed"
os.environ["UI_MODE"] = "fastapi"

# Verify HF_TOKEN is set
if not os.getenv("HF_TOKEN"):
    raise ValueError("HF_TOKEN environment variable is not set. Please set it in .env.local or as a secret in Hugging Face Spaces.")

setup_logging()
logger = logging.getLogger(__name__)

MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")

device = get_device(force_cpu=False)
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")

attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
logger.info(f"Using attention: {attention}")

logger.info(f"Loading Whisper model: {MODEL_ID}")

try:
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation=attention,
        token=os.getenv("HF_TOKEN"),  # Use HF_TOKEN for model loading
    )
    model.to(device)
except Exception as e:
    logger.error(f"Error loading ASR model: {e}")
    logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
    raise

processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    token=os.getenv("HF_TOKEN"),  # Use HF_TOKEN for processor loading
)

transcribe_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

# Warm up the model with empty audio
logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.zeros((16000,), dtype=np_dtype)  # 1s of silence
transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")
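
# Note: `transcribe` below is the ReplyOnPause handler. FastRTC invokes it with a
# (sample_rate, audio_array) tuple once a pause is detected, and every yielded
# AdditionalOutputs value is forwarded to the additional_outputs_handler of the
# Stream configured further down.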
async def transcribe(audio: tuple[int, np.ndarray]):
    sample_rate, audio_array = audio
    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")

    outputs = transcribe_pipeline(
        audio_to_bytes(audio),
        chunk_length_s=3,
        batch_size=1,
        generate_kwargs={
            'task': 'transcribe',
            'language': 'english',
        },
    )
    yield AdditionalOutputs(outputs["text"].strip())
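
# Note: the AlgoOptions and SileroVadOptions below tune pause detection; roughly,
# min_silence_duration_ms=2000 means about two seconds of silence are needed
# before the buffered speech is handed to the transcribe handler above.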
logger.info("Initializing FastRTC stream")
stream = Stream(
    handler=ReplyOnPause(
        transcribe,
        algo_options=AlgoOptions(
            audio_chunk_duration=0.6,
            started_talking_threshold=0.2,
            speech_threshold=0.1,
        ),
        model_options=SileroVadOptions(
            threshold=0.5,
            min_speech_duration_ms=250,
            max_speech_duration_s=30,
            min_silence_duration_ms=2000,
            window_size_samples=1024,
            speech_pad_ms=400,
        ),
    ),
    modality="audio",
    mode="send",
    additional_outputs=[
        gr.Textbox(label="Transcript"),
    ],
    additional_outputs_handler=lambda current, new: current + " " + new,
    rtc_configuration=get_rtc_credentials(provider="hf", token=os.getenv("HF_TOKEN")),  # Pass HF_TOKEN to get_rtc_credentials
)

app = FastAPI()
stream.mount(app)

@app.get("/")
async def index():
    html_content = open("index.html").read()
    rtc_config = get_rtc_credentials(provider="hf", token=os.getenv("HF_TOKEN"))  # Pass HF_TOKEN to get_rtc_credentials
    return HTMLResponse(content=html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config)))
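
# The /transcript route below streams results as server-sent events: each chunk
# yielded by the handler is emitted as an `event: output` message for the
# connection identified by the webrtc_id query parameter.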
@app.get("/transcript")
def _(webrtc_id: str):
    logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")
    async def output_stream():
        try:
            async for output in stream.output_stream(webrtc_id):
                transcript = output.args[0]
                logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
                yield f"event: output\ndata: {transcript}\n\n"
        except Exception as e:
            logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
            raise

    return StreamingResponse(output_stream(), media_type="text/event-stream")
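
Two usage notes, not part of the committed file. Locally, the FastAPI app defined above is typically served with an ASGI server, for example: uvicorn app:app --host 0.0.0.0 --port 7860 (host and port are assumptions; match them to your Space or environment).

Below is a minimal sketch of following the /transcript server-sent-event stream from Python, assuming the server above is reachable at localhost:7860, that httpx is installed, and that you already hold the webrtc_id of an active connection:

    import httpx

    def follow_transcript(webrtc_id: str, base_url: str = "http://localhost:7860"):
        # Stream server-sent events from /transcript and print each transcript chunk
        with httpx.stream(
            "GET",
            f"{base_url}/transcript",
            params={"webrtc_id": webrtc_id},
            timeout=None,
        ) as response:
            for line in response.iter_lines():
                if line.startswith("data: "):
                    print(line[len("data: "):])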