fschwartzer commited on
Commit
d016c3b
·
verified ·
1 Parent(s): 1d29ee7

Update src/brain.py

Browse files
Files changed (1) hide show
  1. src/brain.py +13 -14
src/brain.py CHANGED
@@ -4,19 +4,18 @@ tokenizer = BertTokenizer.from_pretrained('juridics/bertimbaulaw-base-portuguese
4
  model = BertForSequenceClassification.from_pretrained('juridics/bertimbaulaw-base-portuguese-sts-scale')
5
 
6
  def generate_answers(query):
7
- # Garantindo que a query é uma string
8
- if not isinstance(query, str):
9
- raise ValueError("A entrada para a função generate_answers deve ser uma string.")
10
 
11
- # Tokenização
12
- inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
 
 
 
 
 
 
13
 
14
- # Realizando a predição
15
- outputs = model(**inputs)
16
- prediction = torch.argmax(outputs.logits, dim=1).item() # Converter tensor para um inteiro
17
-
18
- # Labels devem corresponder ao número de classes do modelo
19
- labels = ['ds', 'real', 'Group']
20
- predicted_label = labels[prediction] # Usando o índice para acessar a label
21
-
22
- return predicted_label
 
4
  model = BertForSequenceClassification.from_pretrained('juridics/bertimbaulaw-base-portuguese-sts-scale')
5
 
6
  def generate_answers(query):
7
+ inputs = tokenizer(query, return_tensors='pt', padding='max_length', truncation=True, max_length=512)
8
+ attention_mask = inputs['attention_mask']
9
+ input_ids = inputs['input_ids']
10
 
11
+ generated_ids = model.generate(
12
+ input_ids,
13
+ attention_mask=attention_mask,
14
+ max_length=len(input_ids[0]) + 100, # Aumentar o limite de geração
15
+ temperature=0.7, # Ajustar a criatividade
16
+ top_p=0.9, # Usar nucleus sampling
17
+ no_repeat_ngram_size=2 # Evitar repetições desnecessárias
18
+ )
19
 
20
+ generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
21
+ return generated_text