Spaces:
Build error
Build error
Commit
·
ec8ce73
1
Parent(s):
a964a55
reerer
Browse files- launch_llama_omni2.py +142 -86
launch_llama_omni2.py
CHANGED
@@ -114,7 +114,7 @@ def start_controller():
|
|
114 |
print("=== Starting LLaMA-Omni2 Controller ===")
|
115 |
|
116 |
# First try to use our custom implementation
|
117 |
-
direct_controller_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "
|
118 |
if os.path.exists(direct_controller_path):
|
119 |
print(f"Using custom controller implementation: {direct_controller_path}")
|
120 |
cmd = [
|
@@ -128,36 +128,72 @@ def start_controller():
|
|
128 |
print(f"Controller started with PID: {process.pid}")
|
129 |
return process
|
130 |
|
131 |
-
# Fall back to
|
132 |
-
|
133 |
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
return None
|
137 |
-
|
138 |
-
cmd = [
|
139 |
-
sys.executable, controller_path,
|
140 |
-
"--host", "0.0.0.0",
|
141 |
-
"--port", "10000"
|
142 |
-
]
|
143 |
-
|
144 |
-
env = os.environ.copy()
|
145 |
-
if EXTRACTION_DIR not in env.get("PYTHONPATH", ""):
|
146 |
-
env["PYTHONPATH"] = f"{EXTRACTION_DIR}:{env.get('PYTHONPATH', '')}"
|
147 |
-
|
148 |
-
print(f"Running: {' '.join(cmd)}")
|
149 |
-
print(f"With PYTHONPATH: {env.get('PYTHONPATH')}")
|
150 |
-
|
151 |
-
process = subprocess.Popen(cmd, env=env)
|
152 |
-
print(f"Controller started with PID: {process.pid}")
|
153 |
-
return process
|
154 |
|
155 |
def start_model_worker():
|
156 |
"""Start the LLaMA-Omni2 model worker directly"""
|
157 |
print("=== Starting LLaMA-Omni2 Model Worker ===")
|
158 |
|
159 |
# First try to use our custom implementation
|
160 |
-
direct_worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "
|
161 |
if os.path.exists(direct_worker_path):
|
162 |
print(f"Using custom model worker implementation: {direct_worker_path}")
|
163 |
cmd = [
|
@@ -175,40 +211,27 @@ def start_model_worker():
|
|
175 |
print(f"Model worker started with PID: {process.pid}")
|
176 |
return process
|
177 |
|
178 |
-
# Fall back to
|
179 |
-
|
180 |
-
|
181 |
-
if not os.path.exists(model_worker_path):
|
182 |
-
print(f"Model worker script not found at {model_worker_path}")
|
183 |
-
return None
|
184 |
-
|
185 |
-
cmd = [
|
186 |
-
sys.executable, model_worker_path,
|
187 |
-
"--host", "0.0.0.0",
|
188 |
-
"--controller", "http://localhost:10000",
|
189 |
-
"--port", "40000",
|
190 |
-
"--worker", "http://localhost:40000",
|
191 |
-
"--model-path", LLAMA_OMNI2_MODEL_PATH,
|
192 |
-
"--model-name", LLAMA_OMNI2_MODEL_NAME
|
193 |
-
]
|
194 |
|
195 |
-
|
196 |
-
|
197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
|
199 |
-
|
200 |
-
print(f"With PYTHONPATH: {env.get('PYTHONPATH')}")
|
201 |
-
|
202 |
-
process = subprocess.Popen(cmd, env=env)
|
203 |
-
print(f"Model worker started with PID: {process.pid}")
|
204 |
-
return process
|
205 |
|
206 |
def start_gradio_server():
|
207 |
"""Start the LLaMA-Omni2 Gradio web server directly"""
|
208 |
print("=== Starting LLaMA-Omni2 Gradio Server ===")
|
209 |
|
210 |
# First try to use our custom implementation
|
211 |
-
direct_gradio_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "
|
212 |
if os.path.exists(direct_gradio_path):
|
213 |
print(f"Using custom Gradio server implementation: {direct_gradio_path}")
|
214 |
cmd = [
|
@@ -224,32 +247,59 @@ def start_gradio_server():
|
|
224 |
print(f"Gradio server started with PID: {process.pid}")
|
225 |
return process
|
226 |
|
227 |
-
# Fall back to
|
228 |
-
|
229 |
|
230 |
-
|
231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
return None
|
233 |
-
|
234 |
-
cmd = [
|
235 |
-
sys.executable, gradio_server_path,
|
236 |
-
"--host", "0.0.0.0",
|
237 |
-
"--port", "7860",
|
238 |
-
"--controller-url", "http://localhost:10000",
|
239 |
-
"--model-list-mode", "reload",
|
240 |
-
"--vocoder-dir", COSYVOICE_PATH
|
241 |
-
]
|
242 |
-
|
243 |
-
env = os.environ.copy()
|
244 |
-
if EXTRACTION_DIR not in env.get("PYTHONPATH", ""):
|
245 |
-
env["PYTHONPATH"] = f"{EXTRACTION_DIR}:{env.get('PYTHONPATH', '')}"
|
246 |
-
|
247 |
-
print(f"Running: {' '.join(cmd)}")
|
248 |
-
print(f"With PYTHONPATH: {env.get('PYTHONPATH')}")
|
249 |
-
|
250 |
-
process = subprocess.Popen(cmd, env=env)
|
251 |
-
print(f"Gradio server started with PID: {process.pid}")
|
252 |
-
return process
|
253 |
|
254 |
def patch_extracted_files(extraction_dir):
|
255 |
"""Patch the extracted Python files to handle missing imports"""
|
@@ -338,25 +388,31 @@ def main():
|
|
338 |
print("Checking and installing dependencies...")
|
339 |
download_dependencies()
|
340 |
|
341 |
-
#
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
else:
|
347 |
-
print("LLaMA-Omni2 scripts already extracted.")
|
348 |
|
349 |
# Ensure the module structure is complete
|
350 |
ensure_module_structure(EXTRACTION_DIR)
|
351 |
|
352 |
-
#
|
353 |
-
|
354 |
|
355 |
# Add the extraction dir to Python path
|
356 |
if EXTRACTION_DIR not in sys.path:
|
357 |
sys.path.insert(0, EXTRACTION_DIR)
|
358 |
print(f"Added {EXTRACTION_DIR} to sys.path")
|
359 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
360 |
# Start controller
|
361 |
controller_process = start_controller()
|
362 |
if not controller_process:
|
@@ -365,7 +421,7 @@ def main():
|
|
365 |
|
366 |
# Wait for controller to initialize
|
367 |
print("Waiting for controller to initialize...")
|
368 |
-
time.sleep(
|
369 |
|
370 |
# Start model worker
|
371 |
model_worker_process = start_model_worker()
|
@@ -374,9 +430,9 @@ def main():
|
|
374 |
controller_process.terminate()
|
375 |
return 1
|
376 |
|
377 |
-
# Wait for model to load
|
378 |
-
print("Waiting for model to
|
379 |
-
time.sleep(
|
380 |
|
381 |
# Start Gradio server
|
382 |
gradio_process = start_gradio_server()
|
|
|
114 |
print("=== Starting LLaMA-Omni2 Controller ===")
|
115 |
|
116 |
# First try to use our custom implementation
|
117 |
+
direct_controller_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "controller.py")
|
118 |
if os.path.exists(direct_controller_path):
|
119 |
print(f"Using custom controller implementation: {direct_controller_path}")
|
120 |
cmd = [
|
|
|
128 |
print(f"Controller started with PID: {process.pid}")
|
129 |
return process
|
130 |
|
131 |
+
# Fall back to a simple controller implementation
|
132 |
+
print("No controller script found. Implementing a simple controller...")
|
133 |
|
134 |
+
try:
|
135 |
+
from fastapi import FastAPI, HTTPException
|
136 |
+
import uvicorn
|
137 |
+
from pydantic import BaseModel
|
138 |
+
import threading
|
139 |
+
|
140 |
+
app = FastAPI()
|
141 |
+
|
142 |
+
class ModelInfo(BaseModel):
|
143 |
+
model_name: str
|
144 |
+
worker_name: str
|
145 |
+
worker_addr: str
|
146 |
+
|
147 |
+
# Simple in-memory storage
|
148 |
+
registered_models = {}
|
149 |
+
|
150 |
+
@app.get("/")
|
151 |
+
def read_root():
|
152 |
+
return {"status": "ok", "models": list(registered_models.keys())}
|
153 |
+
|
154 |
+
@app.get("/api/v1/models")
|
155 |
+
def list_models():
|
156 |
+
return {"models": list(registered_models.keys())}
|
157 |
+
|
158 |
+
@app.post("/api/v1/register_worker")
|
159 |
+
def register_worker(model_info: ModelInfo):
|
160 |
+
registered_models[model_info.model_name] = {
|
161 |
+
"worker_name": model_info.worker_name,
|
162 |
+
"worker_addr": model_info.worker_addr
|
163 |
+
}
|
164 |
+
return {"status": "ok"}
|
165 |
+
|
166 |
+
# Start a simple controller
|
167 |
+
def run_controller():
|
168 |
+
uvicorn.run(app, host="0.0.0.0", port=10000)
|
169 |
+
|
170 |
+
thread = threading.Thread(target=run_controller, daemon=True)
|
171 |
+
thread.start()
|
172 |
+
|
173 |
+
print("Simple controller started on port 10000")
|
174 |
+
# Return a dummy process for compatibility
|
175 |
+
class DummyProcess:
|
176 |
+
def __init__(self):
|
177 |
+
self.pid = 0
|
178 |
+
def terminate(self):
|
179 |
+
pass
|
180 |
+
def poll(self):
|
181 |
+
return None
|
182 |
+
def wait(self, timeout=None):
|
183 |
+
pass
|
184 |
+
|
185 |
+
return DummyProcess()
|
186 |
+
|
187 |
+
except ImportError as e:
|
188 |
+
print(f"Failed to create simple controller: {e}")
|
189 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
|
191 |
def start_model_worker():
|
192 |
"""Start the LLaMA-Omni2 model worker directly"""
|
193 |
print("=== Starting LLaMA-Omni2 Model Worker ===")
|
194 |
|
195 |
# First try to use our custom implementation
|
196 |
+
direct_worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_worker.py")
|
197 |
if os.path.exists(direct_worker_path):
|
198 |
print(f"Using custom model worker implementation: {direct_worker_path}")
|
199 |
cmd = [
|
|
|
211 |
print(f"Model worker started with PID: {process.pid}")
|
212 |
return process
|
213 |
|
214 |
+
# Fall back to a simple implementation
|
215 |
+
print("No model worker script found. Will try to start Gradio directly with the model.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
|
217 |
+
class DummyProcess:
|
218 |
+
def __init__(self):
|
219 |
+
self.pid = 0
|
220 |
+
def terminate(self):
|
221 |
+
pass
|
222 |
+
def poll(self):
|
223 |
+
return None
|
224 |
+
def wait(self, timeout=None):
|
225 |
+
pass
|
226 |
|
227 |
+
return DummyProcess()
|
|
|
|
|
|
|
|
|
|
|
228 |
|
229 |
def start_gradio_server():
|
230 |
"""Start the LLaMA-Omni2 Gradio web server directly"""
|
231 |
print("=== Starting LLaMA-Omni2 Gradio Server ===")
|
232 |
|
233 |
# First try to use our custom implementation
|
234 |
+
direct_gradio_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "gradio_web_server.py")
|
235 |
if os.path.exists(direct_gradio_path):
|
236 |
print(f"Using custom Gradio server implementation: {direct_gradio_path}")
|
237 |
cmd = [
|
|
|
247 |
print(f"Gradio server started with PID: {process.pid}")
|
248 |
return process
|
249 |
|
250 |
+
# Fall back to a simple Gradio implementation
|
251 |
+
print("No Gradio server found. Attempting to create a simple interface...")
|
252 |
|
253 |
+
try:
|
254 |
+
import gradio as gr
|
255 |
+
import threading
|
256 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
257 |
+
|
258 |
+
# Simple function to launch a basic Gradio interface
|
259 |
+
def launch_simple_gradio():
|
260 |
+
try:
|
261 |
+
print(f"Loading model from {LLAMA_OMNI2_MODEL_PATH}...")
|
262 |
+
tokenizer = AutoTokenizer.from_pretrained(LLAMA_OMNI2_MODEL_PATH)
|
263 |
+
model = AutoModelForCausalLM.from_pretrained(LLAMA_OMNI2_MODEL_PATH)
|
264 |
+
|
265 |
+
def generate_text(input_text):
|
266 |
+
inputs = tokenizer(input_text, return_tensors="pt")
|
267 |
+
outputs = model.generate(inputs.input_ids, max_length=100)
|
268 |
+
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
269 |
+
|
270 |
+
with gr.Blocks() as demo:
|
271 |
+
gr.Markdown("# LLaMA-Omni2 Simple Interface")
|
272 |
+
with gr.Tab("Text Generation"):
|
273 |
+
input_text = gr.Textbox(label="Input Text")
|
274 |
+
output_text = gr.Textbox(label="Generated Text")
|
275 |
+
generate_btn = gr.Button("Generate")
|
276 |
+
generate_btn.click(generate_text, inputs=input_text, outputs=output_text)
|
277 |
+
|
278 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
279 |
+
|
280 |
+
except Exception as e:
|
281 |
+
print(f"Error in simple Gradio interface: {e}")
|
282 |
+
|
283 |
+
thread = threading.Thread(target=launch_simple_gradio, daemon=True)
|
284 |
+
thread.start()
|
285 |
+
|
286 |
+
print("Simple Gradio interface started on port 7860")
|
287 |
+
|
288 |
+
class DummyProcess:
|
289 |
+
def __init__(self):
|
290 |
+
self.pid = 0
|
291 |
+
def terminate(self):
|
292 |
+
pass
|
293 |
+
def poll(self):
|
294 |
+
return None
|
295 |
+
def wait(self, timeout=None):
|
296 |
+
pass
|
297 |
+
|
298 |
+
return DummyProcess()
|
299 |
+
|
300 |
+
except ImportError as e:
|
301 |
+
print(f"Failed to create simple Gradio interface: {e}")
|
302 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
|
304 |
def patch_extracted_files(extraction_dir):
|
305 |
"""Patch the extracted Python files to handle missing imports"""
|
|
|
388 |
print("Checking and installing dependencies...")
|
389 |
download_dependencies()
|
390 |
|
391 |
+
# Create directories directly instead of using extraction script
|
392 |
+
print("Creating necessary directories...")
|
393 |
+
os.makedirs(EXTRACTION_DIR, exist_ok=True)
|
394 |
+
os.makedirs(os.path.join(EXTRACTION_DIR, "llama_omni2"), exist_ok=True)
|
395 |
+
os.makedirs(os.path.join(EXTRACTION_DIR, "llama_omni2", "serve"), exist_ok=True)
|
|
|
|
|
396 |
|
397 |
# Ensure the module structure is complete
|
398 |
ensure_module_structure(EXTRACTION_DIR)
|
399 |
|
400 |
+
# Skip patching files as we're not extracting anything
|
401 |
+
print("Skipping file patching as we're not running extraction")
|
402 |
|
403 |
# Add the extraction dir to Python path
|
404 |
if EXTRACTION_DIR not in sys.path:
|
405 |
sys.path.insert(0, EXTRACTION_DIR)
|
406 |
print(f"Added {EXTRACTION_DIR} to sys.path")
|
407 |
|
408 |
+
# Skip directly to model download and starting services
|
409 |
+
print("Proceeding directly to model download and starting services...")
|
410 |
+
|
411 |
+
# Make directories for models
|
412 |
+
os.makedirs(MODELS_DIR, exist_ok=True)
|
413 |
+
os.makedirs(LLAMA_OMNI2_MODEL_PATH, exist_ok=True)
|
414 |
+
os.makedirs(COSYVOICE_PATH, exist_ok=True)
|
415 |
+
|
416 |
# Start controller
|
417 |
controller_process = start_controller()
|
418 |
if not controller_process:
|
|
|
421 |
|
422 |
# Wait for controller to initialize
|
423 |
print("Waiting for controller to initialize...")
|
424 |
+
time.sleep(5)
|
425 |
|
426 |
# Start model worker
|
427 |
model_worker_process = start_model_worker()
|
|
|
430 |
controller_process.terminate()
|
431 |
return 1
|
432 |
|
433 |
+
# Wait for model to load - reduced from 300 seconds to 30 seconds
|
434 |
+
print("Waiting for model worker to initialize...")
|
435 |
+
time.sleep(30)
|
436 |
|
437 |
# Start Gradio server
|
438 |
gradio_process = start_gradio_server()
|