import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import threading

MODEL_NAME = "my2000cup/Gaia-LLM-4B"

# Load the tokenizer and model once at startup; device_map="auto" places the
# weights on the available GPU(s)/CPU and torch_dtype="auto" keeps the
# checkpoint's native precision.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto"
)


def build_prompt(history, system_message, user_message):
    # Convert Gradio's (user, assistant) history pairs into a chat-message list,
    # then render it with the model's chat template when one is available.
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    for user, assistant in history:
        if user:
            messages.append({"role": "user", "content": user})
        if assistant:
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_message})
    if hasattr(tokenizer, "apply_chat_template"):
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        # Fallback for tokenizers without a chat template: a plain role-prefixed transcript.
        prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + "\nassistant:"
    return prompt


def generate_streamer(model, generation_kwargs):
    # Run generation under inference_mode to avoid deadlocks when driving the
    # GPU from a worker thread.
    with torch.inference_mode():
        model.generate(**generation_kwargs)


def respond(message, history, system_message, max_tokens, temperature, top_p):
    prompt = build_prompt(history, system_message, message)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    # TextIteratorStreamer yields decoded text chunks as generate() produces them.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Generate in a background thread so partial output can stream to the UI.
    thread = threading.Thread(target=generate_streamer, args=(model, generation_kwargs))
    thread.start()
    output = ""
    for new_text in streamer:
        output += new_text
        yield output  # yield the cumulative text so ChatInterface updates live
    thread.join()


demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are an oil & gas industry expert.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    title="Gaia-Petro-LLM Chatbot",
    description="⚡ An oil & gas industry expert assistant built on Hugging Face Transformers."
)

if __name__ == "__main__":
    demo.launch()
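
# Usage note: demo.launch() serves on http://127.0.0.1:7860 by default. Gradio
# also accepts demo.launch(server_name="0.0.0.0") to listen on all interfaces,
# or demo.launch(share=True) for a temporary public link.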