File size: 1,728 Bytes
42d6781
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458aa66
42d6781
 
 
 
 
a5aea8a
42d6781
 
 
458aa66
42d6781
 
458aa66
 
 
 
 
 
 
 
 
 
 
 
42d6781
 
533ef34
39f7efe
42d6781
458aa66
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from typing import Dict, Any
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

class EndpointHandler:
    """Text sequence-classification handler: text in, top label + distribution out.

    The ``__init__(path)`` / ``__call__(data)`` shape appears to follow the
    Hugging Face Inference Endpoints custom-handler interface
    (NOTE(review): confirm against the deployment setup).
    """

    # Maximum number of tokens kept per request; longer inputs are truncated.
    # Declared at class level so the default is available even on instances
    # not built through __init__, and overridable per instance via __init__.
    max_length: int = 64

    def __init__(self, path: str = ".", max_length: int = 64):
        """Load the tokenizer and classifier from *path* and prepare for inference.

        Args:
            path: Location passed to ``from_pretrained`` for both the tokenizer
                and the model.
            max_length: Token budget per input (default 64, the value that was
                previously hard-coded).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.max_length = max_length

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.model.to(self.device)
        self.model.eval()  # inference mode for layers such as dropout

        # Class-index -> label-name mapping taken from the model config.
        self.id2label = self.model.config.id2label

    def _label_for(self, idx: int) -> str:
        """Return the label name for class index *idx*.

        Tries the int key first, then its string form (presumably for configs
        whose keys were left as strings), and finally falls back to
        ``"LABEL_<idx>"`` so that unmapped classes never collapse onto a shared
        ``None`` key in the output distribution.
        """
        label = self.id2label.get(idx)
        if label is None:
            label = self.id2label.get(str(idx))
        return f"LABEL_{idx}" if label is None else label

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Classify ``data["inputs"]``.

        Args:
            data: Request payload; the text to classify is read from the
                ``"inputs"`` key.

        Returns:
            ``{"error": "No input provided."}`` when ``"inputs"`` is missing or
            empty. Otherwise ``{"pack": <top label>, "probDistribution":
            {label: probability rounded to 4 decimals}}``.

        Note:
            Only the first row of the model output is reported, so when
            ``"inputs"`` is a list of texts, only the first text's result is
            returned.
        """
        input_text = data.get("inputs", "")
        if not input_text:
            return {"error": "No input provided."}

        # Tokenization (padding only matters when a list of texts is passed).
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Forward pass without autograd bookkeeping.
        with torch.no_grad():
            logits = self.model(**inputs).logits
            probs = torch.softmax(logits, dim=-1)[0]  # first row: (num_classes,)

        top_class_id = int(torch.argmax(probs).item())
        # Copy to host once rather than one .item() device sync per class.
        prob_values = probs.cpu().tolist()

        prob_distribution = {
            self._label_for(i): round(p, 4) for i, p in enumerate(prob_values)
        }

        return {
            "pack": self._label_for(top_class_id),
            "probDistribution": prob_distribution,
        }