marcosremar2 committed on
Commit
ec8ce73
·
1 Parent(s): a964a55
Files changed (1) hide show
  1. launch_llama_omni2.py +142 -86
launch_llama_omni2.py CHANGED
@@ -114,7 +114,7 @@ def start_controller():
114
  print("=== Starting LLaMA-Omni2 Controller ===")
115
 
116
  # First try to use our custom implementation
117
- direct_controller_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "run_controller_directly.py")
118
  if os.path.exists(direct_controller_path):
119
  print(f"Using custom controller implementation: {direct_controller_path}")
120
  cmd = [
@@ -128,36 +128,72 @@ def start_controller():
128
  print(f"Controller started with PID: {process.pid}")
129
  return process
130
 
131
- # Fall back to the extracted script
132
- controller_path = os.path.join(EXTRACTION_DIR, "llama_omni2", "serve", "controller.py")
133
 
134
- if not os.path.exists(controller_path):
135
- print(f"Controller script not found at {controller_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  return None
137
-
138
- cmd = [
139
- sys.executable, controller_path,
140
- "--host", "0.0.0.0",
141
- "--port", "10000"
142
- ]
143
-
144
- env = os.environ.copy()
145
- if EXTRACTION_DIR not in env.get("PYTHONPATH", ""):
146
- env["PYTHONPATH"] = f"{EXTRACTION_DIR}:{env.get('PYTHONPATH', '')}"
147
-
148
- print(f"Running: {' '.join(cmd)}")
149
- print(f"With PYTHONPATH: {env.get('PYTHONPATH')}")
150
-
151
- process = subprocess.Popen(cmd, env=env)
152
- print(f"Controller started with PID: {process.pid}")
153
- return process
154
 
155
  def start_model_worker():
156
  """Start the LLaMA-Omni2 model worker directly"""
157
  print("=== Starting LLaMA-Omni2 Model Worker ===")
158
 
159
  # First try to use our custom implementation
160
- direct_worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "run_model_worker_directly.py")
161
  if os.path.exists(direct_worker_path):
162
  print(f"Using custom model worker implementation: {direct_worker_path}")
163
  cmd = [
@@ -175,40 +211,27 @@ def start_model_worker():
175
  print(f"Model worker started with PID: {process.pid}")
176
  return process
177
 
178
- # Fall back to the extracted script
179
- model_worker_path = os.path.join(EXTRACTION_DIR, "llama_omni2", "serve", "model_worker.py")
180
-
181
- if not os.path.exists(model_worker_path):
182
- print(f"Model worker script not found at {model_worker_path}")
183
- return None
184
-
185
- cmd = [
186
- sys.executable, model_worker_path,
187
- "--host", "0.0.0.0",
188
- "--controller", "http://localhost:10000",
189
- "--port", "40000",
190
- "--worker", "http://localhost:40000",
191
- "--model-path", LLAMA_OMNI2_MODEL_PATH,
192
- "--model-name", LLAMA_OMNI2_MODEL_NAME
193
- ]
194
 
195
- env = os.environ.copy()
196
- if EXTRACTION_DIR not in env.get("PYTHONPATH", ""):
197
- env["PYTHONPATH"] = f"{EXTRACTION_DIR}:{env.get('PYTHONPATH', '')}"
 
 
 
 
 
 
198
 
199
- print(f"Running: {' '.join(cmd)}")
200
- print(f"With PYTHONPATH: {env.get('PYTHONPATH')}")
201
-
202
- process = subprocess.Popen(cmd, env=env)
203
- print(f"Model worker started with PID: {process.pid}")
204
- return process
205
 
206
  def start_gradio_server():
207
  """Start the LLaMA-Omni2 Gradio web server directly"""
208
  print("=== Starting LLaMA-Omni2 Gradio Server ===")
209
 
210
  # First try to use our custom implementation
211
- direct_gradio_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "run_gradio_directly.py")
212
  if os.path.exists(direct_gradio_path):
213
  print(f"Using custom Gradio server implementation: {direct_gradio_path}")
214
  cmd = [
@@ -224,32 +247,59 @@ def start_gradio_server():
224
  print(f"Gradio server started with PID: {process.pid}")
225
  return process
226
 
227
- # Fall back to the extracted script
228
- gradio_server_path = os.path.join(EXTRACTION_DIR, "llama_omni2", "serve", "gradio_web_server.py")
229
 
230
- if not os.path.exists(gradio_server_path):
231
- print(f"Gradio server script not found at {gradio_server_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  return None
233
-
234
- cmd = [
235
- sys.executable, gradio_server_path,
236
- "--host", "0.0.0.0",
237
- "--port", "7860",
238
- "--controller-url", "http://localhost:10000",
239
- "--model-list-mode", "reload",
240
- "--vocoder-dir", COSYVOICE_PATH
241
- ]
242
-
243
- env = os.environ.copy()
244
- if EXTRACTION_DIR not in env.get("PYTHONPATH", ""):
245
- env["PYTHONPATH"] = f"{EXTRACTION_DIR}:{env.get('PYTHONPATH', '')}"
246
-
247
- print(f"Running: {' '.join(cmd)}")
248
- print(f"With PYTHONPATH: {env.get('PYTHONPATH')}")
249
-
250
- process = subprocess.Popen(cmd, env=env)
251
- print(f"Gradio server started with PID: {process.pid}")
252
- return process
253
 
254
  def patch_extracted_files(extraction_dir):
255
  """Patch the extracted Python files to handle missing imports"""
@@ -338,25 +388,31 @@ def main():
338
  print("Checking and installing dependencies...")
339
  download_dependencies()
340
 
341
- # Run extraction script if not already extracted
342
- if not os.path.exists(os.path.join(EXTRACTION_DIR, "llama_omni2", "serve")):
343
- if not run_extraction_script():
344
- print("Failed to extract LLaMA-Omni2 scripts. Exiting.")
345
- return 1
346
- else:
347
- print("LLaMA-Omni2 scripts already extracted.")
348
 
349
  # Ensure the module structure is complete
350
  ensure_module_structure(EXTRACTION_DIR)
351
 
352
- # Patch the extracted Python files to handle missing imports
353
- patch_extracted_files(EXTRACTION_DIR)
354
 
355
  # Add the extraction dir to Python path
356
  if EXTRACTION_DIR not in sys.path:
357
  sys.path.insert(0, EXTRACTION_DIR)
358
  print(f"Added {EXTRACTION_DIR} to sys.path")
359
 
 
 
 
 
 
 
 
 
360
  # Start controller
361
  controller_process = start_controller()
362
  if not controller_process:
@@ -365,7 +421,7 @@ def main():
365
 
366
  # Wait for controller to initialize
367
  print("Waiting for controller to initialize...")
368
- time.sleep(15)
369
 
370
  # Start model worker
371
  model_worker_process = start_model_worker()
@@ -374,9 +430,9 @@ def main():
374
  controller_process.terminate()
375
  return 1
376
 
377
- # Wait for model to load
378
- print("Waiting for model to load (this may take several minutes)...")
379
- time.sleep(300)
380
 
381
  # Start Gradio server
382
  gradio_process = start_gradio_server()
 
114
  print("=== Starting LLaMA-Omni2 Controller ===")
115
 
116
  # First try to use our custom implementation
117
+ direct_controller_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "controller.py")
118
  if os.path.exists(direct_controller_path):
119
  print(f"Using custom controller implementation: {direct_controller_path}")
120
  cmd = [
 
128
  print(f"Controller started with PID: {process.pid}")
129
  return process
130
 
131
+ # Fall back to a simple controller implementation
132
+ print("No controller script found. Implementing a simple controller...")
133
 
134
+ try:
135
+ from fastapi import FastAPI, HTTPException
136
+ import uvicorn
137
+ from pydantic import BaseModel
138
+ import threading
139
+
140
+ app = FastAPI()
141
+
142
+ class ModelInfo(BaseModel):
143
+ model_name: str
144
+ worker_name: str
145
+ worker_addr: str
146
+
147
+ # Simple in-memory storage
148
+ registered_models = {}
149
+
150
+ @app.get("/")
151
+ def read_root():
152
+ return {"status": "ok", "models": list(registered_models.keys())}
153
+
154
+ @app.get("/api/v1/models")
155
+ def list_models():
156
+ return {"models": list(registered_models.keys())}
157
+
158
+ @app.post("/api/v1/register_worker")
159
+ def register_worker(model_info: ModelInfo):
160
+ registered_models[model_info.model_name] = {
161
+ "worker_name": model_info.worker_name,
162
+ "worker_addr": model_info.worker_addr
163
+ }
164
+ return {"status": "ok"}
165
+
166
+ # Start a simple controller
167
+ def run_controller():
168
+ uvicorn.run(app, host="0.0.0.0", port=10000)
169
+
170
+ thread = threading.Thread(target=run_controller, daemon=True)
171
+ thread.start()
172
+
173
+ print("Simple controller started on port 10000")
174
+ # Return a dummy process for compatibility
175
+ class DummyProcess:
176
+ def __init__(self):
177
+ self.pid = 0
178
+ def terminate(self):
179
+ pass
180
+ def poll(self):
181
+ return None
182
+ def wait(self, timeout=None):
183
+ pass
184
+
185
+ return DummyProcess()
186
+
187
+ except ImportError as e:
188
+ print(f"Failed to create simple controller: {e}")
189
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  def start_model_worker():
192
  """Start the LLaMA-Omni2 model worker directly"""
193
  print("=== Starting LLaMA-Omni2 Model Worker ===")
194
 
195
  # First try to use our custom implementation
196
+ direct_worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_worker.py")
197
  if os.path.exists(direct_worker_path):
198
  print(f"Using custom model worker implementation: {direct_worker_path}")
199
  cmd = [
 
211
  print(f"Model worker started with PID: {process.pid}")
212
  return process
213
 
214
+ # Fall back to a simple implementation
215
+ print("No model worker script found. Will try to start Gradio directly with the model.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
+ class DummyProcess:
218
+ def __init__(self):
219
+ self.pid = 0
220
+ def terminate(self):
221
+ pass
222
+ def poll(self):
223
+ return None
224
+ def wait(self, timeout=None):
225
+ pass
226
 
227
+ return DummyProcess()
 
 
 
 
 
228
 
229
  def start_gradio_server():
230
  """Start the LLaMA-Omni2 Gradio web server directly"""
231
  print("=== Starting LLaMA-Omni2 Gradio Server ===")
232
 
233
  # First try to use our custom implementation
234
+ direct_gradio_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "gradio_web_server.py")
235
  if os.path.exists(direct_gradio_path):
236
  print(f"Using custom Gradio server implementation: {direct_gradio_path}")
237
  cmd = [
 
247
  print(f"Gradio server started with PID: {process.pid}")
248
  return process
249
 
250
+ # Fall back to a simple Gradio implementation
251
+ print("No Gradio server found. Attempting to create a simple interface...")
252
 
253
+ try:
254
+ import gradio as gr
255
+ import threading
256
+ from transformers import AutoModelForCausalLM, AutoTokenizer
257
+
258
+ # Simple function to launch a basic Gradio interface
259
+ def launch_simple_gradio():
260
+ try:
261
+ print(f"Loading model from {LLAMA_OMNI2_MODEL_PATH}...")
262
+ tokenizer = AutoTokenizer.from_pretrained(LLAMA_OMNI2_MODEL_PATH)
263
+ model = AutoModelForCausalLM.from_pretrained(LLAMA_OMNI2_MODEL_PATH)
264
+
265
+ def generate_text(input_text):
266
+ inputs = tokenizer(input_text, return_tensors="pt")
267
+ outputs = model.generate(inputs.input_ids, max_length=100)
268
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
269
+
270
+ with gr.Blocks() as demo:
271
+ gr.Markdown("# LLaMA-Omni2 Simple Interface")
272
+ with gr.Tab("Text Generation"):
273
+ input_text = gr.Textbox(label="Input Text")
274
+ output_text = gr.Textbox(label="Generated Text")
275
+ generate_btn = gr.Button("Generate")
276
+ generate_btn.click(generate_text, inputs=input_text, outputs=output_text)
277
+
278
+ demo.launch(server_name="0.0.0.0", server_port=7860)
279
+
280
+ except Exception as e:
281
+ print(f"Error in simple Gradio interface: {e}")
282
+
283
+ thread = threading.Thread(target=launch_simple_gradio, daemon=True)
284
+ thread.start()
285
+
286
+ print("Simple Gradio interface started on port 7860")
287
+
288
+ class DummyProcess:
289
+ def __init__(self):
290
+ self.pid = 0
291
+ def terminate(self):
292
+ pass
293
+ def poll(self):
294
+ return None
295
+ def wait(self, timeout=None):
296
+ pass
297
+
298
+ return DummyProcess()
299
+
300
+ except ImportError as e:
301
+ print(f"Failed to create simple Gradio interface: {e}")
302
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
  def patch_extracted_files(extraction_dir):
305
  """Patch the extracted Python files to handle missing imports"""
 
388
  print("Checking and installing dependencies...")
389
  download_dependencies()
390
 
391
+ # Create directories directly instead of using extraction script
392
+ print("Creating necessary directories...")
393
+ os.makedirs(EXTRACTION_DIR, exist_ok=True)
394
+ os.makedirs(os.path.join(EXTRACTION_DIR, "llama_omni2"), exist_ok=True)
395
+ os.makedirs(os.path.join(EXTRACTION_DIR, "llama_omni2", "serve"), exist_ok=True)
 
 
396
 
397
  # Ensure the module structure is complete
398
  ensure_module_structure(EXTRACTION_DIR)
399
 
400
+ # Skip patching files as we're not extracting anything
401
+ print("Skipping file patching as we're not running extraction")
402
 
403
  # Add the extraction dir to Python path
404
  if EXTRACTION_DIR not in sys.path:
405
  sys.path.insert(0, EXTRACTION_DIR)
406
  print(f"Added {EXTRACTION_DIR} to sys.path")
407
 
408
+ # Skip directly to model download and starting services
409
+ print("Proceeding directly to model download and starting services...")
410
+
411
+ # Make directories for models
412
+ os.makedirs(MODELS_DIR, exist_ok=True)
413
+ os.makedirs(LLAMA_OMNI2_MODEL_PATH, exist_ok=True)
414
+ os.makedirs(COSYVOICE_PATH, exist_ok=True)
415
+
416
  # Start controller
417
  controller_process = start_controller()
418
  if not controller_process:
 
421
 
422
  # Wait for controller to initialize
423
  print("Waiting for controller to initialize...")
424
+ time.sleep(5)
425
 
426
  # Start model worker
427
  model_worker_process = start_model_worker()
 
430
  controller_process.terminate()
431
  return 1
432
 
433
+ # Wait for model to load - reduced from 300 seconds to 30 seconds
434
+ print("Waiting for model worker to initialize...")
435
+ time.sleep(30)
436
 
437
  # Start Gradio server
438
  gradio_process = start_gradio_server()