from typing import Any, Dict

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class EndpointHandler:
    def __init__(self, path=""):
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # AutoTokenizer.from_pretrained takes no `device` argument;
        # only the model is moved to the device.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.model.to(self.device)
        self.model.eval()

    def tokenize(self, batch):
        # Pair each topic with the (repeated) document text; truncate only
        # the text ("only_second") so the topic is always kept in full.
        return self.tokenizer(
            batch["topic"],
            batch["text"],
            max_length=384,  # 512
            truncation="only_second",
            return_offsets_mapping=False,
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            topics (:obj:`List[str]`)
            text (:obj:`str`)
            id (:obj:`str`)
        Return:
            A :obj:`dict` that will be serialized and returned,
            with one predicted label per topic.
        """
        topics = data.pop("topics", [])
        doc_id = data.pop("id", None)
        text = data.pop("text", "")

        # Build one (topic, text) pair per topic so a single forward pass
        # scores every topic against the same document.
        batch = {
            "id": doc_id,
            "text": [text] * len(topics),
            "topic": list(topics),
        }
        tokenized_inputs = self.tokenize(batch)

        # run normal prediction
        with torch.no_grad():
            output = self.model(**tokenized_inputs)
        predictions = torch.argmax(output.logits, dim=-1).numpy(force=True)
        batch["label"] = [self.model.config.id2label[int(p)] for p in predictions]
        batch.pop("text")  # drop the repeated document text from the response
        return batch
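

# ---------------------------------------------------------------------------
# Usage sketch (not part of the handler contract): a minimal local smoke test
# mirroring the payload Inference Endpoints would POST to this handler. The
# checkpoint path "." and the example topics/text below are assumptions, not
# part of this repository -- point `path` at any sequence-classification
# model directory containing the config and tokenizer files.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    payload = {
        "id": "doc-1",
        "text": "Central banks raised interest rates again this quarter.",
        "topics": ["economy", "sports"],
    }
    result = handler(payload)
    # Expected shape: {"id": "doc-1", "topic": [...], "label": [...]},
    # with one predicted label per requested topic.
    print(result)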