import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/baichuan-7B", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("baichuan-inc/baichuan-7B", trust_remote_code=True, device_map="auto", load_in_8bit=True)
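# Note: load_in_8bit=True relies on the bitsandbytes package and a CUDA GPU, and
# device_map="auto" needs accelerate installed so weight placement is handled automatically;
# both are assumed to be listed in the Space's requirements.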
# Alternative (unused): serving the same model through the ModelScope pipeline API.
# from modelscope.pipelines import pipeline
# from modelscope.utils.constant import Tasks
# text_generation_zh = pipeline(task=Tasks.text_generation, model='baichuan-inc/baichuan-7B', model_revision='v1.0.2')
# text_generation_zh._model_prepare = True
def text_generation(input_text):
    if input_text == "":
        return ""
    # Few-shot completion prompts such as "run->walk\nwrite->" work well with the base model.
    inputs = tokenizer(input_text, return_tensors="pt").to("cuda:0")
    pred = model.generate(**inputs, max_new_tokens=64, repetition_penalty=1.1)
    return tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)  # ModelScope equivalent: text_generation_zh(input_text)['text']
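# Baichuan-7B is a base (non-chat) model, so the demo feeds it plain completion prompts;
# the bundled examples below use a one-shot "A->B\nC->" pattern for the model to continue.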
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # [Chinese] 百川-7B 模型体验
        输入文本以查看生成结果。
        # [English] Baichuan-7B model experience
        Enter text to see the generated continuation.
        """)
    inp = gr.Textbox(label="输入Prompt / Input Prompt")
    submit = gr.Button("提交/Submit")
    out = gr.Textbox(label="续写结果/Generated Text")
    submit.click(text_generation, inp, out)
    gr.Markdown("## Text Examples")
    gr.Examples(
        ["登鹳雀楼->王之涣\n夜雨寄北->", "Hamlet->Shakespeare\nOne Hundred Years of Solitude->"],
        inp,
        out,
        text_generation,
        cache_examples=True,
    )
if __name__ == "__main__":
    demo.queue().launch()
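# Optional quick check without the Gradio UI (a minimal sketch, assuming the weights are
# already downloaded and a CUDA device is available); kept commented out so the Space
# only serves the demo above:
# prompt = "登鹳雀楼->王之涣\n夜雨寄北->"
# print(text_generation(prompt))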