"""Controller service that tracks worker nodes and the models they host.

Workers register over HTTP, report heartbeats, and are looked up by model
name. All state lives in the in-process ``worker_info`` dictionary, so it is
lost when the process exits.
"""

import argparse
import asyncio
import json
import time
from typing import Dict, List, Optional, Union

from fastapi import FastAPI, WebSocket, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended for the deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Store worker information, keyed by worker name. Each value is a dict with
# the keys "url", "models", "status" and "last_heartbeat".
worker_info: Dict[str, Dict] = {}


def _normalize_models(models) -> list:
    """Coerce the ``models`` field of a registration payload into a list.

    ``None`` (missing or JSON ``null``) becomes an empty list and a bare
    string becomes a one-element list, so that later membership tests compare
    whole model names instead of substrings or characters.

    Raises:
        HTTPException: 400 if ``models`` is not iterable (e.g. a number).
    """
    if models is None:
        return []
    if isinstance(models, str):
        return [models]
    try:
        return list(models)
    except TypeError:
        raise HTTPException(
            status_code=400,
            detail="Field 'models' must be a list of model names",
        ) from None


@app.get("/status")
async def get_status():
    """Get the status of the controller and the number of known workers."""
    return {"status": "ok", "worker_count": len(worker_info)}


@app.get("/worker_info")
async def get_worker_info():
    """Get information about all registered workers."""
    return {"worker_info": worker_info}


@app.post("/register_worker")
async def register_worker(worker_info_data: Dict):
    """Register a new worker, replacing any existing entry of the same name.

    The payload must contain non-empty "name" and "url" entries; "models" is
    optional and defaults to an empty list.

    Raises:
        HTTPException: 400 if the name or URL is missing, or if "models" is
            not a list-like value.
    """
    worker_name = worker_info_data.get("name")
    worker_url = worker_info_data.get("url")

    if not worker_name or not worker_url:
        raise HTTPException(status_code=400, detail="Missing name or URL for worker")

    models = _normalize_models(worker_info_data.get("models"))

    worker_info[worker_name] = {
        "url": worker_url,
        "models": models,
        "status": "alive",
        "last_heartbeat": time.time(),
    }

    return {"status": "registered", "worker_name": worker_name}


@app.post("/unregister_worker")
async def unregister_worker(worker_name: str):
    """Unregister a worker.

    Raises:
        HTTPException: 404 if no worker with that name is registered.
    """
    if worker_name not in worker_info:
        raise HTTPException(status_code=404, detail=f"Worker {worker_name} not found")

    del worker_info[worker_name]
    return {"status": "unregistered", "worker_name": worker_name}


@app.post("/heartbeat")
async def heartbeat(worker_data: Dict):
    """Process a worker heartbeat: refresh its timestamp and mark it alive.

    Raises:
        HTTPException: 404 if the named worker is not registered.
    """
    worker_name = worker_data.get("name")

    if worker_name not in worker_info:
        raise HTTPException(status_code=404, detail=f"Worker {worker_name} not found")

    # NOTE(review): nothing in this module ever sets a status other than
    # "alive", so stale workers are not expired here -- confirm whether that
    # is handled elsewhere.
    entry = worker_info[worker_name]
    entry["last_heartbeat"] = time.time()
    entry["status"] = "alive"
    return {"status": "received"}


@app.get("/get_worker_address")
async def get_worker_address(model_name: str):
    """Get the address of a worker that hosts the requested model.

    The first alive worker (in registration order) listing ``model_name`` is
    returned.

    Raises:
        HTTPException: 404 if no alive worker hosts the model.
    """
    for info in worker_info.values():
        if model_name in info["models"] and info["status"] == "alive":
            return {"worker_address": info["url"]}

    raise HTTPException(
        status_code=404,
        detail=f"No available worker found for model {model_name}",
    )


@app.get("/list_models")
async def list_models():
    """List all available models across alive workers, without duplicates.

    Models are returned in first-seen order so the response is stable from
    one call to the next.
    """
    available_models = []
    for info in worker_info.values():
        if info["status"] == "alive":
            available_models.extend(info["models"])

    # dict.fromkeys de-duplicates while preserving insertion order.
    return {"models": list(dict.fromkeys(available_models))}


def main():
    """Run the controller server using host/port from the command line."""
    parser = argparse.ArgumentParser(
        description="LLaMA-Omni controller for managing worker nodes"
    )
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server")
    parser.add_argument("--port", type=int, default=10000, help="Port to bind the server")
    args = parser.parse_args()

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()