ahundt commited on
Commit
b7a0a78
·
1 Parent(s): e45ab03

try more robust code (not working yet), notifications about status

Browse files
Files changed (3) hide show
  1. app.py +255 -85
  2. pyproject.toml +11 -0
  3. uv.lock +0 -0
app.py CHANGED
@@ -3,6 +3,8 @@ import base64
3
  import os
4
  import time
5
  from io import BytesIO
 
 
6
 
7
  import gradio as gr
8
  import numpy as np
@@ -17,21 +19,82 @@ from gradio_webrtc import (
17
  )
18
  from PIL import Image
19
 
 
 
 
 
 
 
 
 
20
 
21
- def encode_audio(data: np.ndarray) -> dict:
22
- """Encode Audio data to send to the server"""
23
- return {"mime_type": "audio/pcm", "data": base64.b64encode(data.tobytes()).decode("UTF-8")}
24
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def encode_image(data: np.ndarray) -> dict:
27
- with BytesIO() as output_bytes:
28
- pil_image = Image.fromarray(data)
29
- pil_image.save(output_bytes, "JPEG")
30
- bytes_data = output_bytes.getvalue()
31
- base64_str = str(base64.b64encode(bytes_data), "utf-8")
32
- return {"mime_type": "image/jpeg", "data": base64_str}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
 
 
 
35
  class GeminiHandler(AsyncAudioVideoStreamHandler):
36
  def __init__(
37
  self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
@@ -54,117 +117,224 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
54
  output_sample_rate=self.output_sample_rate,
55
  output_frame_size=self.output_frame_size,
56
  )
57
-
58
  async def video_receive(self, frame: np.ndarray):
59
  if self.session:
60
- # send image every 1 second
61
- if time.time() - self.last_frame_time > 1:
62
- self.last_frame_time = time.time()
63
- await self.session.send(encode_image(frame))
64
- if self.latest_args[2] is not None:
65
- await self.session.send(encode_image(self.latest_args[2]))
66
- self.video_queue.put_nowait(frame)
67
-
 
 
 
 
68
  async def video_emit(self) -> VideoEmitType:
69
- return await self.video_queue.get()
 
 
 
 
 
 
 
70
 
71
  async def connect(self, api_key: str):
 
72
  if self.session is None:
73
- client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
74
- config = {"response_modalities": ["AUDIO"]}
75
- async with client.aio.live.connect(
76
- model="gemini-2.0-flash-exp", config=config
77
- ) as session:
78
- self.session = session
79
- asyncio.create_task(self.receive_audio())
80
- await self.quit.wait()
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  async def generator(self):
 
 
 
 
83
  while not self.quit.is_set():
84
- turn = self.session.receive()
85
- async for response in turn:
86
- if data := response.data:
87
- yield data
88
-
 
 
 
 
 
89
  async def receive_audio(self):
90
- async for audio_response in async_aggregate_bytes_to_16bit(
91
- self.generator()
92
- ):
93
- self.audio_queue.put_nowait(audio_response)
 
 
 
 
 
 
94
 
95
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
96
  _, array = frame
97
  array = array.squeeze()
98
- audio_message = encode_audio(array)
99
- if self.session:
100
- await self.session.send(audio_message)
 
 
 
 
 
101
 
102
  async def emit(self) -> AudioEmitType:
103
  if not self.args_set.is_set():
104
  await self.wait_for_args()
105
  if self.session is None:
 
106
  asyncio.create_task(self.connect(self.latest_args[1]))
107
- array = await self.audio_queue.get()
108
- return (self.output_sample_rate, array)
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  def shutdown(self) -> None:
 
 
 
111
  self.quit.set()
112
  self.connection = None
113
  self.args_set.clear()
 
 
 
 
114
  self.quit.clear()
 
115
 
116
 
117
-
118
  css = """
119
  #video-source {max-width: 600px !important; max-height: 600 !important;}
120
  """
121
 
122
- with gr.Blocks(css=css) as demo:
123
- gr.HTML(
124
- """
125
- <div style='display: flex; align-items: center; justify-content: center; gap: 20px'>
126
- <div style="background-color: var(--block-background-fill); border-radius: 8px">
127
- <img src="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png" style="width: 100px; height: 100px;">
128
- </div>
129
- <div>
130
- <h1>Gen AI SDK Voice Chat</h1>
131
- <p>Speak with Gemini using real-time audio + video streaming</p>
132
- <p>Powered by <a href="https://gradio.app/">Gradio</a> and <a href=https://freddyaboulton.github.io/gradio-webrtc/">WebRTC</a>⚡️</p>
133
- <p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
 
 
 
 
134
  </div>
135
- </div>
136
- """
137
- )
138
- with gr.Row() as api_key_row:
139
- api_key = gr.Textbox(label="API Key", type="password", placeholder="Enter your API Key", value=os.getenv("GOOGLE_API_KEY"))
140
- with gr.Row(visible=False) as row:
141
- with gr.Column():
142
- webrtc = WebRTC(
143
- label="Video Chat",
144
- modality="audio-video",
145
- mode="send-receive",
146
- elem_id="video-source",
147
- rtc_configuration=get_twilio_turn_credentials(),
148
- icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
149
- pulse_color="rgb(35, 157, 225)",
150
- icon_button_color="rgb(35, 157, 225)",
151
- )
152
- with gr.Column():
153
- image_input = gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
154
-
155
- webrtc.stream(
156
- GeminiHandler(),
157
- inputs=[webrtc, api_key, image_input],
158
- outputs=[webrtc],
159
- time_limit=90,
160
- concurrency_limit=2,
161
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  api_key.submit(
163
- lambda: (gr.update(visible=False), gr.update(visible=True)),
164
- None,
165
- [api_key_row, row],
166
- )
167
 
168
 
169
- if __name__ == "__main__":
170
  demo.launch()
 
 
 
 
3
  import os
4
  import time
5
  from io import BytesIO
6
+ import logging
7
+ import traceback # Import traceback
8
 
9
  import gradio as gr
10
  import numpy as np
 
19
  )
20
  from PIL import Image
21
 
22
+ # --- Setup Logging ---
23
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # --- Global State ---
27
+ twilio_available = None # None = not checked, True = available, False = unavailable
28
+ gemini_connected = False # Track Gemini connection status
29
+ load_complete = asyncio.Event() # Event to signal demo.load completion
30
 
 
 
 
31
 
32
+ # --- Helper Functions ---
33
+ def encode_audio(data: np.ndarray) -> dict:
34
+ """Encode Audio data to send to the server."""
35
+ if not isinstance(data, np.ndarray):
36
+ raise TypeError("encode_audio expected a numpy.ndarray")
37
+ try:
38
+ return {"mime_type": "audio/pcm", "data": base64.b64encode(data.tobytes()).decode("UTF-8")}
39
+ except Exception as e:
40
+ logger.error(f"Error encoding audio: {e}")
41
+ raise # Re-raise the exception after logging
42
 
43
  def encode_image(data: np.ndarray) -> dict:
44
+ """Encode Image data to send to the server."""
45
+ if not isinstance(data, np.ndarray):
46
+ raise TypeError("encode_image expected a numpy.ndarray")
47
+ try:
48
+ with BytesIO() as output_bytes:
49
+ pil_image = Image.fromarray(data)
50
+ pil_image.save(output_bytes, "JPEG")
51
+ bytes_data = output_bytes.getvalue()
52
+ base64_str = str(base64.b64encode(bytes_data), "utf-8")
53
+ return {"mime_type": "image/jpeg", "data": base64_str}
54
+ except Exception as e:
55
+ logger.error(f"Error encoding image: {e}")
56
+ raise
57
+
58
+ async def check_twilio_availability() -> bool:
59
+ """Checks Twilio TURN server availability with retries and timeout."""
60
+ global twilio_available
61
+ timeout = 10
62
+ retries = 3
63
+ delay = 2
64
+
65
+ try:
66
+ async with asyncio.timeout(timeout):
67
+ for attempt in range(retries):
68
+ try:
69
+ # VERY DETAILED LOGGING HERE
70
+ logger.info(f"Attempting to get Twilio credentials (attempt {attempt + 1})...")
71
+ credentials = get_twilio_turn_credentials()
72
+ logger.info(f"Twilio credentials response: {credentials}") # Log the response
73
+ if credentials:
74
+ twilio_available = True
75
+ logger.info("Twilio TURN server available.")
76
+ return True
77
+ except Exception as e:
78
+ logger.warning(f"Attempt {attempt + 1} to get Twilio credentials failed: {e}")
79
+ # Print the full traceback
80
+ logger.warning(traceback.format_exc())
81
+ if attempt < retries - 1:
82
+ await asyncio.sleep(delay)
83
+ twilio_available = False
84
+ logger.warning("Twilio TURN server unavailable after multiple attempts.")
85
+ return False
86
+ except asyncio.TimeoutError:
87
+ twilio_available = False
88
+ logger.error(f"Twilio TURN server check timed out after {timeout} seconds.")
89
+ return False
90
+ except Exception as e:
91
+ twilio_available = False
92
+ logger.exception(f"Unexpected error checking Twilio availability: {e}")
93
+ return False
94
 
95
 
96
+
97
+ # --- Gemini Handler Class ---
98
  class GeminiHandler(AsyncAudioVideoStreamHandler):
99
  def __init__(
100
  self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
 
117
  output_sample_rate=self.output_sample_rate,
118
  output_frame_size=self.output_frame_size,
119
  )
120
+
121
  async def video_receive(self, frame: np.ndarray):
122
  if self.session:
123
+ try:
124
+ # send image every 1 second
125
+ if time.time() - self.last_frame_time > 1:
126
+ self.last_frame_time = time.time()
127
+ await self.session.send(encode_image(frame))
128
+ if self.latest_args[2] is not None:
129
+ await self.session.send(encode_image(self.latest_args[2]))
130
+ except Exception as e:
131
+ logger.error(f"Error sending video frame: {e}")
132
+ gr.Warning("Error sending video to Gemini. Check your connection and API key.")
133
+ self.video_queue.put_nowait(frame) # Always put the frame in the queue
134
+
135
  async def video_emit(self) -> VideoEmitType:
136
+ try:
137
+ return await self.video_queue.get()
138
+ except asyncio.CancelledError:
139
+ logger.info("Video emit cancelled.")
140
+ return None # Or some other default value
141
+ except Exception as e:
142
+ logger.exception(f"Error in video_emit: {e}")
143
+ return None
144
 
145
  async def connect(self, api_key: str):
146
+ global gemini_connected
147
  if self.session is None:
148
+ try:
149
+ client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
150
+ config = {"response_modalities": ["AUDIO"]}
151
+ async with client.aio.live.connect(
152
+ model="gemini-2.0-flash-exp", config=config
153
+ ) as session:
154
+ self.session = session
155
+ gemini_connected = True
156
+ asyncio.create_task(self.receive_audio())
157
+ await self.quit.wait()
158
+ except Exception as e:
159
+ logger.error(f"Error connecting to Gemini: {e}")
160
+ gemini_connected = False # Set connection status to False
161
+ self.shutdown()
162
+ # Display error in the UI
163
+ gr.Warning(f"Failed to connect to Gemini: {e}")
164
+ finally: # Update UI *after* connection attempt (both success and failure)
165
+ gr.Info(f"Gemini connection status: {'Connected' if gemini_connected else 'Disconnected'}")
166
+
167
 
168
  async def generator(self):
169
+ if not self.session: # Check if session exists
170
+ logger.warning("Gemini session is not initialized.")
171
+ return # Or raise an exception, depending on desired behavior
172
+
173
  while not self.quit.is_set():
174
+ try:
175
+ turn = await self.session.receive()
176
+ async for response in turn:
177
+ if data := response.data:
178
+ yield data
179
+ except Exception as e:
180
+ logger.error(f"Error receiving from Gemini: {e}")
181
+ gr.Warning("Error communicating with Gemini. Check network and API key.")
182
+ break # Exit the loop on error
183
+
184
  async def receive_audio(self):
185
+ try:
186
+ async for audio_response in async_aggregate_bytes_to_16bit(
187
+ self.generator()
188
+ ):
189
+ self.audio_queue.put_nowait(audio_response)
190
+ except asyncio.CancelledError:
191
+ logger.info("Audio receive cancelled.")
192
+ except Exception as e:
193
+ logger.exception(f"Error in receive_audio: {e}")
194
+
195
 
196
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
197
  _, array = frame
198
  array = array.squeeze()
199
+ try:
200
+ audio_message = encode_audio(array)
201
+ if self.session:
202
+ await self.session.send(audio_message)
203
+ except Exception as e:
204
+ logger.error(f"Error sending audio: {e}")
205
+ gr.Warning("Error sending audio to Gemini. Check your connection and API key.")
206
+
207
 
208
  async def emit(self) -> AudioEmitType:
209
  if not self.args_set.is_set():
210
  await self.wait_for_args()
211
  if self.session is None:
212
+ try:
213
  asyncio.create_task(self.connect(self.latest_args[1]))
214
+ except Exception as e:
215
+ logger.error(f"emit error connecting: {e}")
216
+
217
+ try:
218
+ array = await self.audio_queue.get()
219
+ return (self.output_sample_rate, array)
220
+ except asyncio.CancelledError:
221
+ logger.info("Audio emit cancelled.")
222
+ return (self.output_sample_rate, np.array([]))
223
+ except Exception as e:
224
+ logger.exception(f"Error in emit: {e}")
225
+ return (self.output_sample_rate, np.array([])) # Return empty array on error
226
+
227
 
228
  def shutdown(self) -> None:
229
+ global gemini_connected
230
+ gemini_connected = False # Reset on shutdown
231
+ logger.info("Shutting down GeminiHandler.")
232
  self.quit.set()
233
  self.connection = None
234
  self.args_set.clear()
235
+ if self.session:
236
+ # No good async close method, this can get stuck.
237
+ # asyncio.create_task(self.session.close())
238
+ pass
239
  self.quit.clear()
240
+ gr.Info("Gemini connection closed.")
241
 
242
 
243
+ # --- Gradio UI ---
244
  css = """
245
  #video-source {max-width: 600px !important; max-height: 600 !important;}
246
  """
247
 
248
+ async def main():
249
+ global twilio_available, gemini_connected
250
+
251
+ with gr.Blocks(css=css) as demo:
252
+ gr.HTML(
253
+ """
254
+ <div style='display: flex; align-items: center; justify-content: center; gap: 20px'>
255
+ <div style="background-color: var(--block-background-fill); border-radius: 8px">
256
+ <img src="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png" style="width: 100px; height: 100px;">
257
+ </div>
258
+ <div>
259
+ <h1>Gen AI SDK Voice Chat</h1>
260
+ <p>Speak with Gemini using real-time audio + video streaming</p>
261
+ <p>Powered by <a href="https://gradio.app/">Gradio</a> and <a href=https://freddyaboulton.github.io/gradio-webrtc/">WebRTC</a>⚡️</p>
262
+ <p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
263
+ </div>
264
  </div>
265
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  )
267
+ twilio_status_message = gr.Markdown("") # For displaying Twilio status
268
+ gemini_status_message = gr.Markdown("") # For Gemini status
269
+
270
+ with gr.Row() as api_key_row:
271
+ api_key = gr.Textbox(
272
+ label="API Key",
273
+ type="password",
274
+ placeholder="Enter your API Key",
275
+ value=os.getenv("GOOGLE_API_KEY"),
276
+ )
277
+ with gr.Row(visible=False) as row:
278
+ with gr.Column():
279
+ webrtc = WebRTC(
280
+ label="Video Chat",
281
+ modality="audio-video",
282
+ mode="send-receive",
283
+ elem_id="video-source",
284
+ rtc_configuration={"iceServers": []}, # DUMMY CONFIGURATION
285
+ icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
286
+ pulse_color="rgb(35, 157, 225)",
287
+ icon_button_color="rgb(35, 157, 225)",
288
+ )
289
+ with gr.Column():
290
+ image_input = gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
291
+
292
+
293
+ async def update_twilio_status_and_ui():
294
+ """Updates Twilio status and UI elements."""
295
+ await check_twilio_availability() # Check Twilio availability
296
+
297
+ if twilio_available:
298
+ rtc_config = get_twilio_turn_credentials()
299
+ message = "Twilio TURN server available. Connection should be reliable."
300
+ else:
301
+ rtc_config = None
302
+ message = "**Warning:** Twilio TURN server unavailable. Connection might be less reliable or fail if you are behind a symmetric NAT."
303
+ load_complete.set() # Signal that load is complete - *before* returning
304
+ return gr.update(rtc_configuration=rtc_config), gr.update(value=message)
305
+
306
+ # Check Twilio availability and update UI on startup.
307
+ demo.load(update_twilio_status_and_ui, [], [webrtc, twilio_status_message])
308
+
309
+ async def start_streaming():
310
+ """Starts the WebRTC streaming after load_complete is set."""
311
+ await load_complete.wait() # *Wait* for load to complete
312
+ await asyncio.sleep(0.1) # Small delay (optional, but can help)
313
+ webrtc.stream(
314
+ GeminiHandler(),
315
+ inputs=[webrtc, api_key, image_input],
316
+ outputs=[webrtc],
317
+ time_limit=90,
318
+ concurrency_limit=None, # Removed concurrency limit
319
+ )
320
+
321
+ # Use .then() to chain start_streaming *after* demo.load
322
+ demo.load(None, [], []).then(start_streaming, [], [])
323
+
324
+
325
+ def check_api_key(api_key_str):
326
+ if not api_key_str:
327
+ return gr.update(visible=True), gr.update(visible=False), gr.update(value="Please enter a valid API key")
328
+ return gr.update(visible=False), gr.update(visible=True), gr.update(value="")
329
+
330
  api_key.submit(
331
+ check_api_key,
332
+ [api_key],
333
+ [api_key_row, row, twilio_status_message],
334
+ )
335
 
336
 
 
337
  demo.launch()
338
+
339
+ if __name__ == "__main__":
340
+ asyncio.run(main())
pyproject.toml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "gemini-audio-video-chat"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.13"
7
+ dependencies = [
8
+ "gradio_webrtc==0.0.28",
9
+ "google-genai==0.3.0",
10
+ "twilio"
11
+ ]
uv.lock ADDED
The diff for this file is too large to render. See raw diff