# Author: ke-lly-d
# Last commit: "Update handler.py" (39f7efe, verified)
from typing import Dict, Any
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
class EndpointHandler:
    """Serve a sequence-classification model through a dict-in / dict-out call.

    The ``__init__(path)`` / ``__call__(data)`` shape appears to follow the
    Hugging Face Inference Endpoints custom-handler convention (NOTE(review):
    confirm against the deployment).
    """

    def __init__(self, path: str = ".", max_length: int = 64):
        """Load the tokenizer and model from ``path`` and prepare them for inference.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model.
            max_length: Maximum number of tokens per input; longer inputs are
                truncated. Defaults to 64, the previously hard-coded value.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.model.to(self.device)
        # Inference mode for layers such as dropout.
        self.model.eval()
        self.id2label = self.model.config.id2label
        self.max_length = max_length

    def _label_for(self, class_id: int) -> str:
        """Return the label name for ``class_id``.

        ``id2label`` is looked up with an ``int`` key first, then a ``str``
        key. If neither is present, ``str(class_id)`` is returned. This keeps
        every class as a distinct key in the distribution, so unmapped
        classes cannot collapse onto a shared ``None`` key.
        """
        label = self.id2label.get(class_id)
        if label is None:
            label = self.id2label.get(str(class_id))
        # `is None` rather than `or`, so a legitimate falsy label (e.g. "") is kept.
        return str(class_id) if label is None else label

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Classify ``data["inputs"]`` and return the top label and the distribution.

        Args:
            data: Request payload; the text to classify is read from its
                ``"inputs"`` key.

        Returns:
            ``{"pack": <top label>, "probDistribution": {label: probability}}``
            with probabilities rounded to 4 decimals, or
            ``{"error": "No input provided."}`` when ``"inputs"`` is missing
            or empty. If ``"inputs"`` is a list of texts, only the first
            text's prediction is returned.
        """
        input_text = data.get("inputs", "")
        if not input_text:
            return {"error": "No input provided."}

        # Tokenization
        encoded = self.tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        # Forward pass without autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**encoded)
        # Softmax over the last (class) axis; [0] keeps only the first sequence.
        probs = torch.softmax(outputs.logits, dim=-1)[0]

        top_class_id = torch.argmax(probs).item()
        prob_distribution = {
            self._label_for(class_id): round(p.item(), 4)
            for class_id, p in enumerate(probs)
        }
        return {
            "pack": self._label_for(top_class_id),
            "probDistribution": prob_distribution,
        }