Update README.md
Browse files
README.md
CHANGED
@@ -24,12 +24,15 @@ from transformers import (
|
|
24 |
|
25 |
|
26 |
def summarize(text, tokenizer, model, num_beams=4, temperature=1, max_new_tokens=512):
|
|
|
|
|
27 |
if len(text) < 20:
|
28 |
raise ValueError('Text must be at least 20 characters long.')
|
29 |
# This text template is important.
|
30 |
inputs = tokenizer(f'{text}\n### 住讬讻讜诐:', return_tensors="pt")
|
31 |
in_data = inputs.input_ids.to('cuda')
|
32 |
-
|
|
|
33 |
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
|
34 |
|
35 |
return generated_text
|
|
|
24 |
|
25 |
|
26 |
def summarize(text, tokenizer, model, num_beams=4, temperature=1, max_new_tokens=512, device='cuda'):
    """Generate a summary of *text* with a Hugging Face tokenizer/model pair.

    Appends the summarization prompt marker ``### 住讬讻讜诐:`` to the input
    (the model is presumably fine-tuned to continue after this marker —
    confirm against the training template), runs beam-search sampling,
    and returns the decoded output with special tokens kept so the
    caller can split on the marker.

    Parameters
    ----------
    text : str
        Input text to summarize; must be at least 20 characters long.
    tokenizer :
        transformers tokenizer matching *model*.
    model :
        transformers model exposing ``generate``.
    num_beams : int, optional
        Number of beams for beam search (default 4).
    temperature : float, optional
        Sampling temperature; sampling is always enabled (default 1).
    max_new_tokens : int, optional
        Cap on generated tokens (default 512).
    device : str, optional
        Torch device the input tensors are moved to. Defaults to
        ``'cuda'``, preserving the previous hard-coded behavior, but can
        now be set to ``'cpu'`` (or any torch device string) on machines
        without a GPU.

    Returns
    -------
    str
        Decoded generation, special tokens included.

    Raises
    ------
    ValueError
        If *text* is shorter than 20 characters.
    """
    # generate() needs a pad token id for beam search; reuse EOS when the
    # tokenizer does not define one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if len(text) < 20:
        raise ValueError('Text must be at least 20 characters long.')
    # This text template is important.
    inputs = tokenizer(f'{text}\n### 住讬讻讜诐:', return_tensors="pt")
    in_data = inputs.input_ids.to(device)
    # Pass the attention mask explicitly so the model distinguishes real
    # tokens from padding (avoids the transformers warning and wrong
    # results when pad == eos).
    attention_mask = inputs.attention_mask.to(device)
    output_ids = model.generate(
        input_ids=in_data,
        attention_mask=attention_mask,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        early_stopping=True,
        use_cache=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
    return generated_text
|