poem-generate / app.py
Wendyy's picture
Update app.py
624b44d
raw
history blame
4.28 kB
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
from peft import PeftModel
import torch
import gradio as gr
import os
import re
class ChineseCharacterStop(StoppingCriteria):
    """Stopping criterion that fires when a sequence ends with a stop string.

    Each stop string (e.g. '。') is tokenized once at construction time. At
    every generation step the trailing tokens of ``input_ids`` are compared
    with each stop pattern, and generation stops as soon as any row matches.
    """

    def __init__(self, chars: list[str], stop_tokenizer=None):
        """Tokenize the stop strings.

        Args:
            chars: Strings that should end generation once produced.
            stop_tokenizer: Tokenizer used to encode ``chars``. When omitted,
                the module-level ``tokenizer`` is used (looked up at
                construction time), so existing calls keep working.
        """
        encoder = tokenizer if stop_tokenizer is None else stop_tokenizer
        # One (1, n_tokens) tensor of token ids per stop string.
        self.chars = [
            encoder(i, add_special_tokens=False, return_tensors='pt').input_ids
            for i in chars
        ]

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if any sequence in the batch ends with a stop pattern.

        Args:
            input_ids: Token ids so far, indexed as (batch, seq_len).
            scores: Unused; part of the ``StoppingCriteria`` call signature.

        Returns:
            True to stop generation, False to keep generating.
        """
        seq_len = input_ids.shape[-1]
        for idx, pattern in enumerate(self.chars):
            width = pattern.shape[-1]
            # An empty pattern (e.g. from '') has nothing to match. A pattern
            # longer than the sequence cannot match, and it would also fail to
            # broadcast against the shorter tail slice.
            if width == 0 or width > seq_len:
                continue
            if pattern.device != input_ids.device:
                # Move the pattern once and cache it, instead of copying it
                # on every generation step.
                pattern = pattern.to(input_ids.device)
                self.chars[idx] = pattern
            tail = input_ids[..., -width:]
            if torch.any(torch.all(torch.eq(tail, pattern), dim=-1)):
                return True
        return False
# Base causal LM checkpoint and the LoRA adapter checkpoint loaded on top of it.
_BASE_MODEL_ID = "IDEA-CCNL/Wenzhong-GPT2-110M"
_LORA_CHECKPOINT = 'checkpoint_lora_v4.1'

tokenizer = AutoTokenizer.from_pretrained(_BASE_MODEL_ID)
# Reuse EOS as the padding token so generate() has a pad id to use.
# Presumably the base tokenizer defines none; confirm.
tokenizer.pad_token = tokenizer.eos_token

gpt2_model = AutoModelForCausalLM.from_pretrained(_BASE_MODEL_ID)
# Wrap the base model with the PEFT (LoRA) adapter weights.
model = PeftModel.from_pretrained(gpt2_model, _LORA_CHECKPOINT)
def cang_tou(tou: str):
    """Generate an acrostic poem (藏头诗) from the head characters in ``tou``.

    Starting from a fixed instruction prefix, each character of ``tou`` is
    appended to the running text. The model then continues it until it emits
    '。' or ',', or until the total length reaches ``max_length=150`` tokens.
    The decoded result becomes the context for the next head character.

    Args:
        tou: Head characters; each non-whitespace character starts one line.

    Returns:
        The generated poem with the instruction prefix removed. Returns ''
        when ``tou`` contains no non-whitespace characters.
    """
    prefix = "写一首唐诗:"
    # Built once: the stop patterns are identical for every line.
    stop = [ChineseCharacterStop(['。', ','])]
    poem_now = prefix
    for c in tou:
        if c.isspace():
            # Whitespace typed by the user is not a usable head character.
            continue
        poem_now += c
        print(poem_now)
        inputs = tokenizer(poem_now, return_tensors='pt')
        outputs = model.generate(
            **inputs,
            return_dict_in_generate=True,
            max_length=150,
            do_sample=True,
            top_p=0.4,
            num_beams=1,
            num_return_sequences=1,
            stopping_criteria=stop,
            pad_token_id=tokenizer.pad_token_id
        )
        poem_now = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
        print(poem_now)
    # Strip the instruction by value instead of the magic offset 6.
    return poem_now.removeprefix(prefix)
def prompt_gen(prompt):
    """Continue a free-form prompt and return three sampled continuations.

    Args:
        prompt: Instruction text, e.g. "写一首关于思乡的古诗:".

    Returns:
        The three generated continuations, with the prompt removed and each
        followed by a newline. Returns '' for an empty prompt.
    """
    if not prompt:
        # An empty prompt tokenizes to zero input tokens, which generate()
        # presumably cannot start from; return nothing instead of erroring.
        return ''
    inputs = tokenizer(prompt, return_tensors='pt')
    outputs = model.generate(
        **inputs,
        return_dict_in_generate=True,
        max_length=200,
        do_sample=True,
        top_p=0.8,
        num_beams=5,
        num_return_sequences=3,
        pad_token_id=tokenizer.pad_token_id
    )
    # Decode only the tokens generated after the prompt. Slicing
    # len(prompt) characters off the fully decoded text goes wrong whenever
    # the decoded prompt does not match the input character-for-character.
    prompt_len = inputs['input_ids'].shape[1]
    continuations = tokenizer.batch_decode(
        outputs.sequences[:, prompt_len:], skip_special_tokens=True
    )
    return ''.join(f'{line}\n' for line in continuations)
# Custom CSS for gr.Blocks(css=css) below:
# - centers the main column (#col-container, at most 510px wide);
# - underlines links and makes them semi-bold;
# - gives `.animate-spin` elements an infinite 1s linear `spin` animation.
# NOTE(review): no `@keyframes spin` is defined here. Presumably it comes
# from Gradio's own stylesheet; confirm.
css = """
#col-container {max-width: 510px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
.animate-spin {
animation: spin 1s linear infinite;
}
"""
# Two-tab Gradio UI: free-form prompt continuation and acrostic poems.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
<h1 style="text-align: center;">✨古诗生成</h1>
<p style="text-align: center;">
根据输入的提示生成古诗、藏头诗<br />
</p>
"""
        )
        # Tab 1: continue a free-form prompt (prompt_gen).
        with gr.Tab("提示"):
            prompt_in = gr.Textbox(label="Prompt", placeholder="写一首关于思乡的古诗:", elem_id="prompt-in")
            prompt_btn = gr.Button("Submit")
            poetry_result = gr.Textbox(label="Output", elem_id="poetry-output")
            prompt_btn.click(fn=prompt_gen,
                             inputs=[prompt_in],
                             outputs=[poetry_result])
        # Tab 2: acrostic poem from head characters (cang_tou).
        with gr.Tab("藏头诗"):
            tou_in = gr.Textbox(label="Prompt", placeholder="一见如故", elem_id="tou-in")
            tou_btn = gr.Button("Submit")
            cangtou_result = gr.Textbox(label="Output", elem_id="cangtou-output")
            tou_btn.click(fn=cang_tou,
                          inputs=[tou_in],
                          outputs=[cangtou_result])

# Enable request queueing (at most 12 queued requests) and start the server.
demo.queue(max_size=12).launch()