File size: 10,068 Bytes
889f722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
"""
Project: corr-steer

This script loads a given dataset (using a datasets configuration), tokenizes the texts,
runs the LLM forward to extract hidden states, and then, for each SAE hook defined
for the model, computes binary feature activations: SAE activations are thresholded (> 0)
and aggregated with a max over the sequence. The point-biserial correlation is then computed
between each feature and each label category.

For each hook, the script finds the top 10 features (sorted in descending order by absolute correlation)
per category and saves that record to a JSON file:

    features/{model}.{dataset}.{hook}.json

Additionally, an aggregated JSON file

    features/{model}.{dataset}.json

is created that, for each category, merges records from all hooks (each record includes its hook)
and retains the top 10 features overall.

This script is callable via Fire, e.g.:
    python corr_extract.py --dataset emgsd --model gpt2
"""

import os
import json
import math
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sae_lens import SAE
from datasets import load_dataset, concatenate_datasets
from tqdm import tqdm
from sklearn.preprocessing import LabelBinarizer
from scipy.stats import pointbiserialr
import wandb
import fire

from config import datasets_config, models_config


# =========
# Load a dataset given the dataset configuration.
#
# Returns:
#   texts, labels, and the maximum token length.
# =========
def load_custom_dataset(dataset_name, limit: int):
    """Load texts and string labels for *dataset_name* from ``datasets_config``.

    The "test" split is loaded, shuffled with a fixed seed (42) and cut to its
    first half (the other half is left for validation). The configured
    columns are selected and any equality filters from the config applied.

    Args:
        dataset_name: key into ``datasets_config``.
        limit: maximum number of examples to keep; a falsy value (0/None)
            keeps all of them. A limit larger than the number of remaining
            rows is clamped instead of raising.

    Returns:
        ``(texts, labels, max_length)``: two lists of ``str`` (in dataset
        order) and the config's ``max_length`` value.
    """
    config = datasets_config[dataset_name]
    # NOTE: the "test" split is used here, not "train".
    dataset = load_dataset(config["id"], split="test")
    # Deterministic shuffle, then keep the first half for feature extraction
    # so the remaining half can serve as validation data.
    dataset = dataset.shuffle(seed=42).select(range(dataset.num_rows // 2))

    # Select only the specified columns.
    dataset = dataset.select_columns(config["columns"])

    # Keep only rows where each configured column equals its value. key/val
    # are bound as lambda defaults to avoid late-binding closure bugs.
    if "filter" in config:
        for key, val in config["filter"].items():
            dataset = dataset.filter(lambda ex, key=key, val=val: ex[key] == val)

    if limit:
        # Clamp: after halving and filtering there may be fewer rows than
        # requested, and select() raises on out-of-range indices.
        dataset = dataset.select(range(min(limit, dataset.num_rows)))

    text_field = config["text_field"]
    label_field = config["label_field"]
    texts = []
    labels = []
    for ex in tqdm(dataset, desc="Loading dataset"):
        texts.append(str(ex[text_field]))
        labels.append(str(ex[label_field]))
    return texts, labels, config["max_length"]

# =========
# Tokenize a list of texts.
#
# Returns a list of tokenized (and device-mapped) inputs.
# =========
def tokenize_texts(tokenizer, texts, max_length, device):
    """Encode every text on its own and move the resulting tensors to *device*.

    Each text is passed to the tokenizer individually (one text per call) with
    truncation at ``max_length`` tokens and PyTorch tensors requested.

    Returns:
        A list with one encoding per input text, in input order; every value
        in each encoding has been moved to ``device``.
    """
    encodings = []
    for sample in tqdm(texts, desc="Tokenizing texts"):
        enc = tokenizer(sample, return_tensors="pt", truncation=True, max_length=max_length)
        # Re-assign each tensor with its device-mapped copy, keeping the
        # tokenizer's own container type.
        enc.update({name: tensor.to(device) for name, tensor in enc.items()})
        encodings.append(enc)
    return encodings

# =========
# Extract aggregated binary feature activations for each SAE hook.
#
# For each sample, the model forward is run once to get all hidden states. Then for each hook,
# the corresponding hidden state (parsed from the hook string) is passed through its SAE. The output
# is thresholded (> 0) and aggregated via max pooling along the sequence dimension.
#
# Returns:
#   A dictionary mapping hook names to a numpy array of shape (num_samples, num_features).
# =========
def extract_features(llm, model_name, tokens_list, hooks, device):
    """Compute per-sample binary "feature fired anywhere" vectors for each hook.

    Args:
        llm: causal LM; called with ``output_hidden_states=True``.
        model_name: key into ``models_config``; its ``"sae"`` entry is the SAE
            release passed to ``SAE.from_pretrained``.
        tokens_list: list of tokenizer encodings (one text each, as produced
            by ``tokenize_texts``).
        hooks: SAE hook names; the integer after the first "." is used as the
            index into ``outputs.hidden_states``.
        device: device the SAEs are loaded on.

    Returns:
        ``{hook: np.ndarray}`` where each array has shape
        (num_samples, num_features) and holds 1 if the feature's SAE
        activation was > 0 at any position of the sample, else 0.
    """
    # Preload one SAE per hook and parse its hidden-state index once, rather
    # than re-parsing the hook string for every sample.
    # NOTE(review): assumes hook names look like "blocks.<L>.<name>" and that
    # hidden_states[L] is the input to block L (matches resid_pre hooks);
    # verify for other hook types.
    sae_models = {}
    layer_by_hook = {}
    for hook in hooks:
        sae, _, _ = SAE.from_pretrained(models_config[model_name]["sae"], hook, device=device)
        sae_models[hook] = sae
        layer_by_hook[hook] = int(hook.split(".")[1])

    features_by_hook = {hook: [] for hook in hooks}

    with torch.no_grad():
        for encoding in tqdm(tokens_list, desc="Extracting activations"):
            outputs = llm(**encoding, output_hidden_states=True)
            for hook in hooks:
                hidden_state = outputs.hidden_states[layer_by_hook[hook]]
                activations = sae_models[hook].encode(hidden_state)
                # Drop the batch dimension (assumes batch size 1), threshold
                # at > 0 and max-pool over the sequence (== any()) on the
                # device, so only a (num_features,) vector is copied to the
                # host instead of the full (seq_len, num_features) matrix.
                fired = (activations.squeeze(0) > 0).any(dim=0)
                features_by_hook[hook].append(fired.cpu().numpy().astype(int))

    # Stack each hook's per-sample vectors into (num_samples, num_features).
    return {hook: np.array(rows) for hook, rows in features_by_hook.items()}

# =========
# Compute correlations per label category.
#
# Given a feature activation array of shape (n_samples, n_features) and binary labels of shape
# (n_samples, n_categories) along with a list of category names, compute the point-biserial correlation
# for each feature against each category. For each category, sort by absolute correlation and keep the top 10.
#
# Returns a dictionary keyed by category.
# =========
def compute_correlations_by_category(feature_activations, binary_labels, categories, top_k=10):
    """Rank features by point-biserial correlation with each label category.

    The point-biserial correlation is the Pearson correlation between a 0/1
    label column and a feature column. It is computed here for all
    (category, feature) pairs at once with matrix algebra instead of one
    ``scipy.stats.pointbiserialr`` call per pair. Values follow that function:
    they are clipped to [-1, 1], and a pair whose feature or label column is
    constant gets NaN.

    Args:
        feature_activations: array-like of shape (n_samples, n_features).
        binary_labels: array-like of shape (n_samples, n_categories); column
            ``i`` is the indicator for ``categories[i]``.
        categories: category names, one per label column.
        top_k: number of records kept per category (default 10).

    Returns:
        ``{category: [{"feature_index": int, "correlation": float}, ...]}``
        with at most ``top_k`` records per category, sorted by descending
        absolute correlation. NaN ranks as 0, and ties keep ascending
        feature index.

    Raises:
        ValueError: if there are fewer than two samples (correlation is
            undefined, as in ``pointbiserialr``).
    """
    # np.array (not asarray) always copies, so the in-place centring below
    # never mutates the caller's arrays.
    feats = np.array(feature_activations, dtype=np.float64)
    labels = np.array(binary_labels, dtype=np.float64)
    n_samples = feats.shape[0]
    if n_samples < 2:
        raise ValueError("At least two samples are required to compute correlations.")

    # Constant columns have an undefined correlation; detect them on the raw
    # values, before centring can introduce rounding noise.
    feat_const = (feats == feats[0]).all(axis=0)
    label_const = (labels == labels[0]).all(axis=0)

    feats -= feats.mean(axis=0)
    labels -= labels.mean(axis=0)
    with np.errstate(divide="ignore", invalid="ignore"):
        denom = np.outer(np.linalg.norm(labels, axis=0), np.linalg.norm(feats, axis=0))
        corr = (labels.T @ feats) / denom  # (n_label_columns, n_features)
    corr = np.clip(corr, -1.0, 1.0)  # guard against |r| marginally > 1
    corr[label_const, :] = np.nan
    corr[:, feat_const] = np.nan

    # Rank key: |r| with NaN treated as 0. Stable argsort on the negated key
    # gives a descending order in which ties keep ascending feature index,
    # matching sorted(..., reverse=True) on the per-feature records.
    rank_key = np.where(np.isnan(corr), 0.0, np.abs(corr))
    results = {}
    for cat_idx, cat in enumerate(categories):
        order = np.argsort(-rank_key[cat_idx], kind="stable")[:top_k]
        results[cat] = [
            {"feature_index": int(i), "correlation": float(corr[cat_idx, i])}
            for i in order
        ]
    return results

# =========
# Main function
#
# This function initializes wandb (project "corr-steer"), loads the specified dataset and LLM along with
# the SAE hooks (from models_config), tokenizes the texts, extracts feature activations,
# computes (per hook) the per-category top 10 correlation records, and writes out one JSON per hook as
# well as one aggregated JSON file that combines (and sorts) records across hooks.
#
# Run via, for example:
#     python corr_extract.py --dataset emgsd --model gpt2
# =========
def main(dataset="emgsd", model="gpt2", limit=1000):
    """Extract and save the top correlated SAE features per label category.

    Args:
        dataset: key into ``datasets_config``.
        model: key into ``models_config``; its entry supplies the HF model
            ``"id"``, the SAE release ``"sae"`` and the list of ``"hooks"``.
        limit: maximum number of examples to use (falsy = all).

    Raises:
        ValueError: if *model* is not configured in ``models_config``.
    """
    # Validate the model key before any expensive work. Previously the
    # models_config[model]["id"] lookup raised a bare KeyError first, which
    # made this check unreachable.
    if model not in models_config:
        raise ValueError(f"Model {model} is not configured in models_config.")
    hooks = models_config[model]["hooks"]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "mps" if torch.backends.mps.is_available() else device

    wandb.init(project="corr-steer", config={"model": model, "dataset": dataset})

    print("Loading tokenizer and model...")
    model_id = models_config[model]["id"]
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    llm = AutoModelForCausalLM.from_pretrained(model_id).to(device)
    llm.eval()

    print("Loading dataset...")
    texts, labels, max_length = load_custom_dataset(dataset, limit)

    # Build one 0/1 indicator column per category, aligned with lb.classes_.
    lb = LabelBinarizer()
    binary_labels = lb.fit_transform(labels)
    if binary_labels.ndim == 1:
        binary_labels = binary_labels.reshape(-1, 1)
    categories = lb.classes_
    # With exactly two classes LabelBinarizer returns a single column (the
    # indicator for classes_[1]). Prepend the complementary column for
    # classes_[0] so column i always matches categories[i]; otherwise indexing
    # column 1 during correlation raises IndexError.
    if len(categories) == 2 and binary_labels.shape[1] == 1:
        binary_labels = np.hstack([1 - binary_labels, binary_labels])

    tokens_list = tokenize_texts(tokenizer, texts, max_length, device)

    print("Extracting features for all hooks...")
    features_by_hook = extract_features(llm, model, tokens_list, hooks, device)

    out_dir = "features"
    os.makedirs(out_dir, exist_ok=True)

    def abs_corr(rec):
        # Rank by |correlation|; NaN (constant feature/label column) ranks as 0.
        corr = rec["correlation"]
        return 0 if math.isnan(corr) else abs(corr)

    # Records from every hook, per category, for the cross-hook ranking.
    aggregated_results = {cat: [] for cat in categories}

    for hook in hooks:
        print(f"Computing per-category correlations for hook {hook} ...")
        feat = features_by_hook[hook]  # (n_samples, n_features)
        hook_corrs = compute_correlations_by_category(feat, binary_labels, categories)

        hook_filepath = os.path.join(out_dir, f"{model}.{dataset}.{hook}.json")
        with open(hook_filepath, "w", encoding="utf-8") as f:
            json.dump(hook_corrs, f, indent=2, ensure_ascii=False)
        print(f"Saved per-hook correlations to {hook_filepath}")

        wandb.log({hook: hook_corrs})

        # Tag each record with its hook so aggregated entries stay traceable.
        for cat in categories:
            aggregated_results[cat].extend({**rec, "hook": hook} for rec in hook_corrs[cat])

    # Keep the overall top 10 per category across all hooks.
    final_aggregated = {
        cat: sorted(aggregated_results[cat], key=abs_corr, reverse=True)[:10]
        for cat in categories
    }

    agg_filepath = os.path.join(out_dir, f"{model}.{dataset}.json")
    with open(agg_filepath, "w", encoding="utf-8") as f:
        json.dump(final_aggregated, f, indent=2, ensure_ascii=False)
    print(f"Saved aggregated correlations to {agg_filepath}")

    wandb.finish()

if __name__ == "__main__":
    # Hand main() to Fire so its keyword arguments (dataset, model, limit)
    # become command-line flags.
    fire.Fire(component=main)