fschwartzer commited on
Commit
1d29ee7
·
verified ·
1 Parent(s): 151d72b

Update src/brain.py

Browse files
Files changed (1) hide show
  1. src/brain.py +13 -3
src/brain.py CHANGED
@@ -4,9 +4,19 @@ tokenizer = BertTokenizer.from_pretrained('juridics/bertimbaulaw-base-portuguese
4
  model = BertForSequenceClassification.from_pretrained('juridics/bertimbaulaw-base-portuguese-sts-scale')
5
 
6
  def generate_answers(query):
 
 
 
 
 
7
  inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
 
 
8
  outputs = model(**inputs)
9
- prediction = torch.argmax(outputs.logits, dim=1)
10
- labels = ['ds','real','Group']
11
- predicted_label = labels[prediction]
 
 
 
12
  return predicted_label
 
4
  model = BertForSequenceClassification.from_pretrained('juridics/bertimbaulaw-base-portuguese-sts-scale')
5
 
6
  def generate_answers(query):
7
+ # Garantindo que a query é uma string
8
+ if not isinstance(query, str):
9
+ raise ValueError("A entrada para a função generate_answers deve ser uma string.")
10
+
11
+ # Tokenização
12
  inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
13
+
14
+ # Realizando a predição
15
  outputs = model(**inputs)
16
+ prediction = torch.argmax(outputs.logits, dim=1).item() # Converter tensor para um inteiro
17
+
18
+ # Labels devem corresponder ao número de classes do modelo
19
+ labels = ['ds', 'real', 'Group']
20
+ predicted_label = labels[prediction] # Usando o índice para acessar a label
21
+
22
  return predicted_label