#!/usr/bin/env python3
"""
Audio interface for LLaMA-Omni2 that accepts audio input and returns audio output.

This interface:
1. Transcribes audio input using Whisper
2. Processes the transcription with the LLaMA-Omni2 model
3. Synthesizes the response back to audio using CosyVoice 2

Enhanced with streaming generation and read-write scheduling for real-time response.
"""

import os
import sys
import argparse
import json
import logging
import tempfile

import torch
import torchaudio
import gradio as gr
import whisper
import aiohttp
import numpy as np

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class AudioInterface:
    def __init__(
        self,
        controller_url: str,
        whisper_model_path: str,
        vocoder_dir: str,
        model_name: str = "LLaMA-Omni2-7B-Bilingual",
        read_tokens: int = 3,
        write_tokens: int = 10
    ):
        self.controller_url = controller_url
        self.whisper_model_path = whisper_model_path
        self.vocoder_dir = vocoder_dir
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Read-write scheduling parameters for streaming generation
        self.read_tokens = read_tokens    # Number of text tokens to read per step
        self.write_tokens = write_tokens  # Number of speech tokens to write per step
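        # A rough sketch of the scheduling policy as implemented below: with
        # read_tokens=3, stream_text_to_speech triggers one synthesis call for
        # roughly every 3 words of generated text (word count stands in for a
        # true token count). write_tokens is intended as the number of speech
        # tokens emitted per read step, but is not yet consumed by that heuristic.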

        # Load Whisper model
        try:
            logger.info(f"Loading Whisper model from {whisper_model_path}")
            self.whisper_model = whisper.load_model(
                "large-v3",
                download_root=whisper_model_path,
                device=self.device
            )
            logger.info("Whisper model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Whisper model: {e}")
            self.whisper_model = None

        # Load CosyVoice vocoder
        try:
            sys.path.insert(0, vocoder_dir)
            from cosy_voice_2.inference import CosyVoice
            self.vocoder = CosyVoice(
                device=self.device,
                model_path=vocoder_dir
            )
            logger.info(f"CosyVoice vocoder loaded from {vocoder_dir}")
        except Exception as e:
            logger.error(f"Failed to load CosyVoice vocoder: {e}")
            self.vocoder = None

        logger.info(f"Using LLaMA-Omni2 model: {model_name}")

    async def get_worker_address(self):
        """Get the address of the worker serving the model."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.controller_url}/get_worker_address?model_name={self.model_name}",
                    timeout=30
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data.get("address")
                    logger.error(f"Failed to get worker address: {await response.text()}")
                    return None
        except Exception as e:
            logger.error(f"Error getting worker address: {e}")
            return None

    async def generate_text_stream(self, prompt: str):
        """Stream a response from the LLaMA-Omni2 model, yielding incremental text chunks."""
        worker_addr = await self.get_worker_address()
        if not worker_addr:
            yield f"Error: No worker available for model {self.model_name}"
            return
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{worker_addr}/generate_stream",
                    json={"prompt": prompt},
                    timeout=120
                ) as response:
                    if response.status == 200:
                        async for raw_line in response.content:
                            line = raw_line.strip()
                            if not line:
                                continue
                            data = json.loads(line)
                            chunk = data.get("text", "")
                            if chunk:
                                yield chunk
                    else:
                        error_text = await response.text()
                        logger.error(f"Failed to generate text stream: {error_text}")
                        yield f"Error: {error_text}"
        except Exception as e:
            logger.error(f"Error generating text stream: {e}")
            yield f"Error: {str(e)}"

    async def generate_text(self, prompt: str):
        """Generate a complete (non-streaming) response from the LLaMA-Omni2 model."""
        worker_addr = await self.get_worker_address()
        if not worker_addr:
            return f"Error: No worker available for model {self.model_name}"
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{worker_addr}/generate",
                    json={"prompt": prompt},
                    timeout=120
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data.get("response", "No response received from model")
                    error_text = await response.text()
                    logger.error(f"Failed to generate text: {error_text}")
                    return f"Error: {error_text}"
        except Exception as e:
            logger.error(f"Error generating text: {e}")
            return f"Error: {str(e)}"

    def transcribe_audio(self, audio_path):
        """Transcribe audio using Whisper."""
        if self.whisper_model is None:
            return "Error: Whisper model not loaded"
        try:
            logger.info(f"Transcribing audio from {audio_path}")
            result = self.whisper_model.transcribe(audio_path)
            logger.info("Transcription completed")
            return result["text"]
        except Exception as e:
            logger.error(f"Error transcribing audio: {e}")
            return f"Error transcribing audio: {str(e)}"

    def synthesize_speech(self, text):
        """Synthesize speech from text using CosyVoice."""
        if self.vocoder is None:
            return None, 16000, "Error: Vocoder not loaded"
        try:
            logger.info("Synthesizing speech from text response")
            # Generate speech using CosyVoice
            waveform = self.vocoder.inference(text)
            sample_rate = self.vocoder.sample_rate
            # Convert to a numpy array for Gradio
            if isinstance(waveform, torch.Tensor):
                waveform = waveform.cpu().numpy()
            logger.info("Speech synthesis completed")
            return waveform, sample_rate, None
        except Exception as e:
            logger.error(f"Error synthesizing speech: {e}")
            return None, 16000, f"Error synthesizing speech: {str(e)}"

    async def synthesize_speech_chunk(self, text_chunk):
        """Synthesize speech for a single text chunk."""
        if self.vocoder is None:
            return None, 16000, "Error: Vocoder not loaded"
        try:
            # Generate speech for this chunk. Note: vocoder.inference is a blocking
            # call; wrap it in asyncio.to_thread if it stalls the event loop.
            waveform = self.vocoder.inference(text_chunk)
            sample_rate = self.vocoder.sample_rate
            # Convert to a numpy array
            if isinstance(waveform, torch.Tensor):
                waveform = waveform.cpu().numpy()
            return waveform, sample_rate, None
        except Exception as e:
            logger.error(f"Error synthesizing speech chunk: {e}")
            return None, 16000, f"Error synthesizing speech chunk: {str(e)}"

    async def stream_text_to_speech(self, text_generator):
        """Stream text to speech using read-write scheduling."""
        buffer = ""
        audio_chunks = []
        sample_rate = 16000
        try:
            async for chunk in text_generator:
                # Accumulate incremental text until we have enough to synthesize
                buffer += chunk
                # When we have enough words (a rough stand-in for read_tokens tokens)
                if len(buffer.split()) >= self.read_tokens:
                    chunk_to_process = buffer
                    buffer = ""
                    # Synthesize this chunk
                    audio_chunk, sample_rate, error = await self.synthesize_speech_chunk(chunk_to_process)
                    if error:
                        logger.error(f"Error in streaming synthesis: {error}")
                        continue
                    audio_chunks.append(audio_chunk)
                    # Yield the audio accumulated so far plus the text just spoken
                    yield np.concatenate(audio_chunks), sample_rate, chunk_to_process
            # Process any remaining text in the buffer
            if buffer:
                audio_chunk, sample_rate, error = await self.synthesize_speech_chunk(buffer)
                if not error and audio_chunk is not None:
                    audio_chunks.append(audio_chunk)
                    yield np.concatenate(audio_chunks), sample_rate, buffer
        except Exception as e:
            # An async generator cannot return a value, so surface errors via logging
            logger.error(f"Error in streaming text to speech: {e}")

    async def process_audio(self, audio_data, sample_rate, streaming=False):
        """Process audio input and return audio output.

        Returns a 5-tuple (audio, sample_rate, transcription, response, error) in
        non-streaming mode, or (audio_generator, transcription) in streaming mode.
        """
        # Reserve a temporary file path for the input audio
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
            temp_path = temp_audio.name
        # Normalize to mono float32 in [-1, 1] (Gradio delivers int16 by default)
        audio_data = np.asarray(audio_data)
        if audio_data.ndim == 2:
            audio_data = audio_data.mean(axis=1)
        if audio_data.dtype == np.int16:
            audio_data = audio_data.astype(np.float32) / 32768.0
        else:
            audio_data = audio_data.astype(np.float32)
        audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)
        # Resample to 16 kHz if needed (Whisper expects 16 kHz input)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio_tensor = resampler(audio_tensor)
            sample_rate = 16000
        # Save as WAV
        torchaudio.save(temp_path, audio_tensor, sample_rate)
        try:
            # Step 1: Transcribe audio
            transcription = self.transcribe_audio(temp_path)
            if transcription.startswith("Error"):
                return None, sample_rate, transcription, "Error occurred during transcription", transcription
            # Step 2: Process with LLaMA-Omni2
            if streaming:
                text_generator = self.generate_text_stream(transcription)
                audio_generator = self.stream_text_to_speech(text_generator)
                return audio_generator, transcription
            response_text = await self.generate_text(transcription)
            if response_text.startswith("Error"):
                return None, sample_rate, transcription, response_text, response_text
            # Step 3: Synthesize speech
            audio_output, out_sample_rate, error = self.synthesize_speech(response_text)
            if error:
                return None, sample_rate, transcription, response_text, error
            return audio_output, out_sample_rate, transcription, response_text, None
        finally:
            # Clean up the temporary file
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    def build_interface(self):
        """Build the Gradio interface."""
        with gr.Blocks(title="LLaMA-Omni2 Audio Interface") as demo:
            gr.Markdown("# LLaMA-Omni2 Audio Interface")
            gr.Markdown("Speak to LLaMA-Omni2 and hear its response in real time.")
            with gr.Row():
                with gr.Column():
                    audio_input = gr.Audio(
                        sources=["microphone", "upload"],
                        type="numpy",
                        label="Input Audio"
                    )
                    with gr.Row():
                        submit_button = gr.Button("Process Audio", variant="primary")
                        stream_button = gr.Button("Stream Audio Response", variant="secondary")
                with gr.Column():
                    transcription = gr.Textbox(label="Transcription", interactive=False)
                    response_text = gr.Textbox(label="Response Text", interactive=False)
                    audio_output = gr.Audio(label="Response Audio", type="numpy", interactive=False)
                    error_text = gr.Textbox(label="Errors (if any)", interactive=False, visible=False)

            async def process_wrapper(audio_data):
                if audio_data is None:
                    return None, "No audio input detected", "Please record or upload audio", "No audio input detected"
                # Gradio numpy audio arrives as a (sample_rate, array) tuple
                sample_rate, audio_array = audio_data
                output, out_sample_rate, trans, resp, error = await self.process_audio(
                    audio_array, sample_rate, streaming=False
                )
                if error:
                    return None, trans, resp, gr.update(value=error, visible=True)
                # gr.Audio with type="numpy" expects a (sample_rate, array) tuple
                return (out_sample_rate, output), trans, resp, ""

            async def stream_wrapper(audio_data):
                if audio_data is None:
                    yield None, "No audio input detected", "Please record or upload audio", "No audio input detected"
                    return
                sample_rate, audio_array = audio_data
                result = await self.process_audio(audio_array, sample_rate, streaming=True)
                if len(result) == 5:
                    # Transcription failed; process_audio returned the non-streaming error tuple
                    _, _, trans, resp, error = result
                    yield None, trans, resp, gr.update(value=error, visible=True)
                    return
                generator, trans = result
                # Show the transcription immediately, then stream audio as it is synthesized
                yield None, trans, "", ""
                current_text = ""
                async for audio_chunk, sr, text_chunk in generator:
                    current_text += text_chunk
                    yield (sr, audio_chunk), trans, current_text, ""

            # Gradio runs async functions and async generators natively
            submit_button.click(
                fn=process_wrapper,
                inputs=[audio_input],
                outputs=[audio_output, transcription, response_text, error_text]
            )
            stream_button.click(
                fn=stream_wrapper,
                inputs=[audio_input],
                outputs=[audio_output, transcription, response_text, error_text]
            )
        return demo

def main():
    parser = argparse.ArgumentParser(description="Audio interface for LLaMA-Omni2")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--whisper-model-path", type=str, default="models/speech_encoder")
    parser.add_argument("--vocoder-dir", type=str, default="models/cosy2_decoder")
    parser.add_argument("--model-name", type=str, default="LLaMA-Omni2-7B-Bilingual")
    parser.add_argument("--read-tokens", type=int, default=3,
                        help="Number of text tokens to read before generating speech")
    parser.add_argument("--write-tokens", type=int, default=10,
                        help="Number of speech tokens to write for each read")
    parser.add_argument("--share", action="store_true", help="Create a public link")
    args = parser.parse_args()

    # Create the interface
    interface = AudioInterface(
        controller_url=args.controller_url,
        whisper_model_path=args.whisper_model_path,
        vocoder_dir=args.vocoder_dir,
        model_name=args.model_name,
        read_tokens=args.read_tokens,
        write_tokens=args.write_tokens
    )

    # Build and launch the interface
    demo = interface.build_interface()
    demo.queue()
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )


if __name__ == "__main__":
    main()