llama-omni / launch_llama_omni2.py
marcosremar2's picture
fddfddffd
f4aa7db
raw
history blame
17.1 kB
#!/usr/bin/env python3
"""
LLaMA-Omni2 Direct Launcher
---------------------------
This script extracts and directly runs the LLaMA-Omni2 components without
relying on package imports.
"""
import os
import sys
import subprocess
import time
import argparse
import shutil
import importlib.util
import tempfile
# Filesystem layout used by the launcher; every path is absolute and rooted
# at /home/user/app.
EXTRACTION_DIR = "/home/user/app/llama_omni2_extracted"
MODELS_DIR = "/home/user/app/models"

# The language-model checkpoint and the CosyVoice decoder both live under
# MODELS_DIR.
LLAMA_OMNI2_MODEL_NAME = "LLaMA-Omni2-0.5B"
LLAMA_OMNI2_MODEL_PATH = "/".join((MODELS_DIR, LLAMA_OMNI2_MODEL_NAME))
COSYVOICE_PATH = "/".join((MODELS_DIR, "cosy2_decoder"))
def download_dependencies():
    """Install or upgrade the Python packages the launcher relies on.

    Runs ``pip install --upgrade`` through the current interpreter
    (``sys.executable -m pip``) so the packages land in the same environment
    that is running this script.

    Returns:
        True if pip exited successfully, False if it returned a non-zero
        exit status.
    """
    print("Installing required dependencies...")
    requirements = (
        "gradio>=3.50.2",
        "fastapi",
        "uvicorn",
        "pydantic",
        "transformers>=4.36.2",
        "sentencepiece",
    )
    pip_command = [sys.executable, "-m", "pip", "install", "--upgrade", *requirements]
    try:
        subprocess.run(pip_command, check=True)
    except subprocess.CalledProcessError as err:
        print(f"Error installing dependencies: {err}")
        return False
    print("Dependencies installed successfully")
    return True
def ensure_module_structure(extraction_dir):
    """Make ``extraction_dir`` importable as the ``llama_omni2`` package.

    Creates the package directory ``llama_omni2`` and its ``serve``, ``model``
    and ``common`` sub-packages, each with an ``__init__.py``. It then writes
    stub ``utils.py`` and ``constants.py`` modules into the package. Files
    that already exist are left untouched.

    Args:
        extraction_dir: Directory that will hold the ``llama_omni2`` package.

    Returns:
        Always True.
    """
    print("Ensuring proper module structure...")
    package_root = os.path.join(extraction_dir, "llama_omni2")

    # Every package directory needs an __init__.py to be importable.
    package_dirs = [package_root] + [
        os.path.join(package_root, sub) for sub in ("serve", "model", "common")
    ]
    for package_dir in package_dirs:
        os.makedirs(package_dir, exist_ok=True)
        marker = os.path.join(package_dir, "__init__.py")
        if os.path.exists(marker):
            continue
        with open(marker, 'w') as handle:
            handle.write("# Auto-generated __init__.py file\n")
        print(f"Created {marker}")

    # Stub modules. constants.py supplies the names that controller.py and
    # model_worker.py import.
    utils_source = (
        "\n"
        "# Dummy utils module\n"
        "def dummy_function():\n"
        "    pass\n"
    )
    constants_source = (
        "\n"
        "# Constants required by LLaMA-Omni2 modules\n"
        "# Controller constants\n"
        "CONTROLLER_HEART_BEAT_EXPIRATION = 120\n"
        "CONTROLLER_STATUS_POLLING_INTERVAL = 15\n"
        "# Worker constants\n"
        "WORKER_HEART_BEAT_INTERVAL = 30\n"
        "WORKER_API_TIMEOUT = 100\n"
        "# Other constants that might be needed\n"
        "DEFAULT_PORT = 8000\n"
    )
    stubs = [
        (os.path.join(package_root, "utils.py"), utils_source),
        (os.path.join(package_root, "constants.py"), constants_source),
    ]
    for stub_path, stub_source in stubs:
        if os.path.exists(stub_path):
            continue
        with open(stub_path, 'w') as handle:
            handle.write(stub_source)
        print(f"Created {stub_path}")
    return True
def start_controller():
    """Start the LLaMA-Omni2 controller.

    If a ``controller.py`` script sits next to this launcher, it is run as a
    subprocess with ``--host 0.0.0.0 --port 10000``. Otherwise a minimal
    in-process FastAPI controller is served on 0.0.0.0:10000 from a
    background daemon thread.

    Returns:
        One of:

        - the ``subprocess.Popen`` for the script;
        - for the fallback, a process-like handle whose ``poll()`` reports
          whether the server thread is still alive, whose ``terminate()``
          asks uvicorn to shut down, and whose ``wait()`` joins the thread;
        - None if fastapi, uvicorn or pydantic cannot be imported.
    """
    print("=== Starting LLaMA-Omni2 Controller ===")
    # First try to use our custom implementation
    direct_controller_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "controller.py")
    if os.path.exists(direct_controller_path):
        print(f"Using custom controller implementation: {direct_controller_path}")
        cmd = [
            sys.executable, direct_controller_path,
            "--host", "0.0.0.0",
            "--port", "10000",
        ]
        process = subprocess.Popen(cmd, env=os.environ.copy())
        print(f"Controller started with PID: {process.pid}")
        return process

    # Fall back to a simple controller implementation
    print("No controller script found. Implementing a simple controller...")
    try:
        from fastapi import FastAPI
        import uvicorn
        from pydantic import BaseModel
        import threading
    except ImportError as e:
        print(f"Failed to create simple controller: {e}")
        return None

    app = FastAPI()

    class ModelInfo(BaseModel):
        model_name: str
        worker_name: str
        worker_addr: str

    # In-memory registry: model name -> worker name/address.
    registered_models = {}

    @app.get("/")
    def read_root():
        return {"status": "ok", "models": list(registered_models.keys())}

    @app.get("/api/v1/models")
    def list_models():
        return {"models": list(registered_models.keys())}

    @app.post("/api/v1/register_worker")
    def register_worker(model_info: ModelInfo):
        registered_models[model_info.model_name] = {
            "worker_name": model_info.worker_name,
            "worker_addr": model_info.worker_addr,
        }
        return {"status": "ok"}

    # Drive uvicorn through a Server object rather than uvicorn.run() so the
    # returned handle can ask it to stop.
    server = uvicorn.Server(uvicorn.Config(app, host="0.0.0.0", port=10000))
    thread = threading.Thread(target=server.run, daemon=True)
    thread.start()
    print("Simple controller started on port 10000")

    class ThreadedServerProcess:
        """Process-like handle for the in-process controller thread."""

        def __init__(self):
            self.pid = 0

        def poll(self):
            # None while the thread is alive (like a running Popen); 0 once it
            # has exited, e.g. because the port could not be bound.
            return None if thread.is_alive() else 0

        def terminate(self):
            server.should_exit = True

        def kill(self):
            server.force_exit = True
            server.should_exit = True

        def wait(self, timeout=None):
            if timeout is None:
                # Join in short slices so Ctrl+C stays responsive.
                while thread.is_alive():
                    thread.join(0.5)
            else:
                thread.join(timeout)
            return self.poll()

    return ThreadedServerProcess()
def start_model_worker():
    """Launch the LLaMA-Omni2 model worker, if a worker script is available.

    If ``model_worker.py`` sits next to this launcher, it is started as a
    subprocess. The subprocess listens on 0.0.0.0:40000, is given the
    controller address http://localhost:10000, and serves
    ``LLAMA_OMNI2_MODEL_PATH`` under the name ``LLAMA_OMNI2_MODEL_NAME``.
    Otherwise nothing is started and a no-op placeholder is returned so the
    caller can carry on.

    Returns:
        The ``subprocess.Popen`` for the worker script, or a placeholder with
        ``pid`` 0 whose ``poll()`` always returns None and whose
        ``terminate()``/``wait()`` do nothing.
    """
    print("=== Starting LLaMA-Omni2 Model Worker ===")
    launcher_dir = os.path.dirname(os.path.abspath(__file__))
    direct_worker_path = os.path.join(launcher_dir, "model_worker.py")

    if not os.path.exists(direct_worker_path):
        print("No model worker script found. Will try to start Gradio directly with the model.")

        class DummyProcess:
            """No-op stand-in mimicking the pid/poll/wait/terminate part of Popen."""

            def __init__(self):
                self.pid = 0

            def poll(self):
                return None

            def wait(self, timeout=None):
                pass

            def terminate(self):
                pass

        return DummyProcess()

    print(f"Using custom model worker implementation: {direct_worker_path}")
    worker_options = {
        "--host": "0.0.0.0",
        "--controller": "http://localhost:10000",
        "--port": "40000",
        "--worker": "http://localhost:40000",
        "--model-path": LLAMA_OMNI2_MODEL_PATH,
        "--model-name": LLAMA_OMNI2_MODEL_NAME,
    }
    cmd = [sys.executable, direct_worker_path]
    for flag, value in worker_options.items():
        cmd.extend((flag, value))
    worker = subprocess.Popen(cmd, env=os.environ.copy())
    print(f"Model worker started with PID: {worker.pid}")
    return worker
def start_gradio_server():
    """Start the LLaMA-Omni2 Gradio web server.

    If ``gradio_web_server.py`` sits next to this launcher, it is run as a
    subprocess. The subprocess listens on 0.0.0.0:7860, is pointed at the
    controller on http://localhost:10000, and gets ``--vocoder-dir
    COSYVOICE_PATH``. Otherwise a minimal text-generation UI for
    ``LLAMA_OMNI2_MODEL_PATH`` is served on port 7860 from a background
    daemon thread.

    Returns:
        One of:

        - the ``subprocess.Popen`` for the script;
        - for the fallback, a process-like handle whose ``wait()`` blocks
          until the UI thread finishes (as ``Popen.wait`` would), whose
          ``poll()`` reports thread liveness, and whose ``terminate()`` asks
          the UI to close;
        - None if gradio, transformers or torch cannot be imported.
    """
    print("=== Starting LLaMA-Omni2 Gradio Server ===")
    # First try to use our custom implementation
    direct_gradio_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "gradio_web_server.py")
    if os.path.exists(direct_gradio_path):
        print(f"Using custom Gradio server implementation: {direct_gradio_path}")
        cmd = [
            sys.executable, direct_gradio_path,
            "--host", "0.0.0.0",
            "--port", "7860",
            "--controller-url", "http://localhost:10000",
            "--vocoder-dir", COSYVOICE_PATH,
        ]
        process = subprocess.Popen(cmd, env=os.environ.copy())
        print(f"Gradio server started with PID: {process.pid}")
        return process

    # Fall back to a simple Gradio implementation
    print("No Gradio server found. Attempting to create a simple interface...")
    try:
        import gradio as gr
        import threading
        from transformers import AutoModelForCausalLM, AutoTokenizer
        import torch
    except ImportError as e:
        print(f"Failed to create simple Gradio interface: {e}")
        return None

    # Set by terminate(); tells the UI thread to close the server and exit.
    stop_requested = threading.Event()

    def launch_simple_gradio():
        """Load the model and serve a one-tab UI until a stop is requested."""
        try:
            print(f"Loading model from {LLAMA_OMNI2_MODEL_PATH}...")
            # Check for CUDA availability
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Using device: {device}")
            if device == "cuda":
                print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
                print(f"CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
            tokenizer = AutoTokenizer.from_pretrained(LLAMA_OMNI2_MODEL_PATH)
            model = AutoModelForCausalLM.from_pretrained(LLAMA_OMNI2_MODEL_PATH).to(device)

            def generate_text(input_text):
                inputs = tokenizer(input_text, return_tensors="pt").to(device)
                # max_new_tokens bounds only the generated continuation;
                # max_length would also count the prompt and leave long
                # prompts no room. The attention mask is passed explicitly.
                outputs = model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.get("attention_mask"),
                    max_new_tokens=100,
                )
                return tokenizer.decode(outputs[0], skip_special_tokens=True)

            with gr.Blocks() as demo:
                gr.Markdown("# LLaMA-Omni2 Simple Interface")
                with gr.Tab("Text Generation"):
                    input_text = gr.Textbox(label="Input Text")
                    output_text = gr.Textbox(label="Generated Text")
                    generate_btn = gr.Button("Generate")
                    generate_btn.click(generate_text, inputs=input_text, outputs=output_text)
            # prevent_thread_lock=True returns once the server is up; this
            # thread then parks until terminate() is called.
            demo.launch(server_name="0.0.0.0", server_port=7860, prevent_thread_lock=True)
            stop_requested.wait()
            demo.close()
        except Exception as e:
            print(f"Error in simple Gradio interface: {e}")

    thread = threading.Thread(target=launch_simple_gradio, daemon=True)
    thread.start()
    print("Simple Gradio interface started on port 7860")

    class ThreadProcess:
        """Process-like handle for the daemon thread serving the fallback UI."""

        def __init__(self):
            self.pid = 0

        def poll(self):
            return None if thread.is_alive() else 0

        def terminate(self):
            stop_requested.set()

        def kill(self):
            stop_requested.set()

        def wait(self, timeout=None):
            # Block like Popen.wait(): returning immediately would let main()
            # exit and take this daemon thread, and the UI, down with it.
            if timeout is None:
                # Join in short slices so Ctrl+C stays responsive.
                while thread.is_alive():
                    thread.join(0.5)
            else:
                thread.join(timeout)
            return self.poll()

    return ThreadProcess()
def _imported_names(names_part):
    """Return the identifiers bound by the names part of a ``from ... import``.

    ``"A, B as C"`` gives ``["A", "C"]``. Surrounding parentheses and empty
    items are ignored. ``*`` (or anything else that is not an identifier) is
    dropped because no concrete name can be stubbed for it.
    """
    names = []
    for item in names_part.strip().strip("()").split(","):
        tokens = item.split()
        # For "B as C" the bound name is the last token.
        if tokens and tokens[-1].isidentifier():
            names.append(tokens[-1])
    return names


def _wrap_optional_imports(content, prefixes):
    """Wrap every import starting with one of ``prefixes`` in try/except ImportError.

    Handles single-line and parenthesized multi-line imports at any
    indentation. The except branch prints a warning and binds each imported
    name to a placeholder ``object()``. Imports directly preceded by a
    ``try:`` line are left alone, so patching already-patched text is a no-op.
    Backslash-continued imports are not supported and are left untouched.

    Returns:
        ``(new_content, changed)``.
    """
    lines = content.split("\n")
    out = []
    changed = False
    i = 0
    while i < len(lines):
        line = lines[i]
        stripped = line.strip()
        prefix = next((p for p in prefixes if stripped.startswith(p)), None)
        already_wrapped = bool(out) and out[-1].strip() == "try:"
        if prefix is None or already_wrapped or stripped.endswith("\\"):
            out.append(line)
            i += 1
            continue

        # Extend over the continuation lines of a parenthesized import.
        end = i
        if "(" in stripped and ")" not in stripped:
            while end + 1 < len(lines) and ")" not in lines[end]:
                end += 1
            if ")" not in lines[end]:
                # Unterminated parenthesis: leave the text as it is.
                out.append(line)
                i += 1
                continue

        statement = lines[i:end + 1]
        indent = line[: len(line) - len(line.lstrip())]
        # Drop trailing comments before extracting the imported names.
        code = " ".join(part.split("#", 1)[0].strip() for part in statement).strip()
        names = _imported_names(code[len(prefix):])
        message = f"Warning: Creating fallback for missing import: {code}"

        out.append(f"{indent}try:")
        out.extend(f"    {part}" for part in statement)
        out.append(f"{indent}except ImportError:")
        out.append(f"{indent}    # Auto-generated fallback for missing import")
        # repr() yields a correctly escaped string literal.
        out.append(f"{indent}    print({message!r})")
        out.extend(f"{indent}    {name} = object()  # Dummy placeholder" for name in names)
        changed = True
        i = end + 1
    return "\n".join(out), changed


def patch_extracted_files(extraction_dir):
    """Guard llama_omni2-internal imports in the extracted serve scripts.

    The files patched are ``controller.py``, ``model_worker.py`` and
    ``gradio_web_server.py`` under ``<extraction_dir>/llama_omni2/serve``. In
    each, every ``from llama_omni2.{constants,model,common} import ...``
    statement is wrapped in try/except ImportError, so a missing module
    degrades to placeholder objects instead of failing at import time.
    Missing files are skipped with a warning. Running this again on
    already-patched files leaves them unchanged.

    Args:
        extraction_dir: Root directory containing the ``llama_omni2`` package.

    Returns:
        List of paths of the files that were rewritten.
    """
    print("Patching extracted Python files to handle missing imports...")
    serve_dir = os.path.join(extraction_dir, "llama_omni2", "serve")
    guarded_prefixes = (
        "from llama_omni2.constants import",
        "from llama_omni2.model import",
        "from llama_omni2.common import",
    )
    patched_files = []
    for script_name in ("controller.py", "model_worker.py", "gradio_web_server.py"):
        file_path = os.path.join(serve_dir, script_name)
        if not os.path.exists(file_path):
            print(f"Warning: File {file_path} not found, skipping patch")
            continue
        with open(file_path, 'r', encoding="utf-8") as f:
            content = f.read()
        new_content, changed = _wrap_optional_imports(content, guarded_prefixes)
        if changed:
            with open(file_path, 'w', encoding="utf-8") as f:
                f.write(new_content)
            patched_files.append(file_path)
            print(f"Patched file: {file_path}")
    if patched_files:
        print(f"Successfully patched {len(patched_files)} files")
    else:
        print("No files needed patching")
    return patched_files
def _shutdown_processes(processes):
    """Terminate each still-running process, killing it if it ignores a 30 s wait.

    Args:
        processes: Iterable of ``(process, name)`` pairs. ``process`` is a
            ``subprocess.Popen`` or a process-like placeholder exposing
            ``poll``/``terminate``/``wait``; ``kill`` is used only if present.
    """
    for process, name in processes:
        if process and process.poll() is None:
            print(f"Shutting down {name}...")
            process.terminate()
            try:
                process.wait(timeout=30)
            except subprocess.TimeoutExpired:
                print(f"{name} did not terminate gracefully. Killing...")
                kill = getattr(process, "kill", None)
                if kill is not None:
                    kill()


def main():
    """Start the controller, model worker and Gradio server, then supervise them.

    Steps:

    1. Install dependencies.
    2. Prepare the package and model directories.
    3. Start the three components in order, with fixed 5 s and 30 s
       start-up delays.
    4. Block until the Gradio server exits or the user presses Ctrl+C.

    Every component that was started is shut down, in reverse start order,
    on all exit paths, including a Ctrl+C during the start-up delays.

    Returns:
        0 on normal exit or keyboard interrupt, 1 if a component failed to
        start.
    """
    print("LLaMA-Omni2 Direct Launcher")
    print("==========================")

    print("Checking and installing dependencies...")
    if not download_dependencies():
        # Not fatal: the packages already installed may be sufficient.
        print("Warning: dependency installation failed; continuing with installed packages")

    # Create directories directly instead of using extraction script
    print("Creating necessary directories...")
    os.makedirs(EXTRACTION_DIR, exist_ok=True)
    os.makedirs(os.path.join(EXTRACTION_DIR, "llama_omni2"), exist_ok=True)
    os.makedirs(os.path.join(EXTRACTION_DIR, "llama_omni2", "serve"), exist_ok=True)
    ensure_module_structure(EXTRACTION_DIR)
    # Nothing is extracted, so there is nothing to patch.
    print("Skipping file patching as we're not running extraction")

    if EXTRACTION_DIR not in sys.path:
        sys.path.insert(0, EXTRACTION_DIR)
        print(f"Added {EXTRACTION_DIR} to sys.path")

    print("Proceeding directly to model download and starting services...")
    # NOTE(review): no download happens in this launcher; the model
    # directories are only created here.
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(LLAMA_OMNI2_MODEL_PATH, exist_ok=True)
    os.makedirs(COSYVOICE_PATH, exist_ok=True)

    # Components started so far, in start order; shut down in reverse.
    started = []
    try:
        controller_process = start_controller()
        if not controller_process:
            print("Failed to start controller. Exiting.")
            return 1
        started.append((controller_process, "Controller"))

        print("Waiting for controller to initialize...")
        time.sleep(5)

        model_worker_process = start_model_worker()
        if not model_worker_process:
            print("Failed to start model worker. Shutting down controller.")
            return 1
        started.append((model_worker_process, "Model worker"))

        print("Waiting for model worker to initialize...")
        time.sleep(30)

        gradio_process = start_gradio_server()
        if not gradio_process:
            print("Failed to start Gradio server. Shutting down other processes.")
            return 1
        started.append((gradio_process, "Gradio server"))

        print("\nAll components started successfully!")
        print("Gradio interface should be available at http://0.0.0.0:7860")
        # Block until the Gradio server finishes.
        gradio_process.wait()
    except KeyboardInterrupt:
        print("\nReceived keyboard interrupt. Shutting down...")
    finally:
        _shutdown_processes(reversed(started))
        if started:
            print("All processes have been shut down.")
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())