File size: 6,620 Bytes
26e1cba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer, Qwen3MoeForCausalLM
import torch
import torch.nn as nn
import os
import signal
from typing import Optional, Tuple
import einops
import jaxtyping

cpu_count = os.cpu_count()
print(f"Number of CPU cores in the system: {cpu_count}")
# Use half of the cores for compute. os.cpu_count() may return None when the
# count cannot be determined, and torch.set_num_threads() rejects 0 (which a
# single-core machine would otherwise produce), so never go below one thread.
half_cpu_count = max(1, (cpu_count or 1) // 2)
# NOTE(review): torch is imported above, before these variables are set;
# OpenMP/MKL usually read them when the runtime initializes, so they may not
# affect this process. torch.set_num_threads() below is the call that
# reliably applies the limit here.
os.environ["MKL_NUM_THREADS"] = str(half_cpu_count)
os.environ["OMP_NUM_THREADS"] = str(half_cpu_count)
torch.set_num_threads(half_cpu_count)

print(f"PyTorch threads: {torch.get_num_threads()}")
print(f"MKL threads: {os.getenv('MKL_NUM_THREADS')}")
print(f"OMP threads: {os.getenv('OMP_NUM_THREADS')}")

# Load the model and tokenizer
MODEL_ID = "kalomaze/Qwen3-16B-A3B"
print(f"Load Model {MODEL_ID} ... ")

# 4-bit bitsandbytes settings. This object is built but NOT passed to
# from_pretrained() below, so the model is loaded unquantized in bfloat16.
quant_config_4 = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

model = Qwen3MoeForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Fall back to EOS as the pad token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): this runs unconditionally, so it also replaces a pad token the
# tokenizer already defined with the EOS id -- confirm that is intended.
tokenizer.pad_token_id = tokenizer.eos_token_id

# Conversation history plus the runtime toggles the chat loop flips via
# /no_think, /skip_prompt and /skip_special_tokens.
messages, enable_thinking = [], True
skip_prompt = skip_special_tokens = True

def direction_ablation_hook(activation: jaxtyping.Float[torch.Tensor, "... d_act"],
                            direction: jaxtyping.Float[torch.Tensor, "d_act"]):
    """Subtract from ``activation`` its component along ``direction``.

    Computes ``activation - (activation . direction) * direction``, broadcasting
    over every leading dimension of ``activation``; the result has the same
    shape as ``activation``.

    NOTE(review): this equals the orthogonal projection onto the complement
    of ``direction`` only when ``direction`` has unit L2 norm. Nothing here
    normalizes it, so confirm callers pass a normalized vector.
    """
    as_column = direction.view(-1, 1)
    # (..., d_act) x (d_act, 1) -> (..., 1): one coefficient per leading index.
    coefficient = einops.einsum(activation, as_column, '... d_act, d_act single -> ... single')
    # (..., 1) * (d_act,) broadcasts back to (..., d_act).
    return activation - coefficient * direction

class AblationDecoderLayer(nn.Module):
    """Wrap a decoder layer and project ``refusal_dir`` out of its input hidden states.

    The wrapped layer is registered as the submodule ``original_layer``, so its
    parameters remain reachable (under the ``original_layer.`` prefix) through
    this wrapper's ``named_parameters()``.
    """

    def __init__(self, original_layer, refusal_dir):
        super().__init__()
        self.original_layer = original_layer
        # Plain attribute, not a registered buffer: it is not part of state_dict.
        self.refusal_dir = refusal_dir

    def forward(self, *args, **kwargs):
        """Ablate the direction from the hidden states, then run the wrapped layer.

        Hidden states are taken from the first positional argument, or from the
        ``hidden_states`` keyword when no positional arguments are given; all
        other arguments are forwarded unchanged.
        """
        if args:
            hidden_states = args[0]
        else:
            hidden_states = kwargs["hidden_states"]

        # Match dtype as well as device: a mismatched direction (e.g. float32
        # against bfloat16 activations) would either fail inside the einsum in
        # direction_ablation_hook or upcast the hidden states fed to the layer.
        direction = self.refusal_dir.to(device=hidden_states.device, dtype=hidden_states.dtype)
        ablated = direction_ablation_hook(hidden_states, direction)

        if args:
            args = (ablated,) + args[1:]
        else:
            kwargs["hidden_states"] = ablated
        # Call the module rather than .forward() so forward hooks / pre-hooks
        # registered on the wrapped layer still run.
        return self.original_layer(*args, **kwargs)

class CustomTextStreamer(TextStreamer):
    """TextStreamer that echoes decoded text, keeps a copy, and can abort generation.

    Stopping is cooperative: once ``stop_generation()`` has been called, the
    next finalized text chunk is still recorded and printed, and then
    ``StopIteration`` is raised out of the callback. Whoever calls
    ``model.generate`` with this streamer must catch it.
    """

    def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=True):
        super().__init__(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
        # All text emitted so far, including the chunk that triggered a stop.
        self.generated_text = ""
        # Flipped by stop_generation() (e.g. from a SIGINT handler).
        self.stop_flag = False

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Append ``text``, print it without a newline, and abort if a stop was requested."""
        self.generated_text = self.generated_text + text
        print(text, end="", flush=True)
        if not self.stop_flag:
            return
        raise StopIteration

    def stop_generation(self):
        """Request that generation stop at the next finalized text chunk."""
        self.stop_flag = True

def generate_stream(model, tokenizer, messages, enable_thinking, skip_prompt, skip_special_tokens, max_new_tokens):
    """Generate one sampled assistant reply, streaming it to stdout.

    While generation runs, Ctrl+C is redirected to a handler that asks the
    streamer to stop. The streamer then raises StopIteration out of
    ``model.generate``, which is caught here. The previously installed SIGINT
    handler is restored afterwards, even if generation fails.

    Args:
        model: causal LM whose ``generate`` accepts a ``streamer``.
        tokenizer: tokenizer providing ``apply_chat_template``.
        messages: chat history passed to the chat template.
        enable_thinking: forwarded to the chat template.
        skip_prompt: passed to the streamer; whether to echo the prompt.
        skip_special_tokens: passed to the streamer's decoding.
        max_new_tokens: generation length limit.

    Returns:
        ``(generated_text, stopped)``: the streamed text, and whether a stop
        was requested by the user.
    """
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        enable_thinking=enable_thinking,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    tokens = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)

    streamer = CustomTextStreamer(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)

    def signal_handler(sig, frame):
        streamer.stop_generation()
        print("\n[Generation stopped by user with Ctrl+C]")

    previous_handler = signal.signal(signal.SIGINT, signal_handler)

    print("Response: ", end="", flush=True)
    try:
        generated_ids = model.generate(
            tokens,
            attention_mask=attention_mask,
            use_cache=False,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
        del generated_ids
    except StopIteration:
        print("\n[Stopped by user]")
    finally:
        # Always undo our handler. signal.signal() returns None when the
        # previous handler was not installed from Python; fall back to the
        # default disposition in that case.
        if previous_handler is None:
            previous_handler = signal.SIG_DFL
        signal.signal(signal.SIGINT, previous_handler)

    del input_ids, attention_mask
    torch.cuda.empty_cache()

    return streamer.generated_text, streamer.stop_flag



# Refusal directions saved by an earlier step, indexed by layer.
# weights_only=True limits unpickling to tensors and plain containers.
final_refusal_dirs = torch.load(MODEL_ID + "/hidden_states/final_refusal_dirs.pt", map_location='cpu', weights_only=True)

# The direction computed at this layer is ablated at every layer (loop below).
candidate_layer = 20
refusal_dir = final_refusal_dirs[candidate_layer]

# Debug snapshot of layer 20's parameters, compared against after each chat
# turn. Layer 20 is also hard-coded in the chat loop's check; keep them in sync.
layer = model.model.layers[20]
for name, param in layer.named_parameters():
    print(f"layer0 {name} ")

# detach() so the snapshot is a plain tensor outside the autograd graph:
# parameters require grad, so a bare clone() would be tracked by autograd.
original_params = {name: param.detach().clone() for name, param in layer.named_parameters()}

# Wrap every decoder layer so the same direction is projected out of its input.
# Iterate over a snapshot list because the ModuleList is modified in the loop.
for idx, decoder_layer in enumerate(list(model.model.layers)):
    model.model.layers[idx] = AblationDecoderLayer(decoder_layer, refusal_dir)

# NOTE(review): only 2 new tokens per reply -- looks like a debugging value
# used together with the parameter check below; raise it for real chats.
MAX_NEW_TOKENS = 2

while True:
    try:
        user_input = input("User: ").strip()
    except EOFError:
        # Ctrl+D / closed stdin: end the session instead of crashing.
        print()
        print("Exiting chat.")
        break

    command = user_input.lower()
    if command == "/exit":
        print("Exiting chat.")
        break
    if command == "/clear":
        messages = []
        print("Chat history cleared. Starting a new conversation.")
        continue
    # Boolean toggles: flip the flag and report its new value.
    if command == "/no_think":
        enable_thinking = not enable_thinking
        print(f"Thinking = {enable_thinking}.")
        continue
    if command == "/skip_prompt":
        skip_prompt = not skip_prompt
        print(f"skip_prompt = {skip_prompt}.")
        continue
    if command == "/skip_special_tokens":
        skip_special_tokens = not skip_special_tokens
        print(f"skip_special_tokens = {skip_special_tokens}.")
        continue
    if not user_input:
        print("Input cannot be empty. Please enter something.")
        continue

    messages.append({"role": "user", "content": user_input})
    response, stop_flag = generate_stream(model, tokenizer, messages, enable_thinking, skip_prompt, skip_special_tokens, MAX_NEW_TOKENS)
    print("", flush=True)
    # A reply cut short by Ctrl+C is still kept in the history.
    messages.append({"role": "assistant", "content": response})

    # Debug check: list the wrapper's parameter names, then the wrapped
    # layer's, and report any parameter that differs from the snapshot taken
    # before the layers were wrapped.
    layer2 = model.model.layers[20]
    for name, param in layer2.named_parameters():
        print(f"layer1 {name} ")

    layer2 = layer2.original_layer
    for name, param in layer2.named_parameters():
        print(f"layer2 {name} ")

    for name, param in layer2.named_parameters():
        if not torch.equal(original_params[name], param):
            # Message reads: "Parameter {name} was modified!"
            print(f"参数 {name} 被修改!")