# File size: 4,282 Bytes
# 624b44d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
from peft import PeftModel
import torch
import gradio as gr
import os
import re

class ChineseCharacterStop(StoppingCriteria):
    """Stop generation once a sequence ends with one of the given strings.

    Each stop string is tokenized once (without special tokens) when the
    criterion is built. On every generation step the tail of ``input_ids``
    is compared against each token pattern, and generation stops as soon as
    any row of the batch ends with one of them.

    Args:
        chars: Stop strings, e.g. ``['。', ',']``. Strings that tokenize to
            zero tokens (such as ``''``) are ignored, since they cannot be
            matched.
        tok: Tokenizer used to encode ``chars``. Defaults to the module-level
            ``tokenizer``, so existing ``ChineseCharacterStop(chars)`` calls
            keep working.
    """

    def __init__(self, chars: list[str], tok=None):
        if tok is None:
            # Looked up at construction time: the module-level tokenizer is
            # defined after this class.
            tok = tokenizer
        encoded = (
            tok(text, add_special_tokens=False, return_tensors='pt').input_ids
            for text in chars
        )
        # An empty pattern would slice as input_ids[..., -0:] (the whole
        # sequence) and then fail in torch.eq with a shape mismatch.
        self.chars = [ids for ids in encoded if ids.shape[-1] > 0]

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if any batch row currently ends with a stop pattern.

        NOTE(review): newer transformers versions document a per-row
        BoolTensor return value; a plain bool is still combined correctly by
        StoppingCriteriaList, but confirm against the installed version.
        """
        seq_len = input_ids.shape[-1]
        for pattern in self.chars:
            width = pattern.shape[-1]
            if width > seq_len:
                # Too short to end with this pattern; slicing would yield a
                # shorter tail and a broadcast error in torch.eq.
                continue
            pattern = pattern.to(input_ids.device)
            match = torch.eq(input_ids[..., -width:], pattern)
            # A row matches only if every position of the tail matches.
            if torch.any(torch.all(match, dim=1)):
                return True
        return False


# Hugging Face Hub id of the base causal LM shared by the tokenizer and model.
_BASE_MODEL_ID = "IDEA-CCNL/Wenzhong-GPT2-110M"

# Base model, then the PEFT adapter from 'checkpoint_lora_v4.1' (local path or
# Hub id) wrapped around it; `model` is what the generation functions call.
gpt2_model = AutoModelForCausalLM.from_pretrained(_BASE_MODEL_ID)
model = PeftModel.from_pretrained(gpt2_model, 'checkpoint_lora_v4.1')

tokenizer = AutoTokenizer.from_pretrained(_BASE_MODEL_ID)
# Padding reuses the EOS token (presumably the tokenizer defines no pad token);
# the generate() calls pass tokenizer.pad_token_id explicitly.
tokenizer.pad_token = tokenizer.eos_token


def cang_tou(tou: str):
    """Generate an acrostic poem (藏头诗) seeded with the characters of *tou*.

    The text starts from a fixed instruction prefix. For each head
    character, that character is appended to the running text and the model
    continues it until a stop string ('。' or ',') is produced or the length
    limit is reached. The decoded sequence becomes the new running text.

    Args:
        tou: Head characters. Whitespace characters are skipped, and ``None``
            is treated as empty.

    Returns:
        The generated poem with the instruction prefix removed, or ``""`` when
        *tou* contains no non-whitespace characters.
    """
    prefix = "写一首唐诗:"
    heads = [ch for ch in (tou or "") if not ch.isspace()]
    if not heads:
        return ""

    # Tokenize the stop strings once; the criterion holds no per-call state,
    # so reusing it across generate() calls is safe.
    # NOTE(review): ',' is the ASCII comma. If the training data uses the
    # full-width '，', the model rarely emits ',' and each head produces a
    # whole '。'-terminated couplet instead of one line. Confirm against the data.
    stopping_criteria = [ChineseCharacterStop(['。', ','])]

    poem_now = prefix
    for head in heads:
        poem_now += head
        print(poem_now)
        inputs = tokenizer(poem_now, return_tensors='pt')
        outputs = model.generate(
            **inputs,
            return_dict_in_generate=True,
            # NOTE(review): max_length caps prompt + continuation, so a long
            # acrostic can exhaust it and later heads get no continuation.
            max_length=150,
            do_sample=True,
            top_p=0.4,
            num_beams=1,
            num_return_sequences=1,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.pad_token_id
        )
        poem_now = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
    print(poem_now)
    # Strip the instruction prefix by its real length rather than a magic number.
    return poem_now[len(prefix):]


def prompt_gen(prompt):
    """Generate three sampled beam-search continuations of *prompt*.

    Args:
        prompt: Free-form instruction text, passed to the model verbatim.

    Returns:
        The three continuations, each with the first ``len(prompt)``
        characters of the decoded text removed and followed by a newline.
        Returns ``""`` for an empty (or ``None``) prompt, which would otherwise
        tokenize to zero tokens that ``generate()`` cannot start from.
    """
    if not prompt:
        return ""

    inputs = tokenizer(prompt, return_tensors='pt')
    outputs = model.generate(
        **inputs,
        return_dict_in_generate=True,
        max_length=200,
        do_sample=True,
        top_p=0.8,
        num_beams=5,
        num_return_sequences=3,
        pad_token_id=tokenizer.pad_token_id
    )
    decoded = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
    # Each decoded sequence begins with the prompt; keep only the generated part.
    return "".join(f"{text[len(prompt):]}\n" for text in decoded)

css = """
#col-container {max-width: 510px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
.animate-spin {
  animation: spin 1s linear infinite;
}
"""

# Two-tab UI: free-form prompt generation (prompt_gen) and acrostic poem
# generation (cang_tou), each with its own input, button and output box.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
            <h1 style="text-align: center;">✨古诗生成</h1>
            <p style="text-align: center;">
            根据输入的提示生成古诗、藏头诗<br />
            </p>            
            """
        )

        # Tab "提示" (prompt): the textbox contents go to prompt_gen unchanged.
        with gr.Tab("提示"):
            prompt_in = gr.Textbox(
                label="Prompt",
                placeholder="写一首关于思乡的古诗:",
                elem_id="prompt-in",
            )
            submit_btn = gr.Button("Submit")
            poetry_result = gr.Textbox(label="Output", elem_id="poetry-output")
            submit_btn.click(
                fn=prompt_gen,
                inputs=[prompt_in],
                outputs=[poetry_result],
            )

        # Tab "藏头诗" (acrostic): each entered character heads part of the poem.
        with gr.Tab("藏头诗"):
            tou_in = gr.Textbox(
                label="Prompt",
                placeholder="一见如故",
                elem_id="tou-in",
            )
            submit_btn = gr.Button("Submit")
            cangtou_result = gr.Textbox(label="Output", elem_id="cangtou-output")
            submit_btn.click(
                fn=cang_tou,
                inputs=[tou_in],
                outputs=[cangtou_result],
            )

# Enable request queuing with at most 12 pending events, then start the server.
demo.queue(max_size=12)
demo.launch()