|
import jaxtyping |
|
import random |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig |
|
import einops |
|
from tqdm import tqdm |
|
from datasets import load_dataset |
|
|
|
import os |
|
|
|
# BUG FIX: a bare `torch.inference_mode()` call only constructs a context
# manager and immediately discards it -- it does NOT disable gradient
# tracking for the rest of the script. Disable autograd globally instead,
# which is the effect the original line intended.
torch.set_grad_enabled(False)

# Freshly-created tensors default to the GPU; the states loaded below are
# still explicitly pinned to CPU via map_location='cpu'.
torch.set_default_device("cuda")
|
|
|
# Model whose cached hidden states this script post-processes.
MODEL_ID = "arcee-ai/Arcee-Blitz"

# Directory holding the per-instruction harmful/harmless hidden-state dumps.
output_dir = f"{MODEL_ID}/hidden_states"

# Number of instruction pairs that were dumped to disk.
n_instructions = 6653

# Transformer layers captured per dump.
num_layers = 40
|
|
|
|
|
# final_refusal_dirs[layer] collects one (harmful, harmless) hidden-state
# tuple per instruction, grouped by transformer layer.
final_refusal_dirs = []

for idx in tqdm(range(n_instructions), desc="Processing instruction"):

    # Per-instruction dumps: one hidden-state tensor per layer.
    # NOTE(review): assumes each file holds an indexable sequence of at
    # least `num_layers` tensors -- confirm against the dumping script.
    harmful = torch.load(f"{output_dir}/harmful_hidden_state_{idx}.pt", map_location='cpu', weights_only=True)
    harmless = torch.load(f"{output_dir}/harmless_hidden_state_{idx}.pt", map_location='cpu', weights_only=True)

    # Grow the per-layer buckets on the first pass through.
    while len(final_refusal_dirs) < num_layers:
        final_refusal_dirs.append([])

    for layer_idx, bucket in enumerate(final_refusal_dirs):
        bucket.append((harmful[layer_idx], harmless[layer_idx]))

    # Drop the top-level references before the next load; empty_cache is a
    # no-op for these CPU tensors but kept for parity with the GPU default.
    del harmful, harmless
    torch.cuda.empty_cache()
|
|
|
|
|
# One unit-norm "refusal direction" per layer: the normalized difference
# between the mean harmful and mean harmless last-token hidden states.
final_refusal_directions = []

for layer_idx in tqdm(range(num_layers), desc="Calculating refusal direction for layer"):
    pos = -1  # last token position of each sequence

    pairs = final_refusal_dirs[layer_idx]

    # Stack every instruction's last-token state, then average over them.
    harmful_mean = torch.stack([harmful[:, pos, :] for harmful, _ in pairs]).mean(dim=0)
    harmless_mean = torch.stack([harmless[:, pos, :] for _, harmless in pairs]).mean(dim=0)

    # Difference of means, scaled to unit length.
    direction = harmful_mean - harmless_mean
    final_refusal_directions.append(direction / direction.norm())
|
|
|
|
|
# Persist the per-layer direction list next to the hidden-state dumps.
save_path = output_dir + "/final_refusal_dirs.pt"
torch.save(final_refusal_directions, save_path)
print("Refusal directions saved successfully.")
|
|