llmguard / api /app.py
Tuathe's picture
Clean repo without large checkpoint files
6826247
raw
history blame
1.35 kB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import torch
app = FastAPI(title="LLMGuard - Prompt Injection Classifier API")
# Add the health check route
@app.get("/health")
def health_check():
return {"status": "ok"}
# Load model and tokenizer once at startup
model_path = "model/injection_classifier"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.eval()
class PromptRequest(BaseModel):
prompt: str
class PromptResponse(BaseModel):
label: str
confidence: float
@app.post("/moderate", response_model=PromptResponse)
def moderate_prompt(req: PromptRequest):
try:
inputs = tokenizer(req.prompt, return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
predicted = torch.argmax(logits, dim=1).item()
confidence = torch.softmax(logits, dim=1)[0][predicted].item()
label = "Injection" if predicted == 1 else "Normal"
return {"label": label, "confidence": round(confidence, 3)}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))