import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import threading
MODEL_NAME = "my2000cup/Gaia-LLM-4B"
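
# Load the tokenizer and model once at startup. device_map="auto" places the
# weights on the available GPU(s) or CPU, and torch_dtype="auto" keeps the
# checkpoint's native precision.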
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto"
)
def build_prompt(history, system_message, user_message):
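    """Assemble a chat-formatted prompt from the Gradio history and the new user message."""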
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    for user, assistant in history:
        if user:
            messages.append({"role": "user", "content": user})
        if assistant:
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_message})
    if hasattr(tokenizer, "apply_chat_template"):
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + "\nassistant:"
    return prompt


def generate_streamer(model, generation_kwargs):
    # Run generation inside inference_mode to avoid multi-thread + GPU deadlocks.
    with torch.inference_mode():
        model.generate(**generation_kwargs)
def respond(message, history, system_message, max_tokens, temperature, top_p):
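    """Stream the model's reply, yielding the cumulative text after each new chunk."""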
    prompt = build_prompt(history, system_message, message)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    # The streamer yields decoded text chunks as generate() produces them.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Run generation in a background thread so this generator can stream partial output.
    thread = threading.Thread(target=generate_streamer, args=(model, generation_kwargs))
    thread.start()
    output = ""
    for new_text in streamer:
        output += new_text
        yield output
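

# Gradio chat UI: the system prompt and sampling parameters are exposed as extra inputs.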
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are an oil & gas industry expert.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    title="Gaia-Petro-LLM Chatbot",
    description="⚡ An oil & gas industry expert assistant built on Hugging Face Transformers."
)
if __name__ == "__main__":
    demo.launch()