#!/usr/bin/env python3
"""
Audio interface for LLaMA-Omni2 that accepts audio input and returns audio output.
This interface:
1. Transcribes audio input using Whisper
2. Processes the transcription with LLaMA-Omni2 model
3. Synthesizes the response back to audio using CosyVoice 2
Enhanced with streaming generation and read-write scheduling for real-time response.
"""
import os
import sys
import argparse
import logging
import time
import asyncio
import tempfile
from pathlib import Path
from queue import Queue
from threading import Thread
import json
import torch
import torchaudio
import gradio as gr
import whisper
import aiohttp
import numpy as np
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AudioInterface:
def __init__(
self,
controller_url: str,
whisper_model_path: str,
vocoder_dir: str,
model_name: str = "LLaMA-Omni2-7B-Bilingual",
read_tokens: int = 3,
write_tokens: int = 10
):
self.controller_url = controller_url
self.whisper_model_path = whisper_model_path
self.vocoder_dir = vocoder_dir
self.model_name = model_name
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Read-write scheduling parameters for streaming generation
self.read_tokens = read_tokens # Number of text tokens to read
self.write_tokens = write_tokens # Number of speech tokens to write
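        # In this interface, read_tokens gates how many whitespace-separated words must
        # accumulate before a chunk is handed to the vocoder (see stream_text_to_speech);
        # write_tokens is accepted for the read-write schedule described in the module
        # docstring but is not consumed elsewhere in this file.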
# Load Whisper model
try:
logger.info(f"Loading Whisper model from {whisper_model_path}")
self.whisper_model = whisper.load_model("large-v3",
download_root=whisper_model_path,
device=self.device)
logger.info("Whisper model loaded successfully")
except Exception as e:
logger.error(f"Failed to load Whisper model: {e}")
self.whisper_model = None
# Load CosyVoice vocoder
try:
sys.path.insert(0, vocoder_dir)
from cosy_voice_2.inference import CosyVoice
self.vocoder = CosyVoice(
device=self.device,
model_path=vocoder_dir
)
logger.info(f"CosyVoice vocoder loaded from {vocoder_dir}")
except Exception as e:
logger.error(f"Failed to load CosyVoice vocoder: {e}")
self.vocoder = None
logger.info(f"Using LLaMA-Omni2 model: {model_name}")
async def get_worker_address(self):
"""Get the address of the worker serving the model"""
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.controller_url}/get_worker_address?model_name={self.model_name}",
timeout=30
) as response:
if response.status == 200:
data = await response.json()
return data.get("address")
else:
logger.error(f"Failed to get worker address: {await response.text()}")
return None
except Exception as e:
logger.error(f"Error getting worker address: {e}")
return None
    async def generate_text(self, prompt: str):
        """Generate a complete text response from the LLaMA-Omni2 model (non-streaming)."""
        worker_addr = await self.get_worker_address()
        if not worker_addr:
            return f"Error: No worker available for model {self.model_name}"
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{worker_addr}/generate",
                    json={"prompt": prompt},
                    timeout=120
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data.get("response", "No response received from model")
                    else:
                        error_text = await response.text()
                        logger.error(f"Failed to generate text: {error_text}")
                        return f"Error: {error_text}"
        except Exception as e:
            logger.error(f"Error generating text: {e}")
            return f"Error: {str(e)}"

    async def generate_text_stream(self, prompt: str):
        """Stream text from the LLaMA-Omni2 model, yielding incremental chunks.

        Kept separate from generate_text: an async generator cannot also return a
        value, so mixing the streaming and non-streaming paths in one function
        would not compile.
        """
        worker_addr = await self.get_worker_address()
        if not worker_addr:
            yield f"Error: No worker available for model {self.model_name}"
            return
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{worker_addr}/generate_stream",
                    json={"prompt": prompt},
                    timeout=120
                ) as response:
                    if response.status == 200:
                        async for line in response.content:
                            if line:
                                data = json.loads(line)
                                chunk = data.get("text", "")
                                if chunk:
                                    # Yield only the new text so consumers can
                                    # accumulate it themselves
                                    yield chunk
                    else:
                        error_text = await response.text()
                        logger.error(f"Failed to generate text stream: {error_text}")
                        yield f"Error: {error_text}"
        except Exception as e:
            logger.error(f"Error generating text stream: {e}")
            yield f"Error: {str(e)}"
def transcribe_audio(self, audio_path):
"""Transcribe audio using Whisper"""
if self.whisper_model is None:
return "Error: Whisper model not loaded"
try:
logger.info(f"Transcribing audio from {audio_path}")
result = self.whisper_model.transcribe(audio_path)
logger.info("Transcription completed")
return result["text"]
except Exception as e:
logger.error(f"Error transcribing audio: {e}")
return f"Error transcribing audio: {str(e)}"
def synthesize_speech(self, text):
"""Synthesize speech from text using CosyVoice"""
if self.vocoder is None:
return None, 16000, "Error: Vocoder not loaded"
try:
logger.info("Synthesizing speech from text response")
# Generate speech using CosyVoice
waveform = self.vocoder.inference(text)
sample_rate = self.vocoder.sample_rate
# Convert to numpy array for Gradio
if isinstance(waveform, torch.Tensor):
waveform = waveform.cpu().numpy()
logger.info("Speech synthesis completed")
return waveform, sample_rate, None
except Exception as e:
logger.error(f"Error synthesizing speech: {e}")
return None, 16000, f"Error synthesizing speech: {str(e)}"
async def synthesize_speech_chunk(self, text_chunk):
"""Synthesize speech for a single text chunk"""
if self.vocoder is None:
return None, 16000, "Error: Vocoder not loaded"
try:
# Generate speech using CosyVoice for this chunk
waveform = self.vocoder.inference(text_chunk)
sample_rate = self.vocoder.sample_rate
# Convert to numpy array
if isinstance(waveform, torch.Tensor):
waveform = waveform.cpu().numpy()
return waveform, sample_rate, None
except Exception as e:
logger.error(f"Error synthesizing speech chunk: {e}")
return None, 16000, f"Error synthesizing speech chunk: {str(e)}"
    async def stream_text_to_speech(self, text_generator):
        """Stream text to speech using read-write scheduling.

        Yields (audio, sample_rate, text_chunk) tuples, where audio is the full
        waveform synthesized so far.
        """
        buffer = ""
        audio_chunks = []
        sample_rate = 16000
        try:
            async for text in text_generator:
                # Accumulate text until we have enough to synthesize
                buffer += text
                # Flush once roughly read_tokens whitespace-separated words have accumulated
                if len(buffer.split()) >= self.read_tokens:
                    chunk_to_process = buffer
                    buffer = ""
                    # Synthesize this chunk
                    audio_chunk, sample_rate, error = await self.synthesize_speech_chunk(chunk_to_process)
                    if error:
                        logger.error(f"Error in streaming synthesis: {error}")
                        continue
                    # Add to our collection of audio chunks
                    audio_chunks.append(audio_chunk)
                    # Yield the concatenated audio generated so far with the new text
                    full_audio = np.concatenate(audio_chunks)
                    yield full_audio, sample_rate, chunk_to_process
            # Process any remaining text in the buffer
            if buffer:
                audio_chunk, sample_rate, error = await self.synthesize_speech_chunk(buffer)
                if not error and audio_chunk is not None:
                    audio_chunks.append(audio_chunk)
                    full_audio = np.concatenate(audio_chunks)
                    yield full_audio, sample_rate, buffer
            if not audio_chunks:
                logger.warning("No audio was generated from the streamed text")
        except Exception as e:
            logger.error(f"Error in streaming text to speech: {e}")
    async def process_audio(self, audio_data, sample_rate, streaming=False):
        """Process audio input and return audio output.

        Non-streaming mode returns (audio, sample_rate, transcription, response_text, error).
        Streaming mode returns (async_generator, transcription), where the generator
        yields (audio, sample_rate, text_chunk) tuples.
        """
        # Save the input audio to a temporary WAV file for Whisper
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
            temp_path = temp_audio.name
            audio_data = np.asarray(audio_data)
            # Gradio delivers integer PCM (typically int16); convert to float32 in [-1, 1]
            # because torchaudio's Resample expects floating-point input
            if np.issubdtype(audio_data.dtype, np.integer):
                max_val = np.iinfo(audio_data.dtype).max
                audio_data = audio_data.astype(np.float32) / max_val
            else:
                audio_data = audio_data.astype(np.float32)
            # Down-mix stereo to mono (Gradio returns shape (n_samples, n_channels))
            if audio_data.ndim > 1:
                audio_data = audio_data.mean(axis=1)
            # Resample to 16 kHz if needed (Whisper expects 16 kHz input)
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=16000
                )
                audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)
                audio_tensor = resampler(audio_tensor)
                audio_data = audio_tensor.squeeze(0).numpy()
                sample_rate = 16000
            # Save as WAV
            torchaudio.save(temp_path, torch.from_numpy(audio_data).unsqueeze(0), sample_rate)
try:
# Step 1: Transcribe audio
transcription = self.transcribe_audio(temp_path)
if transcription.startswith("Error"):
return None, sample_rate, transcription, "Error occurred during transcription", transcription
# Step 2: Process with LLaMA-Omni2
if streaming:
                # Streaming mode hands back an async generator plus the transcription;
                # the caller iterates it for incremental (audio, sample_rate, text) updates
                text_generator = self.generate_text_stream(transcription)
                audio_generator = self.stream_text_to_speech(text_generator)
                return audio_generator, transcription
else:
# For non-streaming mode
response_text = await self.generate_text(transcription)
if response_text.startswith("Error"):
return None, sample_rate, transcription, response_text, response_text
# Step 3: Synthesize speech
audio_output, out_sample_rate, error = self.synthesize_speech(response_text)
if error:
return None, sample_rate, transcription, response_text, error
return audio_output, out_sample_rate, transcription, response_text, None
finally:
# Clean up temporary file
if os.path.exists(temp_path):
os.unlink(temp_path)
def build_interface(self):
"""Build Gradio interface"""
with gr.Blocks(title="LLaMA-Omni2 Audio Interface") as demo:
gr.Markdown("# LLaMA-Omni2 Audio Interface")
gr.Markdown("Speak to LLaMA-Omni2 and hear its response in real-time")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(
sources=["microphone", "upload"],
type="numpy",
label="Input Audio"
)
with gr.Row():
submit_button = gr.Button("Process Audio", variant="primary")
stream_button = gr.Button("Stream Audio Response", variant="secondary")
with gr.Column():
transcription = gr.Textbox(
label="Transcription",
interactive=False
)
response_text = gr.Textbox(
label="Response Text",
interactive=False
)
audio_output = gr.Audio(
label="Response Audio",
type="numpy",
interactive=False
)
error_text = gr.Textbox(
label="Errors (if any)",
interactive=False,
visible=False
)
            async def process_wrapper(audio_data):
                if audio_data is None:
                    return None, "No audio input detected", "Please record or upload audio", gr.update(value="No audio input detected", visible=True)
                # Gradio's numpy audio components provide (sample_rate, data)
                sample_rate, audio_array = audio_data
                output, out_sample_rate, trans, resp, error = await self.process_audio(audio_array, sample_rate, streaming=False)
                if error:
                    # Surface the error in the (normally hidden) error textbox
                    return None, trans, resp, gr.update(value=error, visible=True)
                return (out_sample_rate, output), trans, resp, gr.update(value="", visible=False)
            async def stream_wrapper(audio_data):
                if audio_data is None:
                    yield None, "No audio input detected", "Please record or upload audio", "No audio input detected"
                    return
                # Gradio's numpy audio components provide (sample_rate, data)
                sample_rate, audio_array = audio_data
                result = await self.process_audio(audio_array, sample_rate, streaming=True)
                if len(result) != 2:
                    # process_audio falls back to its non-streaming error tuple when
                    # transcription fails, so surface that error and stop
                    _, _, trans, resp, error = result
                    yield None, trans, resp, error
                    return
                generator, transcription_text = result
                # Show the transcription immediately, before any audio is ready
                yield None, transcription_text, "", ""
                # Stream incremental audio/text updates as chunks are synthesized
                current_text = ""
                async for audio_chunk, sr, text_chunk in generator:
                    current_text += text_chunk
                    yield (sr, audio_chunk), transcription_text, current_text, ""
            # Gradio handles async functions and async generators natively, so the
            # handlers are passed directly instead of being wrapped in tasks/lambdas
            submit_button.click(
                fn=process_wrapper,
                inputs=[audio_input],
                outputs=[audio_output, transcription, response_text, error_text]
            )
            stream_button.click(
                fn=stream_wrapper,
                inputs=[audio_input],
                outputs=[audio_output, transcription, response_text, error_text]
            )
return demo
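# Minimal programmatic sketch (assumes a running controller/worker and the local model
# assets listed in main(); "question.wav" is a placeholder path, not part of this repo):
#
#   interface = AudioInterface(
#       controller_url="http://localhost:10000",
#       whisper_model_path="models/speech_encoder",
#       vocoder_dir="models/cosy2_decoder",
#   )
#   waveform, sr = torchaudio.load("question.wav")
#   audio, out_sr, transcript, reply, err = asyncio.run(
#       interface.process_audio(waveform.squeeze(0).numpy(), sr, streaming=False)
#   )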
def main():
parser = argparse.ArgumentParser(description="Audio interface for LLaMA-Omni2")
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
parser.add_argument("--whisper-model-path", type=str, default="models/speech_encoder")
parser.add_argument("--vocoder-dir", type=str, default="models/cosy2_decoder")
parser.add_argument("--model-name", type=str, default="LLaMA-Omni2-7B-Bilingual")
parser.add_argument("--read-tokens", type=int, default=3,
help="Number of text tokens to read before generating speech")
parser.add_argument("--write-tokens", type=int, default=10,
help="Number of speech tokens to write for each read")
parser.add_argument("--share", action="store_true", help="Create a public link")
args = parser.parse_args()
# Create the interface
interface = AudioInterface(
controller_url=args.controller_url,
whisper_model_path=args.whisper_model_path,
vocoder_dir=args.vocoder_dir,
model_name=args.model_name,
read_tokens=args.read_tokens,
write_tokens=args.write_tokens
)
# Build and launch the interface
demo = interface.build_interface()
demo.queue()
demo.launch(
server_name=args.host,
server_port=args.port,
share=args.share
)
if __name__ == "__main__":
main()