hanbin commited on
Commit
e66d039
·
1 Parent(s): 3c2f2d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -2
app.py CHANGED
@@ -1,4 +1,147 @@
 
 
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
1
+ # This file is .....
2
+ # Author: Hanbin Wang
3
+ # Date: 2023/4/18
4
+ import transformers
5
  import streamlit as st
6
+ from PIL import Image
7
+
8
+ from transformers import RobertaTokenizer, T5ForConditionalGeneration
9
+ from transformers import pipeline
10
+
11
+ @st.cache_resource
12
+ def get_model(model_path):
13
+ tokenizer = RobertaTokenizer.from_pretrained(model_path)
14
+ model = T5ForConditionalGeneration.from_pretrained(model_path)
15
+ model.eval()
16
+ return tokenizer, model
17
+
18
+
19
+ def main():
20
+ # `st.set_page_config` is used to display the default layout width, the title of the app, and the emoticon in the browser tab.
21
+
22
+ st.set_page_config(
23
+ layout="centered", page_title="MaMaL-Gen Demo(代码生成)", page_icon="❄️"
24
+ )
25
+
26
+ c1, c2 = st.columns([0.32, 2])
27
+
28
+ # The snowflake logo will be displayed in the first column, on the left.
29
+
30
+ with c1:
31
+ st.image(
32
+ "./panda23.png",
33
+ width=100,
34
+ )
35
+
36
+ # The heading will be on the right.
37
+
38
+ with c2:
39
+ st.caption("")
40
+ st.title("MaMaL-Gen(代码生成)")
41
+
42
+
43
+ ############ SIDEBAR CONTENT ############
44
+
45
+ st.sidebar.image("./panda23.png",width=270)
46
+
47
+ st.sidebar.markdown("---")
48
+
49
+ st.sidebar.write(
50
+ """
51
+ ## 使用方法:
52
+ 在【输入】文本框输入自然语言,点击【生成】按钮,即会生成想要的代码。
53
+ """
54
+ )
55
+
56
+ st.sidebar.write(
57
+ """
58
+ ## 注意事项:
59
+ 1)APP托管在外网上,请确保您可以全局科学上网。
60
+
61
+ 2)您可以下载[MaMaL-Gen](https://huggingface.co/hanbin/MaMaL-Gen)模型,本地测试。(无需科学上网)
62
+ """
63
+ )
64
+ # For elements to be displayed in the sidebar, we need to add the sidebar element in the widget.
65
+
66
+ # We create a text input field for users to enter their API key.
67
+
68
+ # API_KEY = st.sidebar.text_input(
69
+ # "Enter your HuggingFace API key",
70
+ # help="Once you created you HuggingFace account, you can get your free API token in your settings page: https://huggingface.co/settings/tokens",
71
+ # type="password",
72
+ # )
73
+ #
74
+ # # Adding the HuggingFace API inference URL.
75
+ # API_URL = "https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3"
76
+ #
77
+ # # Now, let's create a Python dictionary to store the API headers.
78
+ # headers = {"Authorization": f"Bearer {API_KEY}"}
79
+
80
+
81
+ st.sidebar.markdown("---")
82
+
83
+ st.write(
84
+ "> **Tip:** 首次运行需要加载模型,可能需要一定的时间!"
85
+ )
86
+
87
+ st.write(
88
+ "> **Tip:** 左侧栏给出了一些good case 和 bad case,you can try it!"
89
+ )
90
+ st.write(
91
+ "> **Tip:** 只支持英文输入,输入过长,效果会变差。"
92
+ )
93
+
94
+ st.sidebar.write(
95
+ "> **Good case:**"
96
+ )
97
+ code_good = """1)Convert a SVG string to a QImage
98
+ 2)Try to seek to given offset"""
99
+ st.sidebar.code(code_good, language='python')
100
+
101
+ st.sidebar.write(
102
+ "> **Bad cases:**"
103
+ )
104
+ code_bad = """Read an OpenAPI binary file ."""
105
+ st.sidebar.code(code_bad, language='python')
106
+
107
+ # Let's add some info about the app to the sidebar.
108
+
109
+ st.sidebar.write(
110
+ """
111
+ App 由 东北大学NLP课小组成员创建, 使用 [Streamlit](https://streamlit.io/)🎈 和 [HuggingFace](https://huggingface.co/inference-api)'s [MaMaL-Gen](https://huggingface.co/hanbin/MaMaL-Gen) 模型.
112
+ """
113
+ )
114
+
115
+ # model, tokenizer = load_model("hanbin/MaMaL-Gen")
116
+ st.write("### 输入:")
117
+ input = st.text_area("", height=100)
118
+ button = st.button('生成')
119
+
120
+ tokenizer,model = get_model("hanbin/MaMaL-Gen")
121
+
122
+ input_ids = tokenizer(input, return_tensors="pt").input_ids
123
+ generated_ids = model.generate(input_ids, max_length=100)
124
+ output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
125
+ # generator = pipeline('text-generation', model="E:\DenseRetrievalGroup\CodeT5-base")
126
+ # output = generator(input)
127
+ # code = '''def hello():
128
+ # print("Hello, Streamlit!")'''
129
+ if button:
130
+ st.write("### 输出:")
131
+ st.code(output, language='python')
132
+ else:
133
+ st.write('#### 输出位置~~')
134
+
135
+
136
+
137
+
138
+
139
+ if __name__ == '__main__':
140
+
141
+ main()
142
+
143
+
144
+
145
+
146
+
147