Spaces:
Sleeping
Sleeping
File size: 4,173 Bytes
e66d039 c793c5f e66d039 c793c5f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
# Streamlit demo app: generate code from an English description with the MaMaL-Gen model.
# Author: Hanbin Wang
# Date: 2023/4/18
import transformers
import streamlit as st
from PIL import Image
from transformers import RobertaTokenizer, T5ForConditionalGeneration
from transformers import pipeline
@st.cache_resource
def get_model(model_path):
    """Load the tokenizer and seq2seq model stored at ``model_path``.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded
    objects across reruns instead of reloading them each time.

    Args:
        model_path: Hugging Face hub id or local directory of the checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is put in eval mode.
    """
    tok = RobertaTokenizer.from_pretrained(model_path)
    # nn.Module.eval() returns the module itself, so it can be chained.
    seq2seq = T5ForConditionalGeneration.from_pretrained(model_path).eval()
    return tok, seq2seq
def _render_header():
    """Render the logo (narrow left column) and the page title (right column)."""
    logo_col, title_col = st.columns([0.32, 2])
    # The panda logo goes in the narrow left column.
    with logo_col:
        st.image(
            "./panda23.png",
            width=100,
        )
    # The heading goes in the wide right column.
    with title_col:
        # Empty caption - presumably a spacer so the title sits level with the logo.
        st.caption("")
        st.title("MaMaL-Gen(代码生成)")


def _render_sidebar():
    """Render the sidebar: logo, usage notes, example prompts and credits."""
    st.sidebar.image("./panda23.png", width=270)
    st.sidebar.markdown("---")
    st.sidebar.write(
        """
## 使用方法:
在【输入】文本框输入自然语言,点击【生成】按钮,即会生成想要的代码。
"""
    )
    st.sidebar.write(
        """
## 注意事项:
1)APP托管在外网上,请确保您可以全局科学上网。
2)您可以下载[MaMaL-Gen](https://huggingface.co/hanbin/MaMaL-Gen)模型,本地测试。(无需科学上网)
"""
    )
    st.sidebar.markdown("---")

    # Example prompts the model handles well / badly.
    st.sidebar.write(
        "> **Good case:**"
    )
    code_good = """1)Convert a SVG string to a QImage
2)Try to seek to given offset"""
    st.sidebar.code(code_good, language='python')
    st.sidebar.write(
        "> **Bad cases:**"
    )
    code_bad = """Read an OpenAPI binary file ."""
    st.sidebar.code(code_bad, language='python')

    # Credits for the app and the model it uses.
    st.sidebar.write(
        """
App 由 东北大学NLP课小组成员创建, 使用 [Streamlit](https://streamlit.io/)🎈 和 [HuggingFace](https://huggingface.co/inference-api)'s [MaMaL-Gen](https://huggingface.co/hanbin/MaMaL-Gen) 模型.
"""
    )


def _render_tips():
    """Render the usage tips shown above the input box in the main area."""
    for tip in (
        "> **Tip:** 首次运行需要加载模型,可能需要一定的时间!",
        "> **Tip:** 左侧栏给出了一些good case 和 bad case,you can try it!",
        "> **Tip:** 只支持英文输入,输入过长,效果会变差。",
    ):
        st.write(tip)


def _generate_code(tokenizer, model, prompt, max_length=100):
    """Generate code for ``prompt`` and return it as text.

    Tokenizes ``prompt`` to PyTorch tensors, calls ``model.generate`` with
    ``max_length`` as the length cap, and decodes the first returned sequence
    with special tokens stripped.
    """
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids, max_length=max_length)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)


def main():
    """Build the MaMaL-Gen demo page and generate code when the user asks.

    Streamlit reruns this function top to bottom on every interaction, so
    model inference happens only when the 生成 button was pressed and the
    prompt is non-empty.
    """
    # Keep this first: Streamlit historically requires set_page_config before
    # any other st.* call. It sets the layout width, tab title and tab icon.
    st.set_page_config(
        layout="centered", page_title="MaMaL-Gen Demo(代码生成)", page_icon="❄️"
    )
    _render_header()
    _render_sidebar()
    _render_tips()

    st.write("### 输入:")
    # Non-empty label for accessibility, hidden because the heading above
    # already labels the box.
    prompt = st.text_area("输入", height=100, label_visibility="collapsed")
    generate_clicked = st.button('生成')

    # Loaded on every run so the first page view warms the cache
    # (st.cache_resource makes later calls cheap).
    tokenizer, model = get_model("hanbin/MaMaL-Gen")

    if not generate_clicked:
        st.write('#### 输出位置~~')
        return
    if not prompt.strip():
        # Generating from an empty prompt produces meaningless output.
        st.warning("请输入英文描述后再点击【生成】。")
        return

    output = _generate_code(tokenizer, model, prompt)
    st.write("### 输出:")
    st.code(output, language='python')


if __name__ == '__main__':
    main()
|