|
from peft import PeftModel, PeftConfig |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig |
|
from threading import Thread |
|
import gradio as gr |
|
import torch |
|
|
|
# Local checkpoint directories; leave empty ('') to fetch from the Hugging Face Hub instead.
lora_folder = ''
model_folder = ''

# Resolve each folder setting to the path/repo id actually passed to from_pretrained.
base_model_path = model_folder if model_folder != '' else "baichuan-inc/Baichuan-13B-Base"
lora_path = lora_folder if lora_folder != '' else "Junity/Genshin-World-Model"

# NOTE(review): `config` is not referenced anywhere else in this file.
config = PeftConfig.from_pretrained(lora_path, trust_remote_code=True)

# Base causal LM in fp16; device_map="auto" lets transformers/accelerate place the weights.
base = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

# Wrap the base model with the LoRA adapter loaded from `lora_path`.
model = PeftModel.from_pretrained(
    base,
    lora_path,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)

# NOTE(review): `history` is never read or written elsewhere in this file.
history = []

# Device that prompt tensors are moved to before calling model.generate.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
def respond(role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k): |
|
if textbox != '': |
|
textbox = (textbox |
|
+ "\n" |
|
+ role_name |
|
+ (":" if role_name != '' else '') |
|
+ msg |
|
+ ('。\n' if msg[-1] not in ['。', '!', '?'] else '')) |
|
yield ["", textbox] |
|
else: |
|
textbox = (textbox |
|
+ role_name |
|
+ (":" if role_name != '' else '') |
|
+ msg |
|
+ ('。' if msg[-1] not in ['。', '!', '?', ')', '}', ':', ':', '('] else '') |
|
+ ('\n' if msg[-1] in ['。', '!', '?', ')', '}'] else '')) |
|
yield ["", textbox] |
|
if character_name != '': |
|
textbox += ('\n' if textbox[-1] != '\n' else '') + character_name + ':' |
|
input_ids = tokenizer.encode(textbox)[-3200:] |
|
input_ids = torch.LongTensor([input_ids]).to(device) |
|
generation_config = model.generation_config |
|
stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True) |
|
gen_kwargs = {} |
|
gen_kwargs.update(dict( |
|
input_ids=input_ids, |
|
temperature=temp, |
|
top_p=top_p, |
|
top_k=top_k, |
|
repetition_penalty=rep, |
|
max_new_tokens=max_len, |
|
do_sample=True, |
|
)) |
|
outputs = [] |
|
print(input_ids) |
|
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) |
|
gen_kwargs["streamer"] = streamer |
|
|
|
thread = Thread(target=model.generate, kwargs=gen_kwargs) |
|
thread.start() |
|
|
|
for new_text in streamer: |
|
textbox += new_text |
|
yield ["", textbox] |
|
|
|
|
|
# Gradio front end: one "创作" tab whose Submit button streams respond() into the transcript.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        ## Genshin-World-Model
        - 模型地址 [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)
        - 此模型不支持要求对方回答什么,只支持续写。
        - 目前运行不了,因为没有钱租卡。
        """
    )

    with gr.Tab("创作") as chat:
        # Speaker names and this turn's message.
        role_name = gr.Textbox(label="你将扮演的角色(可留空)")
        character_name = gr.Textbox(label="对方的角色(可留空)")
        msg = gr.Textbox(label="你说的话")

        with gr.Row():
            clear = gr.ClearButton()
            sub = gr.Button("Submit", variant="primary")

        # Sampling controls; respond() receives them as temp/rep/max_len/top_p/top_k.
        with gr.Row():
            temp = gr.Slider(label="温度(调大则更随机)", minimum=0, maximum=2.0, step=0.1, value=1.5, interactive=True)
            rep = gr.Slider(label="对重复生成的惩罚", minimum=0, maximum=2.0, step=0.1, value=1.0, interactive=True)
            max_len = gr.Slider(label="对方回答的最大长度", minimum=4, maximum=512, step=4, value=256, interactive=True)
            top_p = gr.Slider(label="Top-p(调大则更随机)", minimum=0, maximum=1.0, step=0.1, value=0.7, interactive=True)
            top_k = gr.Slider(label="Top-k(调大则更随机)", minimum=0, maximum=100, step=1, value=50, interactive=True)

        # Full transcript; the user may edit it before the next submission.
        textbox = gr.Textbox(interactive=True, label="全部文本(可修改)")

        clear.add([msg, role_name, textbox])

        sampling_controls = [temp, rep, max_len, top_p, top_k]
        sub.click(
            fn=respond,
            inputs=[role_name, character_name, msg, textbox, *sampling_controls],
            outputs=[msg, textbox],
        )

    gr.Markdown(
        """
        #### 特别鸣谢 XXXX
        """
    )

demo.queue()
demo.launch()
|
|