marcosremar2 committed
Commit
c85077c
·
1 Parent(s): 218771e
Files changed (1)
  1. audio_interface.py +411 -0
audio_interface.py ADDED
@@ -0,0 +1,411 @@
#!/usr/bin/env python3
"""
Audio interface for LLaMA-Omni2 that accepts audio input and returns audio output.
This interface:
1. Transcribes the audio input using Whisper
2. Processes the transcription with the LLaMA-Omni2 model
3. Synthesizes the response back to audio using CosyVoice 2

Enhanced with streaming generation and read-write scheduling for real-time response.
"""
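
# Example launch (the flags and defaults mirror the argparse options defined in
# main() below; the model and vocoder paths are the script defaults and may need
# to point at your local checkouts):
#
#   python audio_interface.py \
#       --controller-url http://localhost:10000 \
#       --whisper-model-path models/speech_encoder \
#       --vocoder-dir models/cosy2_decoder \
#       --model-name LLaMA-Omni2-7B-Bilingual \
#       --read-tokens 3 --write-tokens 10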

import os
import sys
import argparse
import logging
import asyncio
import tempfile
import json

import torch
import torchaudio
import gradio as gr
import whisper
import aiohttp
import numpy as np

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class AudioInterface:
    def __init__(
        self,
        controller_url: str,
        whisper_model_path: str,
        vocoder_dir: str,
        model_name: str = "LLaMA-Omni2-7B-Bilingual",
        read_tokens: int = 3,
        write_tokens: int = 10
    ):
        self.controller_url = controller_url
        self.whisper_model_path = whisper_model_path
        self.vocoder_dir = vocoder_dir
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Read-write scheduling parameters for streaming generation
        self.read_tokens = read_tokens    # Number of text tokens to read
        self.write_tokens = write_tokens  # Number of speech tokens to write

        # Load Whisper model
        try:
            logger.info(f"Loading Whisper model from {whisper_model_path}")
            self.whisper_model = whisper.load_model(
                "large-v3",
                download_root=whisper_model_path,
                device=self.device
            )
            logger.info("Whisper model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Whisper model: {e}")
            self.whisper_model = None

        # Load CosyVoice vocoder
        try:
            sys.path.insert(0, vocoder_dir)
            from cosy_voice_2.inference import CosyVoice

            self.vocoder = CosyVoice(
                device=self.device,
                model_path=vocoder_dir
            )
            logger.info(f"CosyVoice vocoder loaded from {vocoder_dir}")
        except Exception as e:
            logger.error(f"Failed to load CosyVoice vocoder: {e}")
            self.vocoder = None

        logger.info(f"Using LLaMA-Omni2 model: {model_name}")

    async def get_worker_address(self):
        """Get the address of the worker serving the model."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.controller_url}/get_worker_address?model_name={self.model_name}",
                    timeout=30
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data.get("address")
                    else:
                        logger.error(f"Failed to get worker address: {await response.text()}")
                        return None
        except Exception as e:
            logger.error(f"Error getting worker address: {e}")
            return None

100
+ async def generate_text(self, prompt: str, streaming=False):
101
+ """Generate text from LLaMA-Omni2 model"""
102
+ worker_addr = await self.get_worker_address()
103
+ if not worker_addr:
104
+ return f"Error: No worker available for model {self.model_name}"
105
+
106
+ try:
107
+ async with aiohttp.ClientSession() as session:
108
+ # For streaming generation
109
+ if streaming:
110
+ async with session.post(
111
+ f"{worker_addr}/generate_stream",
112
+ json={"prompt": prompt},
113
+ timeout=120
114
+ ) as response:
115
+ if response.status == 200:
116
+ response_text = ""
117
+ async for line in response.content:
118
+ if line:
119
+ data = json.loads(line)
120
+ chunk = data.get("text", "")
121
+ response_text += chunk
122
+ yield response_text
123
+ return response_text
124
+ else:
125
+ error_text = await response.text()
126
+ logger.error(f"Failed to generate text stream: {error_text}")
127
+ return f"Error: {error_text}"
128
+ # For non-streaming generation
129
+ else:
130
+ async with session.post(
131
+ f"{worker_addr}/generate",
132
+ json={"prompt": prompt},
133
+ timeout=120
134
+ ) as response:
135
+ if response.status == 200:
136
+ data = await response.json()
137
+ return data.get("response", "No response received from model")
138
+ else:
139
+ error_text = await response.text()
140
+ logger.error(f"Failed to generate text: {error_text}")
141
+ return f"Error: {error_text}"
142
+ except Exception as e:
143
+ logger.error(f"Error generating text: {e}")
144
+ return f"Error: {str(e)}"
145
+
146
+ def transcribe_audio(self, audio_path):
147
+ """Transcribe audio using Whisper"""
148
+ if self.whisper_model is None:
149
+ return "Error: Whisper model not loaded"
150
+
151
+ try:
152
+ logger.info(f"Transcribing audio from {audio_path}")
153
+ result = self.whisper_model.transcribe(audio_path)
154
+ logger.info("Transcription completed")
155
+ return result["text"]
156
+ except Exception as e:
157
+ logger.error(f"Error transcribing audio: {e}")
158
+ return f"Error transcribing audio: {str(e)}"
159
+
160
+ def synthesize_speech(self, text):
161
+ """Synthesize speech from text using CosyVoice"""
162
+ if self.vocoder is None:
163
+ return None, 16000, "Error: Vocoder not loaded"
164
+
165
+ try:
166
+ logger.info("Synthesizing speech from text response")
167
+ # Generate speech using CosyVoice
168
+ waveform = self.vocoder.inference(text)
169
+ sample_rate = self.vocoder.sample_rate
170
+
171
+ # Convert to numpy array for Gradio
172
+ if isinstance(waveform, torch.Tensor):
173
+ waveform = waveform.cpu().numpy()
174
+
175
+ logger.info("Speech synthesis completed")
176
+ return waveform, sample_rate, None
177
+ except Exception as e:
178
+ logger.error(f"Error synthesizing speech: {e}")
179
+ return None, 16000, f"Error synthesizing speech: {str(e)}"
180
+
181
+ async def synthesize_speech_chunk(self, text_chunk):
182
+ """Synthesize speech for a single text chunk"""
183
+ if self.vocoder is None:
184
+ return None, 16000, "Error: Vocoder not loaded"
185
+
186
+ try:
187
+ # Generate speech using CosyVoice for this chunk
188
+ waveform = self.vocoder.inference(text_chunk)
189
+ sample_rate = self.vocoder.sample_rate
190
+
191
+ # Convert to numpy array
192
+ if isinstance(waveform, torch.Tensor):
193
+ waveform = waveform.cpu().numpy()
194
+
195
+ return waveform, sample_rate, None
196
+ except Exception as e:
197
+ logger.error(f"Error synthesizing speech chunk: {e}")
198
+ return None, 16000, f"Error synthesizing speech chunk: {str(e)}"
199
+
200
+ async def stream_text_to_speech(self, text_generator):
201
+ """Stream text to speech using read-write scheduling"""
202
+ buffer = ""
203
+ audio_chunks = []
204
+
205
+ try:
206
+ async for text in text_generator:
207
+ # Accumulate text until we have enough to synthesize
208
+ buffer += text
209
+
210
+ # When we have enough tokens for synthesis (approximate by characters)
211
+ if len(buffer.split()) >= self.read_tokens:
212
+ # Process the buffer
213
+ chunk_to_process = buffer
214
+ buffer = ""
215
+
216
+ # Synthesize this chunk
217
+ audio_chunk, sample_rate, error = await self.synthesize_speech_chunk(chunk_to_process)
218
+ if error:
219
+ logger.error(f"Error in streaming synthesis: {error}")
220
+ continue
221
+
222
+ # Add to our collection of audio chunks
223
+ audio_chunks.append(audio_chunk)
224
+
225
+ # Yield the current concatenated audio
226
+ if audio_chunks:
227
+ # Concatenate audio chunks
228
+ full_audio = np.concatenate(audio_chunks)
229
+ yield full_audio, sample_rate, chunk_to_process
230
+
231
+ # Process any remaining text in the buffer
232
+ if buffer:
233
+ audio_chunk, sample_rate, error = await self.synthesize_speech_chunk(buffer)
234
+ if not error and audio_chunk is not None:
235
+ audio_chunks.append(audio_chunk)
236
+
237
+ # Final audio output
238
+ if audio_chunks:
239
+ full_audio = np.concatenate(audio_chunks)
240
+ return full_audio, sample_rate, None
241
+ else:
242
+ return None, 16000, "No audio generated"
243
+
244
+ except Exception as e:
245
+ logger.error(f"Error in streaming text to speech: {e}")
246
+ return None, 16000, f"Error in streaming text to speech: {str(e)}"
247
+
248
+ async def process_audio(self, audio_data, sample_rate, streaming=False):
249
+ """Process audio input and return audio output"""
250
+ # Save the input audio to a temporary file
251
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
252
+ temp_path = temp_audio.name
253
+ # Convert sample rate if needed
254
+ if sample_rate != 16000:
255
+ resampler = torchaudio.transforms.Resample(
256
+ orig_freq=sample_rate, new_freq=16000
257
+ )
258
+ audio_tensor = torch.tensor(audio_data).unsqueeze(0)
259
+ audio_tensor = resampler(audio_tensor)
260
+ audio_data = audio_tensor.squeeze(0).numpy()
261
+ sample_rate = 16000
262
+
263
+ # Save as WAV
264
+ torchaudio.save(temp_path, torch.tensor(audio_data).unsqueeze(0), sample_rate)
265
+
266
+ try:
267
+ # Step 1: Transcribe audio
268
+ transcription = self.transcribe_audio(temp_path)
269
+ if transcription.startswith("Error"):
270
+ return None, sample_rate, transcription, "Error occurred during transcription", transcription
271
+
272
+ # Step 2: Process with LLaMA-Omni2
273
+ if streaming:
274
+ # For streaming mode, we use a generator
275
+ text_generator = self.generate_text(transcription, streaming=True)
276
+ audio_generator = self.stream_text_to_speech(text_generator)
277
+ return audio_generator, transcription
278
+ else:
279
+ # For non-streaming mode
280
+ response_text = await self.generate_text(transcription)
281
+ if response_text.startswith("Error"):
282
+ return None, sample_rate, transcription, response_text, response_text
283
+
284
+ # Step 3: Synthesize speech
285
+ audio_output, out_sample_rate, error = self.synthesize_speech(response_text)
286
+ if error:
287
+ return None, sample_rate, transcription, response_text, error
288
+
289
+ return audio_output, out_sample_rate, transcription, response_text, None
290
+ finally:
291
+ # Clean up temporary file
292
+ if os.path.exists(temp_path):
293
+ os.unlink(temp_path)
294
+
    def build_interface(self):
        """Build Gradio interface."""
        with gr.Blocks(title="LLaMA-Omni2 Audio Interface") as demo:
            gr.Markdown("# LLaMA-Omni2 Audio Interface")
            gr.Markdown("Speak to LLaMA-Omni2 and hear its response in real-time")

            with gr.Row():
                with gr.Column():
                    audio_input = gr.Audio(
                        sources=["microphone", "upload"],
                        type="numpy",
                        label="Input Audio"
                    )
                    with gr.Row():
                        submit_button = gr.Button("Process Audio", variant="primary")
                        stream_button = gr.Button("Stream Audio Response", variant="secondary")

                with gr.Column():
                    transcription = gr.Textbox(
                        label="Transcription",
                        interactive=False
                    )
                    response_text = gr.Textbox(
                        label="Response Text",
                        interactive=False
                    )
                    audio_output = gr.Audio(
                        label="Response Audio",
                        type="numpy",
                        interactive=False
                    )
                    error_text = gr.Textbox(
                        label="Errors (if any)",
                        interactive=False,
                        visible=False
                    )

            async def process_wrapper(audio_data):
                if audio_data is None:
                    return None, "No audio input detected", "Please record or upload audio", gr.update(
                        value="No audio input detected", visible=True)

                # Gradio's numpy audio components use (sample_rate, data) tuples
                sample_rate, audio_array = audio_data
                output, out_sample_rate, trans, resp, error = await self.process_audio(
                    audio_array, sample_rate, streaming=False)

                if error:
                    return None, trans, resp, gr.update(value=error, visible=True)

                return (out_sample_rate, output), trans, resp, gr.update(value="", visible=False)

            async def stream_wrapper(audio_data):
                if audio_data is None:
                    yield None, "No audio input detected", "Please record or upload audio", gr.update(
                        value="No audio input detected", visible=True)
                    return

                sample_rate, audio_array = audio_data
                generator, transcription_text = await self.process_audio(
                    audio_array, sample_rate, streaming=True)

                if generator is None:
                    yield None, transcription_text, "", gr.update(value=transcription_text, visible=True)
                    return

                # Update transcription immediately
                yield None, transcription_text, "", gr.update(value="", visible=False)

                # Start streaming
                current_text = ""
                async for audio_chunk, sr, text_chunk in generator:
                    current_text += text_chunk
                    yield (sr, audio_chunk), transcription_text, current_text, gr.update(value="", visible=False)

            submit_button.click(
                fn=process_wrapper,
                inputs=[audio_input],
                outputs=[audio_output, transcription, response_text, error_text]
            )

            stream_button.click(
                fn=stream_wrapper,
                inputs=[audio_input],
                outputs=[audio_output, transcription, response_text, error_text]
            )

        return demo


def main():
    parser = argparse.ArgumentParser(description="Audio interface for LLaMA-Omni2")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--whisper-model-path", type=str, default="models/speech_encoder")
    parser.add_argument("--vocoder-dir", type=str, default="models/cosy2_decoder")
    parser.add_argument("--model-name", type=str, default="LLaMA-Omni2-7B-Bilingual")
    parser.add_argument("--read-tokens", type=int, default=3,
                        help="Number of text tokens to read before generating speech")
    parser.add_argument("--write-tokens", type=int, default=10,
                        help="Number of speech tokens to write for each read")
    parser.add_argument("--share", action="store_true", help="Create a public link")
    args = parser.parse_args()

    # Create the interface
    interface = AudioInterface(
        controller_url=args.controller_url,
        whisper_model_path=args.whisper_model_path,
        vocoder_dir=args.vocoder_dir,
        model_name=args.model_name,
        read_tokens=args.read_tokens,
        write_tokens=args.write_tokens
    )

    # Build and launch the interface
    demo = interface.build_interface()
    demo.queue()
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )


if __name__ == "__main__":
    main()
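
# A minimal programmatic sketch for exercising the pipeline without the Gradio
# UI, assuming a controller and a model worker are already running at the
# default URLs; "question.wav" is a placeholder path to a local recording.
#
#   import asyncio
#   import torchaudio
#
#   iface = AudioInterface(
#       controller_url="http://localhost:10000",
#       whisper_model_path="models/speech_encoder",
#       vocoder_dir="models/cosy2_decoder",
#   )
#   waveform, sr = torchaudio.load("question.wav")
#   audio, out_sr, transcript, reply, err = asyncio.run(
#       iface.process_audio(waveform[0].numpy(), sr, streaming=False)
#   )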