import threading

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "my2000cup/Gaia-Petro-LLM"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)


def build_prompt(history, system_message, user_message):
    # Adjust this to match your model's chat template if needed
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    for user, assistant in history:
        if user:
            messages.append({"role": "user", "content": user})
        if assistant:
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_message})
    # Build the prompt string with the tokenizer's chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    return prompt


def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p
):
    prompt = build_prompt(history, system_message, message)
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    # Streaming output; fall back to non-streaming generation if
    # TextIteratorStreamer is unavailable
    try:
        from transformers import TextIteratorStreamer
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
    except ImportError:
        streamer = None

    gen_kwargs = dict(
        **model_inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

    if streamer:
        gen_kwargs["streamer"] = streamer
        # Run generation in a background thread so tokens can be yielded
        # to the UI as they are produced
        thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()
        response = ""
        for new_text in streamer:
            response += new_text
            yield response
        thread.join()
    else:
        output = model.generate(**gen_kwargs)
        # Decode only the newly generated tokens, skipping the prompt
        response = tokenizer.decode(
            output[0][model_inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        )
        yield response


demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are an oil & gas industry expert.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    demo.launch()