Commit 04ba0b0 by Vx2-3y
Parent(s): 3fa9baf

Implement cloud-ready model loading and inference logic; update requirements for HF Spaces deployment

Files changed:
- main.py (+78 -20)
- requirements.txt (+10 -1)
main.py CHANGED

@@ -1,6 +1,10 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from typing import Optional, Any
+import os
+import logging
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from huggingface_hub import login
 
 app = FastAPI(
     title="NCOS Compliance LLM API",
@@ -11,40 +15,85 @@ app = FastAPI(
 # --- Pydantic models for request/response ---
 
 class InferRequest(BaseModel):
-    input_text: str
-    parameters: Optional[dict] = None  # e.g., temperature, max_tokens
+    input_text: str  # The text to run inference on
+    parameters: Optional[dict] = None  # Optional model parameters (e.g., temperature, max_tokens)
 
 class InferResponse(BaseModel):
-    result: str
-    status: str
-    error: Optional[str] = None
+    result: str  # The model's output
+    status: str  # 'success' or 'error'
+    error: Optional[str] = None  # Error message if status is 'error'
 
 class QueueRequest(BaseModel):
-    input_text: str
-    parameters: Optional[dict] = None
+    input_text: str  # The text to enqueue for inference
+    parameters: Optional[dict] = None  # Optional model parameters
 
 class QueueResponse(BaseModel):
-    job_id: str
-    status: str
-    result: Optional[str] = None
-    error: Optional[str] = None
+    job_id: str  # Unique job identifier
+    status: str  # 'queued', 'pending', 'done', or 'error'
+    result: Optional[str] = None  # Model output if available
+    error: Optional[str] = None  # Error message if status is 'error'
+
+# --- Model Loading (Cloud-Ready) ---
+
+# Read model name and token from environment variables for security
+MODEL_NAME = os.getenv("HF_MODEL_NAME", "ACATECH/ncos")  # Default to ACATECH/ncos
+HF_TOKEN = os.getenv("HF_TOKEN")  # Should be set in Hugging Face Space secrets
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("ncos-backend")
+
+# Login to Hugging Face Hub if token is provided
+if HF_TOKEN:
+    try:
+        login(token=HF_TOKEN)
+        logger.info("Logged in to Hugging Face Hub.")
+    except Exception as e:
+        logger.error(f"Failed to login to Hugging Face Hub: {e}")
+
+# Load model and tokenizer at startup
+try:
+    logger.info(f"Loading model: {MODEL_NAME}")
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+    # Use pipeline for simple inference
+    ncos_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
+    logger.info("Model and tokenizer loaded successfully.")
+except Exception as e:
+    logger.error(f"Model loading failed: {e}")
+    ncos_pipeline = None
 
 # --- Endpoints ---
 
-@app.post("/infer", response_model=InferResponse)
+@app.post("/infer", response_model=InferResponse, summary="Run model inference", description="Run LLM inference on the input text and return the result.")
 def infer(request: InferRequest):
     """
     Run model inference on the input text.
+    - **input_text**: The text to run inference on.
+    - **parameters**: Optional model parameters (e.g., temperature, max_tokens).
+    Returns the model's output or an error message.
     """
-
+    if ncos_pipeline is None:
+        # Model failed to load
+        logger.error("Inference requested but model is not loaded.")
+        return InferResponse(result="", status="error", error="Model not loaded.")
     try:
-        #
-
-
+        # Prepare parameters for the pipeline
+        params = request.parameters or {}
+        # Set sensible defaults if not provided
+        params.setdefault("max_new_tokens", 128)
+        params.setdefault("temperature", 0.7)
+        # Run inference
+        logger.info(f"Running inference for input: {request.input_text}")
+        output = ncos_pipeline(request.input_text, **params)
+        # output is a list of dicts with 'generated_text'
+        result_text = output[0]["generated_text"] if output and "generated_text" in output[0] else str(output)
+        return InferResponse(result=result_text, status="success")
     except Exception as e:
+        logger.error(f"Error during inference: {e}")
         return InferResponse(result="", status="error", error=str(e))
 
-@app.get("/healthz")
+@app.get("/healthz", summary="Health check", description="Check if the backend service is healthy.")
 def healthz():
     """
     Health check endpoint.
@@ -52,23 +101,32 @@ def healthz():
     """
     return {"status": "ok"}
 
-@app.post("/queue", response_model=QueueResponse)
+@app.post("/queue", response_model=QueueResponse, summary="Submit job to queue", description="Submit a job to the Redis queue for asynchronous processing.")
 def submit_job(request: QueueRequest):
     """
     Submit a job to the queue (e.g., Redis).
+    - **input_text**: The text to enqueue for inference.
+    - **parameters**: Optional model parameters.
+    Returns a job ID and status.
     """
     # TODO: Integrate with Redis queue
     job_id = "job_123"  # Placeholder
     return QueueResponse(job_id=job_id, status="queued")
 
-@app.get("/queue", response_model=QueueResponse)
+@app.get("/queue", response_model=QueueResponse, summary="Get job status/result", description="Get the status or result of a queued job by job_id.")
 def get_job_status(job_id: str):
     """
     Get the status/result of a queued job.
+    - **job_id**: The job identifier.
+    Returns the job status and result if available.
     """
     # TODO: Query Redis for job status/result
     return QueueResponse(job_id=job_id, status="pending")
 
 # --- End of API contract skeleton ---
 
-# FastAPI will auto-generate OpenAPI docs at /docs and /openapi.json
+# FastAPI will auto-generate OpenAPI docs at /docs and /openapi.json
+#
+# To test endpoints:
+# - Use curl, httpie, or Postman to send requests to /infer, /healthz, /queue
+# - Visit /docs for interactive API documentation
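The closing comments in main.py suggest exercising the endpoints with curl, httpie, or Postman. As a quick sanity check of the /infer contract, here is a minimal client sketch that is not part of the commit; the base URL is a placeholder for the deployed Space, and the payload simply mirrors the InferRequest/InferResponse models above.

# client_example.py -- hypothetical, illustrative only
import requests

BASE_URL = "https://your-space-name.hf.space"  # placeholder URL, not from the commit

payload = {
    "input_text": "Summarize the NCOS compliance requirements.",
    "parameters": {"max_new_tokens": 64, "temperature": 0.2},
}

# POST to /infer and read the InferResponse fields (result, status, error)
resp = requests.post(f"{BASE_URL}/infer", json=payload, timeout=120)
resp.raise_for_status()
data = resp.json()
if data["status"] == "success":
    print(data["result"])
else:
    print("Inference failed:", data["error"])

The /queue endpoints remain placeholders (see the TODO comments). One possible shape for the Redis integration is sketched below; it is an assumption rather than part of this commit, and REDIS_URL, the "job:" key prefix, and the helper names are illustrative.

# queue_sketch.py -- one way the Redis TODOs could be filled in (assumption)
import json
import os
import uuid
from typing import Optional

import redis

# Connect to Redis; REDIS_URL would need to be configured for the Space
r = redis.Redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379/0"))

def enqueue_job(input_text: str, parameters: Optional[dict]) -> str:
    """Store a new job hash with status 'queued' and return its id."""
    job_id = str(uuid.uuid4())
    r.hset(f"job:{job_id}", mapping={
        "status": "queued",
        "input_text": input_text,
        "parameters": json.dumps(parameters or {}),
    })
    return job_id

def read_job(job_id: str) -> dict:
    """Fetch the stored job hash and decode it for a QueueResponse."""
    raw = r.hgetall(f"job:{job_id}")
    return {k.decode(): v.decode() for k, v in raw.items()}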
requirements.txt CHANGED

@@ -11,4 +11,13 @@ supabase
 python-dotenv
 
 # Data validation and settings management
-pydantic
+pydantic
+
+# For loading and running LLMs from Hugging Face
+transformers==4.40.2
+
+# For model hub integration
+huggingface_hub==0.23.1
+
+# For GPU inference (update if needed for CUDA compatibility)
+torch==2.2.2
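Note on the pinned torch dependency: the pipeline in main.py is created with device=0, which assumes a GPU is present on the Space. On CPU-only hardware that call raises an error, so a device-aware variant may be worth considering. The following is a sketch under that assumption, reusing the same environment variable as main.py; it is not part of this commit.

# device_aware_loading.py -- sketch only, not part of this commit
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

MODEL_NAME = os.getenv("HF_MODEL_NAME", "ACATECH/ncos")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# device=0 targets the first GPU; device=-1 tells the pipeline to run on CPU
device = 0 if torch.cuda.is_available() else -1
ncos_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)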