import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Pick the best available device: CUDA GPU, Apple Silicon (MPS), or CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif (
    hasattr(torch.backends, "mps")
    and torch.backends.mps.is_available()
    and torch.backends.mps.is_built()
):
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"running device: {device}")
# Read the Hugging Face access token from the Space secret; fall back to True,
# which tells transformers to use the locally cached login token.
auth_token = os.environ.get("TOKEN_READ_SECRET") or True

tokenizer = AutoTokenizer.from_pretrained(
    "NorHsangPha/shan_gpt2_news", token=auth_token
)
model = AutoModelForCausalLM.from_pretrained(
    "NorHsangPha/shan_gpt2_news", pad_token_id=tokenizer.eos_token_id, token=auth_token
).to(device)
def greedy_search(model_inputs, max_new_tokens):
    # Greedy decoding: always pick the highest-probability next token.
    greedy_output = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(greedy_output[0], skip_special_tokens=True)
def beam_search(model_inputs, max_new_tokens):
    # Beam search: track the 5 most likely continuations at each step.
    beam_output = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        num_beams=5,
        no_repeat_ngram_size=2,  # never repeat the same 2-gram
        num_return_sequences=5,  # keep 5 beams; only the best is decoded below
        early_stopping=True,
    )
    return tokenizer.decode(beam_output[0], skip_special_tokens=True)
def sample_outputs(model_inputs, max_new_tokens):
    # Plain sampling over the full vocabulary with a lowered temperature.
    sample_output = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=0,  # disable top-k filtering
        temperature=0.6,
    )
    return tokenizer.decode(sample_output[0], skip_special_tokens=True)
def top_k_search(model_inputs, max_new_tokens):
    # Top-k sampling: sample from the 50 most likely next tokens.
    top_k_output = model.generate(
        **model_inputs, max_new_tokens=max_new_tokens, do_sample=True, top_k=50
    )
    return tokenizer.decode(top_k_output[0], skip_special_tokens=True)
def top_p_search(model_inputs, max_new_tokens):
    # Top-p (nucleus) sampling: sample from the smallest set of tokens
    # whose cumulative probability exceeds 0.92.
    top_p_output = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.92,
        top_k=0,
    )
    return tokenizer.decode(top_p_output[0], skip_special_tokens=True)
def generate_text(input_text, search_method="sample_outputs"):
    # Tokenize the prompt and dispatch to the selected decoding strategy.
    model_inputs = tokenizer(input_text, return_tensors="pt").to(device)
    max_new_tokens = 120
    match search_method:
        case "greedy_search":
            text = greedy_search(model_inputs, max_new_tokens)
        case "beam_search":
            text = beam_search(model_inputs, max_new_tokens)
        case "top_k_search":
            text = top_k_search(model_inputs, max_new_tokens)
        case "top_p_search":
            text = top_p_search(model_inputs, max_new_tokens)
        case _:
            text = sample_outputs(model_inputs, max_new_tokens)
    return text
# Example prompts, each paired with the decoding strategy to demonstrate.
GENERATE_EXAMPLES = [
    ["αααΊααα―ααΊαΆαα", "sample_outputs"],
    ["αα’ααΊααα―α΅αΊαΈααα―α΅αΊαΈααα°ααΊ", "greedy_search"],
    ["αα’ααΊααα―α΅αΊαΈααα―α΅αΊαΈααα°ααΊ", "top_k_search"],
    ["αα’ααΊααα―α΅αΊαΈααα―α΅αΊαΈααα°ααΊ", "top_p_search"],
    ["αα’ααΊααα―α΅αΊαΈααα―α΅αΊαΈααα°ααΊ", "beam_search"],
]
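
# On Spaces, pieces like these are typically wired into a small Gradio demo.
# The snippet below is only a minimal sketch of such an interface, not the
# original app's code: the component labels and the `demo` / `radio_choices`
# names are assumptions made for illustration.
import gradio as gr

radio_choices = [
    "sample_outputs",
    "greedy_search",
    "beam_search",
    "top_k_search",
    "top_p_search",
]

demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Radio(choices=radio_choices, value="sample_outputs", label="Decoding strategy"),
    ],
    outputs=gr.Textbox(label="Generated text"),
    examples=GENERATE_EXAMPLES,
)

if __name__ == "__main__":
    demo.launch()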