File size: 4,173 Bytes
e66d039
 
 
 
c793c5f
e66d039
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c793c5f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Streamlit demo app for the MaMaL-Gen code-generation model (hanbin/MaMaL-Gen).
# Author: Hanbin Wang
# Date: 2023/4/18
import transformers
import streamlit as st
from PIL import Image

from transformers import RobertaTokenizer, T5ForConditionalGeneration
from transformers import pipeline

@st.cache_resource
def get_model(model_path):
    """Load the tokenizer and seq2seq model stored at ``model_path``.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the
    already-loaded objects instead of loading them again.

    Args:
        model_path: identifier passed straight to ``from_pretrained``
            (this app uses the Hub id ``"hanbin/MaMaL-Gen"``).

    Returns:
        A ``(tokenizer, model)`` tuple with the model switched to eval mode.
    """
    tokenizer_obj = RobertaTokenizer.from_pretrained(model_path)
    # `Module.eval()` disables training-only behaviour (e.g. dropout) and
    # returns the module itself, so the chained call yields the same object.
    model_obj = T5ForConditionalGeneration.from_pretrained(model_path).eval()
    return tokenizer_obj, model_obj


def _render_header():
    """Render the title row: the logo image on the left, the page heading on the right."""
    # Narrow left column for the logo, wide right column for the heading.
    c1, c2 = st.columns([0.32, 2])

    with c1:
        st.image(
            "./panda23.png",
            width=100,
        )

    with c2:
        # Empty caption, presumably used as vertical spacing so the title
        # lines up with the logo.
        st.caption("")
        st.title("MaMaL-Gen(代码生成)")


def _render_sidebar():
    """Render the sidebar: logo, usage instructions, notes, example prompts and credits."""
    st.sidebar.image("./panda23.png",width=270)

    st.sidebar.markdown("---")

    # Usage: type a natural-language description and press the "生成" (Generate) button.
    st.sidebar.write(
    """
    ## 使用方法:
    在【输入】文本框输入自然语言,点击【生成】按钮,即会生成想要的代码。
    """
    )

    # Notes: the hosted app may need a VPN to reach; the model can also be
    # downloaded from the Hugging Face Hub and tested locally.
    st.sidebar.write(
    """
    ## 注意事项:
    1)APP托管在外网上,请确保您可以全局科学上网。
    
    2)您可以下载[MaMaL-Gen](https://huggingface.co/hanbin/MaMaL-Gen)模型,本地测试。(无需科学上网)
    """
    )

    st.sidebar.markdown("---")

    # Example prompts that work well / poorly with the model.
    st.sidebar.write(
        "> **Good case:**"
    )
    code_good = """1)Convert a SVG string to a QImage
2)Try to seek to given offset"""
    st.sidebar.code(code_good, language='python')

    st.sidebar.write(
        "> **Bad cases:**"
    )
    code_bad = """Read an OpenAPI binary file ."""
    st.sidebar.code(code_bad, language='python')

    # Credits for the app and the underlying model.
    st.sidebar.write(
    """
    App 由 东北大学NLP课小组成员创建, 使用 [Streamlit](https://streamlit.io/)🎈 和 [HuggingFace](https://huggingface.co/inference-api)'s [MaMaL-Gen](https://huggingface.co/hanbin/MaMaL-Gen) 模型.
    """
    )


def _generate_code(tokenizer, model, text, max_length=100):
    """Generate code for a natural-language description.

    Args:
        tokenizer: tokenizer returned by ``get_model``.
        model: seq2seq model returned by ``get_model``.
        text: the user's natural-language prompt.
        max_length: ``max_length`` forwarded to ``model.generate``.

    Returns:
        The first generated sequence decoded to text, special tokens removed.
    """
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids, max_length=max_length)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)


def main():
    """Build the MaMaL-Gen demo page and generate code when the user asks for it.

    Streamlit re-executes the script on every widget interaction, so the
    expensive generation step runs only in a rerun where the "生成" button
    was pressed and the prompt is non-blank.
    """
    # `st.set_page_config` must be the first Streamlit command of a script run;
    # it sets the default layout width, the app title and the browser-tab icon.
    st.set_page_config(
        layout="centered", page_title="MaMaL-Gen Demo(代码生成)", page_icon="❄️"
    )

    _render_header()
    _render_sidebar()

    st.write(
        "> **Tip:** 首次运行需要加载模型,可能需要一定的时间!"
    )
    st.write(
        "> **Tip:** 左侧栏给出了一些good case 和 bad case,you can try it!"
    )
    st.write(
        "> **Tip:** 只支持英文输入,输入过长,效果会变差。"
    )

    st.write("### 输入:")
    # Streamlit warns on empty widget labels (accessibility). Give the box a
    # label but hide it, since the heading above already labels it visually.
    prompt = st.text_area("输入", height=100, label_visibility="collapsed")
    button = st.button('生成')

    # Loaded eagerly so the first-run delay mentioned in the tip happens up
    # front; @st.cache_resource on get_model lets later reruns reuse it.
    tokenizer, model = get_model("hanbin/MaMaL-Gen")

    if button and prompt.strip():
        st.write("### 输出:")
        with st.spinner("生成中..."):
            output = _generate_code(tokenizer, model, prompt)
        st.code(output, language='python')
    else:
        if button:
            # Button pressed with an empty/blank prompt: nothing to generate from.
            st.warning("请先在【输入】文本框中输入英文描述。")
        st.write('####         输出位置~~')





if __name__ == "__main__":
    # Script entry point (e.g. when launched with `streamlit run`): build the page.
    main()