Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,48 @@
|
|
1 |
import streamlit as st
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
import torch
|
3 |
+
from simplet5 import SimpleT5
|
4 |
+
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
5 |
+
import chinese_converter
|
6 |
+
|
7 |
+
MODEL_PATH = "hululuzhu/chinese-poem-t5-mengzi-finetune"
|
8 |
+
class PoemModel(SimpleT5):
|
9 |
+
def __init__(self) -> None:
|
10 |
+
super().__init__()
|
11 |
+
self.device = torch.device("cuda")
|
12 |
+
|
13 |
+
def load_my_model(self):
|
14 |
+
self.tokenizer = T5Tokenizer.from_pretrained(MODEL_PATH)
|
15 |
+
self.model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH)
|
16 |
+
|
17 |
+
# 有一些预先设定参数
|
18 |
+
AUTHOR_PROMPT = "模仿:"
|
19 |
+
TITLE_PROMPT = "作诗:"
|
20 |
+
EOS_TOKEN = '</s>'
|
21 |
+
|
22 |
+
poem_model = PoemModel()
|
23 |
+
poem_model.load_my_model()
|
24 |
+
poem_model.model = poem_model.model.to('cuda')
|
25 |
+
|
26 |
+
MAX_AUTHOR_CHAR = 4
|
27 |
+
MAX_TITLE_CHAR = 12
|
28 |
+
MIN_CONTENT_CHAR = 10
|
29 |
+
MAX_CONTENT_CHAR = 64
|
30 |
+
|
31 |
+
poem("百花")
|
32 |
+
|
33 |
+
def poem(title_str, opt_author=None, model=poem_model,
|
34 |
+
is_input_traditional_chinese=False):
|
35 |
+
model.model = model.model.to('cuda')
|
36 |
+
if opt_author:
|
37 |
+
in_request = TITLE_PROMPT + title_str[:MAX_TITLE_CHAR] + EOS_TOKEN + AUTHOR_PROMPT + opt_author[:MAX_AUTHOR_CHAR]
|
38 |
+
else:
|
39 |
+
in_request = TITLE_PROMPT + title_str[:MAX_TITLE_CHAR]
|
40 |
+
if is_input_traditional_chinese:
|
41 |
+
in_request = chinese_converter.to_simplified(in_request)
|
42 |
+
out = model.predict(in_request,
|
43 |
+
max_length=MAX_CONTENT_CHAR)[0].replace(",", ",")
|
44 |
+
if is_input_traditional_chinese:
|
45 |
+
out = chinese_converter.to_traditional(out)
|
46 |
+
st.text(f"標題: {in_request.replace('</s>', ' ')}\n詩歌: {out}")
|
47 |
+
else:
|
48 |
+
st.text(f"标题: {in_request.replace('</s>', ' ')}\n诗歌: {out}")
|