import argparse
import json
import os
import time
import uuid
import threading
import traceback
import logging
from typing import Dict, List, Optional, Union

import requests
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global worker state
model = None
tokenizer = None
model_name = None
model_path = None
device = "cuda" if torch.cuda.is_available() else "cpu"
controller_url = None
worker_url = None
worker_id = str(uuid.uuid4())[:8]
support_s2s = False


def load_model(model_path_arg, s2s=False, model_name_arg=None):
    """Load the LLaMA-Omni model and tokenizer."""
    global model, tokenizer, model_name, model_path, support_s2s

    model_name = model_name_arg or os.path.basename(model_path_arg)
    model_path = model_path_arg
    support_s2s = s2s

    logger.info(f"Loading model {model_name} from {model_path}...")

    # This is a placeholder for downloading the model.
    # In a real implementation, it would download from HuggingFace or another source.
    logger.info("Model would be downloaded from huggingface.co/ictnlp/Llama-3.1-8B-Omni")

    try:
        # Use placeholder values since we're not actually loading the model in this setup.
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
        model = "PLACEHOLDER - Model would be loaded during actual deployment"
        logger.info(f"Model {model_name} loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        logger.error(traceback.format_exc())
        return False


def register_worker():
    """Register with the controller, retrying until registration succeeds."""
    global worker_id, controller_url, worker_url, model_name

    logger.info(f"Registering worker {worker_id} with controller at {controller_url}")
    while True:
        try:
            response = requests.post(
                f"{controller_url}/register_worker",
                json={
                    "name": worker_id,
                    "url": worker_url,
                    "models": [model_name] if model_name else []
                }
            )
            if response.status_code == 200:
                logger.info(f"Worker {worker_id} registered successfully")
                break
            else:
                logger.error(f"Failed to register worker: {response.text}")
        except Exception as e:
            logger.error(f"Error registering worker: {str(e)}")
        # Retry after a short delay
        time.sleep(5)


def heartbeat_sender():
    """Send periodic heartbeats to the controller."""
    global worker_id, controller_url

    while True:
        try:
            response = requests.post(
                f"{controller_url}/heartbeat",
                json={"name": worker_id}
            )
            if response.status_code == 200:
                logger.debug("Heartbeat sent successfully")
            else:
                logger.warning(f"Failed to send heartbeat: {response.text}")
        except Exception as e:
            logger.error(f"Error sending heartbeat: {str(e)}")
        # Send a heartbeat every 15 seconds
        time.sleep(15)


@app.get("/status")
async def get_status():
    """Return the status of the worker."""
    return {
        "status": "ok",
        "model": model_name,
        "supports_speech": support_s2s
    }


@app.post("/generate_speech")
async def generate_speech(request_data: Dict):
    """Generate a speech response from a prompt."""
    prompt = request_data.get("prompt")
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")

    try:
        # This is a placeholder since we're not actually generating speech.
        # In a real implementation, it would process the prompt and return speech.
        logger.info(f"Received prompt: {prompt[:50]}...")

        # Simulated response
        response = {
            "text": f"This is a response to: {prompt[:20]}...",
            "speech_url": None,  # In a real implementation, this would be the URL to the generated speech
            "success": True
        }
        return response
    except Exception as e:
        logger.error(f"Error generating speech: {str(e)}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")


@app.post("/generate_text")
async def generate_text(request_data: Dict):
    """Generate a text response from a prompt."""
    prompt = request_data.get("prompt")
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")

    try:
        # This is a placeholder since we're not actually generating text.
        # In a real implementation, it would process the prompt and return text.
        logger.info(f"Received prompt: {prompt[:50]}...")

        # Simulated response
        response = {
            "text": f"This is a response to: {prompt[:20]}...",
            "success": True
        }
        return response
    except Exception as e:
        logger.error(f"Error generating text: {str(e)}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")


def main():
    """Run the model worker."""
    global controller_url, worker_url

    parser = argparse.ArgumentParser(description="LLaMA-Omni model worker")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server")
    parser.add_argument("--port", type=int, default=40000, help="Port to bind the server")
    parser.add_argument("--controller", type=str, required=True, help="Controller URL")
    parser.add_argument("--worker", type=str, required=True, help="Worker URL")
    parser.add_argument("--model-path", type=str, required=True, help="Path or name of the model to load")
    parser.add_argument("--model-name", type=str, required=True, help="Name to register the model as")
    parser.add_argument("--s2s", action="store_true", help="Enable speech-to-speech support")
    args = parser.parse_args()

    controller_url = args.controller
    worker_url = args.worker

    # Load the model
    if not load_model(args.model_path, args.s2s, args.model_name):
        logger.error("Failed to load model. Exiting.")
        return

    # Register with the controller
    register_worker()

    # Start the heartbeat thread
    heartbeat_thread = threading.Thread(target=heartbeat_sender, daemon=True)
    heartbeat_thread.start()

    # Start the server
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()
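

# ---------------------------------------------------------------------------
# Example usage (a sketch, not part of the worker itself). The file name
# model_worker.py and the controller address http://localhost:10000 are
# assumed for illustration; the worker port 40000, the CLI flags, and the
# ictnlp/Llama-3.1-8B-Omni model path come from this script.
#
#   python model_worker.py \
#       --controller http://localhost:10000 \
#       --worker http://localhost:40000 \
#       --model-path ictnlp/Llama-3.1-8B-Omni \
#       --model-name Llama-3.1-8B-Omni \
#       --s2s
#
# Once running, the worker's endpoints can be queried directly, e.g.:
#
#   curl -X POST http://localhost:40000/generate_text \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello"}'
# ---------------------------------------------------------------------------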