Spaces:
Build error
Build error
import argparse | |
import asyncio | |
import json | |
import time | |
from fastapi import FastAPI, WebSocket, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
import uvicorn | |
from typing import Dict, List, Optional, Union | |
app = FastAPI()

# NOTE(review): every origin, method and header is allowed while credentials
# are also enabled. That is a very permissive CORS policy; confirm it is
# intended before exposing this controller beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory worker registry, keyed by worker name. Each value is a dict with
# the keys "url", "models", "status" and "last_heartbeat" (see
# register_worker). Contents are lost when the process exits.
worker_info: Dict[str, Dict] = {}
async def get_status():
    """Report that the controller is up, along with the registry size.

    Returns:
        A dict with ``"status"`` fixed to ``"ok"`` and ``"worker_count"`` set
        to the number of entries currently held in ``worker_info``.
    """
    # NOTE(review): no route decorator is visible on this coroutine in this
    # view of the file; confirm how it is attached to ``app``.
    registered = len(worker_info)
    return {"status": "ok", "worker_count": registered}
async def get_worker_info():
    """Expose the complete worker registry.

    Returns:
        A dict whose single ``"worker_info"`` key holds the module-level
        ``worker_info`` mapping itself (the same object, not a copy).
    """
    registry_view = {"worker_info": worker_info}
    return registry_view
async def register_worker(worker_info_data: Dict):
    """Register a worker, replacing any existing entry with the same name.

    Args:
        worker_info_data: Payload describing the worker. ``"name"`` and
            ``"url"`` are required. ``"models"`` is optional and is expected
            to be a list of model names.

    Returns:
        ``{"status": "registered", "worker_name": <name>}``.

    Raises:
        HTTPException: 400 when ``"name"`` or ``"url"`` is missing or empty,
            or when ``"models"`` is present but is not ``None``, a string,
            or a list/tuple.
    """
    worker_name = worker_info_data.get("name")
    worker_url = worker_info_data.get("url")
    if not worker_name or not worker_url:
        raise HTTPException(status_code=400, detail="Missing name or URL for worker")

    # Normalise "models" so the lookups in get_worker_address / list_models
    # always see a list: an explicit null would make ``in`` / ``extend``
    # raise TypeError, and a bare string would be substring-matched or split
    # into characters.
    models = worker_info_data.get("models")
    if models is None:
        models = []
    elif isinstance(models, str):
        models = [models]
    elif isinstance(models, (list, tuple)):
        models = list(models)
    else:
        raise HTTPException(status_code=400, detail="models must be a list of model names")

    worker_info[worker_name] = {
        "url": worker_url,
        "models": models,
        "status": "alive",
        "last_heartbeat": time.time(),
    }
    return {"status": "registered", "worker_name": worker_name}
async def unregister_worker(worker_name: str):
    """Remove a worker from the registry.

    Args:
        worker_name: Key of the entry to drop from ``worker_info``.

    Returns:
        ``{"status": "unregistered", "worker_name": <name>}`` on success.

    Raises:
        HTTPException: 404 when no entry exists under ``worker_name``.
    """
    # Guard clause: reject unknown names before touching the registry.
    if worker_name not in worker_info:
        raise HTTPException(status_code=404, detail=f"Worker {worker_name} not found")

    del worker_info[worker_name]
    return {"status": "unregistered", "worker_name": worker_name}
async def heartbeat(worker_data: Dict):
    """Record a heartbeat for an already-registered worker.

    Args:
        worker_data: Payload that must carry the worker's ``"name"``.

    Returns:
        ``{"status": "received"}`` once the entry's ``"last_heartbeat"`` has
        been set to the current time and its ``"status"`` to ``"alive"``.

    Raises:
        HTTPException: 400 when ``"name"`` is missing or empty (mirrors the
            validation in ``register_worker``); 404 when the name is not in
            the registry.
    """
    worker_name = worker_data.get("name")
    if not worker_name:
        # Without this check the caller would get a misleading
        # "Worker None not found" 404 for a malformed payload.
        raise HTTPException(status_code=400, detail="Missing name for worker")

    entry = worker_info.get(worker_name)
    if entry is None:
        raise HTTPException(status_code=404, detail=f"Worker {worker_name} not found")

    entry["last_heartbeat"] = time.time()
    entry["status"] = "alive"
    return {"status": "received"}
async def get_worker_address(model_name: str):
    """Return the URL of the first alive worker that serves ``model_name``.

    Workers are scanned in registry (insertion) order and the first match
    wins; there is no load balancing.

    Args:
        model_name: Model identifier to look up.

    Returns:
        ``{"worker_address": <url>}`` for the first matching worker.

    Raises:
        HTTPException: 404 when no alive worker lists ``model_name``.
    """
    # NOTE(review): in this view of the file nothing ever sets a worker's
    # status to anything other than "alive", and "last_heartbeat" is never
    # compared against a timeout. Confirm whether stale workers are expired
    # elsewhere.
    for info in worker_info.values():
        if info["status"] != "alive":
            continue
        models = info["models"]
        if models is None:
            # A worker registered with "models": null serves nothing; the
            # membership test below would otherwise raise TypeError.
            continue
        if isinstance(models, str):
            # Treat a bare string as one model name instead of letting
            # ``in`` perform a substring match.
            models = (models,)
        if model_name in models:
            return {"worker_address": info["url"]}
    raise HTTPException(status_code=404, detail=f"No available worker found for model {model_name}")
async def list_models():
    """List the distinct models offered by workers whose status is "alive".

    Returns:
        ``{"models": [...]}`` with each model name appearing once, in the
        order it is first seen while scanning the registry.
    """
    # A dict is used as an ordered set so the response order is
    # deterministic, unlike ``list(set(...))``.
    seen = {}
    for info in worker_info.values():
        if info["status"] != "alive":
            continue
        models = info["models"]
        if models is None:
            # Registered with "models": null; nothing to contribute.
            continue
        if isinstance(models, str):
            # A bare string is a single model name, not a sequence of
            # characters.
            models = (models,)
        seen.update(dict.fromkeys(models))
    return {"models": list(seen)}
def main():
    """Parse ``--host`` / ``--port`` and serve ``app`` with uvicorn.

    Defaults are host ``0.0.0.0`` and port ``10000``. The call to
    ``uvicorn.run`` blocks until the server stops.
    """
    cli = argparse.ArgumentParser(
        description="LLaMA-Omni controller for managing worker nodes"
    )
    cli.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind the server",
    )
    cli.add_argument(
        "--port",
        type=int,
        default=10000,
        help="Port to bind the server",
    )
    options = cli.parse_args()

    uvicorn.run(app, host=options.host, port=options.port, log_level="info")


if __name__ == "__main__":
    main()