# inference.py
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
# --- Model paths ---
base_model_path = "/home/yq238/project_pi_aaa247/yq238/qwen_training/models/Qwen-7B-Chat"
lora_path = "/home/yq238/project_pi_aaa247/yq238/qwen_training/training/test1"
# --- Load the tokenizer ---
tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
# --- Load the base model ---
model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
model = PeftModel.from_pretrained(model, lora_path)
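# Put the model in eval mode so dropout is disabled during inference.
model.eval()
# Optional (a sketch, assuming no further adapter training is needed): PEFT's
# merge_and_unload() folds the LoRA weights into the base model for faster inference:
#   model = model.merge_and_unload()
# Note that Qwen-7B-Chat's remote code also exposes a chat() helper
# (model.chat(tokenizer, query, history=None)); this script instead builds the
# chat prompt manually to match the format used during LoRA training.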
# --- Inference ---
# The prompt strings are kept in Chinese so the model sees the same text it was
# fine-tuned on; English glosses are given in the comments.
# "Generate the analysis input table. The generated table should include
# SampleID, fastq_P1, fastq_P2, ..."
instruction = "生成分析输入表格。生成的表格应包括 SampleID,fastq_P1,fastq_P2,..."
# A directory listing of paired-end FASTQ files (truncated in the original).
user_input = "/gpfs/gibbs/pi/augert/Collaboration/guangxiao/batch6_2/01Sam/merge_data\n├── PU1_WT_D7_R1_P1.fastq.gz\n..."
# "You are an automation assistant.\n\nUser: {instruction}\n{user_input}\n\nAssistant:"
prompt = f"你是一个自动化助手。\n\n用户:{instruction}\n{user_input}\n\n助手:"
# Key fix: pass only input_ids and attention_mask to generate()
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {
    "input_ids": inputs["input_ids"].to("cuda"),
    "attention_mask": inputs["attention_mask"].to("cuda"),
    # Explicitly exclude token_type_ids, which the Qwen model's forward() does not accept.
}
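# An equivalent approach (a sketch): tokenize, move the whole batch to the GPU,
# and drop the unsupported key in place:
#   inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
#   inputs.pop("token_type_ids", None)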
# Key fix: disable the KV cache (avoids past_key_values issues)
outputs = model.generate(
    **inputs,
    max_new_tokens=1024,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    use_cache=False,  # key: disabling the cache avoids past_key_values errors
)
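# For reproducible outputs, a deterministic alternative (a sketch) is greedy
# decoding with the sampling parameters dropped:
#   outputs = model.generate(**inputs, max_new_tokens=1024, do_sample=False)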
# Decode the full output sequence (prompt + completion)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
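
# A minimal sketch for printing only the newly generated text, assuming the
# prompt occupies the first input_ids.shape[1] positions of the output:
prompt_len = inputs["input_ids"].shape[1]
completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
print(completion)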