maayanorner committed on
Commit
17055ea
verified
1 Parent(s): f3ba46f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -1
README.md CHANGED
@@ -24,12 +24,15 @@ from transformers import (
24
 
25
 
26
def summarize(text, tokenizer, model, num_beams=4, temperature=1, max_new_tokens=512):
    """Generate a summary of *text* with a prompt template and beam-sampled decoding.

    Args:
        text: Input text to summarize; must be at least 20 characters long.
        tokenizer: Hugging Face tokenizer matching *model*.
        model: Hugging Face generative model (assumed to live on CUDA — TODO confirm).
        num_beams: Beam-search width passed to ``generate``.
        temperature: Sampling temperature (effective because ``do_sample=True``).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded generation as a string, with special tokens kept
        (``skip_special_tokens=False``); the prompt is included in the output.

    Raises:
        ValueError: If *text* is shorter than 20 characters.
    """
    if len(text) < 20:
        raise ValueError('Text must be at least 20 characters long.')
    # Some tokenizers ship without a pad token; generate() warns (or misbehaves
    # when batching) without one, so fall back to the EOS token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # This text template is important.
    inputs = tokenizer(f'{text}\n### 住讬讻讜诐:', return_tensors="pt")
    in_data = inputs.input_ids.to('cuda')
    # Pass the attention mask explicitly: omitting it makes transformers infer
    # one from pad_token_id, which triggers a warning and can yield wrong
    # results when pad and eos share an id.
    attention_mask = inputs.attention_mask.to('cuda')
    output_ids = model.generate(
        input_ids=in_data,
        attention_mask=attention_mask,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        early_stopping=True,
        use_cache=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)

    return generated_text
 
24
 
25
 
26
def summarize(text, tokenizer, model, num_beams=4, temperature=1, max_new_tokens=512):
    """Generate a summary of *text* with a prompt template and beam-sampled decoding.

    Args:
        text: Input text to summarize; must be at least 20 characters long.
        tokenizer: Hugging Face tokenizer matching *model*.
        model: Hugging Face generative model; tensors are moved to its device.
        num_beams: Beam-search width passed to ``generate``.
        temperature: Sampling temperature (effective because ``do_sample=True``).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded generation as a string, with special tokens kept
        (``skip_special_tokens=False``); the prompt is included in the output.

    Raises:
        ValueError: If *text* is shorter than 20 characters.
    """
    # Some tokenizers ship without a pad token; generate() warns without one,
    # so fall back to the EOS token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if len(text) < 20:
        raise ValueError('Text must be at least 20 characters long.')
    # This text template is important.
    inputs = tokenizer(f'{text}\n### 住讬讻讜诐:', return_tensors="pt")
    # Use the model's own device rather than hard-coding 'cuda': identical
    # behavior for a CUDA-resident model, but also works on CPU and other
    # accelerators. getattr keeps a safe fallback if .device is absent.
    device = getattr(model, 'device', 'cuda')
    in_data = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    output_ids = model.generate(
        input_ids=in_data,
        attention_mask=attention_mask,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        early_stopping=True,
        use_cache=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)

    return generated_text