# inference.py
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

# --- Model paths ---
base_model_path = "/home/yq238/project_pi_aaa247/yq238/qwen_training/models/Qwen-7B-Chat"
lora_path = "/home/yq238/project_pi_aaa247/yq238/qwen_training/training/test1"

# --- Load tokenizer ---
tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)

# --- Load base model and apply the LoRA adapter ---
model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
model = PeftModel.from_pretrained(model, lora_path)

# --- Inference ---
instruction = "Generate the analysis input table. The generated table should include SampleID, fastq_P1, fastq_P2, ..."
user_input = "/gpfs/gibbs/pi/augert/Collaboration/guangxiao/batch6_2/01Sam/merge_data\n├── PU1_WT_D7_R1_P1.fastq.gz\n..."

prompt = f"You are an automation assistant.\n\nUser: {instruction}\n{user_input}\n\nAssistant:"

# Key fix: pass only input_ids and attention_mask to generate()
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {
    "input_ids": inputs["input_ids"].to("cuda"),
    "attention_mask": inputs["attention_mask"].to("cuda"),
    # explicitly exclude token_type_ids
}

# Key fix: disable the KV cache (avoids past_key_values issues)
outputs = model.generate(
    **inputs,
    max_new_tokens=1024,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    use_cache=False,  # key: disable the cache to avoid past_key_values issues
)

# Decode only the newly generated tokens (outputs[0] also contains the prompt tokens)
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
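
# --- Optional sketch: merge the LoRA adapter into the base weights for deployment ---
# merge_and_unload() folds the adapter into the base model so the result can later be
# loaded with AutoModelForCausalLM alone (no peft wrapper at serving time).
# "merged_path" below is a hypothetical output directory, not part of this project.
# merged_path = "./qwen7b_chat_lora_merged"
# merged_model = model.merge_and_unload()
# merged_model.save_pretrained(merged_path)
# tokenizer.save_pretrained(merged_path)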