File size: 1,350 Bytes
6826247 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import torch
app = FastAPI(title="LLMGuard - Prompt Injection Classifier API")
# Add the health check route
@app.get("/health")
def health_check():
return {"status": "ok"}
# Load model and tokenizer once at startup
model_path = "model/injection_classifier"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.eval()
class PromptRequest(BaseModel):
prompt: str
class PromptResponse(BaseModel):
label: str
confidence: float
@app.post("/moderate", response_model=PromptResponse)
def moderate_prompt(req: PromptRequest):
try:
inputs = tokenizer(req.prompt, return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
predicted = torch.argmax(logits, dim=1).item()
confidence = torch.softmax(logits, dim=1)[0][predicted].item()
label = "Injection" if predicted == 1 else "Normal"
return {"label": label, "confidence": round(confidence, 3)}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
|