my2000cup committed
Commit 3c06957 · verified · Parent: ba0a1d1

Update app.py

Files changed (1):
  app.py +20 -50
app.py CHANGED
@@ -2,17 +2,16 @@ import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-model_name = "my2000cup/Gaia-Petro-LLM"
+MODEL_NAME = "my2000cup/Gaia-Petro-LLM"
 
-tokenizer = AutoTokenizer.from_pretrained(model_name)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 model = AutoModelForCausalLM.from_pretrained(
-    model_name,
+    MODEL_NAME,
     torch_dtype="auto",
     device_map="auto"
 )
 
 def build_prompt(history, system_message, user_message):
-    # adjust as needed for your model's template
     messages = []
     if system_message:
         messages.append({"role": "system", "content": system_message})
@@ -22,54 +21,29 @@ def build_prompt(history, system_message, user_message):
         if assistant:
             messages.append({"role": "assistant", "content": assistant})
     messages.append({"role": "user", "content": user_message})
-    # use the tokenizer's chat template
-    prompt = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
+    # if chat template support is available, apply_chat_template is recommended
+    if hasattr(tokenizer, "apply_chat_template"):
+        prompt = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+    else:
+        # fallback: simple concatenation
+        prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + "\nassistant:"
     return prompt
 
-def respond(
-    message,
-    history,
-    system_message,
-    max_tokens,
-    temperature,
-    top_p
-):
+def respond(message, history, system_message, max_tokens, temperature, top_p):
     prompt = build_prompt(history, system_message, message)
-    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
-
-    # streaming output
-    streamer = None
-    try:
-        from transformers import TextIteratorStreamer
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    except ImportError:
-        streamer = None
-
-    gen_kwargs = dict(
-        **model_inputs,
+    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
+    output = model.generate(
+        **inputs,
         max_new_tokens=max_tokens,
         temperature=temperature,
         top_p=top_p,
         do_sample=True,
         pad_token_id=tokenizer.eos_token_id
     )
-    if streamer:
-        gen_kwargs["streamer"] = streamer
-        thread = torch.Thread(target=model.generate, kwargs=gen_kwargs)
-        thread.start()
-        response = ""
-        for new_text in streamer:
-            response += new_text
-            yield response
-        thread.join()
-    else:
-        output = model.generate(**gen_kwargs)
-        response = tokenizer.decode(output[0][model_inputs['input_ids'].shape[1]:], skip_special_tokens=True)
-        yield response
+    response = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
+    yield response
 
 demo = gr.ChatInterface(
     respond,
@@ -77,14 +51,10 @@ demo = gr.ChatInterface(
         gr.Textbox(value="You are an oil & gas industry expert.", label="System message"),
         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
+        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
     ],
+    title="Gaia-Petro-LLM Chatbot",
+    description="⚡ An oil & gas industry expert assistant built on Hugging Face Transformers."
 )
 
 if __name__ == "__main__":
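Review note: the new fallback branch in build_prompt serializes the conversation as plain "role: content" lines and leaves a trailing "assistant:" cue for the model. A quick sanity check of what that branch produces (the conversation here is hypothetical, not from the commit):

# hypothetical conversation, for illustration only
messages = [
    {"role": "system", "content": "You are an oil & gas industry expert."},
    {"role": "user", "content": "What is porosity?"},
]

prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + "\nassistant:"
print(prompt)
# system: You are an oil & gas industry expert.
# user: What is porosity?
# assistant:

This plain format will generally not match the template Gaia-Petro-LLM was fine-tuned with, so in practice the apply_chat_template branch should be the one that runs.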
 
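One thing the diff quietly fixes: the deleted streaming branch could never run, because torch.Thread does not exist; background generation needs the standard library's threading.Thread. If streaming is wanted back in a later commit, a minimal sketch under that correction (respond_streaming is a hypothetical name; it assumes TextIteratorStreamer, which ships with recent transformers releases, and reuses tokenizer, model, and build_prompt from app.py):

import threading

from transformers import TextIteratorStreamer

def respond_streaming(message, history, system_message, max_tokens, temperature, top_p):
    # hypothetical replacement for respond(); same signature, so
    # gr.ChatInterface could call it unchanged
    prompt = build_prompt(history, system_message, message)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    # generate() blocks until finished, so it runs on a worker thread
    # (threading.Thread, not torch.Thread) while this generator yields
    # the partial response as the streamer decodes new tokens
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        yield response
    thread.join()

Yielding the accumulated response rather than each fragment matches what gr.ChatInterface expects from a generator callback: every yield replaces the message shown in the chat window.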