# llama-omni/omni_speech/serve/controller.py
# Page residue from the hosting site: uploader "marcosremar2",
# commit message "dfdfdf", commit 34b8b49.
import argparse
import asyncio
import json
import time
from fastapi import FastAPI, WebSocket, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from typing import Dict, List, Optional, Union
# Single application instance; every route in this module is registered on it.
app = FastAPI()

# Accept cross-origin requests from any origin, with any method and any header.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive -- confirm this is intended for the deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# In-memory worker registry, keyed by worker name. Each value is a dict with
# the keys "url", "models", "status" and "last_heartbeat" (see register_worker).
worker_info = {}
@app.get("/status")
async def get_status():
    """Report that the controller is up and how many workers are registered."""
    registered = len(worker_info)
    return {"status": "ok", "worker_count": registered}
@app.get("/worker_info")
async def get_worker_info():
    """Return the full worker registry under the ``worker_info`` key."""
    payload = {"worker_info": worker_info}
    return payload
@app.post("/register_worker")
async def register_worker(worker_info_data: Dict):
    """Register a worker, replacing any existing entry with the same name.

    Args:
        worker_info_data: Request body with a required ``name`` and ``url``
            and an optional ``models`` entry. ``models`` may be a list of
            model-name strings, a single model-name string, or ``null``.

    Returns:
        ``{"status": "registered", "worker_name": <name>}``.

    Raises:
        HTTPException: 400 if ``name`` or ``url`` is missing or empty, or if
            ``models`` is not a list of strings (after normalisation).
    """
    worker_name = worker_info_data.get("name")
    worker_url = worker_info_data.get("url")
    if not worker_name or not worker_url:
        raise HTTPException(status_code=400, detail="Missing name or URL for worker")

    # Normalise ``models`` before storing it. The lookup endpoints run
    # ``model_name in info["models"]`` and ``extend(info["models"])`` on every
    # registered worker, so a single entry holding ``None`` or a number would
    # make them raise TypeError, and a bare string would be matched by
    # substring / split into characters.
    models = worker_info_data.get("models")
    if models is None:
        models = []
    elif isinstance(models, str):
        models = [models]
    if not isinstance(models, list) or not all(isinstance(m, str) for m in models):
        raise HTTPException(
            status_code=400,
            detail="models must be a list of model names",
        )

    worker_info[worker_name] = {
        "url": worker_url,
        "models": models,
        "status": "alive",
        "last_heartbeat": time.time(),
    }
    return {"status": "registered", "worker_name": worker_name}
@app.post("/unregister_worker")
async def unregister_worker(worker_name: str):
    """Remove a worker from the registry.

    Raises:
        HTTPException: 404 if no worker with ``worker_name`` is registered.
    """
    if worker_name not in worker_info:
        raise HTTPException(status_code=404, detail=f"Worker {worker_name} not found")

    del worker_info[worker_name]
    return {"status": "unregistered", "worker_name": worker_name}
@app.post("/heartbeat")
async def heartbeat(worker_data: Dict):
    """Record a heartbeat for the worker named in the request body.

    Refreshes the worker's ``last_heartbeat`` timestamp and sets its status
    to ``"alive"``.

    Raises:
        HTTPException: 404 if the named worker is not registered.
    """
    worker_name = worker_data.get("name")
    if worker_name not in worker_info:
        raise HTTPException(status_code=404, detail=f"Worker {worker_name} not found")

    record = worker_info[worker_name]
    record["last_heartbeat"] = time.time()
    record["status"] = "alive"
    return {"status": "received"}
@app.get("/get_worker_address")
async def get_worker_address(model_name: str):
    """Return the URL of the first alive worker that lists ``model_name``.

    Workers are examined in registry iteration order.

    Raises:
        HTTPException: 404 if no alive worker hosts the model.
    """
    chosen = next(
        (
            entry
            for entry in worker_info.values()
            if model_name in entry["models"] and entry["status"] == "alive"
        ),
        None,
    )
    if chosen is None:
        raise HTTPException(
            status_code=404,
            detail=f"No available worker found for model {model_name}",
        )
    return {"worker_address": chosen["url"]}
@app.get("/list_models")
async def list_models():
    """List the distinct models hosted by workers whose status is ``"alive"``.

    Returns:
        ``{"models": [...]}`` with each model listed once, in the order it is
        first encountered while walking the registry.
    """
    # dict.fromkeys de-duplicates like a set but keeps first-seen order, so the
    # response order is stable instead of depending on set iteration order.
    unique_models = dict.fromkeys(
        model
        for info in worker_info.values()
        if info["status"] == "alive"
        for model in info["models"]
    )
    return {"models": list(unique_models)}
def main():
    """Parse the command-line options and serve the controller with uvicorn."""
    cli = argparse.ArgumentParser(
        description="LLaMA-Omni controller for managing worker nodes"
    )
    cli.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server")
    cli.add_argument("--port", type=int, default=10000, help="Port to bind the server")
    options = cli.parse_args()

    # Blocks until the server is shut down.
    uvicorn.run(app, host=options.host, port=options.port, log_level="info")


# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    main()