orangewong committed on
Commit
d4a97a6
·
verified ·
1 Parent(s): 0f9d18a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -30
app.py CHANGED
@@ -3,7 +3,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  import spaces
5
 
6
-
7
  model_name = "Zhihu-ai/Zhi-writing-dsr1-14b"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
9
  model = AutoModelForCausalLM.from_pretrained(
@@ -12,10 +12,9 @@ model = AutoModelForCausalLM.from_pretrained(
12
  device_map="auto",
13
  trust_remote_code=True
14
  )
15
-
16
  @spaces.GPU()
17
  def predict(message, history):
18
-
19
  history_text = ""
20
  for human, assistant in history:
21
  history_text += f"Human: {human}\nAssistant: {assistant}\n"
@@ -23,40 +22,29 @@ def predict(message, history):
23
 
24
  # 生成回复
25
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
26
-
27
- # 使用流式生成
28
- for response in model.generate(
29
  **inputs,
30
  max_new_tokens=10000,
31
  do_sample=True,
32
  temperature=0.7,
33
  top_p=0.9,
34
  repetition_penalty=1.1,
35
- pad_token_id=tokenizer.eos_token_id,
36
- streamer=gr.TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
37
- ):
38
- yield response.strip()
39
-
40
-
41
-
42
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
43
- gr.Markdown("# Zhi-writing-dsr1-14")
44
- gr.Markdown("这是一个基于Zhi-writing-dsr1-14的文章生成器")
45
 
46
- chatbot = gr.Chatbot()
47
- msg = gr.Textbox(label="输入消息")
48
- clear = gr.Button("清除对话")
49
-
50
- def respond(message, chat_history):
51
- bot_message = ""
52
- for response in predict(message, chat_history):
53
- bot_message = response
54
- chat_history.append((message, bot_message))
55
- yield chat_history
56
- return "", chat_history
57
-
58
- msg.submit(respond, [msg, chatbot], [msg, chatbot])
59
- clear.click(lambda: None, None, chatbot, queue=False)
60
 
61
  if __name__ == "__main__":
62
  demo.launch(share=True)
 
 
 
3
  import torch
4
  import spaces
5
 
6
+ # 加载模型和分词器
7
  model_name = "Zhihu-ai/Zhi-writing-dsr1-14b"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
9
  model = AutoModelForCausalLM.from_pretrained(
 
12
  device_map="auto",
13
  trust_remote_code=True
14
  )
 
15
  @spaces.GPU()
16
  def predict(message, history):
17
+ # 构建输入
18
  history_text = ""
19
  for human, assistant in history:
20
  history_text += f"Human: {human}\nAssistant: {assistant}\n"
 
22
 
23
  # 生成回复
24
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
25
+ outputs = model.generate(
 
 
26
  **inputs,
27
  max_new_tokens=10000,
28
  do_sample=True,
29
  temperature=0.7,
30
  top_p=0.9,
31
  repetition_penalty=1.1,
32
+ pad_token_id=tokenizer.eos_token_id
33
+ )
34
+ response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
 
 
 
 
 
 
 
35
 
36
+ return response.strip()
37
+
38
+ # 创建Gradio界面
39
+ demo = gr.ChatInterface(
40
+ predict,
41
+ title="测试Zhi-writing-dsr1-14b",
42
+ description="Zhihu-ai/Zhi-writing-dsr1-14b",
43
+ examples=["鲁迅口吻写五百字,描述桔猫的可爱!", "桔了个仔是谁", "介绍自己"],
44
+ theme=gr.themes.Soft()
45
+ )
 
 
 
 
46
 
47
  if __name__ == "__main__":
48
  demo.launch(share=True)
49
+
50
+