# llama-omni/app.py
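"""Hugging Face Space launcher for LLaMA-Omni.

Downloads the required speech models, starts the three LLaMA-Omni
services (controller, model worker, Gradio web server) as subprocesses,
and serves a small status page that links to the real interface once
everything is up.
"""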
import os
import subprocess
import threading
import time
import gradio as gr
import whisper
import requests
# Configuration
MODEL_NAME = "Llama-3.1-8B-Omni"
CONTROLLER_PORT = 10000
WEB_SERVER_PORT = 8000
MODEL_WORKER_PORT = 40000
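# These ports follow the defaults used in the upstream LLaMA-Omni README.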
# Paths
VOCODER_PATH = "vocoder/g_00500000"
VOCODER_CFG = "vocoder/config.json"
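# Unit-based HiFi-GAN vocoder (checkpoint + config) fetched from fairseq
# in download_models() below.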
def download_models():
"""Ensure that required models are available"""
os.makedirs("models/speech_encoder", exist_ok=True)
    # whisper.load_model() downloads large-v3 (~3 GB) into download_root on
    # the first call during deployment and reuses the cached copy afterwards.
print("Setting up Whisper model...")
whisper.load_model("large-v3", download_root="models/speech_encoder/")
# Download vocoder if needed
if not os.path.exists(VOCODER_PATH):
print("Downloading vocoder...")
os.makedirs("vocoder", exist_ok=True)
        subprocess.run([
            "wget", "https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/g_00500000",
            "-P", "vocoder/"
        ], check=True)  # check=True so a failed download raises instead of leaving a broken vocoder dir
        subprocess.run([
            "wget", "https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/config.json",
            "-P", "vocoder/"
        ], check=True)
def start_controller():
"""Start the controller process"""
print("Starting controller...")
controller_process = subprocess.Popen([
"python", "-m", "omni_speech.serve.controller",
"--host", "0.0.0.0",
"--port", str(CONTROLLER_PORT)
])
    time.sleep(5)  # crude fixed wait; see wait_for_port() below for a polling alternative
return controller_process
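
def wait_for_port(port, timeout=60):
    """Poll a local HTTP port until it answers.

    A sketch of a more robust readiness check than the fixed time.sleep()
    calls in the start_* helpers; it is not wired into the startup path
    by default.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            requests.get(f"http://localhost:{port}/", timeout=2)
            return True
        except requests.exceptions.RequestException:
            time.sleep(1)
    return False
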
def start_model_worker():
"""Start the model worker process"""
print("Starting model worker...")
worker_process = subprocess.Popen([
"python", "-m", "omni_speech.serve.model_worker",
"--host", "0.0.0.0",
"--controller", f"http://localhost:{CONTROLLER_PORT}",
"--port", str(MODEL_WORKER_PORT),
"--worker", f"http://localhost:{MODEL_WORKER_PORT}",
"--model-path", MODEL_NAME,
"--model-name", MODEL_NAME,
"--s2s"
])
    time.sleep(10)  # model loading can take far longer than 10 s; wait_for_port() above is a more robust option
return worker_process
def start_web_server():
"""Start the web server process"""
print("Starting web server...")
web_process = subprocess.Popen([
"python", "-m", "omni_speech.serve.gradio_web_server",
"--controller", f"http://localhost:{CONTROLLER_PORT}",
"--port", str(WEB_SERVER_PORT),
"--model-list-mode", "reload",
"--vocoder", VOCODER_PATH,
"--vocoder-cfg", VOCODER_CFG
])
return web_process
def check_services():
"""Check if all services are running"""
try:
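        # NOTE: a GET /status endpoint returning {"status": "ok"} is
        # assumed here; verify against the controller's actual API.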
controller_resp = requests.get(f"http://localhost:{CONTROLLER_PORT}/status").json()
web_server_resp = requests.get(f"http://localhost:{WEB_SERVER_PORT}/").status_code
return controller_resp["status"] == "ok" and web_server_resp == 200
except Exception:
return False
def main():
# Download required models
download_models()
# Start all services
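    # Order matters: the worker registers with the controller, and the web
    # server queries the controller for the available models.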
controller = start_controller()
worker = start_model_worker()
web_server = start_web_server()
# Create a simple redirection interface
with gr.Blocks() as demo:
gr.Markdown("# πŸ¦™πŸŽ§ LLaMA-Omni")
gr.Markdown("## Starting LLaMA-Omni services...")
with gr.Row():
status = gr.Textbox(value="Initializing...", label="Status")
with gr.Row():
redirect_btn = gr.Button("Go to LLaMA-Omni Interface")
def update_status():
if check_services():
return "All services running! Click the button below to access the interface."
else:
return "Still starting services... Please wait."
        # Poll the status every 5 seconds while the page is open.
        demo.load(update_status, outputs=status, every=5)

        # gr.Redirect is not part of the Gradio API; open the real
        # interface in a new tab with a JS callback instead (the `js=`
        # kwarg in Gradio 4.x; older 3.x releases spell it `_js`).
        redirect_btn.click(
            None,
            js=f"() => {{ window.open('http://localhost:{WEB_SERVER_PORT}', '_blank'); }}",
        )
# Launch the Gradio interface
try:
demo.launch(server_name="0.0.0.0")
finally:
# Clean up processes when Gradio is closed
controller.terminate()
worker.terminate()
web_server.terminate()
if __name__ == "__main__":
main()