File size: 2,405 Bytes
26e1cba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import jaxtyping
import random
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig
import einops
from tqdm import tqdm
from datasets import load_dataset

import os

# Turn autograd off for the whole script. A bare `torch.inference_mode()` call
# would only construct a context manager without entering it (a no-op), so the
# function form of `set_grad_enabled` is used because it takes effect at once.
torch.set_grad_enabled(False)
# Newly created tensors default to the CUDA device.
torch.set_default_device("cuda")

# Model identifier; in this script it is only used to derive the directory
# that holds the saved hidden-state files.
MODEL_ID = "arcee-ai/Arcee-Blitz"
output_dir = f"{MODEL_ID}/hidden_states"

# How many instruction indices are read from disk, and how many layers of each
# saved hidden-state object are processed.
n_instructions, num_layers = 6653, 40

# Gather the inputs for the refusal-direction computation.
# final_refusal_dirs[layer_idx] is a list holding one
# (harmful, harmless) tuple of hidden states per instruction.
final_refusal_dirs = []

# Walk over the saved hidden-state files, one harmful/harmless pair per index.
for idx in tqdm(range(n_instructions), desc="Processing instruction"):

    harmful_hidden = torch.load(f"{output_dir}/harmful_hidden_state_{idx}.pt", map_location='cpu', weights_only=True)
    harmless_hidden = torch.load(f"{output_dir}/harmless_hidden_state_{idx}.pt", map_location='cpu', weights_only=True)

    # Process every layer of this instruction.
    for layer_idx in range(num_layers):
        # First time this layer is seen: create its storage list.
        if len(final_refusal_dirs) <= layer_idx:
            final_refusal_dirs.append([])

        # Keep only the final sequence position, as an independent copy.
        # The averaging step further down reads position -1 exclusively, and
        # slicing with `-1:` keeps the tensor 3-D so `[:, -1, :]` still works
        # there. `.clone()` detaches the slice from the full tensor's storage;
        # without it every complete hidden state of every instruction would
        # stay alive until the end of the script.
        harmful_layer_hidden = harmful_hidden[layer_idx][:, -1:, :].clone()
        harmless_layer_hidden = harmless_hidden[layer_idx][:, -1:, :].clone()

        # Record the harmful/harmless pair for this layer.
        final_refusal_dirs[layer_idx].append((harmful_layer_hidden, harmless_layer_hidden))

    # Release the full tensors. They were mapped to the CPU on load, so
    # dropping the references is sufficient; no CUDA cache is involved.
    del harmful_hidden, harmless_hidden

# Derive one refusal direction per layer from the collected hidden states.
final_refusal_directions = []
pos = -1  # sequence position read from every stored hidden state

for layer_idx in tqdm(range(num_layers), desc="Calculating refusal direction for layer"):
    layer_pairs = final_refusal_dirs[layer_idx]

    # Each stored entry is a (harmful, harmless) tuple; take position `pos`
    # from each side and stack the slices along a new leading dimension.
    harmful_batch = torch.stack([harmful[:, pos, :] for harmful, _ in layer_pairs])
    harmless_batch = torch.stack([harmless[:, pos, :] for _, harmless in layer_pairs])

    # Difference between the two means over the stacked dimension, divided by
    # the norm of the whole difference tensor.
    mean_gap = harmful_batch.mean(dim=0) - harmless_batch.mean(dim=0)
    final_refusal_directions.append(mean_gap / mean_gap.norm())

# Persist the per-layer list inside the hidden-state directory.
torch.save(final_refusal_directions, f"{output_dir}/final_refusal_dirs.pt")
print("Refusal directions saved successfully.")