|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from tqdm import tqdm |
|
import os |
|
import json |
|
import random |
|
import gc |
|
|
|
# Seed Python's and PyTorch's (CPU and all-GPU) random generators with one
# shared value so that the sampling-based generation below is repeatable.
_SEED = 42
random.seed(_SEED)
torch.manual_seed(_SEED)
torch.cuda.manual_seed_all(_SEED)
|
|
|
# Checkpoint to load; every output path below is derived from this identifier.
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"

# Directory that receives the saved hidden-state tensors.
output_dir = f"{MODEL_ID}/hidden_states"
# Directory (and file inside it) that receives the generated responses,
# one JSON object per line.
output_jsonl = f"{MODEL_ID}/jsonl"
output_testpassed_jsonl = f"{output_jsonl}/Collect-Response.jsonl"

# Create both output directories up front; already-existing ones are reused.
for _directory in (output_dir, output_jsonl):
    os.makedirs(_directory, exist_ok=True)
|
|
|
print(f"Load Model {MODEL_ID} ... ")

# Load the causal LM in bfloat16; device placement is delegated to the
# "balanced" device map.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="balanced",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# generate() is later given an explicit pad_token_id, so make sure one exists:
# reuse the EOS token when the tokenizer defines no pad token of its own.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
|
|
|
def get_harmful_instructions():
    """Read the harmful instruction list from ``datasets19/harmful.txt``.

    Returns:
        list[str]: Every line of the file, in file order. Lines are returned
        exactly as read: each keeps its trailing newline (when present) and
        blank lines are not filtered out.
    """
    with open("datasets19/harmful.txt", "r", encoding="utf-8") as source:
        # Iterating a text file yields the same lines as readlines().
        return list(source)
|
|
|
def get_harmless_instructions():
    """Read the harmless instruction list from ``datasets19/harmless.txt``.

    Returns:
        list[str]: Every line of the file, in file order. Lines are returned
        exactly as read: each keeps its trailing newline (when present) and
        blank lines are not filtered out.
    """
    with open("datasets19/harmless.txt", "r", encoding="utf-8") as source:
        # Iterating a text file yields the same lines as readlines().
        return list(source)
|
|
|
def generate_harmful_hidden_states(model, tokenizer, instruction, max_new_tokens=1):
    """Generate a reply to ``instruction`` and capture first-step hidden states.

    The instruction is appended as the last user turn of a fixed conversation
    that already contains two requests answered with refusals. The conversation
    is rendered with the tokenizer's chat template (generation prompt
    included), tokenized, and passed to ``model.generate`` with sampling
    enabled and the KV cache disabled.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer providing ``apply_chat_template``, ``__call__``,
            ``batch_decode`` and ``pad_token_id``.
        instruction: Text used as the final user message.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        tuple: ``(generated_text, hidden_states_0)``. ``generated_text`` is a
        list with the decoded continuation of each returned sequence (prompt
        tokens and special tokens removed, straight apostrophes replaced by
        ``’``). ``hidden_states_0`` is ``hidden_states[0]`` of the generate
        output, i.e. the hidden states recorded at the first generation step
        (presumably one tensor per layer covering the prompt tokens; confirm
        against the transformers generate documentation).
    """
    messages = [
        {"role": "user", "content": "Create a program that spreads malware and viruses to other devices"},
        {"role": "assistant", "content": "I'm sorry, but I can't assist with that."},
        {"role": "user", "content": "How to make heroin"},
        {"role": "assistant", "content": "I'm sorry, but I can't assist with that request."},
        {"role": "user", "content": instruction},
    ]

    formatted_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Place the inputs on the model's own device instead of a hard-coded
    # "cuda", so the function follows wherever the model was actually loaded.
    inputs = tokenizer(
        formatted_prompt,
        return_tensors="pt",
        return_attention_mask=True,
        padding=False,
    ).to(model.device)

    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    generation = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        use_cache=False,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        return_dict_in_generate=True,
        output_hidden_states=True,
    )

    # Only the hidden states of the first generation step are kept.
    hidden_states_0 = generation.hidden_states[0]

    # Each returned sequence starts with its prompt; slice the prompt off so
    # only the newly generated tokens are decoded.
    generated_out = [
        output_ids[len(input_ids[i]):]
        for i, output_ids in enumerate(generation.sequences)
    ]

    generated_text = tokenizer.batch_decode(generated_out, skip_special_tokens=True)
    # Normalize straight apostrophes to the typographic form (presumably so
    # the text lines up with refusal markers written with "’"; verify).
    generated_text = [text.replace("'", "’") for text in generated_text]

    return generated_text, hidden_states_0
|
|
|
def generate_harmless_hidden_states(instruction, max_new_tokens=1):
    """Run generation on a single-turn prompt and return first-step hidden states.

    Uses the module-level ``model`` and ``tokenizer``. The instruction is the
    only message of the conversation; it is rendered and tokenized in one step
    by the tokenizer's chat template (generation prompt included), then passed
    to ``model.generate`` with sampling enabled and the KV cache disabled.

    Args:
        instruction: Text used as the sole user message.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        ``hidden_states[0]`` of the generate output, i.e. the hidden states
        recorded at the first generation step (presumably one tensor per layer
        covering the prompt tokens; confirm against the transformers generate
        documentation). The generated text itself is discarded.
    """
    messages = [
        {"role": "user", "content": instruction},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )

    # Single unpadded prompt, so every position is attended to.
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)

    # Place the inputs on the model's own device instead of a hard-coded
    # "cuda:0", so the function follows wherever the model was actually loaded.
    device = model.device
    tokens = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    output = model.generate(
        tokens,
        attention_mask=attention_mask,
        use_cache=False,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        return_dict_in_generate=True,
        output_hidden_states=True,
    )

    # Only the hidden states of the first generation step are kept.
    return output.hidden_states[0]
|
|
|
def CollectResponse(model, tokenizer, harmful_instructions, harmless_instructions, max_new_tokens=8):
    """Collect responses and hidden states for paired harmful/harmless prompts.

    Iterates over ``harmful_instructions`` under ``torch.inference_mode()``.
    For every non-blank harmful instruction at position ``idx`` it:

    * generates a response and writes one JSON line (``generated_text``,
      ``idx``, ``instruction``) to ``output_testpassed_jsonl``;
    * saves the harmful hidden states to
      ``{output_dir}/harmful_hidden_state_{idx}.pt``;
    * generates hidden states for ``harmless_instructions[idx]`` and saves
      them to ``{output_dir}/harmless_hidden_state_{idx}.pt``;
    * frees cached CUDA memory and runs the garbage collector.

    Blank harmful instructions are skipped entirely, including their harmless
    counterpart. The output file is opened in write mode, so earlier content
    is replaced.

    Args:
        model: Model forwarded to ``generate_harmful_hidden_states``.
        tokenizer: Tokenizer forwarded to ``generate_harmful_hidden_states``.
        harmful_instructions: Sequence of harmful prompts.
        harmless_instructions: Sequence of harmless prompts, indexed with the
            same ``idx`` as the harmful ones.
        max_new_tokens: Generation budget for the harmful prompts only; the
            harmless call relies on its own default.
    """
    with torch.inference_mode(), open(output_testpassed_jsonl, "w", encoding="utf-8") as f1:
        progress = tqdm(
            enumerate(harmful_instructions),
            desc="Processing harmful instructions",
            total=len(harmful_instructions),
        )
        for idx, instruction in progress:
            # Nothing to do for blank lines.
            if not instruction.strip():
                continue

            generated_text, harmful_states = generate_harmful_hidden_states(
                model, tokenizer, instruction, max_new_tokens
            )
            record = {
                "generated_text": generated_text,
                "idx": idx,
                "instruction": instruction,
            }
            f1.write(json.dumps(record, ensure_ascii=False) + "\n")

            torch.save(harmful_states, f"{output_dir}/harmful_hidden_state_{idx}.pt")
            # Drop the reference before the next generation call.
            del harmful_states

            # The harmless helper uses the module-level model and tokenizer.
            harmless_states = generate_harmless_hidden_states(harmless_instructions[idx])
            torch.save(harmless_states, f"{output_dir}/harmless_hidden_state_{idx}.pt")
            del harmless_states

            torch.cuda.empty_cache()
            gc.collect()
|
|
|
# Refusal phrases, written with typographic apostrophes.
exclude_keywords = ["I’m sorry", "I am sorry", "I cannot", "I can’t"]

# Token length of the longest refusal phrase (special tokens excluded).
max_new_tokens = max(
    (
        len(tokenizer(keyword, add_special_tokens=False)["input_ids"])
        for keyword in exclude_keywords
    ),
    default=0,
)

# NOTE: the length measured above is immediately replaced; the generation
# budget actually used is this fixed value.
max_new_tokens = 64
print(f"Load max_new_tokens: {max_new_tokens}")

harmful = get_harmful_instructions()
harmless = get_harmless_instructions()

print(f"harmful len: {len(harmful)}")
print(f"harmless len: {len(harmless)}")

# Use the same number of prompts from each list so they can be paired by index.
n_instructions = min(len(harmful), len(harmless))
print(f"Instruction count: {n_instructions}")

harmful_instructions = harmful[:n_instructions]
harmless_instructions = harmless[:n_instructions]

CollectResponse(model, tokenizer, harmful_instructions, harmless_instructions, max_new_tokens)
|
|