""" Project: corr-steer This script loads a given dataset (using a datasets configuration), tokenizes the texts, runs a LLM model forward to extract hidden states, and then - for each SAE hook defined for the model - computes binary feature activations. Activations are thresholded and aggregated (max over the sequence), and the point-biserial correlation is computed between each feature and each label category. For each hook, the script finds the top 10 features (sorted in descending order by absolute correlation) per category and saves that record to a JSON file: features/{model}.{dataset}.{hook}.json Additionally, an aggregated JSON file features/{model}.{dataset}.json is created that, for each category, merges records from all hooks (each record includes its hook) and retains the top 10 features overall. This script is callable via Fire, e.g.: python corr_extract.py --dataset emgsd --model gpt2 """ import os import json import math import torch import numpy as np from transformers import AutoTokenizer, AutoModelForCausalLM from sae_lens import SAE from datasets import load_dataset, concatenate_datasets from tqdm import tqdm from sklearn.preprocessing import LabelBinarizer from scipy.stats import pointbiserialr import wandb import fire from config import datasets_config, models_config # ========= # Load a dataset given the dataset configuration. # # Returns: # texts, labels, and the maximum token length. # ========= def load_custom_dataset(dataset_name, limit: int): config = datasets_config[dataset_name] # For this example, we use the "train" split. dataset = load_dataset(config["id"], split="test") # Shuffle dataset for extracting features to divide validation dataset = dataset.shuffle(seed=42).select(range(int(dataset.num_rows / 2))) # Select only the specified columns. dataset = dataset.select_columns(config["columns"]) # Apply filtering if specified. if "filter" in config: for key, val in config["filter"].items(): dataset = dataset.filter(lambda ex, key=key, val=val: ex[key] == val) texts = [] labels = [] text_field = config["text_field"] label_field = config["label_field"] if limit: dataset = dataset.select(range(limit)) for ex in tqdm(dataset, desc="Loading dataset"): texts.append(str(ex[text_field])) labels.append(str(ex[label_field])) return texts, labels, config["max_length"] # ========= # Tokenize a list of texts. # # Returns a list of tokenized (and device-mapped) inputs. # ========= def tokenize_texts(tokenizer, texts, max_length, device): tokenized = [] for text in tqdm(texts, desc="Tokenizing texts"): encoding = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length) for k, v in encoding.items(): encoding[k] = v.to(device) tokenized.append(encoding) return tokenized # ========= # Extract aggregated binary feature activations for each SAE hook. # # For each sample, the model forward is run once to get all hidden states. Then for each hook, # the corresponding hidden state (parsed from the hook string) is passed through its SAE. The output # is thresholded (> 0) and aggregated via max pooling along the sequence dimension. # # Returns: # A dictionary mapping hook names to a numpy array of shape (num_samples, num_features). # ========= def extract_features(llm, model_name, tokens_list, hooks, device): # Preload SAE models for each hook. 
    sae_models = {}
    for hook in hooks:
        sae, _, _ = SAE.from_pretrained(models_config[model_name]["sae"], hook, device=device)
        sae_models[hook] = sae

    features_by_hook = {hook: [] for hook in hooks}
    for encoding in tqdm(tokens_list, desc="Extracting activations"):
        with torch.no_grad():
            outputs = llm(**encoding, output_hidden_states=True)
        # For each hook, extract its corresponding hidden state.
        for hook in hooks:
            # Parse the layer index from the hook string.
            layer = int(hook.split(".")[1])
            hidden_state = outputs.hidden_states[layer]  # shape: (batch_size, seq_len, hidden_dim)
            activations = sae_models[hook].encode(hidden_state)
            # Remove the batch dimension (assumes batch size = 1) and move to CPU.
            activations = activations.squeeze(0).cpu().numpy()  # (seq_len, num_features)
            # Threshold activations.
            binary_acts = (activations > 0).astype(int)
            # Aggregate over the sequence (max pooling).
            aggregated = binary_acts.max(axis=0)  # (num_features,)
            features_by_hook[hook].append(aggregated)

    # Convert lists to numpy arrays.
    for hook in hooks:
        features_by_hook[hook] = np.array(features_by_hook[hook])
    return features_by_hook


# =========
# Compute correlations per label category.
#
# Given a feature activation array of shape (n_samples, n_features) and binary labels of
# shape (n_samples, n_categories) along with a list of category names, compute the
# point-biserial correlation for each feature against each category. For each category,
# sort by absolute correlation and keep the top 10.
#
# Returns a dictionary keyed by category.
# =========
def compute_correlations_by_category(feature_activations, binary_labels, categories):
    results = {cat: [] for cat in categories}
    n_features = feature_activations.shape[1]
    for feat_idx in range(n_features):
        feat_vec = feature_activations[:, feat_idx]
        for cat_idx, cat in enumerate(categories):
            lbl = binary_labels[:, cat_idx]
            corr, _ = pointbiserialr(lbl, feat_vec)
            results[cat].append({
                "feature_index": feat_idx,
                # Cast to a plain float so the record stays JSON-serializable.
                "correlation": float(corr)
            })
    # For each category, sort and keep the top 10 records (by absolute correlation).
    for cat in categories:
        results[cat] = sorted(
            results[cat],
            key=lambda x: 0 if math.isnan(x["correlation"]) else abs(x["correlation"]),
            reverse=True
        )[:10]
    return results


# =========
# Main function
#
# This function initializes wandb (project "corr-steer"), loads the specified dataset and
# LLM along with the SAE hooks (from models_config), tokenizes the texts, extracts feature
# activations, computes (per hook) the per-category top 10 correlation records, and writes
# out one JSON per hook as well as one aggregated JSON file that combines (and sorts)
# records across hooks.
#
# Run via, for example:
#     python corr_extract.py --dataset emgsd --model gpt2
# =========
def main(dataset="emgsd", model="gpt2", limit=1000):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "mps" if torch.backends.mps.is_available() else device

    # Initialize wandb.
    wandb.init(project="corr-steer", config={"model": model, "dataset": dataset})

    # Load tokenizer and the LLM.
    print("Loading tokenizer and model...")
    model_id = models_config[model]["id"]
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    llm = AutoModelForCausalLM.from_pretrained(model_id).to(device)
    llm.eval()

    # Load dataset texts, labels, and max token length.
    print("Loading dataset...")
    texts, labels, max_length = load_custom_dataset(dataset, limit)

    # Binarize labels.
    lb = LabelBinarizer()
    binary_labels = lb.fit_transform(labels)
    # Ensure binary_labels is 2D.
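    # NOTE: with exactly two label classes, LabelBinarizer returns a single column
    # (1 for lb.classes_[1], 0 for lb.classes_[0]) while lb.classes_ still lists both
    # classes, so the column count can disagree with the number of categories used later.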
    if binary_labels.ndim == 1:
        binary_labels = binary_labels.reshape(-1, 1)
    # Expand the single-column binary case to one column per class so the columns stay
    # aligned with lb.classes_.
    if binary_labels.shape[1] == 1 and len(lb.classes_) == 2:
        binary_labels = np.hstack([1 - binary_labels, binary_labels])

    # Tokenize texts.
    tokens_list = tokenize_texts(tokenizer, texts, max_length, device)

    # Determine hooks from models_config.
    if model in models_config:
        hooks = models_config[model]["hooks"]
    else:
        raise ValueError(f"Model {model} is not configured in models_config.")

    # Extract features for all hooks.
    print("Extracting features for all hooks...")
    features_by_hook = extract_features(llm, model, tokens_list, hooks, device)

    out_dir = "features"
    os.makedirs(out_dir, exist_ok=True)

    # For aggregated results across hooks.
    categories = lb.classes_
    aggregated_results = {cat: [] for cat in categories}

    # Process each hook individually.
    for hook in hooks:
        print(f"Computing per-category correlations for hook {hook} ...")
        feat = features_by_hook[hook]  # (n_samples, n_features)
        hook_corrs = compute_correlations_by_category(feat, binary_labels, categories)

        # Save per-hook file.
        hook_filename = f"{model}.{dataset}.{hook}.json"
        hook_filepath = os.path.join(out_dir, hook_filename)
        with open(hook_filepath, "w", encoding="utf-8") as f:
            json.dump(hook_corrs, f, indent=2, ensure_ascii=False)
        print(f"Saved per-hook correlations to {hook_filepath}")

        # Log each hook's results to wandb.
        wandb.log({hook: hook_corrs})

        # For aggregation, add hook info to each record.
        for cat in categories:
            for rec in hook_corrs[cat]:
                rec_with_hook = rec.copy()
                rec_with_hook["hook"] = hook
                aggregated_results[cat].append(rec_with_hook)

    # Now, for each category, sort aggregated records across hooks and retain the top 10.
    final_aggregated = {}
    for cat in categories:
        sorted_records = sorted(
            aggregated_results[cat],
            key=lambda x: 0 if math.isnan(x["correlation"]) else abs(x["correlation"]),
            reverse=True
        )[:10]
        final_aggregated[cat] = sorted_records

    agg_filename = f"{model}.{dataset}.json"
    agg_filepath = os.path.join(out_dir, agg_filename)
    with open(agg_filepath, "w", encoding="utf-8") as f:
        json.dump(final_aggregated, f, indent=2, ensure_ascii=False)
    print(f"Saved aggregated correlations to {agg_filepath}")

    wandb.finish()


if __name__ == "__main__":
    fire.Fire(main)
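

# =========
# For reference: a minimal, hypothetical sketch of the config.py entries this script
# reads. None of the ids, hook names, or field names below come from the actual project
# config; they only illustrate the keys accessed above (datasets_config: "id", "columns",
# optional "filter", "text_field", "label_field", "max_length"; models_config: "id",
# "sae", "hooks").
#
# datasets_config = {
#     "emgsd": {
#         "id": "org/emgsd-dataset",         # placeholder Hugging Face dataset id
#         "columns": ["text", "category"],
#         "text_field": "text",
#         "label_field": "category",
#         "max_length": 128,
#     },
# }
#
# models_config = {
#     "gpt2": {
#         "id": "gpt2",                      # Hugging Face model id
#         "sae": "gpt2-small-res-jb",        # example sae_lens release name
#         "hooks": ["blocks.8.hook_resid_pre"],
#     },
# }
# =========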