ndc8
Refactor application to implement GGUF backend with native transformers support; update requirements and add GGUF-specific entry point
6e96e6e
#!/usr/bin/env python3
"""
GGUF Backend with Native Transformers Support
Uses transformers library's built-in GGUF loading (no llama-cpp-python needed)
"""
import os
import logging
from contextlib import asynccontextmanager
from typing import List, Dict, Any, Optional
import uuid
import time

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, field_validator

# Import transformers with GGUF support
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Pydantic models for OpenAI-compatible API
class ChatMessage(BaseModel):
    role: str = Field(..., description="The role of the message author")
    content: str = Field(..., description="The content of the message")

    @field_validator("role")
    @classmethod
    def validate_role(cls, v: str) -> str:
        if v not in ["system", "user", "assistant"]:
            raise ValueError("Role must be one of: system, user, assistant")
        return v


class ChatCompletionRequest(BaseModel):
    model: str = Field(default="gemma-3n-e4b-it", description="The model to use for completion")
    messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
    max_tokens: Optional[int] = Field(default=256, ge=1, le=1024, description="Maximum tokens to generate")
    temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0, description="Top-p sampling")
    stream: Optional[bool] = Field(default=False, description="Whether to stream responses")


class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: str


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]


class HealthResponse(BaseModel):
    status: str
    model: str
    version: str
    backend: str
    quantization: str


# Global variables for model management
current_model = os.environ.get("AI_MODEL", "unsloth/gemma-3n-E4B-it-GGUF")
gguf_filename = os.environ.get("GGUF_FILE", "*Q4_K_M.gguf")
tokenizer = None
model = None
text_pipeline = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager with GGUF model loading via transformers"""
    global tokenizer, model, text_pipeline
    logger.info("🚀 Starting GGUF Backend Service (Transformers Native)")

    if os.environ.get("DEMO_MODE", "").strip() not in ("", "0", "false", "False"):
        logger.info("🧪 DEMO_MODE enabled: skipping model load")
        yield
        logger.info("🔄 Shutting down GGUF Backend Service (demo mode)...")
        return

    try:
        logger.info(f"📥 Loading GGUF model: {current_model}")
        logger.info(f"🎯 GGUF file pattern: {gguf_filename}")

        # Load tokenizer first
        tokenizer = AutoTokenizer.from_pretrained(
            current_model,
            trust_remote_code=True,
            use_fast=True,
        )
        # Ensure pad token exists
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load GGUF model using native transformers support
        logger.info("⚙️ Loading GGUF model with transformers native support...")
        model = AutoModelForCausalLM.from_pretrained(
            current_model,
            gguf_file=gguf_filename,    # Key parameter for GGUF loading
            torch_dtype=torch.float32,  # CPU-compatible
            device_map="auto",          # Let transformers handle device placement
            low_cpu_mem_usage=True,     # Memory optimization
            trust_remote_code=True,
        )

        # Create pipeline for efficient generation
        text_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=256,
            do_sample=True,
            temperature=1.0,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )

        logger.info("✅ Successfully loaded GGUF model with transformers")
        logger.info(f"📊 Model: {current_model}")
        logger.info(f"🔧 GGUF File: {gguf_filename}")
        logger.info("🧠 Backend: Transformers native GGUF support")
    except Exception as e:
        logger.error(f"❌ Failed to initialize GGUF model: {e}")
        logger.info("🔄 Starting service in demo mode")
        model = None
        tokenizer = None
        text_pipeline = None

    yield

    logger.info("🔄 Shutting down GGUF Backend Service...")
    # Clean up model resources
    if model:
        del model
    if tokenizer:
        del tokenizer
    if text_pipeline:
        del text_pipeline


# Initialize FastAPI app
app = FastAPI(
    title="GGUF Backend Service (Transformers Native)",
    description="Memory-efficient GGUF model API using transformers native support",
    version="1.0.0",
    lifespan=lifespan,
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def convert_messages_to_prompt(messages: List[ChatMessage]) -> str:
    """Convert OpenAI messages format to Gemma 3n chat format."""
    prompt_parts = []
    for message in messages:
        role = message.role
        content = message.content.strip()
        if role == "system":
            prompt_parts.append(f"<start_of_turn>system\n{content}<end_of_turn>")
        elif role == "user":
            prompt_parts.append(f"<start_of_turn>user\n{content}<end_of_turn>")
        elif role == "assistant":
            prompt_parts.append(f"<start_of_turn>model\n{content}<end_of_turn>")
    # Add the start for model response
    prompt_parts.append("<start_of_turn>model\n")
    return "\n".join(prompt_parts)


def generate_response(messages: List[ChatMessage], max_tokens: int = 256, temperature: float = 1.0, top_p: float = 0.95) -> str:
    """Generate response using GGUF model via transformers pipeline."""
    if text_pipeline is None:
        return "🤖 Demo mode: GGUF model not loaded. This would be a real response from the Gemma 3n GGUF model."
    try:
        # Convert messages to prompt
        prompt = convert_messages_to_prompt(messages)

        # Limit max_tokens for memory efficiency
        max_tokens = min(max_tokens, 512)

        # Generate response
        result = text_pipeline(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            return_full_text=False,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Extract generated text
        if result and len(result) > 0:
            response_text = result[0]["generated_text"].strip()
            # Clean up any unwanted tokens
            if "<end_of_turn>" in response_text:
                response_text = response_text.split("<end_of_turn>")[0].strip()
            return response_text
        else:
            return "I apologize, but I'm having trouble generating a response right now."
    except Exception as e:
        logger.error(f"GGUF generation failed: {e}")
        return "I apologize, but I'm having trouble generating a response right now. Please try again."


@app.get("/")
async def root() -> Dict[str, Any]:
    """Root endpoint with service information"""
    return {
        "service": "GGUF Backend Service",
        "version": "1.0.0",
        "model": current_model,
        "gguf_file": gguf_filename,
        "backend": "transformers-native-gguf",
        "quantization": "Q4_K_M",
        "endpoints": {
            "health": "/health",
            "chat": "/v1/chat/completions",
            "docs": "/docs",
        },
    }


@app.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
    """Health check endpoint"""
    status = "healthy" if text_pipeline is not None else "demo_mode"
    return HealthResponse(
        status=status,
        model=current_model,
        version="1.0.0",
        backend="transformers-native-gguf",
        quantization="Q4_K_M",
    )


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest) -> ChatCompletionResponse:
    """Create a chat completion (OpenAI-compatible) using GGUF model"""
    try:
        # Generate response
        response_text = generate_response(
            messages=request.messages,
            max_tokens=request.max_tokens or 256,
            temperature=request.temperature or 1.0,
            top_p=request.top_p or 0.95,
        )

        # Create response message
        response_message = ChatMessage(role="assistant", content=response_text)

        # Create choice
        choice = ChatCompletionChoice(
            index=0,
            message=response_message,
            finish_reason="stop",
        )

        # Create completion response
        completion = ChatCompletionResponse(
            id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
            object="chat.completion",
            created=int(time.time()),
            model=request.model,
            choices=[choice],
        )
        return completion
    except Exception as e:
        logger.error(f"Chat completion failed: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
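
# --- Usage sketch (illustrative only, not part of the service) ---
# A minimal client call against the /v1/chat/completions endpoint, assuming the
# server is running locally on port 8000; it uses only the standard library, so
# no extra dependency is implied.
#
#   import json
#   import urllib.request
#
#   payload = {
#       "model": "gemma-3n-e4b-it",
#       "messages": [{"role": "user", "content": "Hello!"}],
#       "max_tokens": 64,
#   }
#   req = urllib.request.Request(
#       "http://localhost:8000/v1/chat/completions",
#       data=json.dumps(payload).encode("utf-8"),
#       headers={"Content-Type": "application/json"},
#   )
#   with urllib.request.urlopen(req) as resp:
#       body = json.loads(resp.read())
#       print(body["choices"][0]["message"]["content"])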