from typing import Any, Dict

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Tokenizers run on the CPU and take no device argument; only the
        # model is moved to the target device.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.model.to(self.device)
        self.model.eval()

    def tokenize(self, batch: Dict[str, Any]):
        # Encode (topic, text) pairs; truncation="only_second" truncates the
        # text, never the topic, when the pair exceeds max_length.
        return self.tokenizer(
            batch["topic"],
            batch["text"],
            max_length=384,  # shortened from 512
            truncation="only_second",
            return_offsets_mapping=False,
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            topics (List[str]): topics to score the text against
            text (str): the document text
            id (str): request identifier, echoed back in the response
        Return:
            A :obj:`dict` that will be serialized and returned.
        """
        topics = data.pop("topics")
        request_id = data.pop("id")
        text = data.pop("text")

        # Pair the single text with every topic so the model classifies each
        # (topic, text) combination in one batch.
        batch = {"id": request_id, "text": [], "topic": []}
        for topic in topics:
            batch["text"].append(text)
            batch["topic"].append(topic)

        tokenized_inputs = self.tokenize(batch)

        # Run prediction without tracking gradients.
        with torch.no_grad():
            output = self.model(**tokenized_inputs)
        predictions = torch.argmax(output.logits, dim=-1).cpu().numpy()
        batch["label"] = [self.model.config.id2label[int(p)] for p in predictions]

        # Drop the repeated text from the response payload.
        batch.pop("text")
        return batch
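

# --- Usage sketch (not part of the original handler) ---
# A minimal local smoke test, assuming "./model" is a checkpoint directory
# containing a sequence-classification model and its tokenizer; the path and
# example payload below are hypothetical. Inference Endpoints invokes
# EndpointHandler the same way, passing the parsed request body.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # hypothetical local path
    response = handler({
        "id": "req-001",
        "text": "Wind power capacity grew faster than any other source last year.",
        "topics": ["energy", "sports"],
    })
    # Expected shape: {"id": "req-001", "topic": [...], "label": [...]}
    print(response)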