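"""Gradio Space: classical Chinese poem generation.

Generates 古诗 (classical poems) from a free-form prompt and 藏头诗
(acrostic poems) from user-supplied head characters, using the
IDEA-CCNL/Wenzhong-GPT2-110M base model with a LoRA adapter.
"""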
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
from peft import PeftModel
import torch
import gradio as gr
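
# Custom stopping criterion: halt generation as soon as the output ends with
# any of the given characters. Used below to end each generated clause at
# Chinese punctuation ('。' or ',').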
class ChineseCharacterStop(StoppingCriteria):
    def __init__(self, chars: list[str]):
        # NOTE: relies on the module-level `tokenizer`, which must exist
        # before this class is instantiated (it does, below).
        self.chars = [
            tokenizer(c, add_special_tokens=False, return_tensors='pt').input_ids
            for c in chars
        ]

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, **kwargs) -> bool:
        for c in self.chars:
            c = c.to(input_ids.device)
            if input_ids.shape[1] < c.shape[1]:
                continue  # sequence still shorter than the stop string
            # True when the trailing tokens of any sequence equal the stop tokens.
            match = torch.eq(input_ids[..., -c.shape[1]:], c)
            if torch.any(torch.all(match, dim=1)):
                return True
        return False
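
# Load the Chinese GPT-2 base model and attach the LoRA weights
# fine-tuned for poetry generation.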
tokenizer = AutoTokenizer.from_pretrained("IDEA-CCNL/Wenzhong-GPT2-110M")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
gpt2_model = AutoModelForCausalLM.from_pretrained("IDEA-CCNL/Wenzhong-GPT2-110M")
model = PeftModel.from_pretrained(gpt2_model, 'checkpoint_lora_v4.1')
model.eval()  # disable dropout for inference
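
# Acrostic (藏头诗) generation: for each head character, append it to the poem
# so far and let the model complete the clause, stopping at the next '。' or
# ','. The head characters therefore begin successive lines of the poem.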
def cang_tou(tou: str):
    prefix = "写一首唐诗:"  # priming prefix: "Write a Tang poem:"
    poem_now = prefix
    for c in tou:
        poem_now += c
        print(poem_now)
        inputs = tokenizer(poem_now, return_tensors='pt')
        outputs = model.generate(
            **inputs,
            return_dict_in_generate=True,
            max_length=150,
            do_sample=True,
            top_p=0.4,
            num_beams=1,
            num_return_sequences=1,
            stopping_criteria=[ChineseCharacterStop(['。', ','])],
            pad_token_id=tokenizer.pad_token_id
        )
        poem_now = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
        print(poem_now)
    return poem_now[len(prefix):]  # strip the priming prefix
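
# Free-form generation: sample 3 candidate continuations of the user's prompt
# with beam-sample decoding and return them one per line.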
def prompt_gen(prompt):
    inputs = tokenizer(prompt, return_tensors='pt')
    outputs = model.generate(
        **inputs,
        return_dict_in_generate=True,
        max_length=200,
        do_sample=True,
        top_p=0.8,
        num_beams=5,
        num_return_sequences=3,
        pad_token_id=tokenizer.pad_token_id
    )
    res = ''
    for line in tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True):
        line = line[len(prompt):]  # drop the echoed prompt
        res = res + line + '\n'
    return res
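
# Minimal CSS: center the main column; .animate-spin keeps the stock
# loading-spinner animation used by the Gradio frontend.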
css = """ | |
#col-container {max-width: 510px; margin-left: auto; margin-right: auto;} | |
a {text-decoration-line: underline; font-weight: 600;} | |
.animate-spin { | |
animation: spin 1s linear infinite; | |
} | |
""" | |
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Header: "✨Classical poem generation — generate poems and acrostics
        # from your prompt."
        gr.Markdown(
            """
            <h1 style="text-align: center;">✨古诗生成</h1>
            <p style="text-align: center;">
            根据输入的提示生成古诗、藏头诗<br />
            </p>
            """
        )
        with gr.Tab("提示"):
            # Placeholder: "Write a classical poem about homesickness:"
            prompt_in = gr.Textbox(label="Prompt", placeholder="写一首关于思乡的古诗:", elem_id="prompt-in")
            submit_btn = gr.Button("Submit")
            poetry_result = gr.Textbox(label="Output", elem_id="poetry-output")
            submit_btn.click(fn=prompt_gen,
                             inputs=[prompt_in],
                             outputs=[poetry_result])
        with gr.Tab("藏头诗"):
            # Placeholder: "一见如故" ("friends at first sight"), one head
            # character per line of the acrostic.
            tou_in = gr.Textbox(label="Prompt", placeholder="一见如故", elem_id="tou-in")
            submit_btn = gr.Button("Submit")
            cangtou_result = gr.Textbox(label="Output", elem_id="cangtou-output")
            submit_btn.click(fn=cang_tou,
                             inputs=[tou_in],
                             outputs=[cangtou_result])

demo.queue(max_size=12).launch()  # cap pending requests on the Space