dariadaria committed on
Commit
6439837
·
1 Parent(s): d89e25d

added device to handler

Browse files
Files changed (1) hide show
  1. handler.py +10 -5
handler.py CHANGED
@@ -5,8 +5,12 @@ import torch
5
 
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
- self.tokenizer = AutoTokenizer.from_pretrained(path)
 
 
9
  self.model = AutoModelForSequenceClassification.from_pretrained(path)
 
 
10
  def tokenize(batch):
11
  return self.tokenizer(
12
  batch['topic'],
@@ -16,7 +20,8 @@ class EndpointHandler:
16
  return_offsets_mapping=False,
17
  padding="max_length",
18
  return_tensors='pt'
19
- )
 
20
  self.tokenize = tokenize
21
 
22
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
@@ -36,11 +41,11 @@ class EndpointHandler:
36
  }
37
 
38
  for topic in topics:
39
- for text in texts:
40
  batch['id'].append(text['id'])
41
  batch['text'].append(text['text'])
42
  batch['topic'].append(topic)
43
-
44
  tokenized_inputs = self.tokenize(batch)
45
 
46
  # run normal prediction
@@ -48,4 +53,4 @@ class EndpointHandler:
48
  predictions = torch.argmax(output.logits, dim=-1).numpy(force=True)
49
  batch['label'] = [self.model.config.id2label[p] for p in predictions]
50
  batch.pop('text')
51
- return batch
 
5
 
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
9
+
10
+ self.tokenizer = AutoTokenizer.from_pretrained(path, device=device)
11
  self.model = AutoModelForSequenceClassification.from_pretrained(path)
12
+ self.model.to(device)
13
+
14
  def tokenize(batch):
15
  return self.tokenizer(
16
  batch['topic'],
 
20
  return_offsets_mapping=False,
21
  padding="max_length",
22
  return_tensors='pt'
23
+ ).to(device)
24
+
25
  self.tokenize = tokenize
26
 
27
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
41
  }
42
 
43
  for topic in topics:
44
+ for text in texts:
45
  batch['id'].append(text['id'])
46
  batch['text'].append(text['text'])
47
  batch['topic'].append(topic)
48
+
49
  tokenized_inputs = self.tokenize(batch)
50
 
51
  # run normal prediction
 
53
  predictions = torch.argmax(output.logits, dim=-1).numpy(force=True)
54
  batch['label'] = [self.model.config.id2label[p] for p in predictions]
55
  batch.pop('text')
56
+ return batch