"""
Project: corr-steer

This script loads a dataset (using a datasets configuration), tokenizes the texts,
runs the LLM forward to extract hidden states, and then, for each SAE hook defined
for the model, computes binary feature activations. Activations are thresholded at
zero and aggregated (max over the sequence), and the point-biserial correlation is
computed between each feature and each label category.
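
Note: because the aggregated activations and the one-hot label columns are both
binary, the point-biserial correlation here coincides with the Pearson (phi)
coefficient between two 0/1 vectors.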

For each hook, the script finds the top 10 features per category (sorted in
descending order by absolute correlation) and saves that record to a JSON file:

    features/{model}.{dataset}.{hook}.json

Additionally, an aggregated JSON file

    features/{model}.{dataset}.json

is created that, for each category, merges records from all hooks (each record
includes its hook) and retains the top 10 features overall.
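
A record in the aggregated file looks like this (values are illustrative only;
actual categories, feature indices, and correlations depend on the model, SAE,
and dataset):

    {
      "<category>": [
        {"feature_index": 1234, "correlation": 0.42, "hook": "blocks.5.hook_resid_pre"},
        ...
      ]
    }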

This script is callable via Fire, e.g.:

    python corr_extract.py --dataset emgsd --model gpt2

An optional --limit flag caps the number of examples used (default: 1000).
"""

import os
import json
import math

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sae_lens import SAE
from datasets import load_dataset
from tqdm import tqdm
from sklearn.preprocessing import LabelBinarizer
from scipy.stats import pointbiserialr
import wandb
import fire

from config import datasets_config, models_config


def load_custom_dataset(dataset_name, limit: int):
    config = datasets_config[dataset_name]

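    # Load the dataset's test split, then shuffle with a fixed seed and keep
    # half of the rows.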
    dataset = load_dataset(config["id"], split="test")
    dataset = dataset.shuffle(seed=42).select(range(dataset.num_rows // 2))

    dataset = dataset.select_columns(config["columns"])

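    # Optionally filter rows so that each configured key matches its expected value.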
    if "filter" in config:
        for key, val in config["filter"].items():
            dataset = dataset.filter(lambda ex, key=key, val=val: ex[key] == val)

    texts = []
    labels = []
    text_field = config["text_field"]
    label_field = config["label_field"]
    if limit:
        # Cap the number of examples without stepping past the dataset size.
        dataset = dataset.select(range(min(limit, dataset.num_rows)))

    for ex in tqdm(dataset, desc="Loading dataset"):
        texts.append(str(ex[text_field]))
        labels.append(str(ex[label_field]))
    return texts, labels, config["max_length"]


def tokenize_texts(tokenizer, texts, max_length, device):
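    # Tokenize each text on its own (no padding), truncate to the configured
    # max length, and move the resulting tensors to the target device.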
    tokenized = []
    for text in tqdm(texts, desc="Tokenizing texts"):
        encoding = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
        tokenized.append(encoding.to(device))
    return tokenized


def extract_features(llm, model_name, tokens_list, hooks, device):
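    # Load one pretrained SAE per hook from the SAE release configured for this model.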
    sae_models = {}
    for hook in hooks:
        sae, _, _ = SAE.from_pretrained(models_config[model_name]["sae"], hook, device=device)
        sae_models[hook] = sae

    features_by_hook = {hook: [] for hook in hooks}

    for encoding in tqdm(tokens_list, desc="Extracting activations"):
        with torch.no_grad():
            outputs = llm(**encoding, output_hidden_states=True)

            for hook in hooks:
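                # Hook names are assumed to look like "blocks.<layer>....", so the
                # second dot-separated field selects the matching entry of
                # outputs.hidden_states (index 0 is the embedding output).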
                layer = int(hook.split(".")[1])
                hidden_state = outputs.hidden_states[layer]
                activations = sae_models[hook].encode(hidden_state)
                activations = activations.squeeze(0).cpu().numpy()

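                # A feature counts as "active" for a text if it fires (> 0) on any
                # token, i.e. threshold at zero and take the max over the sequence.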
                binary_acts = (activations > 0).astype(int)
                aggregated = binary_acts.max(axis=0)
                features_by_hook[hook].append(aggregated)

    # Stack the per-example vectors into an (n_examples, n_features) array per hook.
    for hook in hooks:
        features_by_hook[hook] = np.array(features_by_hook[hook])
    return features_by_hook


def compute_correlations_by_category(feature_activations, binary_labels, categories):
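    # For every (feature, category) pair, correlate the binary feature column
    # with the binary label column using the point-biserial correlation.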
    results = {cat: [] for cat in categories}
    n_features = feature_activations.shape[1]
    for feat_idx in range(n_features):
        feat_vec = feature_activations[:, feat_idx]
        for cat_idx, cat in enumerate(categories):
            lbl = binary_labels[:, cat_idx]
            corr, _ = pointbiserialr(lbl, feat_vec)
            results[cat].append({
                "feature_index": feat_idx,
                "correlation": float(corr),
            })

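    # Rank features by absolute correlation and keep the top 10 per category;
    # NaN correlations (e.g. from constant columns) are treated as zero.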
    for cat in categories:
        results[cat] = sorted(
            results[cat],
            key=lambda x: 0 if math.isnan(x["correlation"]) else abs(x["correlation"]),
            reverse=True,
        )[:10]
    return results


def main(dataset="emgsd", model="gpt2", limit=1000):
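    # Prefer Apple MPS when available, otherwise CUDA, otherwise CPU.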
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "mps" if torch.backends.mps.is_available() else device

    wandb.init(project="corr-steer", config={"model": model, "dataset": dataset})

    print("Loading tokenizer and model...")
    model_id = models_config[model]["id"]
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    llm = AutoModelForCausalLM.from_pretrained(model_id).to(device)
    llm.eval()

    print("Loading dataset...")
    texts, labels, max_length = load_custom_dataset(dataset, limit)

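    # One-hot encode the labels: one binary column per category, aligned with lb.classes_.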
    lb = LabelBinarizer()
    binary_labels = lb.fit_transform(labels)

    if binary_labels.ndim == 1:
        binary_labels = binary_labels.reshape(-1, 1)
    # With exactly two classes, LabelBinarizer returns a single column; expand it
    # so there is one column per class, matching the order of lb.classes_.
    if binary_labels.shape[1] == 1 and len(lb.classes_) == 2:
        binary_labels = np.hstack([1 - binary_labels, binary_labels])

    tokens_list = tokenize_texts(tokenizer, texts, max_length, device)

    if model in models_config:
        hooks = models_config[model]["hooks"]
    else:
        raise ValueError(f"Model {model} is not configured in models_config.")

    print("Extracting features for all hooks...")
    features_by_hook = extract_features(llm, model, tokens_list, hooks, device)

    out_dir = "features"
    os.makedirs(out_dir, exist_ok=True)

    categories = lb.classes_
    aggregated_results = {cat: [] for cat in categories}

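    # For each hook: compute per-category correlations, write them to a per-hook
    # JSON file, log them to wandb, and collect the records (tagged with their
    # hook) for the cross-hook aggregation below.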
    for hook in hooks:
        print(f"Computing per-category correlations for hook {hook} ...")
        feat = features_by_hook[hook]
        hook_corrs = compute_correlations_by_category(feat, binary_labels, categories)

        hook_filename = f"{model}.{dataset}.{hook}.json"
        hook_filepath = os.path.join(out_dir, hook_filename)
        with open(hook_filepath, "w", encoding="utf-8") as f:
            json.dump(hook_corrs, f, indent=2, ensure_ascii=False)
        print(f"Saved per-hook correlations to {hook_filepath}")

        wandb.log({hook: hook_corrs})

        for cat in categories:
            for rec in hook_corrs[cat]:
                rec_with_hook = rec.copy()
                rec_with_hook["hook"] = hook
                aggregated_results[cat].append(rec_with_hook)

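    # Merge records across hooks and keep the 10 strongest features per category.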
    final_aggregated = {}
    for cat in categories:
        sorted_records = sorted(
            aggregated_results[cat],
            key=lambda x: 0 if math.isnan(x["correlation"]) else abs(x["correlation"]),
            reverse=True,
        )[:10]
        final_aggregated[cat] = sorted_records

    agg_filename = f"{model}.{dataset}.json"
    agg_filepath = os.path.join(out_dir, agg_filename)
    with open(agg_filepath, "w", encoding="utf-8") as f:
        json.dump(final_aggregated, f, indent=2, ensure_ascii=False)
    print(f"Saved aggregated correlations to {agg_filepath}")

    wandb.finish()

if __name__ == "__main__":
    fire.Fire(main)