dariadaria commited on
Commit
f93531b
·
1 Parent(s): 1aec583

handler without datasets

Browse files
Files changed (1) hide show
  1. handler.py +9 -10
handler.py CHANGED
@@ -1,7 +1,6 @@
1
  from typing import Dict, List, Any
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
- from datasets import Dataset
5
 
6
 
7
  class EndpointHandler:
@@ -30,7 +29,7 @@ class EndpointHandler:
30
  """
31
  topics = data.pop("topics", data)
32
  texts = data.pop("texts", data)
33
- batch_dict = {
34
  'id': [],
35
  'text': [],
36
  'topic': []
@@ -38,16 +37,16 @@ class EndpointHandler:
38
 
39
  for topic in topics:
40
  for text in texts:
41
- batch_dict['id'].append(text['id'])
42
- batch_dict['text'].append(text['text'])
43
- batch_dict['topic'].append(topic)
44
-
45
- batch = Dataset.from_dict(batch_dict)
46
 
47
  tokenized_inputs = self.tokenize(batch)
48
 
49
  # run normal prediction
50
  output = self.model(**tokenized_inputs)
51
- batch = batch.add_column('predictions', torch.argmax(output.logits, dim=-1).numpy(force=True))
52
- batch = batch.map(lambda b: {'label': [self.model.config.id2label[p] for p in b['predictions']]}, batched=True, remove_columns=['text', 'predictions'])
53
- return batch.to_dict()
 
 
 
1
  from typing import Dict, List, Any
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
 
4
 
5
 
6
  class EndpointHandler:
 
29
  """
30
  topics = data.pop("topics", data)
31
  texts = data.pop("texts", data)
32
+ batch = {
33
  'id': [],
34
  'text': [],
35
  'topic': []
 
37
 
38
  for topic in topics:
39
  for text in texts:
40
+ batch['id'].append(text['id'])
41
+ batch['text'].append(text['text'])
42
+ batch['topic'].append(topic)
 
 
43
 
44
  tokenized_inputs = self.tokenize(batch)
45
 
46
  # run normal prediction
47
  output = self.model(**tokenized_inputs)
48
+ predictions = torch.argmax(output.logits, dim=-1).numpy(force=True)
49
+ batch['label'] = [self.model.config.id2label[p] for p in predictions]
50
+ batch.pop('text')
51
+
52
+ return batch