akhaliq (HF staff) committed
Commit: d1eb1fb
Parent(s): 38c74e2

update gemini voice to use gemini-gradio

Files changed (2):
  1. app_gemini_voice.py  +12 -188
  2. requirements.txt      +1 -1
app_gemini_voice.py CHANGED
@@ -1,194 +1,18 @@
-import base64
-import json
 import os

-import gradio as gr
-import numpy as np
-import websockets.sync.client
-from dotenv import load_dotenv
-from gradio_webrtc import StreamHandler, WebRTC, get_twilio_turn_credentials
+import gemini_gradio

+from utils import get_app

-class GeminiConfig:
-    def __init__(self):
-        load_dotenv()
-        self.api_key = self._get_api_key()
-        self.host = "generativelanguage.googleapis.com"
-        self.model = "models/gemini-2.0-flash-exp"
-        self.ws_url = f"wss://{self.host}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self.api_key}"
-
-    def _get_api_key(self):
-        api_key = os.getenv("GOOGLE_API_KEY")
-        if not api_key:
-            raise ValueError("GOOGLE_API_KEY not found in environment variables. Please set it in your .env file.")
-        return api_key
-
-
-class AudioProcessor:
-    @staticmethod
-    def encode_audio(data, sample_rate):
-        encoded = base64.b64encode(data.tobytes()).decode("UTF-8")
-        return {
-            "realtimeInput": {
-                "mediaChunks": [
-                    {
-                        "mimeType": f"audio/pcm;rate={sample_rate}",
-                        "data": encoded,
-                    }
-                ],
-            },
-        }
-
-    @staticmethod
-    def process_audio_response(data):
-        audio_data = base64.b64decode(data)
-        return np.frombuffer(audio_data, dtype=np.int16)
-
-
-class GeminiHandler(StreamHandler):
-    def __init__(self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480) -> None:
-        super().__init__(expected_layout, output_sample_rate, output_frame_size, input_sample_rate=24000)
-        self.config = GeminiConfig()
-        self.ws = None
-        self.all_output_data = None
-        self.audio_processor = AudioProcessor()
-
-    def copy(self):
-        return GeminiHandler(
-            expected_layout=self.expected_layout,
-            output_sample_rate=self.output_sample_rate,
-            output_frame_size=self.output_frame_size,
-        )
-
-    def _initialize_websocket(self):
-        try:
-            self.ws = websockets.sync.client.connect(self.config.ws_url, timeout=30)
-            initial_request = {
-                "setup": {
-                    "model": self.config.model,
-                }
-            }
-            self.ws.send(json.dumps(initial_request))
-            setup_response = json.loads(self.ws.recv())
-            print(f"Setup response: {setup_response}")
-        except websockets.exceptions.WebSocketException as e:
-            print(f"WebSocket connection failed: {str(e)}")
-            self.ws = None
-        except Exception as e:
-            print(f"Setup failed: {str(e)}")
-            self.ws = None
-
-    def receive(self, frame: tuple[int, np.ndarray]) -> None:
-        try:
-            if not self.ws:
-                self._initialize_websocket()
-
-            _, array = frame
-            array = array.squeeze()
-            audio_message = self.audio_processor.encode_audio(array, self.output_sample_rate)
-            self.ws.send(json.dumps(audio_message))  # type: ignore
-        except Exception as e:
-            print(f"Error in receive: {str(e)}")
-            if self.ws:
-                self.ws.close()
-            self.ws = None
-
-    def _process_server_content(self, content):
-        for part in content.get("parts", []):
-            data = part.get("inlineData", {}).get("data", "")
-            if data:
-                audio_array = self.audio_processor.process_audio_response(data)
-                if self.all_output_data is None:
-                    self.all_output_data = audio_array
-                else:
-                    self.all_output_data = np.concatenate((self.all_output_data, audio_array))
-
-                while self.all_output_data.shape[-1] >= self.output_frame_size:
-                    yield (self.output_sample_rate, self.all_output_data[: self.output_frame_size].reshape(1, -1))
-                    self.all_output_data = self.all_output_data[self.output_frame_size :]
-
-    def generator(self):
-        while True:
-            if not self.ws:
-                print("WebSocket not connected")
-                yield None
-                continue
-
-            try:
-                message = self.ws.recv(timeout=5)
-                msg = json.loads(message)
-
-                if "serverContent" in msg:
-                    content = msg["serverContent"].get("modelTurn", {})
-                    yield from self._process_server_content(content)
-            except TimeoutError:
-                print("Timeout waiting for server response")
-                yield None
-            except Exception as e:
-                print(f"Error in generator: {str(e)}")
-                yield None
-
-    def emit(self) -> tuple[int, np.ndarray] | None:
-        if not self.ws:
-            return None
-        if not hasattr(self, "_generator"):
-            self._generator = self.generator()
-        try:
-            return next(self._generator)
-        except StopIteration:
-            self.reset()
-            return None
-
-    def reset(self) -> None:
-        if hasattr(self, "_generator"):
-            delattr(self, "_generator")
-        self.all_output_data = None
-
-    def shutdown(self) -> None:
-        if self.ws:
-            self.ws.close()
-
-    def check_connection(self):
-        try:
-            if not self.ws or self.ws.closed:
-                self._initialize_websocket()
-            return True
-        except Exception as e:
-            print(f"Connection check failed: {str(e)}")
-            return False
-
-
-class GeminiVoiceChat:
-    def __init__(self):
-        load_dotenv()
-        self.demo = self._create_interface()
-
-    def _create_interface(self):
-        with gr.Blocks() as demo:
-            gr.HTML(
-                """
-                <div style='text-align: center'>
-                    <h1>Gemini 2.0 Voice Chat</h1>
-                    <p>Speak with Gemini using real-time audio streaming</p>
-                </div>
-                """
-            )
-
-            webrtc = WebRTC(
-                label="Conversation",
-                modality="audio",
-                mode="send-receive",
-                rtc_configuration=get_twilio_turn_credentials(),
-            )
-
-            webrtc.stream(GeminiHandler(), inputs=[webrtc], outputs=[webrtc], time_limit=90, concurrency_limit=10)
-            return demo
-
-    def launch(self):
-        self.demo.launch()
-
-
-demo = GeminiVoiceChat().demo
+demo = get_app(
+    models=[
+        "gemini-2.0-flash-exp",
+    ],
+    default_model="gemini-2.0-flash-exp",
+    src=gemini_gradio.registry,
+    accept_token=not os.getenv("GEMINI_API_KEY"),
+    enable_voice=True,
+)

 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0")
+    demo.launch()
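For context, a minimal sketch of how a gemini-gradio registry app is typically wired up on its own, assuming the gr.load(..., src=registry) pattern documented for the gemini-gradio package; get_app is this repo's own wrapper around that call, and its accept_token/enable_voice handling is not shown in this diff:

# Hypothetical standalone usage sketch, not part of this commit.
import os

import gemini_gradio
import gradio as gr

# gemini-gradio reads GEMINI_API_KEY from the environment; the registry
# builds the chat interface and forwards requests to the named model.
demo = gr.load(
    name="gemini-2.0-flash-exp",
    src=gemini_gradio.registry,
)

if __name__ == "__main__":
    demo.launch()

The repo's get_app() helper presumably layers model selection, token prompting (accept_token), and the new voice mode (enable_voice) on top of this call.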
requirements.txt CHANGED
@@ -113,7 +113,7 @@ fsspec==2024.10.0
     # via
     #   gradio-client
     #   huggingface-hub
-gemini-gradio==0.0.2
+gemini-gradio==0.0.3
     # via anychat (pyproject.toml)
 google-ai-generativelanguage==0.6.10
     # via google-generativeai