ndc8 · 358e717
Update Dockerfile and application entry point for GGUF backend; optimize memory usage in model parameters and requirements
#!/usr/bin/env python3
"""
Working Gemma 3n GGUF Backend Service
Minimal FastAPI backend using only llama-cpp-python for GGUF models
"""
import os
import logging
import time
from contextlib import asynccontextmanager
from typing import List, Dict, Any, Optional
import uuid
import sys
import subprocess
import threading
from pathlib import Path
import signal  # Use signal.SIGTERM for process termination

from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, field_validator

# Import llama-cpp-python for GGUF model support
try:
    from llama_cpp import Llama
    llama_cpp_available = True
except ImportError:
    llama_cpp_available = False

import uvicorn
import sqlite3
import json  # For persisting job metadata

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Pydantic models for OpenAI-compatible API
class ChatMessage(BaseModel):
    role: str = Field(..., description="The role of the message author")
    content: str = Field(..., description="The content of the message")

    @field_validator("role")
    @classmethod
    def validate_role(cls, v: str) -> str:
        if v not in ["system", "user", "assistant"]:
            raise ValueError("Role must be one of: system, user, assistant")
        return v
class ChatCompletionRequest(BaseModel):
    model: str = Field(default="gemma-3n-e4b-it", description="The model to use for completion")
    messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
    max_tokens: Optional[int] = Field(default=256, ge=1, le=1024, description="Maximum tokens to generate (reduced for memory efficiency)")
    temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0, description="Top-p sampling")
    top_k: Optional[int] = Field(default=64, ge=1, le=100, description="Top-k sampling")
    stream: Optional[bool] = Field(default=False, description="Whether to stream responses")

class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: str

class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]

class HealthResponse(BaseModel):
    status: str
    model: str
    version: str
    backend: str
# Global variables for model management
current_model = os.environ.get("AI_MODEL", "unsloth/gemma-3n-E4B-it-GGUF")
llm = None
def convert_messages_to_gemma_prompt(messages: List[ChatMessage]) -> str:
    """Convert OpenAI messages format to Gemma 3n chat format."""
    # Gemma 3n uses specific format with <start_of_turn> and <end_of_turn>
    prompt_parts = ["<bos>"]
    for message in messages:
        role = message.role
        content = message.content
        if role == "system":
            prompt_parts.append(f"<start_of_turn>system\n{content}<end_of_turn>")
        elif role == "user":
            prompt_parts.append(f"<start_of_turn>user\n{content}<end_of_turn>")
        elif role == "assistant":
            prompt_parts.append(f"<start_of_turn>model\n{content}<end_of_turn>")
    # Add the start for model response
    prompt_parts.append("<start_of_turn>model\n")
    return "\n".join(prompt_parts)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager for startup and shutdown events"""
    global llm
    logger.info("🚀 Starting Gemma 3n GGUF Backend Service...")
    if os.environ.get("DEMO_MODE", "").strip() not in ("", "0", "false", "False"):
        logger.info("🧪 DEMO_MODE enabled: skipping model load")
        llm = None
        yield
        logger.info("🔄 Shutting down Gemma 3n Backend Service (demo mode)...")
        return
    if not llama_cpp_available:
        logger.error("❌ llama-cpp-python is not available. Please install with: pip install llama-cpp-python")
        raise RuntimeError("llama-cpp-python not available")
    try:
        logger.info(f"📥 Loading Gemma 3n GGUF model from {current_model}...")
        # Configure model parameters optimized for HF Spaces memory constraints
        llm = Llama.from_pretrained(
            repo_id=current_model,
            filename="*Q4_0.gguf",  # Use Q4_0 instead of Q4_K_M for lower memory usage
            verbose=True,
            # Memory-optimized settings for HF Spaces
            n_ctx=2048,       # Reduced context length to save memory (was 4096)
            n_threads=2,      # Fewer threads for lower memory usage (was 4)
            n_gpu_layers=0,   # Force CPU-only to avoid GPU memory issues
            # Additional memory optimizations
            n_batch=512,      # Smaller batch size to reduce memory peaks
            use_mmap=True,    # Use memory mapping to reduce RAM usage
            use_mlock=False,  # Don't lock memory pages
            low_vram=True,    # Enable low VRAM mode for additional memory savings
            # Chat template for Gemma 3n format
            chat_format="gemma",  # Try built-in gemma format first
        )
        logger.info("✅ Successfully loaded Gemma 3n GGUF model with memory optimizations")
    except Exception as e:
        logger.error(f"❌ Failed to initialize Gemma 3n model: {e}")
        logger.warning("⚠️ Please download the GGUF model file locally and update the path")
        logger.warning("⚠️ You can download from: https://huggingface.co/unsloth/gemma-3n-E4B-it-GGUF")
        # For demo purposes, we'll continue without the model
        logger.info("🔄 Starting service in demo mode (responses will be mocked)")
    yield
    logger.info("🔄 Shutting down Gemma 3n Backend Service...")
    if llm:
        # Clean up model resources
        llm = None
# Initialize FastAPI app
app = FastAPI(
    title="Gemma 3n GGUF Backend Service",
    description="OpenAI-compatible chat completion API powered by Gemma-3n-E4B-it-GGUF",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def ensure_model_ready():
    """Check if model is loaded and ready"""
    # For demo mode, we'll allow the service to run even without a model
    pass
def generate_response_gguf(messages: List[ChatMessage], max_tokens: int = 256, temperature: float = 1.0, top_p: float = 0.95, top_k: int = 64) -> str:
    """Generate response using GGUF model via llama-cpp-python (memory-optimized)."""
    if llm is None:
        # Demo mode response
        return "🤖 Demo mode: Gemma 3n model not loaded. This would be a real response from the Gemma 3n model. Please download the GGUF model from https://huggingface.co/unsloth/gemma-3n-E4B-it-GGUF"
    # Limit max_tokens for memory efficiency on HF Spaces
    max_tokens = min(max_tokens, 512)  # Cap at 512 tokens max
    try:
        # Use the chat completion method if available
        if hasattr(llm, 'create_chat_completion'):
            # Convert to dict format for llama-cpp-python
            messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages]
            response = llm.create_chat_completion(
                messages=messages_dict,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stop=["<end_of_turn>", "<eos>", "</s>"]  # Gemma 3n stop tokens
            )
            return response['choices'][0]['message']['content'].strip()
        else:
            # Fallback to direct prompt completion
            prompt = convert_messages_to_gemma_prompt(messages)
            response = llm(
                prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stop=["<end_of_turn>", "<eos>", "</s>"],
                echo=False
            )
            return response['choices'][0]['text'].strip()
    except Exception as e:
        logger.error(f"GGUF generation failed: {e}")
        return "I apologize, but I'm having trouble generating a response right now. Please try again."
@app.get("/")
async def root() -> Dict[str, Any]:
    """Root endpoint with service information"""
    return {
        "message": "Gemma 3n GGUF Backend Service is running!",
        "model": current_model,
        "version": "1.0.0",
        "backend": "llama-cpp-python",
        "model_loaded": llm is not None,
        "endpoints": {
            "health": "/health",
            "chat_completions": "/v1/chat/completions"
        }
    }
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy" if (llm is not None) else "demo_mode",
        model=current_model,
        version="1.0.0",
        backend="llama-cpp-python"
    )
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(
    request: ChatCompletionRequest
) -> ChatCompletionResponse:
    """Create a chat completion (OpenAI-compatible) using Gemma 3n GGUF"""
    try:
        ensure_model_ready()
        if not request.messages:
            raise HTTPException(status_code=400, detail="Messages cannot be empty")
        logger.info(f"Generating Gemma 3n response for {len(request.messages)} messages")
        # Explicit None checks so temperature=0.0 and top_p=0.0 are honored
        response_text = generate_response_gguf(
            request.messages,
            request.max_tokens or 512,
            request.temperature if request.temperature is not None else 1.0,
            request.top_p if request.top_p is not None else 0.95,
            request.top_k or 64
        )
        response_text = response_text.strip() if response_text else "No response generated."
        return ChatCompletionResponse(
            id=f"chatcmpl-{int(time.time())}",
            created=int(time.time()),
            model=request.model,
            choices=[ChatCompletionChoice(
                index=0,
                message=ChatMessage(role="assistant", content=response_text),
                finish_reason="stop"
            )]
        )
    except HTTPException:
        # Propagate explicit HTTP errors (e.g. empty messages) instead of wrapping them as 500s
        raise
    except Exception as e:
        logger.error(f"Error in chat completion: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
# -----------------------------
# Training Job Management (Unsloth)
# -----------------------------
# Persistent job store: in-memory dict backed by SQLite
TRAIN_JOBS: Dict[str, Dict[str, Any]] = {}

# Initialize SQLite DB for job persistence
DB_PATH = Path(os.environ.get("JOB_DB_PATH", "./jobs.db"))
conn = sqlite3.connect(str(DB_PATH), check_same_thread=False)
cursor = conn.cursor()
cursor.execute(
    """
    CREATE TABLE IF NOT EXISTS jobs (
        job_id TEXT PRIMARY KEY,
        data TEXT NOT NULL
    )
    """
)
conn.commit()
def load_jobs() -> None:
    cursor.execute("SELECT job_id, data FROM jobs")
    for job_id, data in cursor.fetchall():
        TRAIN_JOBS[job_id] = json.loads(data)

def save_job(job_id: str) -> None:
    # Persist only JSON-serializable fields; the open log file handle is runtime-only state
    persistable = {k: v for k, v in TRAIN_JOBS[job_id].items() if k != "log_file"}
    cursor.execute(
        "INSERT OR REPLACE INTO jobs (job_id, data) VALUES (?, ?)",
        (job_id, json.dumps(persistable))
    )
    conn.commit()

# Load existing jobs on startup
load_jobs()
TRAIN_DIR = Path(os.environ.get("TRAIN_DIR", "./training_runs")).resolve()
TRAIN_DIR.mkdir(parents=True, exist_ok=True)

# Maximum concurrent training jobs
MAX_CONCURRENT_JOBS = int(os.environ.get("MAX_CONCURRENT_JOBS", "5"))

def _start_training_subprocess(job_id: str, args: Dict[str, Any]) -> subprocess.Popen[Any]:
    """Spawn a subprocess to run the Unsloth fine-tuning script."""
    logs_dir = TRAIN_DIR / job_id
    logs_dir.mkdir(parents=True, exist_ok=True)
    log_file = open(logs_dir / "train.log", "w", encoding="utf-8")
    # Store log file handle to close later
    TRAIN_JOBS.setdefault(job_id, {})["log_file"] = log_file
    save_job(job_id)
    # Build absolute script path to avoid module/package resolution issues
    script_path = (Path(__file__).parent / "training" / "train_gemma_unsloth.py").resolve()
    # Verify training script exists
    if not script_path.exists():
        logger.error(f"Training script not found at {script_path}")
        raise HTTPException(status_code=500, detail=f"Training script not found at {script_path}")
    python_exec = sys.executable
    cmd = [
        python_exec,
        str(script_path),
        "--job-id", job_id,
        "--output-dir", str(logs_dir),
    ]

    # Optional user-specified args
    def _extend(k: str, v: Any):
        if v is None:
            return
        if isinstance(v, bool):
            cmd.extend([f"--{k}"] if v else [])
        else:
            cmd.extend([f"--{k}", str(v)])

    _extend("dataset", args.get("dataset"))
    _extend("text-field", args.get("text_field"))
    _extend("prompt-field", args.get("prompt_field"))
    _extend("response-field", args.get("response_field"))
    _extend("max-steps", args.get("max_steps"))
    _extend("epochs", args.get("epochs"))
    _extend("lr", args.get("lr"))
    _extend("batch-size", args.get("batch_size"))
    _extend("gradient-accumulation", args.get("gradient_accumulation"))
    _extend("lora-r", args.get("lora_r"))
    _extend("lora-alpha", args.get("lora_alpha"))
    _extend("cutoff-len", args.get("cutoff_len"))
    _extend("model-id", args.get("model_id"))
    _extend("use-bf16", args.get("use_bf16"))
    _extend("use-fp16", args.get("use_fp16"))
    _extend("seed", args.get("seed"))
    _extend("dry-run", args.get("dry_run"))

    logger.info(f"🧵 Starting training subprocess for job {job_id}: {' '.join(cmd)}")
    logger.info(f"🐍 Using interpreter: {python_exec}")
    proc = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT, cwd=str(Path(__file__).parent))
    return proc

def _watch_process(job_id: str, proc: subprocess.Popen[Any]):
    """Monitor a training process and update job state on exit."""
    return_code = proc.wait()
    status = "completed" if return_code == 0 else "failed"
    TRAIN_JOBS[job_id]["status"] = status
    TRAIN_JOBS[job_id]["return_code"] = return_code
    TRAIN_JOBS[job_id]["ended_at"] = int(time.time())
    # Persist updated job status
    save_job(job_id)
    # Close the log file handle to prevent resource leaks
    log_file = TRAIN_JOBS[job_id].get("log_file")
    if log_file:
        try:
            log_file.close()
        except Exception as close_err:
            logger.warning(f"Failed to close log file for job {job_id}: {close_err}")
    logger.info(f"🏁 Training job {job_id} finished with status={status}, code={return_code}")
class StartTrainingRequest(BaseModel):
    dataset: str = Field(..., description="HF dataset name or path to local JSONL/JSON file")
    model_id: Optional[str] = Field(default="unsloth/gemma-3n-E4B-it", description="Base model for training (HF Transformers format)")
    text_field: Optional[str] = Field(default=None, description="Single text field name (SFT)")
    prompt_field: Optional[str] = Field(default=None, description="Prompt/instruction field (chat data)")
    response_field: Optional[str] = Field(default=None, description="Response/output field (chat data)")
    max_steps: Optional[int] = Field(default=None)
    epochs: Optional[int] = Field(default=1)
    lr: Optional[float] = Field(default=2e-4)
    batch_size: Optional[int] = Field(default=1)
    gradient_accumulation: Optional[int] = Field(default=8)
    lora_r: Optional[int] = Field(default=16)
    lora_alpha: Optional[int] = Field(default=32)
    cutoff_len: Optional[int] = Field(default=4096)
    use_bf16: Optional[bool] = Field(default=True)
    use_fp16: Optional[bool] = Field(default=False)
    seed: Optional[int] = Field(default=42)
    dry_run: Optional[bool] = Field(default=False, description="Write DONE and exit without running (for CI/macOS)")

class StartTrainingResponse(BaseModel):
    job_id: str
    status: str
    output_dir: str

class TrainStatusResponse(BaseModel):
    job_id: str
    status: str
    created_at: int
    started_at: Optional[int] = None
    ended_at: Optional[int] = None
    output_dir: Optional[str] = None
    return_code: Optional[int] = None
def start_training(req: StartTrainingRequest):
    """Start a background Unsloth fine-tuning job. Returns a job_id to poll."""
    # Enforce maximum concurrent training jobs
    running_jobs = sum(1 for job in TRAIN_JOBS.values() if job.get("status") == "running")
    if running_jobs >= MAX_CONCURRENT_JOBS:
        raise HTTPException(
            status_code=429,
            detail=f"Maximum concurrent training jobs reached ({MAX_CONCURRENT_JOBS}). Try again later."
        )
    job_id = uuid.uuid4().hex[:12]
    now = int(time.time())
    output_dir = str((TRAIN_DIR / job_id).resolve())
    TRAIN_JOBS[job_id] = {
        "status": "starting",
        "created_at": now,
        "started_at": now,
        "args": req.model_dump(),
        "output_dir": output_dir,
    }
    save_job(job_id)
    try:
        proc = _start_training_subprocess(job_id, req.model_dump())
        TRAIN_JOBS[job_id]["status"] = "running"
        TRAIN_JOBS[job_id]["pid"] = proc.pid
        save_job(job_id)
        watcher = threading.Thread(target=_watch_process, args=(job_id, proc), daemon=True)
        watcher.start()
        return StartTrainingResponse(job_id=job_id, status="running", output_dir=output_dir)
    except Exception as e:
        logger.exception("Failed to start training job")
        TRAIN_JOBS[job_id]["status"] = "failed_to_start"
        save_job(job_id)
        raise HTTPException(status_code=500, detail=f"Failed to start training: {e}")

def train_status(job_id: str):
    job = TRAIN_JOBS.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    return TrainStatusResponse(
        job_id=job_id,
        status=job.get("status", "unknown"),
        created_at=job.get("created_at", 0),
        started_at=job.get("started_at"),
        ended_at=job.get("ended_at"),
        output_dir=job.get("output_dir"),
        return_code=job.get("return_code"),
    )

def train_logs(
    job_id: str,
    tail: int = Query(200, ge=0, le=1000, description="Number of lines to tail, between 0 and 1000"),
):
    job = TRAIN_JOBS.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    log_path = Path(job["output_dir"]) / "train.log"
    if not log_path.exists():
        return {"job_id": job_id, "logs": "(no logs yet)"}
    try:
        with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
            lines = f.readlines()[-tail:]
        return {"job_id": job_id, "logs": "".join(lines)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to read logs: {e}")

def train_stop(job_id: str):
    job = TRAIN_JOBS.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    pid = job.get("pid")
    if not pid:
        raise HTTPException(status_code=400, detail="Job does not have an active PID")
    try:
        os.kill(pid, signal.SIGTERM)
    except ProcessLookupError:
        logger.warning(
            f"Process {pid} for job {job_id} not found; may have exited already"
        )
        job["status"] = "stopping_failed"
        save_job(job_id)
        return {"job_id": job_id, "status": job["status"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to stop job: {e}")
    else:
        job["status"] = "stopping"
        save_job(job_id)
        return {"job_id": job_id, "status": "stopping"}

# Main entry point
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
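
For a quick end-to-end check once the service is running (by default on port 8000, per the uvicorn.run call above), a minimal client request against the OpenAI-compatible /v1/chat/completions endpoint could look like the sketch below. It uses only the Python standard library and assumes the server is reachable on localhost.

# Minimal client-side sketch (assumes the backend above is running on localhost:8000).
import json
from urllib.request import Request, urlopen

payload = {
    "model": "gemma-3n-e4b-it",
    "messages": [{"role": "user", "content": "Hello! What can you do?"}],
    "max_tokens": 64,
}
req = Request(
    "http://localhost:8000/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urlopen(req) as resp:
    body = json.loads(resp.read().decode("utf-8"))
    # Response follows the OpenAI chat-completion shape defined by ChatCompletionResponse
    print(body["choices"][0]["message"]["content"])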