from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from typing import Optional, Any
import os
import json
import logging
import threading
import time
import uuid
import torch
import redis
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from huggingface_hub import login
from supabase import create_client, Client
app = FastAPI(
    title="NCOS Compliance LLM API",
    description="API contract for inference, health checks, and job queueing.",
    version="1.0.0"
)
# --- Pydantic models for request/response ---
class InferRequest(BaseModel):
    input_text: str  # The text to run inference on
    parameters: Optional[dict] = None  # Optional model parameters (e.g., temperature, max_tokens)

class InferResponse(BaseModel):
    result: str  # The model's output
    status: str  # 'success' or 'error'
    error: Optional[str] = None  # Error message if status is 'error'

class QueueRequest(BaseModel):
    input_text: str  # The text to enqueue for inference
    parameters: Optional[dict] = None  # Optional model parameters

class QueueResponse(BaseModel):
    job_id: str  # Unique job identifier
    status: str  # 'queued', 'pending', 'done', or 'error'
    result: Optional[str] = None  # Model output if available
    error: Optional[str] = None  # Error message if status is 'error'
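# Illustrative payloads for these models (values are hypothetical, shown only
# to document the contract):
#
#   POST /infer
#     {"input_text": "Summarize NCOS clause 4.2", "parameters": {"max_new_tokens": 64}}
#     -> {"result": "...", "status": "success", "error": null}
#
#   POST /queue
#     {"input_text": "Summarize NCOS clause 4.2"}
#     -> {"job_id": "3f1c...", "status": "queued", "result": null, "error": null}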
# --- Model Loading (Cloud-Ready) ---
# Read model name and token from environment variables for security
MODEL_NAME = os.getenv("HF_MODEL_NAME", "gpt2")  # Use gpt2 for testing; switch back to ACATECH/ncos after
HF_TOKEN = os.getenv("HF_TOKEN")  # Should be set in Hugging Face Space secrets

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ncos-backend")
# Login to Hugging Face Hub if token is provided
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        logger.info("Logged in to Hugging Face Hub.")
    except Exception as e:
        logger.error(f"Failed to login to Hugging Face Hub: {e}")
# Load model and tokenizer at startup
try:
    logger.info(f"Loading model: {MODEL_NAME}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    # Use pipeline for simple inference; fall back to CPU when no GPU is present
    device = 0 if torch.cuda.is_available() else -1
    ncos_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
    logger.info("Model and tokenizer loaded successfully.")
except Exception as e:
    logger.error(f"Model loading failed: {e}")
    ncos_pipeline = None
# --- Redis Connection ---
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")  # Set your cloud Redis URL in env
redis_client = redis.Redis.from_url(REDIS_URL)

# --- Job Queue Logic ---
JOB_QUEUE = "ncos_job_queue"
JOB_RESULT_PREFIX = "ncos_job_result:"
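# Each queued job is serialized JSON with this shape (produced by submit_job below):
#   {"job_id": "<uuid4>", "input_text": "...", "parameters": {...} or null, "model_name": "gpt2"}
# Completed results live at key "ncos_job_result:<job_id>" as plain text, or as
# "ERROR: <message>" on failure.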
# --- Model Cache ---
model_cache = {"name": None, "pipeline": None}
# --- Supabase Connection ---
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
supabase: Optional[Client] = None
if SUPABASE_URL and SUPABASE_KEY:
    try:
        supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
        logger.info("Connected to Supabase.")
    except Exception as e:
        logger.error(f"Failed to connect to Supabase: {e}")
else:
    logger.warning("Supabase credentials not set. Skipping Supabase integration.")
# --- Background Worker Thread ---
def job_worker():
    while True:
        job_data = redis_client.lpop(JOB_QUEUE)
        if job_data:
            job = json.loads(job_data)  # Jobs are enqueued as JSON (safer than eval)
            job_id = job["job_id"]
            input_text = job["input_text"]
            parameters = job.get("parameters") or {}
            model_name = job.get("model_name", "gpt2")
            try:
                # Load model if needed
                if model_cache["name"] != model_name:
                    logger.info(f"Loading model for job: {model_name}")
                    tokenizer = AutoTokenizer.from_pretrained(model_name)
                    model = AutoModelForCausalLM.from_pretrained(model_name)
                    device = 0 if torch.cuda.is_available() else -1
                    model_cache["pipeline"] = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
                    model_cache["name"] = model_name
                pipe = model_cache["pipeline"]
                params = parameters or {}
                params.setdefault("max_new_tokens", 128)
                params.setdefault("temperature", 0.7)
                output = pipe(input_text, **params)
                result_text = output[0]["generated_text"] if output and "generated_text" in output[0] else str(output)
                redis_client.set(JOB_RESULT_PREFIX + job_id, result_text)
                # --- Store result in Supabase ---
                if supabase:
                    try:
                        data = {
                            "job_id": job_id,
                            "input_text": input_text,
                            "parameters": str(parameters),
                            "model_name": model_name,
                            "result": result_text
                        }
                        supabase.table("inference_results").insert(data).execute()
                        logger.info(f"Stored job {job_id} result in Supabase.")
                    except Exception as e:
                        logger.error(f"Failed to store job {job_id} in Supabase: {e}")
            except Exception as e:
                logger.error(f"Job {job_id} failed: {e}")
                redis_client.set(JOB_RESULT_PREFIX + job_id, f"ERROR: {e}")
        else:
            time.sleep(1)

# Start background worker thread
threading.Thread(target=job_worker, daemon=True).start()
# --- Endpoints ---
@app.post("/infer", response_model=InferResponse)
def infer(request: InferRequest):
    """
    Run model inference on the input text.

    - **input_text**: The text to run inference on.
    - **parameters**: Optional model parameters (e.g., temperature, max_tokens).

    Returns the model's output or an error message.
    """
    if ncos_pipeline is None:
        # Model failed to load
        logger.error("Inference requested but model is not loaded.")
        return InferResponse(result="", status="error", error="Model not loaded.")
    try:
        # Prepare parameters for the pipeline
        params = request.parameters or {}
        # Set sensible defaults if not provided
        params.setdefault("max_new_tokens", 128)
        params.setdefault("temperature", 0.7)
        # Run inference
        logger.info(f"Running inference for input: {request.input_text}")
        output = ncos_pipeline(request.input_text, **params)
        # output is a list of dicts with 'generated_text'
        result_text = output[0]["generated_text"] if output and "generated_text" in output[0] else str(output)
        return InferResponse(result=result_text, status="success")
    except Exception as e:
        logger.error(f"Error during inference: {e}")
        return InferResponse(result="", status="error", error=str(e))
@app.get("/healthz")
def healthz():
    """
    Health check endpoint.
    Returns 200 OK if the service is healthy.
    """
    return {"status": "ok"}
@app.post("/queue", response_model=QueueResponse)
def submit_job(request: QueueRequest):
    """
    Submit a job to the queue (e.g., Redis).

    - **input_text**: The text to enqueue for inference.
    - **parameters**: Optional model parameters.

    Returns a job ID and status.
    """
    job_id = str(uuid.uuid4())
    job = {
        "job_id": job_id,
        "input_text": request.input_text,
        "parameters": request.parameters,
        "model_name": os.getenv("HF_MODEL_NAME", "gpt2")  # Allow override per job in future
    }
    redis_client.rpush(JOB_QUEUE, json.dumps(job))  # Serialize as JSON so the worker can parse it safely
    return QueueResponse(job_id=job_id, status="queued")
@app.get("/queue/{job_id}", response_model=QueueResponse)  # Path is assumed; adjust to your routing convention
def get_job_status(job_id: str):
    """
    Get the status/result of a queued job.

    - **job_id**: The job identifier.

    Returns the job status and result if available.
    """
    result = redis_client.get(JOB_RESULT_PREFIX + job_id)
    if result:
        result_str = result.decode("utf-8")
        if result_str.startswith("ERROR:"):
            return QueueResponse(job_id=job_id, status="error", error=result_str)
        return QueueResponse(job_id=job_id, status="done", result=result_str)
    else:
        return QueueResponse(job_id=job_id, status="pending")
@app.get("/")
def root():
    return {
        "message": "Welcome to the NCOS_S3 FastAPI backend!",
        "docs": "/docs",
        "health": "/healthz"
    }
# Middleware to log every incoming request path and method
@app.middleware("http")
async def log_requests(request: Request, call_next):
    logger.info(f"Incoming request: {request.method} {request.url.path}")
    response = await call_next(request)
    return response
# --- End of API contract skeleton ---
# FastAPI will auto-generate OpenAPI docs at /docs and /openapi.json
#
# To test endpoints:
# - Use curl, httpie, or Postman to send requests to /infer, /healthz, /queue
# - Visit /docs for interactive API documentation
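#
# Example curl calls (illustrative; host and port assume a local uvicorn run
# on 7860, the Hugging Face Spaces default):
#
#   curl http://localhost:7860/healthz
#   curl -X POST http://localhost:7860/infer \
#        -H "Content-Type: application/json" \
#        -d '{"input_text": "Hello", "parameters": {"max_new_tokens": 32}}'
#   curl -X POST http://localhost:7860/queue \
#        -H "Content-Type: application/json" \
#        -d '{"input_text": "Hello"}'
#   curl http://localhost:7860/queue/<job_id>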