from typing import Any, Dict

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Run on the GPU when one is available, otherwise on the CPU.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # `device` is not a valid from_pretrained() kwarg for tokenizers;
        # tokenized inputs are moved to the device in tokenize() instead.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.model.to(self.device)

    def tokenize(self, batch: Dict[str, Any]):
        # Encode (topic, text) pairs; truncation="only_second" truncates only
        # the text, never the topic, when a pair exceeds max_length.
        return self.tokenizer(
            batch["topic"],
            batch["text"],
            max_length=384,
            truncation="only_second",
            return_offsets_mapping=False,
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)
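
    # For example (a sketch, assuming a BERT-style pair tokenizer):
    # self.tokenize({"topic": ["economy"], "text": ["Rates rose again."]})
    # returns a BatchEncoding whose input_ids have shape (1, 384), already
    # moved to self.device.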
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            topics (List[str]): candidate topic labels to pair with the text
            text (str): the document to classify
            id (str): request identifier, echoed back in the response
        Return:
            A dict that will be serialized and returned
        """
        # Use explicit defaults: the `data.pop(key, data)` template pattern
        # would hand back the whole payload whenever a key is missing.
        topics = data.pop("topics", [])
        request_id = data.pop("id", None)
        text = data.pop("text", "")

        # Pair the single text with every candidate topic so all pairs are
        # scored in one forward pass.
        batch = {
            "id": request_id,
            "text": [text] * len(topics),
            "topic": list(topics),
        }

        tokenized_inputs = self.tokenize(batch)

        # Gradients are not needed at inference time.
        with torch.inference_mode():
            output = self.model(**tokenized_inputs)

        # numpy(force=True) copies the predictions to the CPU first.
        predictions = torch.argmax(output.logits, dim=-1).numpy(force=True)
        batch["label"] = [self.model.config.id2label[int(p)] for p in predictions]

        # Drop the repeated input text from the response.
        batch.pop("text")
        return batch
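

# A minimal local smoke test; the checkpoint path and payload values are
# illustrative assumptions, not part of the deployed endpoint contract.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # assumes the model files sit next to this handler
    response = handler(
        {
            "id": "example-1",
            "text": "The central bank raised interest rates again this quarter.",
            "topics": ["economy", "sports", "technology"],
        }
    )
    print(response)  # {"id": "example-1", "topic": [...], "label": [...]}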