skip_key_show_status

#1
by ahundt - opened
Files changed (3) hide show
  1. app.py +265 -56
  2. pyproject.toml +12 -0
  3. uv.lock +0 -0
app.py CHANGED
@@ -1,8 +1,13 @@
 
 
 
1
  import asyncio
2
  import base64
3
  import os
4
  import time
5
- from io import BytesIO
 
 
6
 
7
  import gradio as gr
8
  import numpy as np
@@ -15,23 +20,109 @@ from gradio_webrtc import (
15
  AudioEmitType,
16
  get_twilio_turn_credentials,
17
  )
18
- from PIL import Image
 
 
 
 
 
 
 
 
19
 
20
 
 
21
  def encode_audio(data: np.ndarray) -> dict:
22
- """Encode Audio data to send to the server"""
23
- return {"mime_type": "audio/pcm", "data": base64.b64encode(data.tobytes()).decode("UTF-8")}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- def encode_image(data: np.ndarray) -> dict:
27
- with BytesIO() as output_bytes:
28
- pil_image = Image.fromarray(data)
29
- pil_image.save(output_bytes, "JPEG")
30
- bytes_data = output_bytes.getvalue()
31
- base64_str = str(base64.b64encode(bytes_data), "utf-8")
32
  return {"mime_type": "image/jpeg", "data": base64_str}
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
 
 
35
  class GeminiHandler(AsyncAudioVideoStreamHandler):
36
  def __init__(
37
  self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
@@ -54,71 +145,142 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
54
  output_sample_rate=self.output_sample_rate,
55
  output_frame_size=self.output_frame_size,
56
  )
57
-
58
  async def video_receive(self, frame: np.ndarray):
59
  if self.session:
60
- # send image every 1 second
61
- if time.time() - self.last_frame_time > 1:
62
- self.last_frame_time = time.time()
63
- await self.session.send(encode_image(frame))
64
- if self.latest_args[2] is not None:
65
- await self.session.send(encode_image(self.latest_args[2]))
 
 
 
66
  self.video_queue.put_nowait(frame)
67
-
68
  async def video_emit(self) -> VideoEmitType:
69
- return await self.video_queue.get()
 
 
 
 
 
 
 
70
 
71
  async def connect(self, api_key: str):
 
72
  if self.session is None:
73
- client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
74
- config = {"response_modalities": ["AUDIO"]}
75
- async with client.aio.live.connect(
76
- model="gemini-2.0-flash-exp", config=config
77
- ) as session:
78
- self.session = session
79
- asyncio.create_task(self.receive_audio())
80
- await self.quit.wait()
 
 
 
 
 
 
 
 
 
81
 
82
  async def generator(self):
 
 
 
 
83
  while not self.quit.is_set():
84
- turn = self.session.receive()
85
- async for response in turn:
86
- if data := response.data:
87
- yield data
88
-
 
 
 
 
 
 
 
 
 
 
89
  async def receive_audio(self):
90
- async for audio_response in async_aggregate_bytes_to_16bit(
91
- self.generator()
92
- ):
93
- self.audio_queue.put_nowait(audio_response)
 
94
 
95
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
96
  _, array = frame
97
  array = array.squeeze()
98
- audio_message = encode_audio(array)
99
- if self.session:
100
- await self.session.send(audio_message)
 
 
 
 
101
 
102
  async def emit(self) -> AudioEmitType:
103
  if not self.args_set.is_set():
104
  await self.wait_for_args()
105
  if self.session is None:
106
  asyncio.create_task(self.connect(self.latest_args[1]))
107
- array = await self.audio_queue.get()
108
- return (self.output_sample_rate, array)
 
 
 
 
 
 
 
 
109
 
110
  def shutdown(self) -> None:
111
- self.quit.set()
 
 
 
 
 
 
 
 
 
112
  self.connection = None
113
  self.args_set.clear()
 
114
  self.quit.clear()
 
 
 
 
 
 
 
 
115
 
116
 
117
 
 
118
  css = """
119
  #video-source {max-width: 600px !important; max-height: 600 !important;}
120
  """
121
 
 
 
 
 
 
122
  with gr.Blocks(css=css) as demo:
123
  gr.HTML(
124
  """
@@ -135,16 +297,29 @@ with gr.Blocks(css=css) as demo:
135
  </div>
136
  """
137
  )
 
 
 
138
  with gr.Row() as api_key_row:
139
- api_key = gr.Textbox(label="API Key", type="password", placeholder="Enter your API Key", value=os.getenv("GOOGLE_API_KEY"))
 
 
 
 
 
140
  with gr.Row(visible=False) as row:
141
  with gr.Column():
 
 
 
 
 
142
  webrtc = WebRTC(
143
  label="Video Chat",
144
  modality="audio-video",
145
  mode="send-receive",
146
  elem_id="video-source",
147
- rtc_configuration=get_twilio_turn_credentials(),
148
  icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
149
  pulse_color="rgb(35, 157, 225)",
150
  icon_button_color="rgb(35, 157, 225)",
@@ -152,19 +327,53 @@ with gr.Blocks(css=css) as demo:
152
  with gr.Column():
153
  image_input = gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
154
 
155
- webrtc.stream(
156
- GeminiHandler(),
157
- inputs=[webrtc, api_key, image_input],
158
- outputs=[webrtc],
159
- time_limit=90,
160
- concurrency_limit=2,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  )
162
- api_key.submit(
163
- lambda: (gr.update(visible=False), gr.update(visible=True)),
164
- None,
165
- [api_key_row, row],
 
166
  )
167
 
 
 
 
 
 
 
 
168
 
169
- if __name__ == "__main__":
170
- demo.launch()
 
1
+ # https://huggingface.co/spaces/freddyaboulton/gemini-audio-video-chat
2
+ # related demos: https://github.com/freddyaboulton/gradio-webrtc
3
+
4
  import asyncio
5
  import base64
6
  import os
7
  import time
8
+ import logging
9
+ import traceback
10
+ import cv2
11
 
12
  import gradio as gr
13
  import numpy as np
 
20
  AudioEmitType,
21
  get_twilio_turn_credentials,
22
  )
23
+ import requests # Use requests for synchronous Twilio check
24
+
25
+ # --- Setup Logging ---
26
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # --- Global State ---
30
+ twilio_available = None # Will be set *before* Gradio initialization
31
+ gemini_connected = False
32
 
33
 
34
+ # --- Helper Functions ---
def encode_audio(data: np.ndarray) -> dict:
    """Package raw audio samples as a base64 ``audio/pcm`` message.

    Args:
        data: Audio samples; their in-memory bytes are sent unchanged.

    Returns:
        A dict with "mime_type" set to "audio/pcm" and "data" holding the
        base64 text of ``data.tobytes()``.

    Raises:
        TypeError: If ``data`` is not a ``numpy.ndarray``.
    """
    if isinstance(data, np.ndarray):
        try:
            raw_bytes = data.tobytes()
            encoded = base64.b64encode(raw_bytes).decode("UTF-8")
            return {"mime_type": "audio/pcm", "data": encoded}
        except Exception as e:
            # Log for diagnosis, then let the caller decide how to react.
            logger.error(f"Error encoding audio: {e}")
            raise
    raise TypeError("encode_audio expected a numpy.ndarray")
43
+
def encode_image(data: np.ndarray, quality: int = 85) -> dict:
    """Encode an image array as a base64 JPEG message using OpenCV.

    ``uint8`` input is encoded as-is. Any other integer or floating-point
    dtype is min-max rescaled to the 0-255 range before encoding.

    Args:
        data: Image array of shape (height, width, 3).
        quality: JPEG quality; values outside 0-100 are clamped.

    Returns:
        A dict with "mime_type" set to "image/jpeg" and "data" holding the
        base64 text of the JPEG bytes.

    Raises:
        TypeError: If ``data`` is not a NumPy array, or its dtype is neither
            integer nor floating-point.
        ValueError: If the shape is not (height, width, 3), a dimension is
            empty, or the array contains NaN/Inf.
        RuntimeError: If OpenCV fails to produce a JPEG.
    """
    if not isinstance(data, np.ndarray):
        raise TypeError("Input must be a NumPy array.")
    if data.ndim != 3 or data.shape[2] != 3:
        raise ValueError("Input array must have shape (height, width, 3).")
    if 0 in data.shape:
        raise ValueError("Input array cannot have a dimension of size 0.")

    if data.dtype == np.uint8:
        # Already 8-bit pixels: encode unchanged. Min-max stretching here
        # would alter contrast and turn a uniform frame black.
        pixels = data
    elif np.issubdtype(data.dtype, np.floating) or np.issubdtype(data.dtype, np.integer):
        if np.issubdtype(data.dtype, np.floating) and (
            np.any(np.isnan(data)) or np.any(np.isinf(data))
        ):
            raise ValueError("Input array contains NaN or Inf")
        # Cast to float64 first so dtypes OpenCV may not accept directly
        # (e.g. int64, float16) can still be rescaled.
        pixels = cv2.normalize(
            data.astype(np.float64), None, 0, 255, cv2.NORM_MINMAX
        ).astype(np.uint8)
    else:
        raise TypeError("Input array must have a floating-point or integer data type.")

    # NOTE(review): cv2.imencode treats 3-channel input as BGR; the channel
    # order of the arrays passed in is not visible here — confirm with callers.
    jpeg_quality = max(0, min(100, int(quality)))
    try:
        ok, buf = cv2.imencode(
            ".jpg", pixels, [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality]
        )
    except Exception as e:
        # Chain the cause so the OpenCV traceback is preserved.
        raise RuntimeError(f"JPEG encoding failed: {e}") from e
    if not ok:
        raise RuntimeError("JPEG encoding failed: cv2.imencode failed")

    jpeg_bytes = np.array(buf).tobytes()
    base64_str = base64.b64encode(jpeg_bytes).decode("utf-8")
    return {"mime_type": "image/jpeg", "data": base64_str}
93
 
def check_twilio_availability_sync(retries: int = 3, delay: float = 2) -> bool:
    """Check whether Twilio TURN credentials can be obtained (blocking).

    Records the outcome in the module-level ``twilio_available`` flag.

    Args:
        retries: Maximum number of attempts.
        delay: Seconds to sleep between attempts.

    Returns:
        True if credentials were obtained, otherwise False.
    """
    global twilio_available

    for attempt in range(1, retries + 1):
        try:
            logger.info(f"Attempting to get Twilio credentials (attempt {attempt})...")
            credentials = get_twilio_turn_credentials()
        except requests.exceptions.RequestException as e:
            # Network-level failure: worth retrying.
            logger.warning(f"Attempt {attempt}: {e}", exc_info=True)
        except Exception as e:
            # Anything else is treated as permanent for this run.
            logger.exception(f"Unexpected error checking Twilio: {e}")
            twilio_available = False
            return False
        else:
            if credentials:
                # Deliberately not logging the credentials: they are secrets.
                twilio_available = True
                logger.info("Twilio TURN server available.")
                return True
            logger.warning(f"Attempt {attempt}: Twilio returned no credentials.")

        if attempt < retries:
            time.sleep(delay)

    twilio_available = False
    logger.warning("Twilio TURN server unavailable.")
    return False
122
+
123
 
124
+
125
+ # --- Gemini Handler Class ---
126
  class GeminiHandler(AsyncAudioVideoStreamHandler):
127
  def __init__(
128
  self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
 
145
  output_sample_rate=self.output_sample_rate,
146
  output_frame_size=self.output_frame_size,
147
  )
148
+
async def video_receive(self, frame: np.ndarray):
    """Forward an incoming video frame to Gemini (throttled) and queue it.

    While a session exists, at most one frame per second is sent; the
    uploaded image (third stream argument), when present, is sent right
    after it. The frame is always placed on ``self.video_queue``.
    """
    if self.session:
        try:
            seconds_since_last = time.time() - self.last_frame_time
            if seconds_since_last > 1:
                self.last_frame_time = time.time()
                await self.session.send(encode_image(frame))
                uploaded_image = self.latest_args[2]
                if uploaded_image is not None:
                    await self.session.send(encode_image(uploaded_image))
        except Exception as e:
            # A failed send must not stop the local video from flowing.
            logger.error(f"Error sending video frame: {e}")
            gr.Warning("Error sending video to Gemini.")
    self.video_queue.put_nowait(frame)
161
+
async def video_emit(self) -> VideoEmitType:
    """Return the next queued video frame, waiting until one is available.

    Returns:
        The next item from ``self.video_queue``, or None if retrieving it
        fails with an unexpected error.

    Raises:
        asyncio.CancelledError: Re-raised so task cancellation propagates
            instead of being turned into a ``None`` frame.
    """
    try:
        return await self.video_queue.get()
    except asyncio.CancelledError:
        logger.info("Video emit cancelled.")
        raise
    except Exception as e:
        logger.exception(f"Error in video_emit: {e}")
        return None
171
 
async def connect(self, api_key: str):
    """Open a Gemini live session and keep it alive until ``quit`` is set.

    Does nothing if a session already exists. While connected, a background
    task runs ``receive_audio``. When the session ends for any reason the
    handler is returned to the disconnected state so a later call can
    connect again.

    Args:
        api_key: Gemini API key used to create the client.
    """
    global gemini_connected
    if self.session is not None:
        return

    receive_task = None
    try:
        client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
        config = {"response_modalities": ["AUDIO"]}
        async with client.aio.live.connect(
            model="gemini-2.0-flash-exp", config=config
        ) as session:
            self.session = session
            gemini_connected = True
            # Keep a reference so the task is not garbage-collected mid-run.
            receive_task = asyncio.create_task(self.receive_audio())
            await self.quit.wait()
    except Exception as e:
        logger.error(f"Error connecting to Gemini: {e}")
        self.shutdown()
        gr.Warning(f"Failed to connect to Gemini: {e}")
    finally:
        # The session context has exited (or never opened): stop the reader
        # and drop the stale session so nothing keeps sending to it.
        if receive_task is not None and not receive_task.done():
            receive_task.cancel()
        self.session = None
        gemini_connected = False
        update_gemini_status_sync()
192
 
async def generator(self):
    """Yield raw data chunks from Gemini responses until ``quit`` is set.

    Stops immediately (after a warning) when there is no session. Any error
    while receiving is logged, sets ``quit`` and ends the generator.
    """
    if not self.session:
        logger.warning("Gemini session is not initialized.")
        return

    while True:
        if self.quit.is_set():
            return
        try:
            # Give other tasks a chance to run before starting a new turn.
            await asyncio.sleep(0)
            if self.quit.is_set():
                break
            async for response in self.session.receive():
                if self.quit.is_set():
                    # Leave the current turn; the outer loop re-checks quit.
                    break
                chunk = response.data
                if chunk:
                    yield chunk
        except Exception as e:
            logger.error(f"Error receiving from Gemini: {e}")
            self.quit.set()
            break
213
+
async def receive_audio(self):
    """Pump aggregated audio from ``generator`` into ``self.audio_queue``.

    Errors are logged with a traceback and end the pump.
    """
    try:
        aggregated_stream = async_aggregate_bytes_to_16bit(self.generator())
        async for audio_response in aggregated_stream:
            self.audio_queue.put_nowait(audio_response)
    except Exception as e:
        logger.exception(f"Error in receive_audio: {e}")
220
 
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
    """Encode an incoming audio frame and send it to Gemini if connected.

    Args:
        frame: A 2-tuple whose second item is the sample array; the first
            item is ignored.
    """
    _unused, samples = frame
    samples = samples.squeeze()
    try:
        message = encode_audio(samples)
        if self.session:
            await self.session.send(message)
    except Exception as e:
        # Keep the stream alive even if one frame cannot be sent.
        logger.error(f"Error sending audio: {e}")
        gr.Warning("Error sending audio to Gemini.")
231
 
async def emit(self) -> AudioEmitType:
    """Return the next chunk of Gemini audio, connecting first if needed.

    Waits for the stream arguments, starts ``connect`` in the background
    when there is no session (and no connect attempt still running), then
    waits for audio on ``self.audio_queue``.

    Returns:
        ``(self.output_sample_rate, samples)``; on an unexpected error the
        samples are an empty int16 array.

    Raises:
        asyncio.CancelledError: Re-raised so task cancellation propagates.
    """
    if not self.args_set.is_set():
        await self.wait_for_args()
    if self.session is None:
        # Keep a reference to the task (asyncio only holds weak references)
        # and avoid starting a second connect while one is in flight.
        pending = getattr(self, "_connect_task", None)
        if pending is None or pending.done():
            self._connect_task = asyncio.create_task(
                self.connect(self.latest_args[1])
            )

    try:
        array = await self.audio_queue.get()
        return (self.output_sample_rate, array)
    except asyncio.CancelledError:
        logger.info("Audio emit cancelled.")
        raise
    except Exception as e:
        logger.exception(f"Error in emit: {e}")
        return (self.output_sample_rate, np.array([], dtype=np.int16))
247
 
def shutdown(self) -> None:
    """Mark Gemini as disconnected and reset this handler's stream state.

    Sets and then clears ``self.quit``, resets ``self.connection`` and
    ``self.args_set``, and finally calls ``update_gemini_status_sync``.
    No session close call is made here.
    """
    global gemini_connected
    gemini_connected = False
    logger.info("Shutting down GeminiHandler.")

    self.quit.set()
    self.connection = None
    self.args_set.clear()

    # NOTE(review): quit is cleared right after being set, so code that only
    # polls quit.is_set() later will presumably never observe it — confirm
    # this set/clear pulse is intended.
    self.quit.clear()
    update_gemini_status_sync()
264
+
265
+
def update_gemini_status_sync():
    """Build the Gemini status update from the ``gemini_connected`` flag.

    Returns:
        A ``gr.update`` carrying the status text. It only changes the UI
        when returned from a Gradio event handler wired to a status
        component, so it is returned rather than discarded; a plain call
        just logs the status.
    """
    status = "✅ Gemini: Connected" if gemini_connected else "❌ Gemini: Disconnected"
    logger.info(status)
    return gr.update(value=status)
271
 
272
 
273
 
274
+ # --- Gradio UI ---
# Cap the size of the WebRTC video element. CSS lengths need a unit, so the
# height is given in px like the width.
css = """
#video-source {max-width: 600px !important; max-height: 600px !important;}
"""
278
 
279
+ # Perform Twilio check *before* Gradio UI definition (synchronously)
280
+ if __name__ == "__main__":
281
+ check_twilio_availability_sync()
282
+
283
+
284
  with gr.Blocks(css=css) as demo:
285
  gr.HTML(
286
  """
 
297
  </div>
298
  """
299
  )
# Placeholder status lines shown until the real state is known.
twilio_status_message = gr.Markdown("❓ Twilio: Checking...")
gemini_status_message = gr.Markdown("❓ Gemini: Checking...")
302
+
303
  with gr.Row() as api_key_row:
304
+ api_key = gr.Textbox(
305
+ label="API Key",
306
+ type="password",
307
+ placeholder="Enter your API Key",
308
+ value=os.getenv("GOOGLE_API_KEY"),
309
+ )
310
  with gr.Row(visible=False) as row:
311
  with gr.Column():
# Use Twilio TURN credentials only when the start-up check succeeded;
# otherwise fall back to the default (None) configuration.
rtc_config = None
if twilio_available:
    try:
        rtc_config = get_twilio_turn_credentials()
    except Exception as e:
        # The pre-check passed but this fetch can still fail; do not let
        # that prevent the UI from being built.
        logger.warning(f"Could not fetch Twilio TURN credentials: {e}")
# No 'codecs' entry is added: it is not an RTCConfiguration member, and
# codec preference is not configured through this dictionary.
317
  webrtc = WebRTC(
318
  label="Video Chat",
319
  modality="audio-video",
320
  mode="send-receive",
321
  elem_id="video-source",
322
+ rtc_configuration=rtc_config,
323
  icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
324
  pulse_color="rgb(35, 157, 225)",
325
  icon_button_color="rgb(35, 157, 225)",
 
327
  with gr.Column():
328
  image_input = gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
329
 
330
+
def update_twilio_status_ui():
    """Return the Twilio status text for the status Markdown component.

    Reads the module-level ``twilio_available`` flag: None means the check
    has not run, otherwise it holds the result of the check.
    """
    if twilio_available is None:
        message = "❓ Twilio: Not checked (connection may be less reliable)"
    elif twilio_available:
        message = "✅ Twilio: Available"
    else:
        message = "❌ Twilio: Unavailable (connection may be less reliable)"
    return gr.update(value=message)
337
+
338
+ demo.load(update_twilio_status_ui, [], [twilio_status_message])
339
+
340
+ handler = GeminiHandler()
341
+ webrtc.stream(
342
+ handler,
343
+ inputs=[webrtc, api_key, image_input],
344
+ outputs=[webrtc],
345
+ time_limit=90,
346
+ concurrency_limit=None,
347
+ )
348
+
349
+
def check_api_key(api_key_str):
    """Switch between the API-key prompt and the chat UI on submit.

    Args:
        api_key_str: Text submitted in the API key box.

    Returns:
        Updates for (api_key_row, row, twilio_status_message,
        gemini_status_message), in that order. The Twilio status is left
        untouched; key feedback is shown in the Gemini status line.
    """
    has_key = bool(api_key_str and api_key_str.strip())
    if not has_key:
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(),  # keep the Twilio status as it is
            gr.update(value="❌ Gemini: Please enter a valid API key"),
        )
    return (
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(),  # keep the Twilio status as it is
        gr.update(value="❓ Gemini: Checking..."),
    )
364
+
365
+ api_key.submit(
366
+ check_api_key,
367
+ [api_key],
368
+ [api_key_row, row, twilio_status_message, gemini_status_message],
369
  )
370
 
371
+ # If API key is already set via environment variables, hide the API key row and show content
372
+ if os.getenv("GOOGLE_API_KEY"):
373
+ demo.load(
374
+ lambda: (gr.update(visible=False), gr.update(visible=True)),
375
+ None,
376
+ [api_key_row, row],
377
+ )
378
 
# Start the server only when run as a script, so importing this module does
# not launch the app.
if __name__ == "__main__":
    demo.launch()
 
pyproject.toml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
[project]
name = "gemini-audio-video-chat"
version = "0.1.0"
description = "Gradio WebRTC demo for real-time audio/video chat with Gemini"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
    "gradio_webrtc==0.0.28",
    "google-genai==0.3.0",
    "twilio",
    # Headless build: the app only encodes images and needs no GUI libraries.
    # Regenerate uv.lock after changing this entry.
    "opencv-python-headless",
    # Imported directly by app.py, so declared explicitly.
    "numpy",
    "requests",
]
uv.lock ADDED
The diff for this file is too large to render. See raw diff