# inference.py
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

# --- Model paths ---
base_model_path = "/home/yq238/project_pi_aaa247/yq238/qwen_training/models/Qwen-7B-Chat"
lora_path = "/home/yq238/project_pi_aaa247/yq238/qwen_training/training/test1"

# --- Load the tokenizer ---
tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)

# --- Load the base model and apply the LoRA adapter ---
model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
model = PeftModel.from_pretrained(model, lora_path)

# --- Inference ---
# Instruction (Chinese): "Generate the analysis input table. The generated table
# should include SampleID, fastq_P1, fastq_P2, ..."
instruction = "生成分析输入表格。生成的表格应包括 SampleID,fastq_P1,fastq_P2,..."
user_input = "/gpfs/gibbs/pi/augert/Collaboration/guangxiao/batch6_2/01Sam/merge_data\n├── PU1_WT_D7_R1_P1.fastq.gz\n..."
# Prompt template (Chinese): "You are an automation assistant.\n\nUser: {instruction}\n{user_input}\n\nAssistant:"
prompt = f"你是一个自动化助手。\n\n用户:{instruction}\n{user_input}\n\n助手:"

# ✅ Key fix: pass only input_ids and attention_mask to generate()
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {
    "input_ids": inputs["input_ids"].to("cuda"),
    "attention_mask": inputs["attention_mask"].to("cuda"),
    # ✅ token_type_ids is deliberately excluded
}

# ✅ Key fix: disable the KV cache to avoid past_key_values issues
outputs = model.generate(
    **inputs,
    max_new_tokens=1024,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    use_cache=False,
)

# Decode the output
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
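
# Optional sketch (not in the original script): for a causal LM, generate() returns
# the prompt tokens followed by the newly generated tokens, so the print above
# echoes the full prompt. Slicing off the prompt length keeps only the model's answer.
# The names prompt_len and answer_only are illustrative, not from the original code.
prompt_len = inputs["input_ids"].shape[1]
answer_only = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
print(answer_only)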