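# --- Earlier non-streaming version, kept commented out for reference ---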
# import gradio as gr
# import torch
# from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# import threading

# MODEL_NAME = "my2000cup/Gaia-LLM-4B"

# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_NAME,
#     torch_dtype="auto",
#     device_map="auto"
# )

# def build_prompt(history, system_message, user_message):
#     messages = []
#     if system_message:
#         messages.append({"role": "system", "content": system_message})
#     for user, assistant in history:
#         if user:
#             messages.append({"role": "user", "content": user})
#         if assistant:
#             messages.append({"role": "assistant", "content": assistant})
#     messages.append({"role": "user", "content": user_message})
#     if hasattr(tokenizer, "apply_chat_template"):
#         prompt = tokenizer.apply_chat_template(
#             messages, tokenize=False, add_generation_prompt=True
#         )
#     else:
#         prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + "\nassistant:"
#     return prompt

# def respond(message, history, system_message, max_tokens, temperature, top_p):
#     prompt = build_prompt(history, system_message, message)
#     inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
#     output_ids = model.generate(
#         **inputs,
#         max_new_tokens=max_tokens,
#         temperature=temperature,
#         top_p=top_p,
#         do_sample=True,
#         pad_token_id=tokenizer.eos_token_id,
#     )
#     output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
#     # Extract the assistant reply portion (adjust this for your chat template if needed)
#     yield output[len(prompt):]

# demo = gr.ChatInterface(
#     respond,
#     additional_inputs=[
#         gr.Textbox(value="You are an oil & gas industry expert.", label="System message"),
#         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
#         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
#         gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
#     ],
#     title="Gaia-Petro-LLM Chatbot",
#     description="⚡ An oil & gas industry expert assistant built on Hugging Face Transformers."
# )

# if __name__ == "__main__":
#     demo.launch()

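# Gaia-Petro-LLM chat app: loads my2000cup/Gaia-LLM-4B with transformers and
# serves a Gradio ChatInterface that streams replies via TextIteratorStreamer.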
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import threading

MODEL_NAME = "my2000cup/Gaia-LLM-4B"

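# torch_dtype="auto" keeps the checkpoint's stored dtype; device_map="auto" lets
# accelerate place the weights on the available GPU(s) or fall back to CPU.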
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto"
)

def build_prompt(history, system_message, user_message):
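    """Convert Gradio's (user, assistant) tuple-style history into a chat prompt.

    Uses the tokenizer's chat template when available, otherwise falls back to
    a simple "role: content" transcript ending with an assistant turn.
    """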
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    for user, assistant in history:
        if user:
            messages.append({"role": "user", "content": user})
        if assistant:
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_message})
    if hasattr(tokenizer, "apply_chat_template"):
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + "\nassistant:"
    return prompt

def generate_streamer(model, generation_kwargs):
    # Run generation under inference_mode to avoid multithreading + GPU deadlocks
    with torch.inference_mode():
        model.generate(**generation_kwargs)

def respond(message, history, system_message, max_tokens, temperature, top_p):
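    """Generate a reply and stream it back incrementally.

    model.generate runs on a background thread and writes into the
    TextIteratorStreamer; this generator drains the streamer and yields the
    accumulated text so Gradio can update the chat window as tokens arrive.
    """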
    prompt = build_prompt(history, system_message, message)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    thread = threading.Thread(target=generate_streamer, args=(model, generation_kwargs))
    thread.start()
    output = ""
    for new_text in streamer:
        output += new_text
        yield output

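# ChatInterface forwards the additional_inputs values to respond after (message, history).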
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are an oil & gas industry expert.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    title="Gaia-Petro-LLM Chatbot",
    description="⚡ An oil & gas industry expert assistant built on Hugging Face Transformers."
)

if __name__ == "__main__":
    demo.launch()