# Hugging Face Spaces page residue: the Space status was shown as "Runtime error".
import streamlit as st | |
import torch | |
from simplet5 import SimpleT5 | |
from transformers import T5Tokenizer, T5ForConditionalGeneration | |
import chinese_converter | |
# Checkpoint identifier handed to ``from_pretrained`` for both the tokenizer
# and the model (a Hugging Face Hub repo id in "user/name" form).
MODEL_PATH: str = "hululuzhu/chinese-poem-t5-mengzi-finetune"
class PoemModel(SimpleT5):
    """SimpleT5 wrapper that loads the fine-tuned Chinese poem T5 checkpoint.

    This class sets ``tokenizer``, ``model`` and ``device`` itself rather than
    going through ``SimpleT5.load_model``. ``SimpleT5.predict`` is expected to
    use those three attributes — confirm against the installed simplet5 version.
    """

    def __init__(self, device=None) -> None:
        """Initialise the wrapper and choose the inference device.

        Args:
            device: Explicit device (e.g. ``"cpu"``, ``"cuda"`` or a
                ``torch.device``). When omitted, CUDA is used if available and
                the CPU otherwise. Previously CUDA was always used, which
                breaks on CPU-only hosts.
        """
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

    def load_my_model(self, model_path=MODEL_PATH):
        """Load the tokenizer and model weights and move the model to ``self.device``.

        Args:
            model_path: Hub repo id or local directory of the checkpoint;
                defaults to ``MODEL_PATH``.
        """
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)
        model = T5ForConditionalGeneration.from_pretrained(model_path)
        # Keep weights on the same device the inputs will be sent to.
        self.model = model.to(self.device)
# Preset prompt parameters.
# Prefix placed before the author name in a request ("模仿" = "imitate").
AUTHOR_PROMPT = "模仿:"
# Prefix placed before the title in a request ("作诗" = "compose a poem").
TITLE_PROMPT = "作诗:"
# Separator inserted between the title part and the author part of a request.
EOS_TOKEN = '</s>'

# Load the model at module import time and put the weights on whatever device
# the model object selected; a hard-coded 'cuda' crashed on CPU-only hosts.
# NOTE(review): Streamlit re-executes the script on every interaction, so this
# reloads the model each rerun; consider caching if load time matters.
poem_model = PoemModel()
poem_model.load_my_model()
poem_model.model = poem_model.model.to(poem_model.device)

# Input/output length limits.
MAX_AUTHOR_CHAR = 4  # author name is truncated to this many characters
MAX_TITLE_CHAR = 12  # title is truncated to this many characters
MIN_CONTENT_CHAR = 10  # not referenced in the visible code
# Passed as ``max_length`` to ``predict``; this may count tokens rather than
# characters — TODO confirm.
MAX_CONTENT_CHAR = 64
def poem(title_str, opt_author=None, model=poem_model,
         is_input_traditional_chinese=False):
    """Generate a poem for a title and render it on the Streamlit page.

    The request sent to the model is ``TITLE_PROMPT + title``. When an author
    is given, ``EOS_TOKEN + AUTHOR_PROMPT + author`` is appended to it.

    Args:
        title_str: Poem title; only the first ``MAX_TITLE_CHAR`` characters
            are used.
        opt_author: Optional author name; truncated to ``MAX_AUTHOR_CHAR``
            characters. Ignored when empty or ``None``.
        model: ``PoemModel`` instance used for generation.
        is_input_traditional_chinese: If true, the request is converted to
            Simplified Chinese before generation. The generated poem is then
            converted back to Traditional Chinese for display.

    Returns:
        The generated poem text, which is also written to the page via
        ``st.text``.
    """
    # Make sure the weights sit on the model's chosen device instead of a
    # hard-coded 'cuda', which fails on CPU-only hosts.
    model.model = model.model.to(model.device)

    in_request = TITLE_PROMPT + title_str[:MAX_TITLE_CHAR]
    if opt_author:
        in_request += EOS_TOKEN + AUTHOR_PROMPT + opt_author[:MAX_AUTHOR_CHAR]
    if is_input_traditional_chinese:
        # Presumably the model was fine-tuned on simplified text — TODO confirm.
        in_request = chinese_converter.to_simplified(in_request)

    generated = model.predict(in_request, max_length=MAX_CONTENT_CHAR)[0]
    # Convert ASCII commas to full-width Chinese commas. The previous
    # ``replace(",", ",")`` was a no-op.
    out = generated.replace(",", "，")

    shown_request = in_request.replace(EOS_TOKEN, ' ')
    if is_input_traditional_chinese:
        out = chinese_converter.to_traditional(out)
        st.text(f"標題: {shown_request}\n詩歌: {out}")
    else:
        st.text(f"标题: {shown_request}\n诗歌: {out}")
    return out


# Render a sample poem when the page loads. This call must come after ``poem``
# is defined; calling it earlier raised NameError at import time.
poem("百花")