Spaces:
Runtime error
Runtime error
John Doe
commited on
Commit
·
04bdba9
1
Parent(s):
684ff8a
app.py, Zmaker.py, requirements.txtのアップロード
Browse filesapp.py : streamlitによるGUI制御
Zmaker.py : fine-tuning済みのGPT-2で推論を行うためのコード
- app.py +94 -0
- requirements.txt +67 -0
app.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from Zmaker import Zmaker
|
3 |
+
|
4 |
+
if __name__ == "__main__":
|
5 |
+
|
6 |
+
#ファインチューニング済みモデルの読み込み
|
7 |
+
with st.spinner(text = "loading GPT-2..."):
|
8 |
+
if not ("AI" in st.session_state.keys()):
|
9 |
+
st.session_state["AI"] = Zmaker(
|
10 |
+
ft_path = "model/gpt2-ft/"
|
11 |
+
)
|
12 |
+
|
13 |
+
#設定用サイドバーの設定
|
14 |
+
with st.sidebar:
|
15 |
+
st.title("GPT-2のパラメータ")
|
16 |
+
|
17 |
+
#max_lenの設定用スライダ
|
18 |
+
sld_max_len = st.sidebar.slider(
|
19 |
+
"length of the sentence", min_value = 0, max_value = 256,
|
20 |
+
value = (25, 75), step = 1, key = "length"
|
21 |
+
)
|
22 |
+
|
23 |
+
#temperatureの設定用スライダ
|
24 |
+
sld_temp = st.sidebar.slider(
|
25 |
+
"temperature", min_value = 0.1, max_value = 1.5,
|
26 |
+
value = 0.1, step = 0.1, key = "temp"
|
27 |
+
)
|
28 |
+
|
29 |
+
#top_kの設定用スライダ
|
30 |
+
sld_top_k = st.sidebar.slider(
|
31 |
+
"top_k", min_value = 0, max_value = 500,
|
32 |
+
value = 40, step = 1, key = "top_k"
|
33 |
+
)
|
34 |
+
|
35 |
+
#top_pの設定用スライダ
|
36 |
+
sld_top_p = st.sidebar.slider(
|
37 |
+
"top_p", min_value = 0.01, max_value = 1.0,
|
38 |
+
value = 0.95, step = 0.01, key = "top_p"
|
39 |
+
)
|
40 |
+
|
41 |
+
#repeat_ngram_sizeの設定用スライダ
|
42 |
+
sld_top_p = st.sidebar.slider(
|
43 |
+
"repeat_ngram_size ", min_value = 1, max_value = 10,
|
44 |
+
value = 1, step = 1, key = "repeat_ngram_size"
|
45 |
+
)
|
46 |
+
|
47 |
+
#メインフォームの設定
|
48 |
+
with st.form(key = "Letter Form", clear_on_submit = False):
|
49 |
+
st.title("おてがみ 入力欄")
|
50 |
+
body = st.empty()
|
51 |
+
if ("letter_body" in st.session_state.keys()):
|
52 |
+
ret = body.text_area(
|
53 |
+
label = "お手紙を途中まで漢字+ひらがなで書いてください。続きをAIが生成します。\n"\
|
54 |
+
"本アプリで生成できるのは本文のみです。",
|
55 |
+
value = st.session_state["letter_body"]
|
56 |
+
)
|
57 |
+
else:
|
58 |
+
ret = body.text_area(
|
59 |
+
label = "お手紙を途中まで漢字+ひらがなで書いてください。\n"\
|
60 |
+
"続きをAIが生成します。",
|
61 |
+
value = "ズッポシ村へようこそ!"
|
62 |
+
)
|
63 |
+
sub = st.form_submit_button("Generate")
|
64 |
+
|
65 |
+
#注意事項
|
66 |
+
with st.expander("注意事項"):
|
67 |
+
st.text(
|
68 |
+
"※このAIは「どうぶつの森e+実況プレイ」"\
|
69 |
+
" (https://www.nicovideo.jp/mylist/45062007)において"\
|
70 |
+
" 稲葉百万鉄氏により作成された文章を学習データに用いております。\n"
|
71 |
+
" また,教師データの作成においてmintmama氏の作成した"\
|
72 |
+
" 「ズッポシむら手紙集」(https://www.nicovideo.jp/series/85494)\n"\
|
73 |
+
"を用いております。"
|
74 |
+
)
|
75 |
+
|
76 |
+
|
77 |
+
#submitボタンが押された
|
78 |
+
if sub == True:
|
79 |
+
#predictに必要な条件をGUIで設定した値に更新
|
80 |
+
st.session_state["AI"].min_len = st.session_state["length"][0]
|
81 |
+
st.session_state["AI"].max_len = st.session_state["length"][-1]
|
82 |
+
st.session_state["AI"].top_k = st.session_state["top_k"]
|
83 |
+
st.session_state["AI"].top_p = st.session_state["top_p"]
|
84 |
+
st.session_state["AI"].temp = st.session_state["temp"]
|
85 |
+
st.session_state["AI"].repeat_ngram_size = st.session_state["repeat_ngram_size"]
|
86 |
+
|
87 |
+
#AIによる予測を実行
|
88 |
+
with st.spinner(text = "generating..."):
|
89 |
+
prompt = ret
|
90 |
+
text = str(st.session_state["AI"].GenLetter("<s>"+prompt)[0])
|
91 |
+
text = text.replace('<s>', '')
|
92 |
+
text = text.replace('</s>', '')
|
93 |
+
st.session_state["letter_body"] = text
|
94 |
+
st.experimental_rerun()
|
requirements.txt
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.18.0
|
2 |
+
altair==4.2.2
|
3 |
+
attrs==23.1.0
|
4 |
+
blinker==1.6.2
|
5 |
+
cachetools==5.3.0
|
6 |
+
certifi==2022.12.7
|
7 |
+
charset-normalizer==3.1.0
|
8 |
+
click==8.1.3
|
9 |
+
colorama==0.4.6
|
10 |
+
decorator==5.1.1
|
11 |
+
entrypoints==0.4
|
12 |
+
filelock==3.12.0
|
13 |
+
fsspec==2023.4.0
|
14 |
+
gitdb==4.0.10
|
15 |
+
GitPython==3.1.31
|
16 |
+
huggingface-hub==0.14.1
|
17 |
+
idna==3.4
|
18 |
+
importlib-metadata==6.6.0
|
19 |
+
Jinja2==3.1.2
|
20 |
+
JsonForm==0.0.2
|
21 |
+
jsonschema==4.17.3
|
22 |
+
JsonSir==0.0.2
|
23 |
+
markdown-it-py==2.2.0
|
24 |
+
MarkupSafe==2.1.2
|
25 |
+
mdurl==0.1.2
|
26 |
+
mojimoji==0.0.12
|
27 |
+
mpmath==1.3.0
|
28 |
+
networkx==3.1
|
29 |
+
numpy==1.24.3
|
30 |
+
packaging==23.1
|
31 |
+
pandas==2.0.1
|
32 |
+
Pillow==9.5.0
|
33 |
+
protobuf==3.20.3
|
34 |
+
psutil==5.9.5
|
35 |
+
pyarrow==12.0.0
|
36 |
+
pydeck==0.8.1b0
|
37 |
+
Pygments==2.15.1
|
38 |
+
Pympler==1.0.1
|
39 |
+
pyrsistent==0.19.3
|
40 |
+
python-dateutil==2.8.2
|
41 |
+
Python-EasyConfig==0.1.7
|
42 |
+
pytz==2023.3
|
43 |
+
pytz-deprecation-shim==0.1.0.post0
|
44 |
+
PyYAML==6.0
|
45 |
+
regex==2023.3.23
|
46 |
+
requests==2.29.0
|
47 |
+
rich==13.3.5
|
48 |
+
sentencepiece==0.1.99
|
49 |
+
six==1.16.0
|
50 |
+
smmap==5.0.0
|
51 |
+
streamlit==1.22.0
|
52 |
+
sympy==1.11.1
|
53 |
+
tenacity==8.2.2
|
54 |
+
tokenizers==0.13.3
|
55 |
+
toml==0.10.2
|
56 |
+
toolz==0.12.0
|
57 |
+
torch==2.0.0
|
58 |
+
tornado==6.3.1
|
59 |
+
tqdm==4.65.0
|
60 |
+
transformers==4.28.1
|
61 |
+
typing_extensions==4.5.0
|
62 |
+
tzdata==2023.3
|
63 |
+
tzlocal==4.3
|
64 |
+
urllib3==1.26.15
|
65 |
+
validators==0.20.0
|
66 |
+
watchdog==3.0.0
|
67 |
+
zipp==3.15.0
|