"""Gradio web server for the LLaMA-Omni speech interaction demo.

Discovers model workers through a controller service, then serves a chat UI
supporting speech-to-speech and text-to-speech interaction. The speech
transcription and audio generation paths are placeholders pending integration
with the real Whisper / vocoder pipeline.
"""

import argparse
import json
import logging
import os
import tempfile
import time
import uuid
from typing import Dict, List, Optional, Tuple, Union

import gradio as gr
import requests

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Global configuration, populated by main() from CLI arguments.
controller_url = None     # Base URL of the model controller service
vocoder_path = None       # Path to vocoder model (reserved; not yet used)
vocoder_cfg = None        # Path to vocoder config (reserved; not yet used)
model_list_mode = "once"  # "once" or "reload"
avatars = {}
message_history = {}

# Timeout (seconds) for every HTTP call so a dead controller or worker
# cannot hang the UI thread indefinitely.
REQUEST_TIMEOUT = 60


def list_models():
    """Return the list of model names known to the controller ([] on error)."""
    try:
        response = requests.get(f"{controller_url}/list_models",
                                timeout=REQUEST_TIMEOUT)
        if response.status_code == 200:
            return response.json().get("models", [])
        logger.error(f"Failed to list models: {response.text}")
        return []
    except Exception as e:
        logger.error(f"Error listing models: {str(e)}")
        return []


def get_worker_address(model_name):
    """Return the address of a worker serving *model_name*, or None on error."""
    try:
        response = requests.get(
            f"{controller_url}/get_worker_address",
            params={"model_name": model_name},
            timeout=REQUEST_TIMEOUT,
        )
        if response.status_code == 200:
            return response.json().get("worker_address")
        logger.error(f"Failed to get worker address: {response.text}")
        return None
    except Exception as e:
        logger.error(f"Error getting worker address: {str(e)}")
        return None


def transcribe_audio(audio_path):
    """Placeholder for audio transcription.

    In a real implementation, this would use the Whisper model; for now it
    returns a canned string naming the input file.
    """
    logger.info(f"Transcribing audio from {audio_path}...")
    # Simulated transcription
    return f"This is a placeholder transcription for audio file {os.path.basename(audio_path)}"


def _generate_speech(prompt, model_name):
    """Send *prompt* to a worker for *model_name* and return the result.

    Shared request path for the speech and text entry points. Returns
    (text_response, speech_url); on failure the first element is an
    "Error: ..." string and the second is None. May raise on network
    errors — callers wrap this in their own try/except.
    """
    worker_address = get_worker_address(model_name)
    if not worker_address:
        return f"Error: No worker available for model {model_name}", None

    response = requests.post(
        f"{worker_address}/generate_speech",
        json={"prompt": prompt},
        timeout=REQUEST_TIMEOUT,
    )
    if response.status_code == 200:
        result = response.json()
        text_response = result.get("text", "No text response generated")
        speech_url = result.get("speech_url")
        # In a real implementation, we would handle the audio file
        # For now, we'll just return the text response
        return text_response, speech_url
    return f"Error: {response.text}", None


def process_speech_to_speech(audio_path, model_name):
    """Process speech to speech generation.

    Transcribes *audio_path*, forwards the transcription to a worker, and
    returns (text_response, speech_url) — or an error string and None.
    """
    if not audio_path:
        return "Error: No audio provided", None
    try:
        transcription = transcribe_audio(audio_path)
        return _generate_speech(transcription, model_name)
    except Exception as e:
        logger.error(f"Error in speech-to-speech processing: {str(e)}")
        return f"Error: {str(e)}", None


def process_text_to_speech(text, model_name):
    """Process text to speech generation.

    Forwards *text* to a worker and returns (text_response, speech_url) —
    or an error string and None.
    """
    if not text:
        return "Error: No text provided", None
    try:
        return _generate_speech(text, model_name)
    except Exception as e:
        logger.error(f"Error in text-to-speech processing: {str(e)}")
        return f"Error: {str(e)}", None


def create_chat_ui():
    """Create and return the Gradio chat UI (a gr.Blocks demo)."""
    available_models = list_models()
    logger.info(f"Available models: {available_models}")

    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        gr.Markdown("# 🦙🎧 LLaMA-Omni Speech Interaction Demo")

        with gr.Row():
            with gr.Column(scale=3):
                # Input area
                with gr.Tab("Speech Input"):
                    audio_input = gr.Audio(
                        sources=["microphone", "upload"],
                        type="filepath",
                        label="Record or upload audio",
                    )
                    transcription_output = gr.Textbox(
                        label="Transcription", interactive=False)
                with gr.Tab("Text Input"):
                    text_input = gr.Textbox(
                        label="Text Input",
                        placeholder="Type your message here...",
                    )

                # Common controls
                with gr.Row():
                    model_selector = gr.Dropdown(
                        choices=available_models,
                        label="Model",
                        value=available_models[0] if available_models else None,
                    )
                    submit_btn = gr.Button("Submit")
                    if model_list_mode == "reload":
                        refresh_btn = gr.Button("Refresh Models")

            with gr.Column(scale=4):
                # Output area
                chatbot = gr.Chatbot(label="Conversation", height=500)
                with gr.Row():
                    audio_output = gr.Audio(label="Generated Speech",
                                            interactive=False)

        # ---- Event handlers ------------------------------------------------

        def on_audio_input(audio):
            # Live-preview the transcription as soon as audio is provided.
            if audio:
                return transcribe_audio(audio)
            return ""

        def on_speech_submit(audio, model_name, chat_history):
            if not audio:
                return chat_history, None
            transcription = transcribe_audio(audio)
            text_response, speech_url = process_speech_to_speech(audio, model_name)
            # Copy so we never mutate the history object Gradio handed us.
            new_history = chat_history.copy()
            new_history.append((transcription, text_response))
            # In a real implementation, we would handle the audio file
            # For now, we'll just return None for audio output
            return new_history, None

        def on_text_submit(text, model_name, chat_history):
            if not text:
                return chat_history, None
            text_response, speech_url = process_text_to_speech(text, model_name)
            new_history = chat_history.copy()
            new_history.append((text, text_response))
            # In a real implementation, we would handle the audio file
            # For now, we'll just return None for audio output
            return new_history, None

        def on_refresh_models():
            # gr.update works on both Gradio 3.x and 4.x;
            # gr.Dropdown.update was removed in 4.x.
            return gr.update(choices=list_models())

        # Connect events. The submit button routes to the speech path when
        # audio is present, otherwise to the text path.
        audio_input.change(on_audio_input, [audio_input], [transcription_output])
        submit_btn.click(
            fn=lambda audio, text, model, chat: (
                on_speech_submit(audio, model, chat)
                if audio
                else on_text_submit(text, model, chat)
            ),
            inputs=[audio_input, text_input, model_selector, chatbot],
            outputs=[chatbot, audio_output],
        )
        if model_list_mode == "reload":
            refresh_btn.click(on_refresh_models, [], [model_selector])

    return demo


def main():
    """Parse CLI arguments and run the Gradio web server."""
    global controller_url, vocoder_path, vocoder_cfg, model_list_mode

    parser = argparse.ArgumentParser(description="LLaMA-Omni Gradio web server")
    parser.add_argument("--host", type=str, default="0.0.0.0",
                        help="Host to bind the server")
    parser.add_argument("--port", type=int, default=8000,
                        help="Port to bind the server")
    parser.add_argument("--controller", type=str, required=True,
                        help="Controller URL")
    parser.add_argument("--vocoder", type=str, required=True,
                        help="Path to vocoder model")
    parser.add_argument("--vocoder-cfg", type=str, required=True,
                        help="Path to vocoder config")
    parser.add_argument("--model-list-mode", type=str, default="once",
                        choices=["once", "reload"],
                        help="Model listing mode")
    args = parser.parse_args()

    controller_url = args.controller
    vocoder_path = args.vocoder
    vocoder_cfg = args.vocoder_cfg
    model_list_mode = args.model_list_mode

    # Create the demo and launch the server.
    demo = create_chat_ui()
    demo.launch(server_name=args.host, server_port=args.port, share=False)


if __name__ == "__main__":
    main()