censky commited on
Commit
01481cf
1 Parent(s): 6cf73f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -5
app.py CHANGED
@@ -1,6 +1,48 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
- translator = pipeline("translation_en_to_zh", model='Helsinki-NLP/opus-mt-en-zh')
4
- text = st.text_input('请输入文本', '')
5
- if st.button('生成'):
6
- st.write(translator(text))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ from simplet5 import SimpleT5
4
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
5
+ import chinese_converter
6
+
7
+ MODEL_PATH = "hululuzhu/chinese-poem-t5-mengzi-finetune"
8
+ class PoemModel(SimpleT5):
9
+ def __init__(self) -> None:
10
+ super().__init__()
11
+ self.device = torch.device("cuda")
12
+
13
+ def load_my_model(self):
14
+ self.tokenizer = T5Tokenizer.from_pretrained(MODEL_PATH)
15
+ self.model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH)
16
+
17
+ # 有一些预先设定参数
18
+ AUTHOR_PROMPT = "模仿:"
19
+ TITLE_PROMPT = "作诗:"
20
+ EOS_TOKEN = '</s>'
21
+
22
+ poem_model = PoemModel()
23
+ poem_model.load_my_model()
24
+ poem_model.model = poem_model.model.to('cuda')
25
+
26
+ MAX_AUTHOR_CHAR = 4
27
+ MAX_TITLE_CHAR = 12
28
+ MIN_CONTENT_CHAR = 10
29
+ MAX_CONTENT_CHAR = 64
30
+
31
+ poem("百花")
32
+
33
+ def poem(title_str, opt_author=None, model=poem_model,
34
+ is_input_traditional_chinese=False):
35
+ model.model = model.model.to('cuda')
36
+ if opt_author:
37
+ in_request = TITLE_PROMPT + title_str[:MAX_TITLE_CHAR] + EOS_TOKEN + AUTHOR_PROMPT + opt_author[:MAX_AUTHOR_CHAR]
38
+ else:
39
+ in_request = TITLE_PROMPT + title_str[:MAX_TITLE_CHAR]
40
+ if is_input_traditional_chinese:
41
+ in_request = chinese_converter.to_simplified(in_request)
42
+ out = model.predict(in_request,
43
+ max_length=MAX_CONTENT_CHAR)[0].replace(",", ",")
44
+ if is_input_traditional_chinese:
45
+ out = chinese_converter.to_traditional(out)
46
+ st.text(f"標題: {in_request.replace('</s>', ' ')}\n詩歌: {out}")
47
+ else:
48
+ st.text(f"标题: {in_request.replace('</s>', ' ')}\n诗歌: {out}")