#!/usr/bin/env python3
"""
Audio interface for LLaMA-Omni2 that accepts audio input and returns audio output.

This interface:
1. Transcribes audio input using Whisper
2. Processes the transcription with the LLaMA-Omni2 model
3. Synthesizes the response back to audio using CosyVoice 2

Enhanced with streaming generation and read-write scheduling for real-time response.
"""

import argparse
import asyncio
import json
import logging
import os
import sys
import tempfile

import aiohttp
import gradio as gr
import numpy as np
import torch
import torchaudio
import whisper

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class AudioInterface:
    def __init__(
        self,
        controller_url: str,
        whisper_model_path: str,
        vocoder_dir: str,
        model_name: str = "LLaMA-Omni2-7B-Bilingual",
        read_tokens: int = 3,
        write_tokens: int = 10,
    ):
        self.controller_url = controller_url
        self.whisper_model_path = whisper_model_path
        self.vocoder_dir = vocoder_dir
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Read-write scheduling parameters for streaming generation
        self.read_tokens = read_tokens    # Number of text tokens to read before each synthesis call
        self.write_tokens = write_tokens  # Number of speech tokens to write per read (kept for configuration parity; not consumed directly in this interface)

        # Load Whisper model
        try:
            logger.info(f"Loading Whisper model from {whisper_model_path}")
            self.whisper_model = whisper.load_model(
                "large-v3", download_root=whisper_model_path, device=self.device
            )
            logger.info("Whisper model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Whisper model: {e}")
            self.whisper_model = None

        # Load CosyVoice vocoder
        try:
            sys.path.insert(0, vocoder_dir)
            from cosy_voice_2.inference import CosyVoice

            self.vocoder = CosyVoice(device=self.device, model_path=vocoder_dir)
            logger.info(f"CosyVoice vocoder loaded from {vocoder_dir}")
        except Exception as e:
            logger.error(f"Failed to load CosyVoice vocoder: {e}")
            self.vocoder = None

        logger.info(f"Using LLaMA-Omni2 model: {model_name}")

    async def get_worker_address(self):
        """Get the address of the worker serving the model."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.controller_url}/get_worker_address?model_name={self.model_name}",
                    timeout=aiohttp.ClientTimeout(total=30),
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data.get("address")
                    logger.error(f"Failed to get worker address: {await response.text()}")
                    return None
        except Exception as e:
            logger.error(f"Error getting worker address: {e}")
            return None
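    # Note: the controller is assumed to expose GET /get_worker_address and to reply
    # with JSON of the form {"address": "http://<worker-host>:<port>"}; the exact
    # schema depends on the deployed LLaMA-Omni2 controller.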
    async def generate_text_stream(self, prompt: str):
        """Stream text from the LLaMA-Omni2 model, yielding incremental text chunks."""
        worker_addr = await self.get_worker_address()
        if not worker_addr:
            yield f"Error: No worker available for model {self.model_name}"
            return

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{worker_addr}/generate_stream",
                    json={"prompt": prompt},
                    timeout=aiohttp.ClientTimeout(total=120),
                ) as response:
                    if response.status == 200:
                        async for raw_line in response.content:
                            line = raw_line.strip()
                            if not line:
                                continue
                            data = json.loads(line)
                            yield data.get("text", "")
                    else:
                        error_text = await response.text()
                        logger.error(f"Failed to generate text stream: {error_text}")
                        yield f"Error: {error_text}"
        except Exception as e:
            logger.error(f"Error generating text stream: {e}")
            yield f"Error: {str(e)}"

    async def generate_text(self, prompt: str):
        """Generate the full (non-streaming) text response from the LLaMA-Omni2 model."""
        worker_addr = await self.get_worker_address()
        if not worker_addr:
            return f"Error: No worker available for model {self.model_name}"

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{worker_addr}/generate",
                    json={"prompt": prompt},
                    timeout=aiohttp.ClientTimeout(total=120),
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data.get("response", "No response received from model")
                    error_text = await response.text()
                    logger.error(f"Failed to generate text: {error_text}")
                    return f"Error: {error_text}"
        except Exception as e:
            logger.error(f"Error generating text: {e}")
            return f"Error: {str(e)}"

    def transcribe_audio(self, audio_path):
        """Transcribe audio using Whisper."""
        if self.whisper_model is None:
            return "Error: Whisper model not loaded"

        try:
            logger.info(f"Transcribing audio from {audio_path}")
            result = self.whisper_model.transcribe(audio_path)
            logger.info("Transcription completed")
            return result["text"]
        except Exception as e:
            logger.error(f"Error transcribing audio: {e}")
            return f"Error transcribing audio: {str(e)}"

    def synthesize_speech(self, text):
        """Synthesize speech from text using CosyVoice."""
        if self.vocoder is None:
            return None, 16000, "Error: Vocoder not loaded"

        try:
            logger.info("Synthesizing speech from text response")
            # Generate speech using CosyVoice
            waveform = self.vocoder.inference(text)
            sample_rate = self.vocoder.sample_rate

            # Convert to numpy array for Gradio
            if isinstance(waveform, torch.Tensor):
                waveform = waveform.cpu().numpy()

            logger.info("Speech synthesis completed")
            return waveform, sample_rate, None
        except Exception as e:
            logger.error(f"Error synthesizing speech: {e}")
            return None, 16000, f"Error synthesizing speech: {str(e)}"

    async def synthesize_speech_chunk(self, text_chunk):
        """Synthesize speech for a single text chunk."""
        if self.vocoder is None:
            return None, 16000, "Error: Vocoder not loaded"

        try:
            # Generate speech using CosyVoice for this chunk
            waveform = self.vocoder.inference(text_chunk)
            sample_rate = self.vocoder.sample_rate

            # Convert to numpy array
            if isinstance(waveform, torch.Tensor):
                waveform = waveform.cpu().numpy()

            return waveform, sample_rate, None
        except Exception as e:
            logger.error(f"Error synthesizing speech chunk: {e}")
            return None, 16000, f"Error synthesizing speech chunk: {str(e)}"

    async def stream_text_to_speech(self, text_generator):
        """Stream text to speech using read-write scheduling.

        Yields (audio_so_far, sample_rate, newly_spoken_text) tuples.
        """
        buffer = ""
        audio_chunks = []
        sample_rate = 16000

        try:
            async for text in text_generator:
                # Accumulate text until we have enough to synthesize
                buffer += text

                # When roughly read_tokens tokens have accumulated (approximated by
                # whitespace-separated words), synthesize them as one chunk
                if len(buffer.split()) >= self.read_tokens:
                    chunk_to_process = buffer
                    buffer = ""

                    # Synthesize this chunk
                    audio_chunk, sample_rate, error = await self.synthesize_speech_chunk(chunk_to_process)
                    if error:
                        logger.error(f"Error in streaming synthesis: {error}")
                        continue

                    # Add to our collection of audio chunks and yield the audio generated so far
                    audio_chunks.append(audio_chunk)
                    yield np.concatenate(audio_chunks), sample_rate, chunk_to_process

            # Flush any remaining text in the buffer
            if buffer:
                audio_chunk, sample_rate, error = await self.synthesize_speech_chunk(buffer)
                if not error and audio_chunk is not None:
                    audio_chunks.append(audio_chunk)
                    yield np.concatenate(audio_chunks), sample_rate, buffer

            if not audio_chunks:
                logger.warning("No audio generated from the streamed text")
        except Exception as e:
            logger.error(f"Error in streaming text to speech: {e}")
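    # Illustrative sketch (not an additional API): the streaming pipeline above can be
    # driven directly, outside Gradio, roughly as follows. It assumes a running
    # controller/worker and uses only the methods defined in this class.
    #
    #   async def speak(interface, prompt):
    #       text_gen = interface.generate_text_stream(prompt)
    #       async for audio_so_far, sr, new_text in interface.stream_text_to_speech(text_gen):
    #           ...  # push (sr, audio_so_far) to a player as chunks arrive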
    async def process_audio(self, audio_data, sample_rate, streaming=False):
        """Process audio input and return audio output."""
        # Save the input audio to a temporary file
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
            temp_path = temp_audio.name

        # Gradio delivers int16 PCM; convert to float32 in [-1, 1] and down-mix to mono
        audio_data = np.asarray(audio_data)
        if np.issubdtype(audio_data.dtype, np.integer):
            audio_data = audio_data.astype(np.float32) / 32768.0
        else:
            audio_data = audio_data.astype(np.float32)
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=-1)

        # Convert sample rate if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio_tensor = torch.tensor(audio_data).unsqueeze(0)
            audio_tensor = resampler(audio_tensor)
            audio_data = audio_tensor.squeeze(0).numpy()
            sample_rate = 16000

        # Save as WAV
        torchaudio.save(temp_path, torch.tensor(audio_data).unsqueeze(0), sample_rate)

        try:
            # Step 1: Transcribe audio
            transcription = self.transcribe_audio(temp_path)
            if transcription.startswith("Error"):
                if streaming:
                    return None, transcription
                return None, sample_rate, transcription, "Error occurred during transcription", transcription

            # Step 2: Process with LLaMA-Omni2
            if streaming:
                # For streaming mode, return an async generator of (audio, sample_rate, text) chunks
                text_generator = self.generate_text_stream(transcription)
                audio_generator = self.stream_text_to_speech(text_generator)
                return audio_generator, transcription

            # For non-streaming mode
            response_text = await self.generate_text(transcription)
            if response_text.startswith("Error"):
                return None, sample_rate, transcription, response_text, response_text

            # Step 3: Synthesize speech
            audio_output, out_sample_rate, error = self.synthesize_speech(response_text)
            if error:
                return None, sample_rate, transcription, response_text, error

            return audio_output, out_sample_rate, transcription, response_text, None
        finally:
            # Clean up temporary file
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    def build_interface(self):
        """Build Gradio interface."""
        with gr.Blocks(title="LLaMA-Omni2 Audio Interface") as demo:
            gr.Markdown("# LLaMA-Omni2 Audio Interface")
            gr.Markdown("Speak to LLaMA-Omni2 and hear its response in real-time")

            with gr.Row():
                with gr.Column():
                    audio_input = gr.Audio(
                        sources=["microphone", "upload"],
                        type="numpy",
                        label="Input Audio"
                    )
                    with gr.Row():
                        submit_button = gr.Button("Process Audio", variant="primary")
                        stream_button = gr.Button("Stream Audio Response", variant="secondary")

                with gr.Column():
                    transcription = gr.Textbox(label="Transcription", interactive=False)
                    response_text = gr.Textbox(label="Response Text", interactive=False)
                    audio_output = gr.Audio(label="Response Audio", type="numpy", interactive=False)
                    error_text = gr.Textbox(label="Errors (if any)", interactive=False, visible=False)

            async def process_wrapper(audio_data):
                if audio_data is None:
                    return None, "No audio input detected", "Please record or upload audio", "No audio input detected"

                # gr.Audio with type="numpy" provides and expects (sample_rate, waveform)
                in_sample_rate, audio_array = audio_data
                output, out_sample_rate, trans, resp, error = await self.process_audio(
                    audio_array, in_sample_rate, streaming=False
                )

                if error:
                    return None, trans, resp, gr.update(value=error, visible=True)
                return (out_sample_rate, output), trans, resp, ""

            async def stream_wrapper(audio_data):
                if audio_data is None:
                    yield None, "No audio input detected", "Please record or upload audio", "No audio input detected"
                    return

                in_sample_rate, audio_array = audio_data
                generator, transcribed = await self.process_audio(audio_array, in_sample_rate, streaming=True)

                if generator is None:
                    # Transcription failed; surface the error and stop
                    yield None, transcribed, "", gr.update(value=transcribed, visible=True)
                    return

                # Update the transcription immediately, then stream audio as it is synthesized
                yield None, transcribed, "", ""

                current_text = ""
                async for audio_chunk, sr, text_chunk in generator:
                    current_text += text_chunk
                    yield (sr, audio_chunk), transcribed, current_text, ""
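            # Gradio invokes coroutine functions and async generators directly, so the
            # handlers above are wired as-is; wrapping them in lambdas or asyncio tasks
            # would prevent their results (and streamed updates) from reaching the UI.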
            submit_button.click(
                fn=process_wrapper,
                inputs=[audio_input],
                outputs=[audio_output, transcription, response_text, error_text]
            )
            stream_button.click(
                fn=stream_wrapper,
                inputs=[audio_input],
                outputs=[audio_output, transcription, response_text, error_text]
            )

        return demo


def main():
    parser = argparse.ArgumentParser(description="Audio interface for LLaMA-Omni2")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--whisper-model-path", type=str, default="models/speech_encoder")
    parser.add_argument("--vocoder-dir", type=str, default="models/cosy2_decoder")
    parser.add_argument("--model-name", type=str, default="LLaMA-Omni2-7B-Bilingual")
    parser.add_argument("--read-tokens", type=int, default=3,
                        help="Number of text tokens to read before generating speech")
    parser.add_argument("--write-tokens", type=int, default=10,
                        help="Number of speech tokens to write for each read")
    parser.add_argument("--share", action="store_true", help="Create a public link")
    args = parser.parse_args()

    # Create the interface
    interface = AudioInterface(
        controller_url=args.controller_url,
        whisper_model_path=args.whisper_model_path,
        vocoder_dir=args.vocoder_dir,
        model_name=args.model_name,
        read_tokens=args.read_tokens,
        write_tokens=args.write_tokens
    )

    # Build and launch the interface
    demo = interface.build_interface()
    demo.queue()
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )


if __name__ == "__main__":
    main()
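# Example launch (illustrative; the script filename and a running LLaMA-Omni2 controller
# on localhost:10000 with a registered worker are assumptions, and the model paths must
# match your local setup):
#
#   python audio_interface.py \
#       --controller-url http://localhost:10000 \
#       --whisper-model-path models/speech_encoder \
#       --vocoder-dir models/cosy2_decoder \
#       --model-name LLaMA-Omni2-7B-Bilingual \
#       --read-tokens 3 --write-tokens 10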