#!/usr/bin/env python3
"""
A minimal Gradio web interface for LLaMA-Omni2 that doesn't rely on importing
from the LLaMA-Omni2 package.
"""
import argparse
import asyncio
import json
import logging
import os
import time
from typing import Dict, List, Optional

import aiohttp
import gradio as gr

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class LLaMA_Omni2_UI:
    """Thin HTTP client + Gradio front-end for a LLaMA-Omni2 controller.

    The controller exposes ``/list_models`` and ``/get_worker_address``;
    individual workers expose ``/generate``. The vocoder directory is only
    checked for existence here (voice synthesis itself is not wired up in
    this minimal UI).
    """

    def __init__(self, controller_url: str, vocoder_dir: str):
        """
        Args:
            controller_url: Base URL of the model controller
                (e.g. ``http://localhost:10000``).
            vocoder_dir: Path to the vocoder assets; a warning is logged if
                it does not exist.
        """
        self.controller_url = controller_url
        self.vocoder_dir = vocoder_dir
        self.model_list: List[Dict] = []   # raw model records from the controller
        self.model_names: List[str] = []   # just the "name" field of each record

        # Verify vocoder directory exists (non-fatal: text chat still works)
        if not os.path.exists(vocoder_dir):
            logger.warning(f"Vocoder directory not found at {vocoder_dir}")
            logger.warning("Voice synthesis will not be available")
        else:
            logger.info(f"Using vocoder at {vocoder_dir}")

    async def fetch_model_list(self) -> List[str]:
        """Fetch the list of models from the controller.

        Returns:
            The list of model names (also cached on ``self.model_names``);
            an empty list on any HTTP or network error.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.controller_url}/list_models",
                    # ClientTimeout is the documented form; a bare number
                    # is deprecated in aiohttp.
                    timeout=aiohttp.ClientTimeout(total=30),
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        self.model_list = data.get("models", [])
                        self.model_names = [
                            model.get("name") for model in self.model_list
                        ]
                        return self.model_names
                    else:
                        logger.error(
                            f"Failed to fetch model list: {await response.text()}"
                        )
                        return []
        except Exception as e:
            # Best-effort: a dead controller should not crash the UI.
            logger.error(f"Error fetching model list: {e}")
            return []

    async def get_worker_address(self, model_name: str) -> Optional[str]:
        """Get the address of a worker serving the specified model.

        Returns:
            The worker's base URL, or ``None`` if no worker is available
            or the request fails.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.controller_url}/get_worker_address?model_name={model_name}",
                    timeout=aiohttp.ClientTimeout(total=30),
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data.get("address")
                    else:
                        logger.error(
                            f"Failed to get worker address: {await response.text()}"
                        )
                        return None
        except Exception as e:
            logger.error(f"Error getting worker address: {e}")
            return None

    async def generate_text(self, prompt: str, model_name: str) -> str:
        """Generate text using the specified model.

        Resolves a worker for ``model_name`` via the controller, then POSTs
        the prompt to that worker's ``/generate`` endpoint.

        Returns:
            The model's response text, or a human-readable ``"Error: ..."``
            string on failure (never raises).
        """
        worker_addr = await self.get_worker_address(model_name)
        if not worker_addr:
            return f"Error: No worker available for model {model_name}"
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{worker_addr}/generate",
                    json={"prompt": prompt},
                    # Generation can be slow; allow up to two minutes.
                    timeout=aiohttp.ClientTimeout(total=120),
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data.get(
                            "response", "No response received from model"
                        )
                    else:
                        error_text = await response.text()
                        logger.error(f"Failed to generate text: {error_text}")
                        return f"Error: {error_text}"
        except Exception as e:
            logger.error(f"Error generating text: {e}")
            return f"Error: {str(e)}"

    def build_demo(self) -> gr.Blocks:
        """Build and return the Gradio interface (not yet launched)."""
        with gr.Blocks(title="LLaMA-Omni2 Web UI") as demo:
            gr.Markdown("# LLaMA-Omni2 Web UI")

            with gr.Row():
                with gr.Column(scale=1):
                    model_dropdown = gr.Dropdown(
                        choices=self.model_names or ["No models available"],
                        label="Model",
                        value=self.model_names[0] if self.model_names else None,
                    )
                    refresh_button = gr.Button("Refresh Models")

            with gr.Row():
                with gr.Column(scale=3):
                    text_input = gr.Textbox(
                        lines=5,
                        placeholder="Enter text here...",
                        label="Input Text",
                    )

            with gr.Row():
                with gr.Column(scale=1):
                    submit_button = gr.Button("Generate", variant="primary")
                    clear_button = gr.Button("Clear")

            with gr.Row():
                with gr.Column(scale=3):
                    text_output = gr.Textbox(
                        lines=10,
                        label="Generated Text",
                        interactive=False,
                    )

            async def refresh_models():
                model_names = await self.fetch_model_list()
                # gr.update works on both Gradio 3.x and 4.x
                # (gr.Dropdown.update was removed in 4.x).
                return gr.update(choices=model_names or ["No models available"])

            async def generate(text, model):
                if not text.strip():
                    return "Please enter some text"
                if not model or model == "No models available":
                    return "Please select a model"
                return await self.generate_text(text, model)

            def clear():
                return "", ""

            # Pass the async callbacks directly: Gradio awaits coroutine
            # functions itself. Wrapping them in sync lambdas that call
            # asyncio.create_task would raise "no running event loop" in
            # Gradio's worker thread and would return a Task object instead
            # of the result.
            refresh_button.click(fn=refresh_models, outputs=[model_dropdown])
            submit_button.click(
                fn=generate,
                inputs=[text_input, model_dropdown],
                outputs=[text_output],
            )
            clear_button.click(fn=clear, outputs=[text_input, text_output])

        return demo


def main():
    """Parse CLI args, build the UI, and launch the Gradio server."""
    parser = argparse.ArgumentParser(
        description="Gradio web server for LLaMA-Omni2"
    )
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument(
        "--controller-url", type=str, default="http://localhost:10000"
    )
    parser.add_argument("--vocoder-dir", type=str, required=True)
    parser.add_argument(
        "--share", action="store_true", help="Create a public link"
    )
    args = parser.parse_args()

    logger.info(f"Using controller at {args.controller_url}")

    # Create the UI
    ui = LLaMA_Omni2_UI(
        controller_url=args.controller_url,
        vocoder_dir=args.vocoder_dir,
    )

    # Start by fetching the model list (one-shot; the UI's Refresh button
    # re-fetches later).
    asyncio.run(ui.fetch_model_list())

    # Build and launch the demo
    demo = ui.build_demo()
    demo.queue()
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
    )


if __name__ == "__main__":
    main()