#!/usr/bin/env python3
"""
A simplified implementation of the LLaMA-Omni2 model worker that doesn't rely
on deep LLaMA-Omni2 imports.
"""
import argparse
import asyncio
import logging
import uuid

import aiohttp
import fastapi
import uvicorn
from fastapi import Request
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define constants
WORKER_HEART_BEAT_INTERVAL = 30
CONTROLLER_HEART_BEAT_EXPIRATION = 120


class ModelWorker:
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_name: str,
        device: str = "cpu",
        limit_worker_concurrency: int = 5,
    ):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        self.model_path = model_path
        self.model_name = model_name
        self.device = device
        self.limit_worker_concurrency = limit_worker_concurrency

        # Track current requests
        self.lock = asyncio.Lock()
        self.messages = {}
        self.sem = asyncio.Semaphore(limit_worker_concurrency)

        # Placeholder: only the tokenizer is loaded here. A real
        # implementation would also load the model weights.
        logger.info(f"Loading tokenizer from {model_path}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            self.model = None  # A real implementation would load the model here
            logger.info("Initialization successful (tokenizer only, no model weights)")
        except Exception as e:
            logger.error(f"Failed to load tokenizer: {e}")
            logger.info("Falling back to a dummy worker (no tokenizer, no model)")
            self.tokenizer = None
            self.model = None
        logger.info(f"Worker initialized for model {model_name}")

    async def generate_response(self, request_data):
        """Generate a response (simulated)."""
        # Hold the semaphore so limit_worker_concurrency is actually enforced
        async with self.sem:
            prompt = request_data.get("prompt", "")
            return f"This is a simulated response for prompt: {prompt[:30]}..."
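    # NOTE: a minimal sketch of what real loading and generation might look
    # like using plain Hugging Face transformers, rather than the LLaMA-Omni2
    # stack. The method names, local import, and generation settings below
    # are illustrative assumptions, not part of the original worker.
    def load_hf_model(self):
        """Hypothetical: load a causal LM with transformers (assumption)."""
        from transformers import AutoModelForCausalLM

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path, trust_remote_code=True
        ).to(self.device)

    def generate_with_hf(self, prompt: str, max_new_tokens: int = 128) -> str:
        """Hypothetical: greedy decoding with the loaded model (assumption)."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        output_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)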
    async def register_to_controller(self):
        """Register this worker with the controller."""
        data = {
            "worker_name": self.worker_id,
            "worker_url": self.worker_addr,
            "model_names": [self.model_name],
            "check_heart_beat": True,
        }
        logger.info(f"Registering to controller at {self.controller_addr}")
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.controller_addr}/register_worker",
                json=data,
                timeout=aiohttp.ClientTimeout(total=30),
            ) as response:
                if response.status != 200:
                    logger.error(f"Failed to register to controller: {await response.text()}")
                    return False
                logger.info("Registered to controller successfully")
                return True

    async def send_heart_beat(self):
        """Send a heartbeat to the controller periodically."""
        data = {"worker_name": self.worker_id}
        async with aiohttp.ClientSession() as session:
            while True:
                try:
                    async with session.post(
                        f"{self.controller_addr}/heart_beat",
                        json=data,
                        timeout=aiohttp.ClientTimeout(total=30),
                    ) as response:
                        if response.status != 200:
                            logger.error(f"Failed to send heart beat: {await response.text()}")
                except Exception as e:
                    logger.error(f"Error sending heart beat: {e}")
                await asyncio.sleep(WORKER_HEART_BEAT_INTERVAL)


# FastAPI app
app = fastapi.FastAPI()

# Global model worker instance, set in main()
model_worker = None


@app.post("/generate")
async def generate(request: Request):
    """Generate text based on the prompt."""
    if model_worker is None:
        return JSONResponse(
            {"error": "Model worker not initialized"},
            status_code=500,
        )
    data = await request.json()
    response = await model_worker.generate_response(data)
    return {"response": response}


@app.get("/status")
async def status():
    """Get the status of the worker."""
    if model_worker is None:
        return {"status": "offline"}
    return {
        "status": "online",
        "model_name": model_worker.model_name,
        "worker_id": model_worker.worker_id,
    }


async def start_background_tasks():
    """Register with the controller and start the heartbeat loop on startup."""
    asyncio.create_task(model_worker.register_to_controller())
    asyncio.create_task(model_worker.send_heart_beat())


def main():
    global model_worker

    parser = argparse.ArgumentParser(description="Model worker for LLaMA-Omni2")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=40000)
    parser.add_argument("--controller", type=str, default="http://localhost:10000")
    parser.add_argument("--worker", type=str, default="http://localhost:40000")
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--model-name", type=str, required=True)
    parser.add_argument("--limit-worker-concurrency", type=int, default=5)
    parser.add_argument("--device", type=str, default="cpu")
    args = parser.parse_args()

    logger.info(f"Initializing model worker with model {args.model_name}")

    # Initialize the model worker
    worker_id = f"worker-{str(uuid.uuid4())[:8]}"
    model_worker = ModelWorker(
        controller_addr=args.controller,
        worker_addr=args.worker,
        worker_id=worker_id,
        model_path=args.model_path,
        model_name=args.model_name,
        device=args.device,
        limit_worker_concurrency=args.limit_worker_concurrency,
    )

    # Register the startup handler; it is async, so Starlette awaits it on
    # the running event loop, where asyncio.create_task is safe to call.
    app.add_event_handler("startup", start_background_tasks)

    logger.info(f"Starting model worker server at http://{args.host}:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()
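# Example usage (illustrative; the script filename and model path are
# assumptions, and a controller exposing /register_worker and /heart_beat
# must already be running at --controller):
#
#   python model_worker.py --model-path ./models/llama-omni2 \
#       --model-name llama-omni2
#
#   curl -X POST http://localhost:40000/generate \
#       -H "Content-Type: application/json" \
#       -d '{"prompt": "Hello there"}'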