Quintino Fernandes committed
Commit 80ceb39 · 1 Parent(s): b6506cc

Model changes

Files changed (1)
models/summarization.py +6 -4
models/summarization.py CHANGED
@@ -4,10 +4,12 @@ import torch
 class SummarizationModel:
     def __init__(self):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.tokenizer = T5Tokenizer.from_pretrained('unicamp-dl/ptt5-base-portuguese-vocab')
+        self.tokenizer = T5Tokenizer.from_pretrained('unicamp-dl/ptt5-base-portuguese-vocab').to(self.device)
         self.model = T5ForConditionalGeneration.from_pretrained('recogna-nlp/ptt5-base-summ').to(self.device)
 
-    def summarize(self, text: str, max_length: int = 256, min_length: int = 128) -> str:
+    def summarize(self, text: str) -> str:
+        print(next(self.tokenizer.parameters()).device)
+        print(next(self.model.parameters()).device)
         """Summarize the input text using T5 model"""
         # Model and tokenization parameters
         inputs = self.tokenizer.encode(
@@ -19,8 +21,8 @@ class SummarizationModel:
 
         summary_ids = self.model.generate(
             inputs,
-            max_length=max_length,
-            min_length=min_length,
+            max_length=256,
+            min_length=64,
             num_beams=4,
             no_repeat_ngram_size=3,
             early_stopping=True,
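
Note on the device handling touched by this change: in the Hugging Face Transformers API, T5Tokenizer is a plain Python object rather than a torch.nn.Module, so it exposes neither .to() nor .parameters(); only the model holds parameters that live on a device. The snippet below is a minimal illustration of that convention using the same checkpoints as this file; it is a sketch for reference, not part of the commit.

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"

# The tokenizer only converts text to token ids; it has no device to move to.
tokenizer = T5Tokenizer.from_pretrained('unicamp-dl/ptt5-base-portuguese-vocab')

# The model is an nn.Module, so it is the object that gets moved to the GPU/CPU.
model = T5ForConditionalGeneration.from_pretrained('recogna-nlp/ptt5-base-summ').to(device)

# Device placement can be verified through the model's parameters.
print(next(model.parameters()).device)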