# gaia_demo/app.py
import threading

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_NAME = "my2000cup/Gaia-LLM-4B"
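
# Load the tokenizer and model once at startup; device_map="auto" lets
# accelerate place the weights on the available GPU(s) or CPU.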
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
)
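

# Flatten the system message, prior turns, and the new user message into a
# single prompt string, preferring the tokenizer's built-in chat template.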
def build_prompt(history, system_message, user_message):
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    for user, assistant in history:
        if user:
            messages.append({"role": "user", "content": user})
        if assistant:
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_message})
    if hasattr(tokenizer, "apply_chat_template"):
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + "\nassistant:"
    return prompt


def generate_streamer(model, generation_kwargs):
    # Run generation inside inference_mode to avoid deadlocks when a
    # background thread drives a GPU model.
    with torch.inference_mode():
        model.generate(**generation_kwargs)
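

# Gradio calls respond() once per user turn; yielding the growing string lets
# ChatInterface render the reply incrementally.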
def respond(message, history, system_message, max_tokens, temperature, top_p):
    prompt = build_prompt(history, system_message, message)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    # skip_prompt=True keeps the echoed prompt out of the streamed text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Generate on a background thread and stream tokens as they arrive.
    thread = threading.Thread(target=generate_streamer, args=(model, generation_kwargs))
    thread.start()
    output = ""
    for new_text in streamer:
        output += new_text
        yield output
    # The streamer is exhausted once generation ends, so this join is brief.
    thread.join()
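

# Stock Gradio chat UI with sliders for the sampling parameters.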
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are an oil & gas industry expert.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    title="Gaia-Petro-LLM Chatbot",
    description="⚡ An oil & gas industry expert assistant built on Hugging Face Transformers.",
)

if __name__ == "__main__":
    demo.launch()
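
# Note: if streaming does not update incrementally on older Gradio releases,
# enabling the queue explicitly (demo.queue().launch()) is the usual fix.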