import os

import torch
from tqdm import tqdm
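
# Limit math-library threading to half of the available cores so the rest
# of the machine stays responsive. Note: the environment variables below
# may have no effect once torch is already imported; torch.set_num_threads()
# is the reliable control here.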
cpu_count = os.cpu_count()
print(f"Number of CPU cores in the system: {cpu_count}")

half_cpu_count = max(1, cpu_count // 2)  # guard against single-core machines
os.environ["MKL_NUM_THREADS"] = str(half_cpu_count)
os.environ["OMP_NUM_THREADS"] = str(half_cpu_count)
torch.set_num_threads(half_cpu_count)

print(f"PyTorch threads: {torch.get_num_threads()}")
print(f"MKL threads: {os.getenv('MKL_NUM_THREADS')}")
print(f"OMP threads: {os.getenv('OMP_NUM_THREADS')}")
MODEL_ID = "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B"
output_dir = MODEL_ID + "/hidden_states"

n_instructions = 5510  # size of the instruction set (kept for reference; unused below)
num_layers = 36  # transformer layers in DeepSeek-R1-0528-Qwen3-8B
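
# final_refusal_dirs[layer_idx] accumulates one (harmful, harmless)
# hidden-state pair per matched instruction.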
final_refusal_dirs = []

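
# Diagnostic helper (not called in the main flow below): for each line of
# the small file, report the 0-based line numbers at which it occurs in
# the large file.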
def find_lines_positions(small_file_path, large_file_path):
    with open(small_file_path, 'r', encoding='utf-8') as small_file:
        small_lines = {line.strip() for line in small_file if line.strip()}

    result = {}
    with open(large_file_path, 'r', encoding='utf-8') as large_file:
        for line_num, line in enumerate(large_file):  # 0-based, like the saved filenames
            line = line.strip().strip("?")
            if line in small_lines:
                result.setdefault(line, []).append(line_num)

    for line in small_lines:
        if line in result:
            print(f"##Line '{line}' found at line number(s): {result[line]}")
    return result


def count_lines(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return sum(1 for line in f)

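
# Input files: the small file lists the instructions of interest (those the
# model refused, per its filename); the large file is the full harmful set,
# one instruction per line.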
small_file_path = 'datasets21/harmful-refuese-r1.txt'
large_file_path = 'datasets22/harmful.txt'

total_lines = count_lines(large_file_path)
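
# For every instruction in the large file that also appears in the small
# file, load its previously saved harmful/harmless hidden states and group
# the per-layer tensors by layer index.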
with open(small_file_path, 'r', encoding='utf-8') as small_file:
    small_lines = {line.strip() for line in small_file if line.strip()}

with open(large_file_path, 'r', encoding='utf-8') as large_file:
    for line_num, line in tqdm(enumerate(large_file), total=total_lines, desc="Processing instructions"):
        line = line.strip().strip("?")  # drop surrounding '?' so phrasing matches the small file
        if line in small_lines:
            try:
                # One file per instruction; each holds per-layer hidden-state
                # tensors of shape [batch, seq_len, hidden_size].
                harmful_hidden = torch.load(f"{output_dir}/harmful_hidden_state_{line_num}.pt", map_location='cpu', weights_only=True)
                harmless_hidden = torch.load(f"{output_dir}/harmless_hidden_state_{line_num}.pt", map_location='cpu', weights_only=True)

                for layer_idx in range(num_layers):
                    harmful_layer_hidden = harmful_hidden[layer_idx]
                    harmless_layer_hidden = harmless_hidden[layer_idx]

                    if len(final_refusal_dirs) <= layer_idx:
                        final_refusal_dirs.append([])

                    final_refusal_dirs[layer_idx].append((harmful_layer_hidden, harmless_layer_hidden))

                del harmful_hidden, harmless_hidden  # free memory between instructions
            except FileNotFoundError as e:
                print(f"Error: File not found for line {line_num}: {e}")
                continue
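
# Difference-of-means: for each layer, the refusal direction is the
# normalized difference between the mean harmful and mean harmless hidden
# state at the last token position. Near-zero differences are zeroed out
# rather than normalized to avoid dividing by ~0.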
final_refusal_directions16 = []
final_refusal_directions32 = []

for layer_idx in range(num_layers):
    pos = -1  # hidden state at the last token position

    harmful_hidden_list = [hidden[0][:, pos, :] for hidden in final_refusal_dirs[layer_idx]]
    harmless_hidden_list = [hidden[1][:, pos, :] for hidden in final_refusal_dirs[layer_idx]]

    harmful_mean = torch.stack(harmful_hidden_list).mean(dim=0)
    harmless_mean = torch.stack(harmless_hidden_list).mean(dim=0)

    mean_diff_norm = (harmful_mean - harmless_mean).norm().item()

    refusal_dir16 = harmful_mean - harmless_mean  # keeps the loaded dtype
    refusal_dir32 = refusal_dir16.to(torch.float32)

    if mean_diff_norm < 1e-6:
        print(f"Warning: Layer {layer_idx} has near-zero refusal_dir")
        refusal_dir16 = torch.zeros_like(refusal_dir16)
        refusal_dir32 = torch.zeros_like(refusal_dir32)
    else:
        refusal_dir16 = refusal_dir16 / refusal_dir16.norm()
        refusal_dir32 = refusal_dir32 / refusal_dir32.norm()

    print(f"layer {layer_idx:3d}: {mean_diff_norm:.6f}, {refusal_dir32.norm().item():.16f}")

    final_refusal_directions16.append(refusal_dir16)
    final_refusal_directions32.append(refusal_dir32)
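
# Persist both precisions alongside the hidden states; the lists can be
# reloaded later with torch.load(path, weights_only=True).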
torch.save(final_refusal_directions16, output_dir + "/final_refusal_dirs16-1.pt")
torch.save(final_refusal_directions32, output_dir + "/final_refusal_dirs32-1.pt")
print("Refusal directions saved successfully.")
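
# Rank layers by the float32 norm of their direction, largest first.
# Normalized directions have norm ~1.0; zeroed degenerate layers have 0.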
refusal_data = []
for layer_idx, refusal_dir in enumerate(final_refusal_directions32):
    value = refusal_dir.norm().item()
    refusal_data.append((layer_idx, value))

sorted_data = sorted(refusal_data, key=lambda x: (-x[1], x[0]))  # by norm desc, then layer index
for layer_idx, value in sorted_data:
    print(f"layer {layer_idx}:{value:.16f}")
print("----------")
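
# Collect layers whose stored norm falls below 1.0: the zeroed degenerate
# layers plus any whose float32 norm rounds just under 1.0 after
# normalization.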
test_layers = []
print("test_layers = [", end="")
for layer_idx, value in sorted_data:
    if value < 1.0:
        print(f"'{layer_idx}', ", end="")
        test_layers.append(layer_idx)
print("]")

print("----------")

for layer_idx in test_layers:
    print(f"layer {layer_idx}")