# aifeifei798's picture
# Upload 3 files
# a9d6403 verified
# raw
# history blame
# 6.28 kB
# ====================================================================================
# diagnose_layers.py
# 目的:通过计算和可视化每一层激活的“AB面差异显著性”,
# 来诊断模型中特定行为(如“安全拒绝”)主要发生在哪些层。
# ====================================================================================
# Startup banner; it is printed before the imports below are executed.
print("--- Running Layer Diagnosis Script ---")
import torch
import gc
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
from collections import defaultdict
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- [Configuration] ---
# Local checkpoint directory; it is opened with local_files_only=True in STEP 1.
MODEL_ID: str = "./gemma-3-4b-it-qat-q4_0-unquantized"

# A small number of prompts per side keeps the diagnosis quick.
NUM_SAMPLES_TO_DIAGNOSE: int = 64
# Number of prompts sent through the model per forward pass.
BATCH_SIZE: int = 4

# Where the per-layer significance chart is written.
OUTPUT_CHART_FILENAME: str = "layer_significance_chart.png"
# --- [STEP 1] Set up the model and tokenizer ---
print(f"\n[STEP 1] Loading model and tokenizer from: {MODEL_ID}")
# The script only runs forward passes, so autograd is switched off globally.
torch.set_grad_enabled(False)
hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    local_files_only=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, local_files_only=True)
# Left padding keeps each prompt's final token in the last column, which is
# the position the activation hooks read ([:, -1, :]).
tokenizer.padding_side = 'left'
if tokenizer.pad_token is None:
    # Fall back to EOS so batched padding works on tokenizers without a pad token.
    tokenizer.pad_token = tokenizer.eos_token
# Multimodal configs (such as the checkpoint this script targets) nest the
# language model's settings under `text_config`; text-only configs keep
# `num_hidden_layers` at the top level. Support both instead of crashing.
_layer_config = getattr(hf_model.config, "text_config", None) or hf_model.config
TOTAL_LAYERS = _layer_config.num_hidden_layers
print(f"[SUCCESS] Model with {TOTAL_LAYERS} layers and tokenizer loaded.")
# --- [STEP 2] Prepare the data ---
print(f"\n[STEP 2] Preparing datasets with {NUM_SAMPLES_TO_DIAGNOSE} samples each...")
def reformat_texts(texts):
    """Wrap each raw prompt string as a single-turn chat conversation.

    Args:
        texts: iterable of prompt strings.

    Returns:
        A list with one conversation per input string, each conversation
        being ``[{"role": "user", "content": text}]`` -- the message-list
        shape this script later passes to ``tokenizer.apply_chat_template``.
    """
    conversations = []
    for prompt in texts:
        user_turn = {"role": "user", "content": prompt}
        conversations.append([user_turn])
    return conversations
# Load the harmful instructions (side A).
harmful_dataset = load_dataset('./harmful_behaviors')
# Slice the rows *before* extracting and reformatting the column, so only the
# samples actually used are turned into Python objects. Slicing past the end
# of the split just returns the available rows.
harmful_inst = reformat_texts(harmful_dataset['train'][:NUM_SAMPLES_TO_DIAGNOSE]['text'])
# Load the harmless instructions (side B).
harmless_dataset = load_dataset('./harmless_alpaca')
harmless_inst = reformat_texts(harmless_dataset['train'][:NUM_SAMPLES_TO_DIAGNOSE]['text'])
# Make a short dataset visible instead of silently diagnosing fewer samples.
for _side_name, _side_inst in (("harmful", harmful_inst), ("harmless", harmless_inst)):
    if len(_side_inst) < NUM_SAMPLES_TO_DIAGNOSE:
        print(
            f"[WARNING] Only {len(_side_inst)} {_side_name} samples available "
            f"(requested {NUM_SAMPLES_TO_DIAGNOSE})."
        )
print("[SUCCESS] Datasets prepared.")
# --- [STEP 3] Helpers for collecting activations ---
# Padded sequence length for every batch. It must be measured on the
# *templated* prompt, i.e. including what the chat template adds (BOS, turn
# markers and the generation prompt), because that is what
# tokenize_instructions feeds the model. Measuring only the raw text
# undercounts. tokenize_instructions uses truncation=True, so with the
# tokenizer's default right-side truncation the tail of the longest prompts
# would be cut off. That tail includes the generation-prompt token whose
# activation ([:, -1, :]) this script analyses.
max_len = max(
    len(
        tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_dict=True
        )["input_ids"]
    )
    for conversation in harmful_inst + harmless_inst
)
print(f"Max sequence length calculated: {max_len}")
def tokenize_instructions(tokenizer, instructions, max_length):
    """Render and tokenize a batch of chat conversations for one forward pass.

    Args:
        tokenizer: tokenizer exposing ``apply_chat_template``.
        instructions: list of conversations (each a list of message dicts).
        max_length: every sequence is padded to, and truncated at, this many
            tokens. It must cover the full templated prompt, otherwise
            truncation removes tokens.

    Returns:
        The result of ``tokenizer.apply_chat_template`` called with
        ``return_dict=True`` and ``return_tensors="pt"`` -- presumably a
        ``BatchEncoding`` of PyTorch tensors (verify against the installed
        transformers version).
    """
    template_options = dict(
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
        return_dict=True,
        # Append the assistant-turn prefix so the last position is where the
        # model would start answering.
        add_generation_prompt=True,
    )
    return tokenizer.apply_chat_template(instructions, **template_options)
def _resolve_decoder_layer(model, layer_index):
    """Return decoder layer ``layer_index`` of ``model``'s language model.

    The path used originally (``model.language_model.layers``) is tried first.
    Then come alternative layouts, e.g. the one used by text-only causal LMs
    (``model.layers``). The alternatives are assumptions -- confirm them for
    other architectures.

    Raises:
        AttributeError: if none of the candidate paths exists.
    """
    candidate_paths = (
        f"model.language_model.layers.{layer_index}",
        f"language_model.model.layers.{layer_index}",
        f"model.layers.{layer_index}",
    )
    for path in candidate_paths:
        try:
            return model.get_submodule(path)
        except AttributeError:
            continue
    raise AttributeError(
        f"Could not locate decoder layer {layer_index}; tried: {', '.join(candidate_paths)}"
    )


def get_activations(model, instructions, num_samples):
    """Collect the last-position output of every decoder layer.

    A forward hook is registered on each of the ``TOTAL_LAYERS`` decoder
    layers. The first ``num_samples`` conversations are then run through
    ``model`` in batches of ``BATCH_SIZE`` (tokenized with the module-level
    ``tokenizer`` padded to ``max_len``). For every layer, the hidden state at
    the final sequence position is recorded; with the tokenizer's left padding
    this is the last prompt token.

    Args:
        model: the causal LM to run.
        instructions: list of chat conversations (see ``reformat_texts``).
        num_samples: number of conversations to process. It is clamped to
            ``len(instructions)`` so a short dataset never produces an empty
            batch.

    Returns:
        A ``defaultdict`` mapping ``"layer_{i}"`` to a CPU tensor that stacks
        the recorded per-sample vectors along dim 0.
    """
    num_samples = min(num_samples, len(instructions))
    cache = defaultdict(list)

    def create_hook_fn(layer_name):
        def hook_fn(_module, _inputs, output):
            # Depending on the transformers version, a decoder layer returns
            # either a tuple whose first element is the hidden states or the
            # hidden-states tensor itself. Indexing a bare tensor with [0]
            # would silently select batch element 0, so both cases are
            # handled explicitly.
            hidden_states = output[0] if isinstance(output, (tuple, list)) else output
            # Keep only the final position's activation, moved to CPU.
            cache[layer_name].append(hidden_states[:, -1, :].cpu())
        return hook_fn

    hooks = []
    try:
        for layer_index in range(TOTAL_LAYERS):
            module = _resolve_decoder_layer(model, layer_index)
            hooks.append(module.register_forward_hook(create_hook_fn(f"layer_{layer_index}")))

        num_batches = (num_samples + BATCH_SIZE - 1) // BATCH_SIZE  # ceil division
        for batch_index in tqdm(range(num_batches), desc="Collecting activations"):
            start_idx = batch_index * BATCH_SIZE
            end_idx = min(num_samples, start_idx + BATCH_SIZE)
            batch_instructions = instructions[start_idx:end_idx]
            tokenized_input = tokenize_instructions(
                tokenizer, batch_instructions, max_length=max_len
            ).to(model.device)
            model(**tokenized_input)
    finally:
        # Always detach the hooks, even if a forward pass fails. Otherwise they
        # stay registered on the model, and a later call would record every
        # layer more than once.
        for hook in hooks:
            hook.remove()

    # Concatenate the per-batch slices into one tensor per layer.
    for layer_name, chunks in cache.items():
        cache[layer_name] = torch.cat(chunks, dim=0)
    return cache
# --- [STEP 4] Collect activations for both datasets ---
print("\n[STEP 4] Collecting activations for both datasets...")

# Side A: harmful prompts.
print("Collecting for Harmful dataset (A-Side)...")
harmful_activations = get_activations(
    model=hf_model,
    instructions=harmful_inst,
    num_samples=NUM_SAMPLES_TO_DIAGNOSE,
)

# Side B: harmless prompts.
print("Collecting for Harmless dataset (B-Side)...")
harmless_activations = get_activations(
    model=hf_model,
    instructions=harmless_inst,
    num_samples=NUM_SAMPLES_TO_DIAGNOSE,
)

print("[SUCCESS] All activations collected.")
# --- [STEP 5] Compute (and later visualize) each layer's significance ---
print("\n[STEP 5] Calculating and visualizing layer significance...")
layer_significance = []
layer_indices = range(TOTAL_LAYERS)
for layer_idx in layer_indices:
    layer_name = f"layer_{layer_idx}"
    # Upcast to float32 before averaging. The activations were collected in
    # the model's bfloat16 dtype, whose ~8-bit mantissa is too coarse to
    # subtract two nearly-equal means reliably.
    harmful_mean_act = harmful_activations[layer_name].float().mean(dim=0)
    harmless_mean_act = harmless_activations[layer_name].float().mean(dim=0)
    # Difference of the per-side mean activations: A (harmful) minus B (harmless).
    diff_vector = harmful_mean_act - harmless_mean_act
    # The L2 norm (magnitude) of that difference is the layer's significance score.
    significance = torch.linalg.norm(diff_vector).item()
    layer_significance.append(significance)
    print(f"Layer {layer_idx:02d}: Significance (L2 Norm of diff) = {significance:.4f}")
# Release the activation caches and the model before plotting.
del harmful_activations, harmless_activations, hf_model
gc.collect()
# Return cached GPU memory to the driver (nothing to release without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# --- [STEP 6] Draw and save the chart ---
print(f"\n[STEP 6] Generating chart and saving to {OUTPUT_CHART_FILENAME}...")
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*', and
# later releases removed the old names. Use whichever variant this install
# provides instead of crashing on an unknown style name. If neither exists,
# the default style is kept.
for _style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
fig, ax = plt.subplots(figsize=(15, 7))
ax.plot(layer_indices, layer_significance, marker='o', linestyle='-', color='royalblue', label='Signal Significance')
ax.set_title('Significance of "Refusal Signal" Across Model Layers', fontsize=16, fontweight='bold')
ax.set_xlabel('Layer Index', fontsize=12)
ax.set_ylabel('Significance Score (L2 Norm of Activation Difference)', fontsize=12)
ax.grid(True, which='both', linestyle='--', linewidth=0.5)
ax.set_xticks(np.arange(0, TOTAL_LAYERS, 2))  # one tick every 2 layers
ax.legend()
fig.tight_layout()
# Save the chart, then close the figure to free its resources.
fig.savefig(OUTPUT_CHART_FILENAME)
plt.close(fig)
print(f"\n[SUCCESS] Diagnosis complete. Chart saved to '{OUTPUT_CHART_FILENAME}'.")
print("You can now analyze this chart to determine the optimal layers for your fine-tuning surgery.")