print("--- Running Layer Diagnosis Script ---") |
|
|
|
import torch |
|
import gc |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
|
|
from datasets import load_dataset |
|
from tqdm import tqdm |
|
from collections import defaultdict |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
MODEL_ID = "./gemma-3-4b-it-qat-q4_0-unquantized" |
|
|
|
NUM_SAMPLES_TO_DIAGNOSE = 64 |
|
BATCH_SIZE = 4 |
|
OUTPUT_CHART_FILENAME = "layer_significance_chart.png" |

print(f"\n[STEP 1] Loading model and tokenizer from: {MODEL_ID}")
torch.set_grad_enabled(False)

hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    local_files_only=True,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, local_files_only=True)
tokenizer.padding_side = 'left'
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
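# Left padding (set above) keeps each prompt's final real token at position -1,
# which the forward hooks below rely on when slicing the last-token activation.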

TOTAL_LAYERS = hf_model.config.text_config.num_hidden_layers
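# Gemma 3 checkpoints like this one are multimodal wrappers, so the language model's
# layer count lives under config.text_config rather than on the top-level config.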
print(f"[SUCCESS] Model with {TOTAL_LAYERS} layers and tokenizer loaded.")


print(f"\n[STEP 2] Preparing datasets with {NUM_SAMPLES_TO_DIAGNOSE} samples each...")

def reformat_texts(texts):
    return [[{"role": "user", "content": text}] for text in texts]

harmful_dataset = load_dataset('./harmful_behaviors')
harmful_inst = reformat_texts(harmful_dataset['train']['text'])[:NUM_SAMPLES_TO_DIAGNOSE]

harmless_dataset = load_dataset('./harmless_alpaca')
harmless_inst = reformat_texts(harmless_dataset['train']['text'])[:NUM_SAMPLES_TO_DIAGNOSE]

print("[SUCCESS] Datasets prepared.")


print("\n[STEP 3] Calculating max sequence length for padding...")
all_texts = [instr[0]['content'] for instr in harmful_inst + harmless_inst]
max_len = max([tokenizer(text, return_tensors="pt").input_ids.shape[1] for text in all_texts])
print(f"Max sequence length calculated: {max_len}")

def tokenize_instructions(tokenizer, instructions, max_length):
    return tokenizer.apply_chat_template(
        instructions, padding="max_length", truncation=True, max_length=max_length,
        return_tensors="pt", return_dict=True, add_generation_prompt=True,
    )
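# With return_dict=True, apply_chat_template returns input_ids and attention_mask,
# so the result can be passed directly as model(**tokenized_input) below.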

def get_activations(model, instructions, num_samples):
    """Collect the activations of every decoder layer (last-token position only)."""
    cache = defaultdict(list)

    def create_hook_fn(layer_name):
        def hook_fn(module, input, output):
            cache[layer_name].append(output[0][:, -1, :].cpu())
        return hook_fn
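    # Each hook stores a (batch_size, hidden_size) slice: the layer's output hidden
    # state at the final position only, moved to CPU so GPU memory stays flat.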

    hooks = []
    for i in range(TOTAL_LAYERS):
        layer_name = f"layer_{i}"
        module = model.get_submodule(f"model.language_model.layers.{i}")
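        # This submodule path matches the Gemma 3 conditional-generation wrapper;
        # other architectures may expose their decoder layers under a different path.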
        hook = module.register_forward_hook(create_hook_fn(layer_name))
        hooks.append(hook)

    num_batches = (num_samples + BATCH_SIZE - 1) // BATCH_SIZE
    for i in tqdm(range(num_batches), desc="Collecting activations"):
        start_idx, end_idx = i * BATCH_SIZE, min(num_samples, (i + 1) * BATCH_SIZE)
        batch_instructions = instructions[start_idx:end_idx]
        tokenized_input = tokenize_instructions(tokenizer, batch_instructions, max_length=max_len).to(model.device)
        model(**tokenized_input)

    for hook in hooks:
        hook.remove()

    for layer_name, activations in cache.items():
        cache[layer_name] = torch.cat(activations, dim=0)
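    # After concatenation, each cache entry is a single (num_samples, hidden_size) tensor.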

    return cache


print("\n[STEP 4] Collecting activations for both datasets...")

print("Collecting for Harmful dataset (A-Side)...")
harmful_activations = get_activations(hf_model, harmful_inst, NUM_SAMPLES_TO_DIAGNOSE)

print("Collecting for Harmless dataset (B-Side)...")
harmless_activations = get_activations(hf_model, harmless_inst, NUM_SAMPLES_TO_DIAGNOSE)

print("[SUCCESS] All activations collected.")


print("\n[STEP 5] Calculating and visualizing layer significance...")

layer_significance = []
layer_indices = range(TOTAL_LAYERS)

for l in layer_indices:
    layer_name = f"layer_{l}"

    harmful_mean_act = harmful_activations[layer_name].mean(dim=0)
    harmless_mean_act = harmless_activations[layer_name].mean(dim=0)

    diff_vector = harmful_mean_act - harmless_mean_act
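    # The mean-difference vector is the per-layer candidate "refusal direction";
    # its magnitude indicates how strongly this layer separates the two prompt sets.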

    significance = torch.linalg.norm(diff_vector).item()
    layer_significance.append(significance)
    print(f"Layer {l:02d}: Significance (L2 Norm of diff) = {significance:.4f}")


del harmful_activations, harmless_activations, hf_model
gc.collect()
torch.cuda.empty_cache()


print(f"\n[STEP 6] Generating chart and saving to {OUTPUT_CHART_FILENAME}...")
plt.style.use('seaborn-v0_8-whitegrid')
fig, ax = plt.subplots(figsize=(15, 7))

ax.plot(layer_indices, layer_significance, marker='o', linestyle='-', color='royalblue', label='Signal Significance')
ax.set_title('Significance of "Refusal Signal" Across Model Layers', fontsize=16, fontweight='bold')
ax.set_xlabel('Layer Index', fontsize=12)
ax.set_ylabel('Significance Score (L2 Norm of Activation Difference)', fontsize=12)
ax.grid(True, which='both', linestyle='--', linewidth=0.5)
ax.set_xticks(np.arange(0, TOTAL_LAYERS, 2))
ax.legend()
plt.tight_layout()

plt.savefig(OUTPUT_CHART_FILENAME)

print(f"\n[SUCCESS] Diagnosis complete. Chart saved to '{OUTPUT_CHART_FILENAME}'.")
print("You can now analyze this chart to determine the optimal layers for your fine-tuning surgery.")