|
import jaxtyping |
|
import random |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
|
import einops |
|
from tqdm import tqdm |
|
from datasets import load_dataset |
|
|
|
import os |
|
|
|
# Size of the CPU thread pools used by MKL, OpenMP and PyTorch intra-op work.
# NOTE(review): 72 is host-specific; lower it on machines with fewer cores.
NUM_THREADS = 72

# NOTE(review): these variables are assigned after `import torch`; a runtime
# that already read them during import may ignore the new values. Exporting
# them in the shell (or before importing torch) is the reliable way.
os.environ["MKL_NUM_THREADS"] = str(NUM_THREADS)
os.environ["OMP_NUM_THREADS"] = str(NUM_THREADS)
torch.set_num_threads(NUM_THREADS)

print(f"PyTorch threads: {torch.get_num_threads()}")
print(f"MKL threads: {os.getenv('MKL_NUM_THREADS')}")
print(f"OMP threads: {os.getenv('OMP_NUM_THREADS')}")

# `torch.inference_mode()` only takes effect when used as a context manager or
# decorator; calling it as a bare statement does nothing. This script only runs
# forward passes, so autograd is switched off globally instead.
torch.set_grad_enabled(False)

# Tensors created without an explicit device are allocated on the GPU.
torch.set_default_device("cuda")
|
|
|
MODEL_ID = "agentica-org/DeepCoder-14B-Preview"

# Per-prompt hidden-state files are written beneath <MODEL_ID>/hidden_states.
output_dir = f"{MODEL_ID}/hidden_states"
os.makedirs(output_dir, exist_ok=True)

print(f"Load Model {MODEL_ID} ... ")

# 4-bit bitsandbytes weights with bf16 compute and nested ("double") quantization.
# fp32 CPU offload is allowed for modules that device_map="auto" places on the CPU.
quant_config_4 = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

# NOTE(review): trust_remote_code=True executes Python shipped with the model repo.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=quant_config_4,
    torch_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Left padding keeps each prompt's final token at position -1, which is the
# position the direction computation later reads; EOS doubles as the pad token.
tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.eos_token

num_layers = len(model.model.layers)
print(f"Model has {num_layers} layers.")

print("Load data ... ")
|
|
|
|
|
def reformat_texts(texts):
    """Wrap each instruction string as a single-turn user chat conversation.

    Texts coming from ``readlines()`` still end in a newline, and blank lines
    would otherwise turn into empty prompts. Each text is therefore stripped
    of surrounding whitespace, and texts that end up empty are dropped.

    Args:
        texts: iterable of instruction strings.

    Returns:
        A list of conversations, each of the form
        ``[{"role": "user", "content": text}]``.
    """
    conversations = []
    for text in texts:
        cleaned = text.strip()
        if cleaned:
            conversations.append([{"role": "user", "content": cleaned}])
    return conversations
|
|
|
def get_harmful_instructions():
    """Read ``datasets17/harmful.txt`` (UTF-8) line by line and pass the lines
    through ``reformat_texts`` to obtain chat conversations."""
    with open("datasets17/harmful.txt", "r", encoding="utf-8") as handle:
        lines = list(handle)
    return reformat_texts(lines)
|
|
|
def get_harmless_instructions():
    """Read ``datasets17/harmless.txt`` (UTF-8) line by line and pass the lines
    through ``reformat_texts`` to obtain chat conversations."""
    with open("datasets17/harmless.txt", "r", encoding="utf-8") as handle:
        lines = list(handle)
    return reformat_texts(lines)
|
|
|
|
|
|
|
harmful = get_harmful_instructions()
harmless = get_harmless_instructions()

print(f"harmful len: {len(harmful)}")
print(f"harmless len: {len(harmless)}")

# Truncate both sets to the same length: the hidden-state files are later
# loaded in (harmful, harmless) pairs for every index below n_instructions.
n_instructions = min(len(harmful), len(harmless))
print(f"Instruction count: {n_instructions}")

harmful_instructions, harmless_instructions = (
    harmful[:n_instructions],
    harmless[:n_instructions],
)
|
|
|
print("Tokenizer ... ")


def _tokenize_conversations(conversations):
    """Apply the chat template (with generation prompt) to each conversation.

    Returns one tokenizer output per conversation, holding PyTorch tensors; the
    generation step reads its ``input_ids`` and ``attention_mask`` entries.
    """
    return [
        tokenizer.apply_chat_template(
            conversation,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        )
        for conversation in conversations
    ]


harmful_toks = _tokenize_conversations(harmful_instructions)
harmless_toks = _tokenize_conversations(harmless_instructions)
|
|
|
# One progress tick per processed prompt: all harmful plus all harmless ones.
max_its = 2 * n_instructions
bar = tqdm(total=max_its)
|
|
|
|
|
import gc |
|
|
|
def generate_and_process(toks, label, idx):
    """Run one tokenized prompt through the model and save its hidden states.

    Writes ``{output_dir}/{label}_hidden_state_{idx}.pt``. The file holds the
    tuple of hidden-state tensors the model returns with
    ``output_hidden_states=True`` for the prompt tokens. This is the same
    object that ``generate(..., max_new_tokens=1).hidden_states[0]`` produced.

    Args:
        toks: tokenizer output with ``input_ids`` and ``attention_mask``.
        label: filename prefix, e.g. ``'harmful'`` or ``'harmless'``.
        idx: index of the prompt within its set, used in the filename.
    """
    bar.update(n=1)

    # With device_map="auto", inputs belong on the model's first device.
    input_ids = toks['input_ids'].to(model.device)
    attention_mask = toks['attention_mask'].to(model.device)

    # Only the prompt's hidden states are needed. A single forward pass gives
    # them directly, without sampling a new token that would be discarded.
    with torch.inference_mode():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False,
            output_hidden_states=True,
            return_dict=True,
        )

    # Store on CPU so the files can be loaded without a GPU.
    hidden_states = tuple(h.cpu() for h in outputs.hidden_states)
    torch.save(hidden_states, f"{output_dir}/{label}_hidden_state_{idx}.pt")

    del input_ids, attention_mask, outputs, hidden_states
    # Collect garbage first so that any CUDA blocks it frees are then released
    # by empty_cache().
    gc.collect()
    torch.cuda.empty_cache()
|
|
|
print("Generate and process...")

# All harmful prompts first, then all harmless ones. idx restarts at 0 for each
# label and becomes part of the saved filename.
for label, token_list in (("harmful", harmful_toks), ("harmless", harmless_toks)):
    for idx, toks in enumerate(token_list):
        generate_and_process(toks, label, idx)

bar.close()
|
|
|
# The model is no longer needed: the rest of the script works on saved files.
# gc.collect() runs before empty_cache() so that tensors freed by the collector
# are already back in PyTorch's cache when the cache is released.
del model, tokenizer
gc.collect()
torch.cuda.empty_cache()
|
|
|
|
|
# Per-layer list of (harmful, harmless) activation pairs, one pair per
# instruction. Despite the name, these are activations, not directions.
final_refusal_dirs = [[] for _ in range(num_layers)]

for idx in tqdm(range(n_instructions), desc="Processing instruction"):
    harmful_hidden = torch.load(f"{output_dir}/harmful_hidden_state_{idx}.pt", map_location='cpu', weights_only=True)
    harmless_hidden = torch.load(f"{output_dir}/harmless_hidden_state_{idx}.pt", map_location='cpu', weights_only=True)

    # NOTE(review): the saved tuple presumably starts with the embedding output
    # (transformers convention), so entry layer_idx is the input to decoder
    # layer layer_idx -- confirm for this model.
    for layer_idx in range(num_layers):
        # Only the last prompt position is used downstream. Keep a length-1
        # slice so that `[:, -1, :]` indexing still works, and clone it so the
        # full-sequence tensor is not kept alive through a view.
        harmful_last = harmful_hidden[layer_idx][:, -1:, :].clone()
        harmless_last = harmless_hidden[layer_idx][:, -1:, :].clone()
        final_refusal_dirs[layer_idx].append((harmful_last, harmless_last))

    del harmful_hidden, harmless_hidden
|
|
|
|
|
final_refusal_directions = []

for layer_idx in tqdm(range(num_layers), desc="Calculating refusal direction for layer"):
    # Last prompt token; with left padding this is the final real token.
    pos = -1

    pairs = final_refusal_dirs[layer_idx]
    harmful_stack = torch.stack([harmful[:, pos, :] for harmful, _ in pairs])
    harmless_stack = torch.stack([harmless[:, pos, :] for _, harmless in pairs])

    # Average in float32. Averaging many low-precision (e.g. bf16) activations
    # in their own dtype loses precision. The result is cast back afterwards
    # so the saved dtype matches what downstream code already expects.
    out_dtype = harmful_stack.dtype
    harmful_mean = harmful_stack.float().mean(dim=0)
    harmless_mean = harmless_stack.float().mean(dim=0)

    # Difference of means, scaled to unit L2 norm.
    refusal_dir = harmful_mean - harmless_mean
    refusal_dir = refusal_dir / refusal_dir.norm()

    final_refusal_directions.append(refusal_dir.to(out_dtype))

torch.save(final_refusal_directions, output_dir + "/final_refusal_dirs.pt")
print("Refusal directions saved successfully.")
|
|