"""Streamlit demo: generate Chinese poems with a fine-tuned Mengzi T5 model.

The model at ``MODEL_PATH`` is prompted with a title and, optionally, the
name of an author whose style it should imitate; the generated poem is
rendered with ``st.text``.
"""
import streamlit as st
import torch
from simplet5 import SimpleT5
from transformers import T5Tokenizer, T5ForConditionalGeneration
import chinese_converter

MODEL_PATH = "hululuzhu/chinese-poem-t5-mengzi-finetune"

# Prefer the GPU, but fall back to the CPU so the demo does not crash on
# machines without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PoemModel(SimpleT5):
    """SimpleT5 wrapper that loads the poem model from ``MODEL_PATH``."""

    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): presumably SimpleT5.predict moves its encoded input
        # to ``self.device`` -- confirm against the simplet5 version in use.
        self.device = DEVICE

    def load_my_model(self):
        """Load the tokenizer and weights from ``MODEL_PATH`` onto ``self.device``."""
        self.tokenizer = T5Tokenizer.from_pretrained(MODEL_PATH)
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH).to(
            self.device
        )


# Preset prompt parameters.
AUTHOR_PROMPT = "模仿:"  # "imitate:" -- followed by the author name
TITLE_PROMPT = "作诗:"  # "write a poem:" -- followed by the title
# Separator between the title part and the author part of the prompt; it is
# swapped for a space when the prompt is displayed.
EOS_TOKEN = "</s>"

MAX_AUTHOR_CHAR = 4  # author names are truncated to this many characters
MAX_TITLE_CHAR = 12  # titles are truncated to this many characters
MIN_CONTENT_CHAR = 10  # not referenced below
MAX_CONTENT_CHAR = 64  # passed to predict() as max_length

poem_model = PoemModel()
poem_model.load_my_model()


def poem(title_str, opt_author=None, model=poem_model,
         is_input_traditional_chinese=False):
    """Generate a poem for ``title_str`` and render it with ``st.text``.

    Args:
        title_str: Poem title; truncated to ``MAX_TITLE_CHAR`` characters.
        opt_author: Optional author to imitate; truncated to
            ``MAX_AUTHOR_CHAR`` characters. Falsy values are ignored.
        model: A ``PoemModel`` whose weights are loaded (defaults to the
            module-level instance).
        is_input_traditional_chinese: If True, the prompt is converted to
            Simplified Chinese before generation and the generated poem is
            converted back to Traditional Chinese for display.
    """
    # Cheap when the weights are already on the right device.
    model.model = model.model.to(model.device)

    in_request = TITLE_PROMPT + title_str[:MAX_TITLE_CHAR]
    if opt_author:
        in_request += EOS_TOKEN + AUTHOR_PROMPT + opt_author[:MAX_AUTHOR_CHAR]
    if is_input_traditional_chinese:
        in_request = chinese_converter.to_simplified(in_request)

    # Turn ASCII commas in the generated text into full-width Chinese commas.
    out = model.predict(in_request, max_length=MAX_CONTENT_CHAR)[0].replace(",", "，")

    # Show the prompt with the separator token replaced by a space.
    shown_request = in_request.replace(EOS_TOKEN, " ")
    if is_input_traditional_chinese:
        out = chinese_converter.to_traditional(out)
        st.text(f"標題: {shown_request}\n詩歌: {out}")
    else:
        st.text(f"标题: {shown_request}\n诗歌: {out}")


# Demo call; it must come after ``poem`` is defined.
poem("百花")