my2000cup committed
Commit ce7a9cc · verified · 1 Parent(s): 3c06957

Update app.py

Files changed (1):
  app.py  +14 -7
app.py CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+import threading

 MODEL_NAME = "my2000cup/Gaia-Petro-LLM"

@@ -21,29 +22,35 @@ def build_prompt(history, system_message, user_message):
         if assistant:
             messages.append({"role": "assistant", "content": assistant})
     messages.append({"role": "user", "content": user_message})
-    # If chat-template support is available, apply_chat_template is recommended
     if hasattr(tokenizer, "apply_chat_template"):
         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
     else:
-        # Fallback: simple concatenation
         prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + "\nassistant:"
     return prompt

 def respond(message, history, system_message, max_tokens, temperature, top_p):
     prompt = build_prompt(history, system_message, message)
     inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
-    output = model.generate(
+
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    # Generate asynchronously in a new thread
+    generation_kwargs = dict(
         **inputs,
+        streamer=streamer,
         max_new_tokens=max_tokens,
         temperature=temperature,
         top_p=top_p,
         do_sample=True,
-        pad_token_id=tokenizer.eos_token_id
+        pad_token_id=tokenizer.eos_token_id,
     )
-    response = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
-    yield response
+    gen_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
+    gen_thread.start()
+    output = ""
+    for new_text in streamer:
+        output += new_text
+        yield output

 demo = gr.ChatInterface(
     respond,
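For context, a minimal sketch of how the new streaming respond generator plugs into gr.ChatInterface. The diff's context ends at the gr.ChatInterface(respond, ...) call, so the additional inputs below (the system-message textbox and the max-tokens / temperature / top-p sliders, their labels and defaults) are assumptions for illustration, not the actual values in app.py. Because respond now yields a growing string from the TextIteratorStreamer loop, ChatInterface renders the reply incrementally instead of waiting for generation to finish.

import gradio as gr

# Sketch only: labels and defaults are assumed, not taken from app.py.
# The widget order must match respond(message, history, system_message, max_tokens, temperature, top_p).
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a helpful petroleum-domain assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.05, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
    ],
)

if __name__ == "__main__":
    demo.launch()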