# Source: Huihui-InternVL3_5-1B-Instruct-abliterated / 01-compute_refusal_dir-Arcee-Blitz-2.py
# Uploaded by huihui-ai ("Add files using upload-large-folder tool"),
# commit 26e1cba (verified).
import jaxtyping
import random
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig
import einops
from tqdm import tqdm
from datasets import load_dataset
import os
# Gradients are never needed in this script.  A bare `torch.inference_mode()`
# call only constructs a context manager without entering it, so it had no
# effect; `torch.set_grad_enabled(False)` takes effect when called as a function.
torch.set_grad_enabled(False)
# NOTE(review): the hidden-state files below are loaded with map_location='cpu',
# so this default device may be unnecessary -- confirm before removing.
torch.set_default_device("cuda")

MODEL_ID = "arcee-ai/Arcee-Blitz"
# Directory the per-instruction hidden-state files are read from and the
# resulting refusal directions are written to.
output_dir = MODEL_ID + "/hidden_states"
# Number of harmful/harmless hidden-state file pairs to read (indices 0..n-1).
n_instructions = 6653
# Number of layer entries read from each hidden-state file.
num_layers = 40
def compute_refusal_directions(hidden_dir: str, n_instructions: int, num_layers: int,
                               pos: int = -1, progress=None) -> list:
    """Return one unit-norm refusal direction per layer.

    For every ``idx`` in ``range(n_instructions)`` the files
    ``{hidden_dir}/harmful_hidden_state_{idx}.pt`` and
    ``{hidden_dir}/harmless_hidden_state_{idx}.pt`` are loaded onto the CPU.
    For each layer the slice ``hidden[layer_idx][:, pos, :]`` is collected,
    averaged over all instructions, and the refusal direction is
    ``harmful_mean - harmless_mean`` divided by its norm.

    Only the selected token slice is kept (as a copy) per instruction, so the
    full hidden-state tensors can be freed as soon as a file pair has been
    processed instead of staying in memory until the end.

    Args:
        hidden_dir: Directory containing the hidden-state ``.pt`` files.
        n_instructions: Number of harmful/harmless file pairs to read.
        num_layers: Number of layer entries to read from each file.
        pos: Index along the second dimension of each layer tensor
            (assumed to be the token position -- TODO confirm); ``-1``
            selects the last one.
        progress: Optional callable ``progress(iterable, desc=...)`` such as
            ``tqdm`` used to wrap the loops; no progress display if ``None``.

    Returns:
        A list of ``num_layers`` tensors, each with the shape of
        ``hidden[layer_idx][:, pos, :]``.

    Raises:
        ValueError: If ``n_instructions`` is smaller than 1.
    """
    if n_instructions < 1:
        raise ValueError("n_instructions must be at least 1")
    if progress is None:
        def progress(iterable, desc=None):
            return iterable

    # One (harmful, harmless) pair of lists per layer.
    per_layer = [([], []) for _ in range(num_layers)]

    # Walk over the saved hidden states of every instruction.
    for idx in progress(range(n_instructions), desc="Processing instruction"):
        harmful_hidden = torch.load(f"{hidden_dir}/harmful_hidden_state_{idx}.pt", map_location='cpu', weights_only=True)
        harmless_hidden = torch.load(f"{hidden_dir}/harmless_hidden_state_{idx}.pt", map_location='cpu', weights_only=True)

        for layer_idx in range(num_layers):
            harmful_rows, harmless_rows = per_layer[layer_idx]
            # clone() detaches the slice from the full tensor's storage;
            # a plain view would keep the whole tensor alive.
            harmful_rows.append(harmful_hidden[layer_idx][:, pos, :].clone())
            harmless_rows.append(harmless_hidden[layer_idx][:, pos, :].clone())

        # Drop the full tensors before loading the next pair of files.
        del harmful_hidden, harmless_hidden

    # Compute the refusal direction of every layer.
    directions = []
    for layer_idx in progress(range(num_layers), desc="Calculating refusal direction for layer"):
        harmful_rows, harmless_rows = per_layer[layer_idx]
        # Mean activation over all instructions.
        harmful_mean = torch.stack(harmful_rows).mean(dim=0)
        harmless_mean = torch.stack(harmless_rows).mean(dim=0)
        # Difference of means, normalized to unit length.
        refusal_dir = harmful_mean - harmless_mean
        refusal_dir = refusal_dir / refusal_dir.norm()
        directions.append(refusal_dir)
    return directions


if __name__ == "__main__":
    final_refusal_directions = compute_refusal_directions(
        output_dir, n_instructions, num_layers, progress=tqdm
    )
    # The per-layer refusal directions are stored together as one list.
    torch.save(final_refusal_directions, output_dir + "/final_refusal_dirs.pt")
    print("Refusal directions saved successfully.")