#!/usr/bin/env python3
"""
LLaMA-Omni2 Direct Launcher
---------------------------
This script extracts and directly runs the LLaMA-Omni2 components without
relying on package imports.
"""
import argparse
import importlib.util
import os
import shutil
import subprocess
import sys
import tempfile
import time
# --- Path configuration ---------------------------------------------------
# Root directory where the llama_omni2 package tree is scaffolded.
EXTRACTION_DIR = "/home/user/app/llama_omni2_extracted"
# Directory holding all downloaded model weights.
MODELS_DIR = "/home/user/app/models"
LLAMA_OMNI2_MODEL_NAME = "LLaMA-Omni2-0.5B"
LLAMA_OMNI2_MODEL_PATH = f"{MODELS_DIR}/{LLAMA_OMNI2_MODEL_NAME}"
# CosyVoice 2 decoder weights, passed to the Gradio server as the vocoder dir.
COSYVOICE_PATH = f"{MODELS_DIR}/cosy2_decoder"
def download_dependencies():
    """Install the Python packages required by LLaMA-Omni2.

    Runs ``pip install --upgrade`` with the current interpreter so the
    packages land in the active environment.

    Returns:
        bool: True if pip exited successfully, False on a non-zero exit.
    """
    print("Installing required dependencies...")
    dependencies = [
        "gradio>=3.50.2",
        "fastapi",
        "uvicorn",
        "pydantic",
        "transformers>=4.36.2",
        "sentencepiece",
    ]
    try:
        # List form (shell=False) avoids shell quoting issues; check=True
        # turns a failed install into CalledProcessError.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "--upgrade"] + dependencies,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"Error installing dependencies: {e}")
        return False
    print("Dependencies installed successfully")
    return True
def ensure_module_structure(extraction_dir):
    """Create the minimal ``llama_omni2`` package layout under *extraction_dir*.

    Ensures every expected package directory exists with an ``__init__.py``,
    and writes stub ``utils.py`` / ``constants.py`` modules so the serve
    scripts can import them even when the real sources were not extracted.
    Existing files are never overwritten.

    Args:
        extraction_dir: Root directory that should contain the package tree.

    Returns:
        bool: Always True (I/O failures propagate as exceptions).
    """
    print("Ensuring proper module structure...")
    # Create __init__.py files if they don't exist
    module_dirs = [
        os.path.join(extraction_dir, "llama_omni2"),
        os.path.join(extraction_dir, "llama_omni2", "serve"),
        os.path.join(extraction_dir, "llama_omni2", "model"),
        os.path.join(extraction_dir, "llama_omni2", "common"),
    ]
    for dir_path in module_dirs:
        os.makedirs(dir_path, exist_ok=True)
        init_file = os.path.join(dir_path, "__init__.py")
        if not os.path.exists(init_file):
            with open(init_file, 'w') as f:
                f.write("# Auto-generated __init__.py file\n")
            print(f"Created {init_file}")
    # Stub modules providing the constants/functions the serve scripts import.
    dummy_modules = {
        # Utils module
        os.path.join(extraction_dir, "llama_omni2", "utils.py"): """
# Dummy utils module
def dummy_function():
    pass
""",
        # Constants module - required by controller.py and model_worker.py
        os.path.join(extraction_dir, "llama_omni2", "constants.py"): """
# Constants required by LLaMA-Omni2 modules
# Controller constants
CONTROLLER_HEART_BEAT_EXPIRATION = 120
CONTROLLER_STATUS_POLLING_INTERVAL = 15
# Worker constants
WORKER_HEART_BEAT_INTERVAL = 30
WORKER_API_TIMEOUT = 100
# Other constants that might be needed
DEFAULT_PORT = 8000
""",
    }
    for file_path, content in dummy_modules.items():
        if not os.path.exists(file_path):
            with open(file_path, 'w') as f:
                f.write(content)
            print(f"Created {file_path}")
    return True
def start_controller():
    """Start the LLaMA-Omni2 controller.

    Prefers a sibling ``controller.py`` script (launched as a subprocess
    listening on port 10000). Otherwise falls back to a minimal in-process
    FastAPI controller run in a daemon thread via uvicorn.

    Returns:
        A process-like object (a real ``Popen`` or a dummy with the same
        small interface) on success, or None if the fallback's dependencies
        are missing.
    """
    print("=== Starting LLaMA-Omni2 Controller ===")
    # First try to use our custom implementation
    direct_controller_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "controller.py")
    if os.path.exists(direct_controller_path):
        print(f"Using custom controller implementation: {direct_controller_path}")
        cmd = [
            sys.executable, direct_controller_path,
            "--host", "0.0.0.0",
            "--port", "10000"
        ]
        env = os.environ.copy()
        process = subprocess.Popen(cmd, env=env)
        print(f"Controller started with PID: {process.pid}")
        return process
    # Fall back to a simple controller implementation
    print("No controller script found. Implementing a simple controller...")
    try:
        from fastapi import FastAPI, HTTPException
        import uvicorn
        from pydantic import BaseModel
        import threading

        app = FastAPI()

        class ModelInfo(BaseModel):
            model_name: str
            worker_name: str
            worker_addr: str

        # Simple in-memory storage
        registered_models = {}

        # NOTE(review): these handlers were previously defined but never
        # registered on the app, so the fallback controller served no routes.
        # Paths follow the usual FastChat-style controller API -- confirm
        # they match what the model worker expects.
        @app.get("/")
        def read_root():
            return {"status": "ok", "models": list(registered_models.keys())}

        @app.get("/list_models")
        def list_models():
            return {"models": list(registered_models.keys())}

        @app.post("/register_worker")
        def register_worker(model_info: ModelInfo):
            registered_models[model_info.model_name] = {
                "worker_name": model_info.worker_name,
                "worker_addr": model_info.worker_addr
            }
            return {"status": "ok"}

        # Run uvicorn in a daemon thread so this function can return.
        def run_controller():
            uvicorn.run(app, host="0.0.0.0", port=10000)

        thread = threading.Thread(target=run_controller, daemon=True)
        thread.start()
        print("Simple controller started on port 10000")

        # Dummy stand-in so callers can treat this like a Popen handle.
        class DummyProcess:
            def __init__(self):
                self.pid = 0
            def terminate(self):
                pass
            def poll(self):
                return None
            def wait(self, timeout=None):
                pass
        return DummyProcess()
    except ImportError as e:
        print(f"Failed to create simple controller: {e}")
        return None
def start_model_worker():
    """Start the LLaMA-Omni2 model worker.

    Launches a sibling ``model_worker.py`` script as a subprocess (worker on
    port 40000, registering with the controller on port 10000). If no script
    is present, returns a no-op process stand-in so the caller can proceed
    and let the Gradio layer load the model itself.

    Returns:
        A ``Popen`` handle, or a dummy object exposing the same small
        interface (pid / terminate / poll / wait).
    """
    print("=== Starting LLaMA-Omni2 Model Worker ===")
    # First try to use our custom implementation
    direct_worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_worker.py")
    if os.path.exists(direct_worker_path):
        print(f"Using custom model worker implementation: {direct_worker_path}")
        cmd = [
            sys.executable, direct_worker_path,
            "--host", "0.0.0.0",
            "--controller", "http://localhost:10000",
            "--port", "40000",
            "--worker", "http://localhost:40000",
            "--model-path", LLAMA_OMNI2_MODEL_PATH,
            "--model-name", LLAMA_OMNI2_MODEL_NAME
        ]
        env = os.environ.copy()
        process = subprocess.Popen(cmd, env=env)
        print(f"Model worker started with PID: {process.pid}")
        return process
    # Fall back to a simple implementation
    print("No model worker script found. Will try to start Gradio directly with the model.")

    # Dummy stand-in so callers can treat this like a Popen handle.
    class DummyProcess:
        def __init__(self):
            self.pid = 0
        def terminate(self):
            pass
        def poll(self):
            return None
        def wait(self, timeout=None):
            pass
    return DummyProcess()
def start_gradio_server():
    """Start the LLaMA-Omni2 Gradio web server.

    Prefers a sibling ``gradio_web_server.py`` script (launched as a
    subprocess on port 7860, pointed at the controller and the CosyVoice
    vocoder directory). Otherwise falls back to a minimal in-process Gradio
    text-generation interface that loads the model with transformers in a
    daemon thread.

    Returns:
        A process-like object (Popen or dummy) on success, or None if the
        fallback's dependencies (gradio/transformers/torch) are missing.
    """
    print("=== Starting LLaMA-Omni2 Gradio Server ===")
    # First try to use our custom implementation
    direct_gradio_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "gradio_web_server.py")
    if os.path.exists(direct_gradio_path):
        print(f"Using custom Gradio server implementation: {direct_gradio_path}")
        cmd = [
            sys.executable, direct_gradio_path,
            "--host", "0.0.0.0",
            "--port", "7860",
            "--controller-url", "http://localhost:10000",
            "--vocoder-dir", COSYVOICE_PATH
        ]
        env = os.environ.copy()
        process = subprocess.Popen(cmd, env=env)
        print(f"Gradio server started with PID: {process.pid}")
        return process
    # Fall back to a simple Gradio implementation
    print("No Gradio server found. Attempting to create a simple interface...")
    try:
        import gradio as gr
        import threading
        from transformers import AutoModelForCausalLM, AutoTokenizer
        import torch

        # Simple function to launch a basic Gradio interface
        def launch_simple_gradio():
            try:
                print(f"Loading model from {LLAMA_OMNI2_MODEL_PATH}...")
                # Check for CUDA availability
                device = "cuda" if torch.cuda.is_available() else "cpu"
                print(f"Using device: {device}")
                if device == "cuda":
                    print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
                    print(f"CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
                tokenizer = AutoTokenizer.from_pretrained(LLAMA_OMNI2_MODEL_PATH)
                model = AutoModelForCausalLM.from_pretrained(LLAMA_OMNI2_MODEL_PATH).to(device)

                def generate_text(input_text):
                    inputs = tokenizer(input_text, return_tensors="pt").to(device)
                    outputs = model.generate(inputs.input_ids, max_length=100)
                    return tokenizer.decode(outputs[0], skip_special_tokens=True)

                with gr.Blocks() as demo:
                    gr.Markdown("# LLaMA-Omni2 Simple Interface")
                    with gr.Tab("Text Generation"):
                        input_text = gr.Textbox(label="Input Text")
                        output_text = gr.Textbox(label="Generated Text")
                        generate_btn = gr.Button("Generate")
                        generate_btn.click(generate_text, inputs=input_text, outputs=output_text)
                # Blocks this daemon thread to serve requests.
                demo.launch(server_name="0.0.0.0", server_port=7860)
            except Exception as e:
                # Best-effort fallback: log and let the thread die quietly.
                print(f"Error in simple Gradio interface: {e}")

        thread = threading.Thread(target=launch_simple_gradio, daemon=True)
        thread.start()
        print("Simple Gradio interface started on port 7860")

        # Dummy stand-in so callers can treat this like a Popen handle.
        class DummyProcess:
            def __init__(self):
                self.pid = 0
            def terminate(self):
                pass
            def poll(self):
                return None
            def wait(self, timeout=None):
                pass
        return DummyProcess()
    except ImportError as e:
        print(f"Failed to create simple Gradio interface: {e}")
        return None
def patch_extracted_files(extraction_dir):
    """Wrap fragile ``llama_omni2`` imports in the extracted serve scripts.

    For each known serve script under *extraction_dir*, every line matching
    one of the listed import prefixes is wrapped in a try/except that, on
    ImportError, binds each imported name to a dummy ``object()`` placeholder
    so the script can still be loaded without the real package.

    Args:
        extraction_dir: Root of the extracted package tree.

    Returns:
        list[str]: Paths of the files that were actually modified.
    """
    print("Patching extracted Python files to handle missing imports...")
    # Define files to patch and their imports to check/fix
    files_to_patch = {
        os.path.join(extraction_dir, "llama_omni2", "serve", "controller.py"): [
            "from llama_omni2.constants import",
            "from llama_omni2.model import",
            "from llama_omni2.common import",
        ],
        os.path.join(extraction_dir, "llama_omni2", "serve", "model_worker.py"): [
            "from llama_omni2.constants import",
            "from llama_omni2.model import",
            "from llama_omni2.common import",
        ],
        os.path.join(extraction_dir, "llama_omni2", "serve", "gradio_web_server.py"): [
            "from llama_omni2.constants import",
            "from llama_omni2.model import",
            "from llama_omni2.common import",
        ]
    }
    patched_files = []
    for file_path, imports_to_check in files_to_patch.items():
        if not os.path.exists(file_path):
            print(f"Warning: File {file_path} not found, skipping patch")
            continue
        with open(file_path, 'r') as f:
            content = f.read()
        modified = False
        # Add try-except blocks around problematic imports
        for import_line in imports_to_check:
            if import_line not in content:
                continue
            # Collect the full lines containing this import prefix, deduped:
            # a duplicate line would otherwise be re-wrapped on its second
            # pass because str.replace substitutes every occurrence.
            import_lines = list(dict.fromkeys(
                line for line in content.split('\n') if import_line in line
            ))
            for full_line in import_lines:
                try:
                    # Names after 'import', e.g. "... import A, B" -> [A, B].
                    imported_vars = [v.strip() for v in full_line.split('import')[1].split(',')]
                    # NOTE(review): full_line is embedded verbatim inside a
                    # double-quoted string in the generated code; a source
                    # line containing '"' would still produce broken output.
                    replacement = f"""try:
    {full_line}
except ImportError:
    # Auto-generated fallback for missing import
    print("Warning: Creating fallback for missing import: {full_line}")
"""
                    for var in imported_vars:
                        if var:  # Skip empty strings
                            replacement += f"    {var} = object()  # Dummy placeholder\n"
                    # Replace the original import with the try-except block
                    content = content.replace(full_line, replacement)
                    modified = True
                except Exception as e:
                    print(f"Error processing import line '{full_line}': {e}")
        # Write the modified content back if changes were made
        if modified:
            with open(file_path, 'w') as f:
                f.write(content)
            patched_files.append(file_path)
            print(f"Patched file: {file_path}")
    if patched_files:
        print(f"Successfully patched {len(patched_files)} files")
    else:
        print("No files needed patching")
    return patched_files
def main():
    """Install dependencies, scaffold the module tree, and launch services.

    Order: pip dependencies -> directory/module scaffolding -> controller ->
    model worker -> Gradio server. Then blocks on the Gradio process and
    tears everything down on exit or Ctrl-C.

    Returns:
        int: Process exit code (0 on clean shutdown, 1 on startup failure).
    """
    print("LLaMA-Omni2 Direct Launcher")
    print("==========================")
    # Install dependencies first
    print("Checking and installing dependencies...")
    download_dependencies()
    # Create directories directly instead of using extraction script
    print("Creating necessary directories...")
    os.makedirs(EXTRACTION_DIR, exist_ok=True)
    os.makedirs(os.path.join(EXTRACTION_DIR, "llama_omni2"), exist_ok=True)
    os.makedirs(os.path.join(EXTRACTION_DIR, "llama_omni2", "serve"), exist_ok=True)
    # Ensure the module structure is complete
    ensure_module_structure(EXTRACTION_DIR)
    # Skip patching files as we're not extracting anything
    print("Skipping file patching as we're not running extraction")
    # Add the extraction dir to Python path so llama_omni2 imports resolve
    if EXTRACTION_DIR not in sys.path:
        sys.path.insert(0, EXTRACTION_DIR)
        print(f"Added {EXTRACTION_DIR} to sys.path")
    # Skip directly to model download and starting services
    print("Proceeding directly to model download and starting services...")
    # Make directories for models
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(LLAMA_OMNI2_MODEL_PATH, exist_ok=True)
    os.makedirs(COSYVOICE_PATH, exist_ok=True)
    # Start controller
    controller_process = start_controller()
    if not controller_process:
        print("Failed to start controller. Exiting.")
        return 1
    # Wait for controller to initialize
    print("Waiting for controller to initialize...")
    time.sleep(5)
    # Start model worker
    model_worker_process = start_model_worker()
    if not model_worker_process:
        print("Failed to start model worker. Shutting down controller.")
        controller_process.terminate()
        return 1
    # Wait for model to load - reduced from 300 seconds to 30 seconds
    print("Waiting for model worker to initialize...")
    time.sleep(30)
    # Start Gradio server
    gradio_process = start_gradio_server()
    if not gradio_process:
        print("Failed to start Gradio server. Shutting down other processes.")
        model_worker_process.terminate()
        controller_process.terminate()
        return 1
    print("\nAll components started successfully!")
    print("Gradio interface should be available at http://0.0.0.0:7860")
    try:
        # Block until the Gradio process exits. NOTE(review): the dummy
        # fallback processes return immediately from wait(), so in fallback
        # mode this proceeds straight to shutdown -- confirm intended.
        gradio_process.wait()
    except KeyboardInterrupt:
        print("\nReceived keyboard interrupt. Shutting down...")
    finally:
        # Terminate children in reverse start order; escalate to kill() if a
        # real process ignores SIGTERM for 30 seconds.
        for process, name in [
            (gradio_process, "Gradio server"),
            (model_worker_process, "Model worker"),
            (controller_process, "Controller")
        ]:
            if process and process.poll() is None:
                print(f"Shutting down {name}...")
                process.terminate()
                try:
                    process.wait(timeout=30)
                except subprocess.TimeoutExpired:
                    print(f"{name} did not terminate gracefully. Killing...")
                    process.kill()
        print("All processes have been shut down.")
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())