# Source: Hugging Face Space file "transformers / app.py"
# Last commit by censky: "Update app.py" (01481cf)
import streamlit as st
import torch
from simplet5 import SimpleT5
from transformers import T5Tokenizer, T5ForConditionalGeneration
import chinese_converter
# Hugging Face Hub id of the fine-tuned Chinese poem T5 checkpoint.
MODEL_PATH = "hululuzhu/chinese-poem-t5-mengzi-finetune"


class PoemModel(SimpleT5):
    """SimpleT5 subclass that loads the fine-tuned Chinese poem checkpoint.

    Instead of the base class's own loading helpers, :meth:`load_my_model`
    pulls the tokenizer and weights directly through ``transformers`` and
    stores them on ``self.tokenizer`` / ``self.model``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Prefer the GPU, but fall back to CPU so the object is still usable
        # on machines without CUDA instead of failing later at inference time.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_my_model(self, model_path=MODEL_PATH):
        """Load the tokenizer and model weights and place the model on ``self.device``.

        Args:
            model_path: Hub id or local directory accepted by
                ``from_pretrained``; defaults to ``MODEL_PATH``.
        """
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)
        self.model = T5ForConditionalGeneration.from_pretrained(model_path).to(self.device)
# Preset prompt markers and generation parameters.
AUTHOR_PROMPT = "模仿:"  # prefix before the author whose style to imitate ("imitate:")
TITLE_PROMPT = "作诗:"   # prefix before the poem title ("compose a poem:")
EOS_TOKEN = '</s>'       # separator placed between the title part and the author part

# Load the fine-tuned model when the script runs and move it to the device the
# PoemModel instance selected, rather than hard-coding CUDA.
# NOTE(review): Streamlit re-executes the whole script on every rerun, so the
# model is reloaded each time; consider caching the load (e.g. st.cache_resource)
# if the installed Streamlit version provides it.
poem_model = PoemModel()
poem_model.load_my_model()
poem_model.model = poem_model.model.to(poem_model.device)

MAX_AUTHOR_CHAR = 4     # author names are truncated to this many characters
MAX_TITLE_CHAR = 12     # titles are truncated to this many characters
MIN_CONTENT_CHAR = 10   # not referenced by poem()
MAX_CONTENT_CHAR = 64   # passed to model.predict as max_length
def poem(title_str, opt_author=None, model=poem_model,
         is_input_traditional_chinese=False):
    """Generate a poem for a title with the fine-tuned T5 model and render it.

    The prompt is ``TITLE_PROMPT`` followed by the title (truncated to
    ``MAX_TITLE_CHAR`` characters). When an author is given, it is followed by
    ``EOS_TOKEN + AUTHOR_PROMPT`` and the author name (truncated to
    ``MAX_AUTHOR_CHAR`` characters).

    Args:
        title_str: Poem title.
        opt_author: Optional author name; ignored when falsy.
        model: Loaded model wrapper exposing ``model``, ``device`` and
            ``predict``; defaults to the module-level ``poem_model``.
        is_input_traditional_chinese: If True, the prompt is converted to
            Simplified Chinese before generation and the generated text is
            converted back to Traditional Chinese for display.

    Returns:
        None. The title and poem are written to the page with ``st.text``.
    """
    # Keep the weights on the wrapper's configured device instead of
    # hard-coding CUDA.
    model.model = model.model.to(model.device)

    in_request = TITLE_PROMPT + title_str[:MAX_TITLE_CHAR]
    if opt_author:
        in_request += EOS_TOKEN + AUTHOR_PROMPT + opt_author[:MAX_AUTHOR_CHAR]

    # Show the request in the script the user supplied; only the model input
    # is converted to Simplified Chinese.
    shown_request = in_request.replace(EOS_TOKEN, ' ')
    if is_input_traditional_chinese:
        in_request = chinese_converter.to_simplified(in_request)

    out = model.predict(in_request, max_length=MAX_CONTENT_CHAR)[0]
    # Normalize ASCII commas in the generated text to the full-width Chinese comma.
    out = out.replace(",", "，")

    if is_input_traditional_chinese:
        out = chinese_converter.to_traditional(out)
        st.text(f"標題: {shown_request}\n詩歌: {out}")
    else:
        st.text(f"标题: {shown_request}\n诗歌: {out}")


# Render a demo poem when the Streamlit script runs. This call must come after
# poem() is defined, otherwise it raises NameError.
poem("百花")