|
|
|
|
|
|
|
import asyncio |
|
import base64 |
|
import os |
|
import time |
|
import logging |
|
import traceback |
|
import cv2 |
|
|
|
import gradio as gr |
|
import numpy as np |
|
from google import genai |
|
from gradio_webrtc import ( |
|
AsyncAudioVideoStreamHandler, |
|
WebRTC, |
|
async_aggregate_bytes_to_16bit, |
|
VideoEmitType, |
|
AudioEmitType, |
|
get_twilio_turn_credentials, |
|
) |
|
import requests |
|
|
|
|
|
# Process-wide logging: timestamped, INFO level and above.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared status flags.
# twilio_available: None until check_twilio_availability_sync() runs, then True/False.
twilio_available = None
# gemini_connected: set True by GeminiHandler.connect() when a live session opens,
# cleared on shutdown and on connection errors.
gemini_connected = False
|
|
|
|
|
|
|
def encode_audio(data: np.ndarray) -> dict:
    """Wrap a sample array as a base64 ``audio/pcm`` message part.

    The array's raw buffer (``tobytes()``, native dtype and byte order) is
    base64-encoded and decoded to text.

    Args:
        data: Audio samples to send.

    Returns:
        ``{"mime_type": "audio/pcm", "data": <base64 text>}``.

    Raises:
        TypeError: If *data* is not a NumPy array.
    """
    if isinstance(data, np.ndarray):
        try:
            payload = base64.b64encode(data.tobytes()).decode("UTF-8")
        except Exception as err:
            logger.error(f"Error encoding audio: {err}")
            raise
        return {"mime_type": "audio/pcm", "data": payload}
    raise TypeError("encode_audio expected a numpy.ndarray")
|
|
|
def encode_image(data: np.ndarray, quality: int = 85) -> dict: |
|
""" |
|
Encodes a NumPy array (image) to a JPEG, Base64-encoded UTF-8 string using OpenCV. |
|
Handles various input data types. |
|
|
|
Args: |
|
data: A NumPy array of shape (n, n, 3). |
|
quality: JPEG quality (0-100). |
|
|
|
Returns: |
|
A dictionary with keys "mime_type" and "data". |
|
|
|
Raises: |
|
TypeError: If input is not a NumPy array. |
|
ValueError: If input shape is incorrect or contains NaN/Inf. |
|
Exception: If JPEG encoding fails. |
|
""" |
|
|
|
|
|
if not isinstance(data, np.ndarray): |
|
raise TypeError("Input must be a NumPy array.") |
|
if data.ndim != 3 or data.shape[2] != 3: |
|
raise ValueError("Input array must have shape (n, n, 3).") |
|
if 0 in data.shape: |
|
raise ValueError("Input array cannot have a dimension of size 0.") |
|
|
|
|
|
if np.any(np.isnan(data)) or np.any(np.isinf(data)): |
|
raise ValueError("Input array contains NaN or Inf") |
|
|
|
|
|
if np.issubdtype(data.dtype, np.floating) or np.issubdtype(data.dtype, np.integer): |
|
scaled_data = cv2.normalize(data, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8) |
|
else: |
|
raise TypeError("Input array must have a floating-point or integer data type.") |
|
|
|
|
|
try: |
|
retval, buf = cv2.imencode(".jpg", scaled_data, [int(cv2.IMWRITE_JPEG_QUALITY), quality]) |
|
if not retval: |
|
raise Exception("cv2.imencode failed") |
|
except Exception as e: |
|
raise Exception(f"JPEG encoding failed: {e}") |
|
|
|
|
|
jpeg_bytes = np.array(buf).tobytes() |
|
base64_str = base64.b64encode(jpeg_bytes).decode('utf-8') |
|
|
|
return {"mime_type": "image/jpeg", "data": base64_str} |
|
|
|
def check_twilio_availability_sync(*, retries: int = 3, delay: float = 2.0) -> bool:
    """Probe the Twilio TURN service and record the result in ``twilio_available``.

    Calls ``get_twilio_turn_credentials()`` up to *retries* times, sleeping
    *delay* seconds between attempts. ``requests`` network errors and empty
    responses are retried. Any other exception is logged and ends the check
    immediately as unavailable.

    Args:
        retries: Maximum number of attempts (0 or less means no attempt).
        delay: Seconds to sleep between consecutive attempts.

    Returns:
        True if credentials were obtained, otherwise False. The same value is
        stored in the module-level ``twilio_available`` flag.
    """
    global twilio_available

    for attempt in range(1, retries + 1):
        try:
            logger.info(f"Attempting to get Twilio credentials (attempt {attempt})...")
            credentials = get_twilio_turn_credentials()
        except requests.exceptions.RequestException as e:
            logger.warning(f"Attempt {attempt}: {e}", exc_info=True)
        except Exception as e:
            logger.exception(f"Unexpected error checking Twilio: {e}")
            twilio_available = False
            return False
        else:
            if credentials:
                # Deliberately do not log the payload: it is the TURN credentials.
                twilio_available = True
                logger.info("Twilio TURN server available.")
                return True
            logger.warning(f"Attempt {attempt}: Twilio returned no credentials.")

        # Back off before every retry, whether the attempt raised or came back empty.
        if attempt < retries:
            time.sleep(delay)

    twilio_available = False
    logger.warning("Twilio TURN server unavailable.")
    return False
|
|
|
|
|
|
|
|
|
class GeminiHandler(AsyncAudioVideoStreamHandler):
    """Relay a WebRTC audio/video stream to a Gemini Live session.

    Microphone audio (``receive``) and throttled camera frames
    (``video_receive``) are sent to the live session. Gemini's audio replies
    are collected by a background task into ``audio_queue`` and returned by
    ``emit``. Camera frames are echoed back unchanged by ``video_emit``.
    """

    def __init__(
        self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        # Gemini audio chunks waiting to be returned by emit().
        self.audio_queue = asyncio.Queue()
        # Camera frames waiting to be echoed back by video_emit().
        self.video_queue = asyncio.Queue()
        # Setting this releases connect()'s quit.wait(), which closes the session.
        self.quit = asyncio.Event()
        # The open Gemini live session, or None when not connected.
        self.session = None
        # time.time() of the last frame forwarded to Gemini (for throttling).
        self.last_frame_time = 0
        # Strong references to background tasks. The event loop keeps only weak
        # references to tasks, so an unreferenced task can be collected mid-run.
        self._connect_task = None
        self._receive_task = None

    def copy(self) -> "GeminiHandler":
        """Return a new, unconnected handler with the same audio settings."""
        return GeminiHandler(
            expected_layout=self.expected_layout,
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    async def video_receive(self, frame: np.ndarray):
        """Forward a camera frame to Gemini (at most once per second) and queue it for echo.

        When a frame is forwarded and an image was supplied as the third
        stream input (``latest_args[2]``), that image is sent as well. Every
        frame, whether forwarded or not, is queued for ``video_emit``.
        """
        if self.session:
            try:
                if time.time() - self.last_frame_time > 1:
                    self.last_frame_time = time.time()
                    await self.session.send(encode_image(frame))
                    if self.latest_args[2] is not None:
                        await self.session.send(encode_image(self.latest_args[2]))
            except Exception as e:
                logger.error(f"Error sending video frame: {e}")
                gr.Warning("Error sending video to Gemini.")
        self.video_queue.put_nowait(frame)

    async def video_emit(self) -> VideoEmitType:
        """Return the next queued camera frame.

        Cancellation is logged and re-raised so the caller's cancel/timeout
        takes effect. Any other error is logged and yields None.
        """
        try:
            return await self.video_queue.get()
        except asyncio.CancelledError:
            logger.info("Video emit cancelled.")
            raise
        except Exception as e:
            logger.exception(f"Error in video_emit: {e}")
            return None

    async def connect(self, api_key: str):
        """Open a Gemini live session and hold it open until ``quit`` is set.

        Does nothing if a session is already open. While the session is open,
        a background task pumps its audio into ``audio_queue``. On exit (normal,
        error or cancellation) the session reference is dropped, the pump task
        is cancelled and ``gemini_connected`` is cleared, so a later call can
        reconnect.
        """
        global gemini_connected
        if self.session is not None:
            return
        # A previous session may have ended with ``quit`` still set by
        # generator(); clear it or this session would close immediately.
        self.quit.clear()
        try:
            client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
            config = {"response_modalities": ["AUDIO"]}
            async with client.aio.live.connect(
                model="gemini-2.0-flash-exp", config=config
            ) as session:
                self.session = session
                gemini_connected = True
                self._receive_task = asyncio.create_task(self.receive_audio())
                await self.quit.wait()
        except Exception as e:
            logger.error(f"Error connecting to Gemini: {e}")
            gemini_connected = False
            self.shutdown()
            gr.Warning(f"Failed to connect to Gemini: {e}")
        finally:
            # The session is closed once ``async with`` exits. Drop it so no
            # further sends target it and emit() is able to reconnect.
            self.session = None
            gemini_connected = False
            if self._receive_task is not None and not self._receive_task.done():
                self._receive_task.cancel()
            self._receive_task = None
            update_gemini_status_sync()

    async def generator(self):
        """Yield non-empty ``response.data`` payloads from the live session until ``quit`` is set.

        The session is captured once up front, so connect() resetting
        ``self.session`` during teardown does not cause an AttributeError here.
        A receive error is logged and sets ``quit`` (which ends the session in
        connect()), and the generator then stops.
        """
        session = self.session
        if not session:
            logger.warning("Gemini session is not initialized.")
            return

        while not self.quit.is_set():
            try:
                # Yield control between turns so other tasks can run.
                await asyncio.sleep(0)
                if self.quit.is_set():
                    break
                async for response in session.receive():
                    if self.quit.is_set():
                        break
                    if data := response.data:
                        yield data
            except Exception as e:
                logger.error(f"Error receiving from Gemini: {e}")
                self.quit.set()
                break

    async def receive_audio(self):
        """Pump generator() output through ``async_aggregate_bytes_to_16bit`` into ``audio_queue``."""
        try:
            async for audio_response in async_aggregate_bytes_to_16bit(self.generator()):
                self.audio_queue.put_nowait(audio_response)
        except Exception as e:
            logger.exception(f"Error in receive_audio: {e}")

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Send one ``(sample_rate, samples)`` microphone frame to Gemini, if connected."""
        _, array = frame
        array = array.squeeze()
        try:
            audio_message = encode_audio(array)
            if self.session:
                await self.session.send(audio_message)
        except Exception as e:
            logger.error(f"Error sending audio: {e}")
            gr.Warning("Error sending audio to Gemini.")

    async def emit(self) -> AudioEmitType:
        """Return the next Gemini audio chunk as ``(output_sample_rate, samples)``.

        If the stream inputs are not available yet, waits for them. When no
        session is open and no connect task is already in flight, starts
        connect() in the background with the API key (``latest_args[1]``).
        This prevents repeated emit() calls during a slow handshake from
        opening duplicate sessions. Cancellation is re-raised. Other errors
        are logged and yield an empty chunk.
        """
        if not self.args_set.is_set():
            await self.wait_for_args()

        connecting = self._connect_task is not None and not self._connect_task.done()
        if self.session is None and not connecting:
            self._connect_task = asyncio.create_task(self.connect(self.latest_args[1]))

        try:
            array = await self.audio_queue.get()
        except asyncio.CancelledError:
            logger.info("Audio emit cancelled.")
            raise
        except Exception as e:
            logger.exception(f"Error in emit: {e}")
            return (self.output_sample_rate, np.array([]))
        return (self.output_sample_rate, array)

    def shutdown(self) -> None:
        """Mark Gemini as disconnected and ask connect() to close the live session."""
        global gemini_connected
        gemini_connected = False
        logger.info("Shutting down GeminiHandler.")
        # Pulse ``quit``: set() completes connect()'s pending quit.wait() (which
        # then closes the session), and clear() resets the event for reuse.
        self.quit.set()
        self.connection = None
        self.args_set.clear()
        self.quit.clear()
        update_gemini_status_sync()
|
|
|
|
|
def update_gemini_status_sync():
    """Return the Gemini status line for the current connection state.

    The text reflects the module-level ``gemini_connected`` flag. Calling this
    has no UI side effect. To display the text, register the function as a
    Gradio event handler whose output is the status component.

    Returns:
        ``"✅ Gemini: Connected"`` or ``"❌ Gemini: Disconnected"``.
    """
    return "✅ Gemini: Connected" if gemini_connected else "❌ Gemini: Disconnected"
|
|
|
|
|
|
|
|
|
# Page CSS: cap the WebRTC video component (elem_id="video-source") at 600x600 px.
# Both lengths need a unit; a unitless "600" is invalid CSS and gets ignored.
css = """
#video-source {max-width: 600px !important; max-height: 600px !important;}
"""
|
|
|
|
|
if __name__ == "__main__":
    check_twilio_availability_sync()

    with gr.Blocks(css=css) as demo:
        gr.HTML(
            """
    <div style='display: flex; align-items: center; justify-content: center; gap: 20px'>
        <div style="background-color: var(--block-background-fill); border-radius: 8px">
            <img src="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png" style="width: 100px; height: 100px;">
        </div>
        <div>
            <h1>Gen AI SDK Voice Chat</h1>
            <p>Speak with Gemini using real-time audio + video streaming</p>
            <p>Powered by <a href="https://gradio.app/">Gradio</a> and <a href="https://freddyaboulton.github.io/gradio-webrtc/">WebRTC</a>⚡️</p>
            <p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
        </div>
    </div>
    """
        )
        twilio_status_message = gr.Markdown("⌛ Twilio: Checking...")
        # NOTE(review): after check_api_key, nothing in this script writes to this
        # component, so it keeps showing "Checking..." once connected.
        gemini_status_message = gr.Markdown("⌛ Gemini: Checking...")

        with gr.Row() as api_key_row:
            api_key = gr.Textbox(
                label="API Key",
                type="password",
                placeholder="Enter your API Key",
                value=os.getenv("GOOGLE_API_KEY"),
            )
        with gr.Row(visible=False) as row:
            with gr.Column():
                # Fetch TURN credentials only if the startup probe succeeded.
                rtc_config = get_twilio_turn_credentials() if twilio_available else None
                if rtc_config:
                    # NOTE(review): "codecs" is not a standard RTCConfiguration
                    # member; confirm gradio_webrtc/the browser actually honors it.
                    rtc_config['codecs'] = ['VP8', 'H264']
                webrtc = WebRTC(
                    label="Video Chat",
                    modality="audio-video",
                    mode="send-receive",
                    elem_id="video-source",
                    rtc_configuration=rtc_config,
                    icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
                    pulse_color="rgb(35, 157, 225)",
                    icon_button_color="rgb(35, 157, 225)",
                )
            with gr.Column():
                image_input = gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])

        def update_twilio_status_ui():
            """Return an update showing the result of the startup Twilio probe."""
            if twilio_available:
                message = "✅ Twilio: Available"
            else:
                message = "❌ Twilio: Unavailable (connection may be less reliable)"
            return gr.update(value=message)

        demo.load(update_twilio_status_ui, [], [twilio_status_message])

        handler = GeminiHandler()
        webrtc.stream(
            handler,
            inputs=[webrtc, api_key, image_input],
            outputs=[webrtc],
            time_limit=90,
            concurrency_limit=None,
        )

        def check_api_key(api_key_str):
            """Show the chat UI once a key is entered; otherwise keep the key prompt.

            Returns updates for (api_key_row, row, gemini_status_message). The
            Twilio status is deliberately left untouched.
            """
            if not api_key_str:
                gr.Warning("Please enter a valid API key")
                return (
                    gr.update(visible=True),
                    gr.update(visible=False),
                    gr.update(value="⌛ Gemini: Checking..."),
                )
            return (
                gr.update(visible=False),
                gr.update(visible=True),
                gr.update(value="⌛ Gemini: Checking..."),
            )

        api_key.submit(
            check_api_key,
            [api_key],
            [api_key_row, row, gemini_status_message],
        )

        # With a key already in the environment, skip the prompt on page load.
        if os.getenv("GOOGLE_API_KEY"):
            demo.load(
                lambda: (gr.update(visible=False), gr.update(visible=True)),
                None,
                [api_key_row, row],
            )

    demo.launch()