File size: 4,370 Bytes
d478b16
 
34b8b49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbce578
34b8b49
 
 
fbce578
34b8b49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ead3363
34b8b49
 
 
 
 
 
 
 
 
8ce9406
34b8b49
 
 
 
b9d0632
34b8b49
 
 
 
dc2c952
34b8b49
 
d478b16
34b8b49
 
dc2c952
34b8b49
 
 
 
 
dc2c952
34b8b49
 
b9d0632
34b8b49
 
 
d478b16
34b8b49
 
 
 
 
 
 
 
b9d0632
 
34b8b49
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import subprocess
import threading
import time
import gradio as gr
import whisper
import requests

# Configuration
MODEL_NAME = "Llama-3.1-8B-Omni"  # model path/name passed to the model worker
CONTROLLER_PORT = 10000   # HTTP port for the omni_speech controller service
WEB_SERVER_PORT = 8000    # HTTP port for the Gradio web server frontend
MODEL_WORKER_PORT = 40000  # HTTP port for the model worker process

# Paths
VOCODER_PATH = "vocoder/g_00500000"  # HiFi-GAN vocoder checkpoint (downloaded on demand)
VOCODER_CFG = "vocoder/config.json"  # vocoder configuration file (downloaded on demand)

def download_models():
    """Ensure the Whisper speech encoder and HiFi-GAN vocoder are available.

    Idempotent: ``whisper.load_model`` caches its checkpoint under
    ``models/speech_encoder/``, and the vocoder files are only fetched
    when missing on disk.

    Raises:
        subprocess.CalledProcessError: if a vocoder download fails.
    """
    os.makedirs("models/speech_encoder", exist_ok=True)

    # Downloads the checkpoint on first call; reuses the cached copy afterwards.
    print("Setting up Whisper model...")
    whisper.load_model("large-v3", download_root="models/speech_encoder/")

    # Fetch the vocoder if either the checkpoint or its config is missing
    # (the original only checked the checkpoint, so a missing config.json
    # was never re-downloaded).
    if not (os.path.exists(VOCODER_PATH) and os.path.exists(VOCODER_CFG)):
        print("Downloading vocoder...")
        os.makedirs("vocoder", exist_ok=True)
        base_url = (
            "https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/"
            "code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj"
        )
        for filename in ("g_00500000", "config.json"):
            # check=True surfaces download failures instead of silently
            # proceeding with a broken/missing vocoder.
            subprocess.run(
                ["wget", f"{base_url}/{filename}", "-P", "vocoder/"],
                check=True,
            )

def start_controller():
    """Launch the omni_speech controller as a subprocess.

    Returns:
        subprocess.Popen: handle for the running controller process.
    """
    print("Starting controller...")
    cmd = [
        "python", "-m", "omni_speech.serve.controller",
        "--host", "0.0.0.0",
        "--port", str(CONTROLLER_PORT),
    ]
    process = subprocess.Popen(cmd)
    # Give the controller a moment to bind its port before workers register.
    time.sleep(5)
    return process

def start_model_worker():
    """Launch the model worker and point it at the controller.

    Returns:
        subprocess.Popen: handle for the running worker process.
    """
    print("Starting model worker...")
    controller_url = f"http://localhost:{CONTROLLER_PORT}"
    worker_url = f"http://localhost:{MODEL_WORKER_PORT}"
    cmd = [
        "python", "-m", "omni_speech.serve.model_worker",
        "--host", "0.0.0.0",
        "--controller", controller_url,
        "--port", str(MODEL_WORKER_PORT),
        "--worker", worker_url,
        "--model-path", MODEL_NAME,
        "--model-name", MODEL_NAME,
        "--s2s",
    ]
    process = subprocess.Popen(cmd)
    # Model loading is slow; wait before anything starts querying the worker.
    time.sleep(10)
    return process

def start_web_server():
    """Launch the Gradio web server that fronts the controller.

    Returns:
        subprocess.Popen: handle for the running web server process.
    """
    print("Starting web server...")
    cmd = [
        "python", "-m", "omni_speech.serve.gradio_web_server",
        "--controller", f"http://localhost:{CONTROLLER_PORT}",
        "--port", str(WEB_SERVER_PORT),
        "--model-list-mode", "reload",
        "--vocoder", VOCODER_PATH,
        "--vocoder-cfg", VOCODER_CFG,
    ]
    return subprocess.Popen(cmd)

def check_services(timeout=5):
    """Return True when the controller and the web server both look healthy.

    Args:
        timeout: Per-request timeout in seconds. The original code passed
            no timeout, so a wedged (accepting but unresponsive) service
            would hang this check forever.

    Returns:
        bool: True if the controller reports ``{"status": "ok"}`` and the
        web server answers HTTP 200; False on any error.
    """
    try:
        # NOTE(review): assumes the controller exposes /status returning
        # JSON with a "status" key — confirm against omni_speech.serve.controller.
        controller_resp = requests.get(
            f"http://localhost:{CONTROLLER_PORT}/status", timeout=timeout
        ).json()
        web_status = requests.get(
            f"http://localhost:{WEB_SERVER_PORT}/", timeout=timeout
        ).status_code
        return controller_resp["status"] == "ok" and web_status == 200
    except Exception:
        # Connection/timeout/JSON/key errors all mean "not up yet".
        return False

def main():
    """Download models, start all services, and serve a Gradio status page.

    Spawns the controller, model worker, and web server as child
    processes, then blocks on a small Gradio UI that polls service
    health and offers a link to the real interface. Child processes are
    terminated when the Gradio server shuts down.
    """
    # Download required models
    download_models()
    
    # Start all services
    controller = start_controller()
    worker = start_model_worker()
    web_server = start_web_server()
    
    # Create a simple redirection interface
    with gr.Blocks() as demo:
        gr.Markdown("# πŸ¦™πŸŽ§ LLaMA-Omni")
        gr.Markdown("## Starting LLaMA-Omni services...")
        
        with gr.Row():
            status = gr.Textbox(value="Initializing...", label="Status")
        
        with gr.Row():
            redirect_btn = gr.Button("Go to LLaMA-Omni Interface")
        
        def update_status():
            # Polled by demo.load(every=5); text is shown in the status box.
            if check_services():
                return "All services running! Click the button below to access the interface."
            else:
                return "Still starting services... Please wait."
        
        def redirect():
            # NOTE(review): gr.Redirect is not part of all Gradio releases,
            # and this click handler declares no outputs — verify this
            # actually navigates in the deployed Gradio version.
            return gr.Redirect(f"http://localhost:{WEB_SERVER_PORT}")
        
        # Update status every 5 seconds
        demo.load(update_status, outputs=status, every=5)
        redirect_btn.click(redirect)
    
    # Launch the Gradio interface
    try:
        demo.launch(server_name="0.0.0.0")
    finally:
        # Clean up processes when Gradio is closed
        controller.terminate()
        worker.terminate()
        web_server.terminate()

# Script entry point: only run the orchestration when executed directly.
if __name__ == "__main__":
    main()