Commit
·
6439837
1
Parent(s):
d89e25d
added device to handler
Browse files- handler.py +10 -5
handler.py
CHANGED
@@ -5,8 +5,12 @@ import torch
|
|
5 |
|
6 |
class EndpointHandler:
|
7 |
def __init__(self, path=""):
|
8 |
-
|
|
|
|
|
9 |
self.model = AutoModelForSequenceClassification.from_pretrained(path)
|
|
|
|
|
10 |
def tokenize(batch):
|
11 |
return self.tokenizer(
|
12 |
batch['topic'],
|
@@ -16,7 +20,8 @@ class EndpointHandler:
|
|
16 |
return_offsets_mapping=False,
|
17 |
padding="max_length",
|
18 |
return_tensors='pt'
|
19 |
-
)
|
|
|
20 |
self.tokenize = tokenize
|
21 |
|
22 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
@@ -36,11 +41,11 @@ class EndpointHandler:
|
|
36 |
}
|
37 |
|
38 |
for topic in topics:
|
39 |
-
for text in texts:
|
40 |
batch['id'].append(text['id'])
|
41 |
batch['text'].append(text['text'])
|
42 |
batch['topic'].append(topic)
|
43 |
-
|
44 |
tokenized_inputs = self.tokenize(batch)
|
45 |
|
46 |
# run normal prediction
|
@@ -48,4 +53,4 @@ class EndpointHandler:
|
|
48 |
predictions = torch.argmax(output.logits, dim=-1).numpy(force=True)
|
49 |
batch['label'] = [self.model.config.id2label[p] for p in predictions]
|
50 |
batch.pop('text')
|
51 |
-
return batch
|
|
|
5 |
|
6 |
class EndpointHandler:
|
7 |
def __init__(self, path=""):
|
8 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
9 |
+
|
10 |
+
self.tokenizer = AutoTokenizer.from_pretrained(path, device=device)
|
11 |
self.model = AutoModelForSequenceClassification.from_pretrained(path)
|
12 |
+
self.model.to(device)
|
13 |
+
|
14 |
def tokenize(batch):
|
15 |
return self.tokenizer(
|
16 |
batch['topic'],
|
|
|
20 |
return_offsets_mapping=False,
|
21 |
padding="max_length",
|
22 |
return_tensors='pt'
|
23 |
+
).to(device)
|
24 |
+
|
25 |
self.tokenize = tokenize
|
26 |
|
27 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
|
41 |
}
|
42 |
|
43 |
for topic in topics:
|
44 |
+
for text in texts:
|
45 |
batch['id'].append(text['id'])
|
46 |
batch['text'].append(text['text'])
|
47 |
batch['topic'].append(topic)
|
48 |
+
|
49 |
tokenized_inputs = self.tokenize(batch)
|
50 |
|
51 |
# run normal prediction
|
|
|
53 |
predictions = torch.argmax(output.logits, dim=-1).numpy(force=True)
|
54 |
batch['label'] = [self.model.config.id2label[p] for p in predictions]
|
55 |
batch.pop('text')
|
56 |
+
return batch
|