#!/usr/bin/env python3
"""
A standalone implementation of the LLaMA-Omni2 controller that doesn't rely
on any LLaMA-Omni2 imports.

The controller keeps an in-memory registry of model workers: which model
names each worker serves, the URL each worker registered with, and when each
worker last sent a heartbeat.
"""
import argparse
import asyncio
import dataclasses
import json
import logging
import time
from typing import Dict, List, Optional, Set, Tuple, Union

import fastapi
from fastapi import BackgroundTasks, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn

# Define constants
# Seconds since the last heartbeat after which a worker that opted into
# heartbeat checking is considered "dead".
CONTROLLER_HEART_BEAT_EXPIRATION = 120
# NOTE(review): not referenced by the handlers in this file.
MODEL_WORKER_API_TIMEOUT = 100

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Define data models using dataclasses instead of pydantic
@dataclasses.dataclass
class ModelInfo:
    """Registry entry for one model name and the workers that serve it."""

    id: str  # controller-assigned identifier of the form "model-<n>"
    name: str
    worker_names: List[str]
    # Creation timestamp. The default factory is evaluated before this field
    # name is bound in the class body, so ``time.time`` is the module function.
    time: float = dataclasses.field(default_factory=time.time)


@dataclasses.dataclass
class WorkerInfo:
    """Registry entry for one worker process."""

    worker_name: str
    model_names: List[str]
    # When false, the worker is always reported alive regardless of heartbeats.
    check_heart_beat: bool
    last_heart_beat: float = dataclasses.field(default_factory=time.time)


# Global state (plain in-memory dicts; nothing is persisted)
worker_info: Dict[str, WorkerInfo] = {}
model_info: Dict[str, ModelInfo] = {}
worker_addr: Dict[str, str] = {}
# Monotonic source for model ids. Deriving ids from len(model_info) could
# reissue an id that is still in use after a model entry has been deleted.
_next_model_index = 0

# FastAPI app
app = fastapi.FastAPI()


@app.post("/register_worker")
async def register_worker(request: Request):
    """Register (or re-register) a worker and the model names it serves.

    JSON body fields:
        worker_name: unique worker identifier (required).
        worker_url: address returned to clients by ``/get_worker_address``.
        model_names: model names served by the worker (default ``[]``).
        check_heart_beat: whether liveness depends on heartbeats (default True).

    Returns:
        ``{"result": "success"}``, or ``{"result": "failure", "error": ...}``
        when ``worker_name`` is missing.
    """
    global _next_model_index

    data = await request.json()
    worker_name = data.get("worker_name")
    worker_url = data.get("worker_url")
    # ``or []`` also covers an explicit JSON null, which would break iteration.
    model_names = data.get("model_names") or []
    check_heart_beat = data.get("check_heart_beat", True)

    if not worker_name:
        return {"result": "failure", "error": "Missing worker_name"}

    logger.info(f"Registering worker {worker_name} at {worker_url}")

    # A worker re-registering (e.g. after a restart) must not be listed twice
    # under a model, and must not stay attached to models it no longer serves.
    previous = worker_info.get(worker_name)
    if previous is not None:
        for model_name in previous.model_names:
            entry = model_info.get(model_name)
            if entry is None:
                continue
            if worker_name in entry.worker_names:
                entry.worker_names.remove(worker_name)
            # Keep entries the worker still serves so their ids stay stable.
            if not entry.worker_names and model_name not in model_names:
                del model_info[model_name]

    worker_info[worker_name] = WorkerInfo(
        worker_name=worker_name,
        model_names=model_names,
        check_heart_beat=check_heart_beat,
        last_heart_beat=time.time(),
    )
    worker_addr[worker_name] = worker_url

    # Register models
    for model_name in model_names:
        entry = model_info.get(model_name)
        if entry is None:
            model_info[model_name] = ModelInfo(
                id=f"model-{_next_model_index}",
                name=model_name,
                worker_names=[worker_name],
            )
            _next_model_index += 1
        elif worker_name not in entry.worker_names:
            entry.worker_names.append(worker_name)

    return {"result": "success"}
# Per-model request counters driving round-robin worker selection.
_round_robin_index: Dict[str, int] = {}


def _is_worker_alive(info: WorkerInfo, now: Optional[float] = None) -> bool:
    """Return True if ``info`` is exempt from heartbeat checks or beat recently enough."""
    if not info.check_heart_beat:
        return True
    if now is None:
        now = time.time()
    return (now - info.last_heart_beat) < CONTROLLER_HEART_BEAT_EXPIRATION


@app.post("/unregister_worker")
async def unregister_worker(request: Request):
    """Remove a worker and detach it from every model it served.

    Model entries left with no workers are deleted. Unknown workers are
    ignored, and the endpoint always returns ``{"result": "success"}``.
    """
    data = await request.json()
    worker_name = data.get("worker_name")

    logger.info(f"Unregistering worker {worker_name}")

    if worker_name in worker_info:
        for model_name in worker_info[worker_name].model_names:
            entry = model_info.get(model_name)
            if entry is None:
                continue
            if worker_name in entry.worker_names:
                entry.worker_names.remove(worker_name)
            if not entry.worker_names:
                del model_info[model_name]
                _round_robin_index.pop(model_name, None)
        del worker_info[worker_name]

    # The address map is cleaned up even if worker_info had no entry.
    worker_addr.pop(worker_name, None)

    return {"result": "success"}


@app.post("/heart_beat")
async def heart_beat(request: Request):
    """Refresh the heartbeat timestamp of a registered worker.

    Returns ``{"result": "success"}``, or a failure dict when the worker is
    not present in both the worker registry and the address map.
    """
    data = await request.json()
    worker_name = data.get("worker_name")

    if worker_name not in worker_info or worker_name not in worker_addr:
        return {"result": "failure", "error": f"Worker {worker_name} not found"}

    worker_info[worker_name].last_heart_beat = time.time()
    return {"result": "success"}


@app.get("/list_models")
async def list_models():
    """Return ``{"models": [{"id", "name"}, ...]}`` for every registered model."""
    return {"models": [{"id": info.id, "name": name} for name, info in model_info.items()]}


@app.get("/get_worker_address")
async def get_worker_address(model_name: str):
    """Pick a worker serving ``model_name`` and return its address.

    Only workers that are alive (see ``_is_worker_alive``) and have a known
    address are considered. Successive calls for the same model rotate
    through those workers in round-robin order.

    Returns:
        ``{"address": <worker_url>}`` on success, or a 400 JSON error when no
        usable worker serves the model.
    """
    entry = model_info.get(model_name)
    now = time.time()
    candidates = [
        name
        for name in (entry.worker_names if entry is not None else [])
        if name in worker_addr
        and name in worker_info
        and _is_worker_alive(worker_info[name], now)
    ]
    if not candidates:
        return JSONResponse(
            {"error": f"No available workers for model {model_name}"},
            status_code=400,
        )

    # A per-model counter gives real round-robin. The previous time-based
    # index sent every request within the same second to the same worker.
    index = _round_robin_index.get(model_name, 0)
    _round_robin_index[model_name] = index + 1
    selected_worker = candidates[index % len(candidates)]
    return {"address": worker_addr[selected_worker]}


@app.get("/worker_status")
async def worker_status():
    """Report address, models, last heartbeat and liveness for every worker."""
    return {"worker_info": [
        {
            "name": name,
            "address": worker_addr.get(name),
            "models": info.model_names,
            "last_heart_beat": info.last_heart_beat,
            "status": "alive" if _is_worker_alive(info) else "dead",
        }
        for name, info in worker_info.items()
    ]}
@app.get("/status")
async def status():
    """Report every registered model and every registered worker.

    Returns a dict with two lists:

    * ``model_info``: one entry per model name (``name``, ``id``, ``workers``).
    * ``worker_info``: one entry per worker (``name``, ``address``, ``models``,
      ``last_heart_beat``, ``status``). ``status`` is ``"alive"`` when the
      worker does not use heartbeat checks or last beat less than
      ``CONTROLLER_HEART_BEAT_EXPIRATION`` seconds ago, otherwise ``"dead"``.
    """
    models = []
    for model_name, model in model_info.items():
        models.append({"name": model_name, "id": model.id, "workers": model.worker_names})

    workers = []
    for worker_name, worker in worker_info.items():
        if worker.check_heart_beat:
            elapsed = time.time() - worker.last_heart_beat
            state = "alive" if elapsed < CONTROLLER_HEART_BEAT_EXPIRATION else "dead"
        else:
            state = "alive"
        workers.append({
            "name": worker_name,
            "address": worker_addr.get(worker_name),
            "models": worker.model_names,
            "last_heart_beat": worker.last_heart_beat,
            "status": state,
        })

    return {"model_info": models, "worker_info": workers}


# Run the server
def main():
    """Parse ``--host``/``--port`` from the command line and serve the app with uvicorn."""
    cli = argparse.ArgumentParser(description="Controller for LLaMA-Omni2")
    cli.add_argument("--host", type=str, default="0.0.0.0")
    cli.add_argument("--port", type=int, default=10000)
    opts = cli.parse_args()

    host, port = opts.host, opts.port
    logger.info(f"Starting controller server at http://{host}:{port}")
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()