from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria from peft import PeftModel import torch import gradio as gr import os import re class ChineseCharacterStop(StoppingCriteria): def __init__(self, chars: list[str]): self.chars = [ tokenizer(i, add_special_tokens=False, return_tensors='pt').input_ids for i in chars ] # for chars, tokens in zip(chars, self.chars): # print(f"'{chars}':{tokens}") def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: for c in self.chars: c = c.to(input_ids.device) match = torch.eq(input_ids[..., -c.shape[1]:], c) if torch.any(torch.all(match, dim=1)): return True return False tokenizer = AutoTokenizer.from_pretrained("IDEA-CCNL/Wenzhong-GPT2-110M") tokenizer.pad_token = tokenizer.eos_token gpt2_model = AutoModelForCausalLM.from_pretrained("IDEA-CCNL/Wenzhong-GPT2-110M") model = PeftModel.from_pretrained(gpt2_model, 'checkpoint_lora_v4.1') def cang_tou(tou: str): poem_now = "写一首唐诗:" for c in tou: poem_now += c print(poem_now) inputs = tokenizer(poem_now, return_tensors='pt') outputs = model.generate( **inputs, return_dict_in_generate=True, max_length=150, do_sample=True, top_p=0.4, num_beams=1, num_return_sequences=1, stopping_criteria=[ChineseCharacterStop(['。', ','])], pad_token_id=tokenizer.pad_token_id ) poem_now = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0] print(poem_now) return poem_now[6:] def prompt_gen(prompt): inputs = tokenizer(prompt, return_tensors='pt') outputs = model.generate( **inputs, return_dict_in_generate=True, max_length=200, do_sample=True, top_p=0.8, num_beams=5, num_return_sequences=3, # stopping_criteria=[ChineseCharacterStop(['。', ',', ''])], pad_token_id=tokenizer.pad_token_id ) res = '' for line in tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True): line = line[len(prompt):] res = res+line+'\n' return res css = """ #col-container {max-width: 510px; margin-left: auto; margin-right: auto;} a {text-decoration-line: underline; font-weight: 600;} .animate-spin { animation: spin 1s linear infinite; } """ with gr.Blocks(css=css) as demo: with gr.Column(elem_id="col-container"): gr.Markdown( """
根据输入的提示生成古诗、藏头诗