Commit
·
889f722
0
Parent(s):
feat: hf space corr-steer
Browse files- .gitignore +3 -0
- Dockerfile +57 -0
- config.py +103 -0
- corr_extract.py +253 -0
- corr_steer/steer.py +65 -0
- demo/.gitignore +24 -0
- demo/README.md +71 -0
- demo/components.json +21 -0
- demo/eslint.config.js +28 -0
- demo/index.html +13 -0
- deploy.sh +49 -0
- features/gpt2.emgsd.json +314 -0
- requirements.txt +8 -0
- server.py +195 -0
- start.sh +30 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
wandb
|
2 |
+
__pycache__
|
3 |
+
.env
|
Dockerfile
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use Node.js 18 as base image
|
2 |
+
FROM node:18-bullseye
|
3 |
+
|
4 |
+
# Install Python 3.9 and pip
|
5 |
+
RUN apt-get update && apt-get install -y \
|
6 |
+
python3.9 \
|
7 |
+
python3-pip \
|
8 |
+
python3-dev \
|
9 |
+
build-essential \
|
10 |
+
&& rm -rf /var/lib/apt/lists/*
|
11 |
+
|
12 |
+
# Set up aliases for python
|
13 |
+
RUN ln -s /usr/bin/python3.9 /usr/bin/python
|
14 |
+
|
15 |
+
# Install pnpm
|
16 |
+
RUN npm install -g pnpm
|
17 |
+
|
18 |
+
# Set working directory
|
19 |
+
WORKDIR /app
|
20 |
+
|
21 |
+
# Copy Python requirements and install dependencies
|
22 |
+
COPY requirements.txt .
|
23 |
+
RUN pip3 install --no-cache-dir -r requirements.txt
|
24 |
+
|
25 |
+
# Copy Python files
|
26 |
+
COPY server.py .
|
27 |
+
COPY config.py .
|
28 |
+
COPY features/ ./features/
|
29 |
+
|
30 |
+
# Copy frontend files
|
31 |
+
COPY demo/ ./demo/
|
32 |
+
|
33 |
+
# Install frontend dependencies and build
|
34 |
+
WORKDIR /app/demo
|
35 |
+
RUN pnpm install
|
36 |
+
RUN pnpm build
|
37 |
+
|
38 |
+
# Verify build output
|
39 |
+
RUN ls -la dist/
|
40 |
+
|
41 |
+
# Go back to app directory
|
42 |
+
WORKDIR /app
|
43 |
+
|
44 |
+
# Copy startup script
|
45 |
+
COPY start.sh .
|
46 |
+
RUN chmod +x start.sh
|
47 |
+
|
48 |
+
# Expose port 7860 for HF Spaces
|
49 |
+
EXPOSE 7860
|
50 |
+
|
51 |
+
# Set environment variables
|
52 |
+
ENV NODE_ENV=production
|
53 |
+
ENV FLASK_ENV=production
|
54 |
+
ENV PYTHONPATH=/app
|
55 |
+
|
56 |
+
# Start both services
|
57 |
+
CMD ["./start.sh"]
|
config.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# =========
|
2 |
+
# Configuration: Define dataset information.
|
3 |
+
# =========
|
4 |
+
datasets_config = {
|
5 |
+
"emgsd": {
|
6 |
+
"id": "holistic-ai/EMGSD",
|
7 |
+
"columns": ["category", "stereotype_type", "text"],
|
8 |
+
"filter": {"category": "stereotype"},
|
9 |
+
"text_field": "text",
|
10 |
+
"label_field": "stereotype_type",
|
11 |
+
"max_length": 512,
|
12 |
+
"category": {
|
13 |
+
"lgbtq+": {
|
14 |
+
"prompt": "The gay people are",
|
15 |
+
},
|
16 |
+
"nationality": {
|
17 |
+
"prompt": "The French people are",
|
18 |
+
},
|
19 |
+
"race": {
|
20 |
+
"prompt": "The Black people are",
|
21 |
+
}
|
22 |
+
}
|
23 |
+
},
|
24 |
+
"yahoo_question": {
|
25 |
+
"id": "community-datasets/yahoo_answers_topics",
|
26 |
+
"columns": ["topic", "question_title", "question_content"],
|
27 |
+
"text_field": "question_title",
|
28 |
+
"label_field": "topic",
|
29 |
+
"max_length": 512
|
30 |
+
},
|
31 |
+
"yahoo_answer": {
|
32 |
+
"id": "community-datasets/yahoo_answers_topics",
|
33 |
+
"columns": ["topic", "best_answer"],
|
34 |
+
"text_field": "best_answer",
|
35 |
+
"label_field": "topic",
|
36 |
+
"max_length": 512
|
37 |
+
},
|
38 |
+
"science": {
|
39 |
+
"id": "knowledgator/Scientific-text-classification",
|
40 |
+
"columns": ["text", "label"],
|
41 |
+
"text_field": "text",
|
42 |
+
"label_field": "label",
|
43 |
+
"max_length": 512
|
44 |
+
},
|
45 |
+
"wiki256": {
|
46 |
+
"id": "seonglae/wikipedia-256",
|
47 |
+
"columns": ["text", "title"],
|
48 |
+
"text_field": "text",
|
49 |
+
"label_field": "title",
|
50 |
+
"max_length": 512
|
51 |
+
},
|
52 |
+
"wiki512": {
|
53 |
+
"id": "seonglae/wikipedia-512",
|
54 |
+
"columns": ["text", "title"],
|
55 |
+
"text_field": "text",
|
56 |
+
"label_field": "title",
|
57 |
+
"max_length": 1024
|
58 |
+
}
|
59 |
+
}
|
60 |
+
|
61 |
+
# =========
|
62 |
+
# Configuration: Define model-specific information.
|
63 |
+
# For "gpt2", we specify the SAE source and the list of hooks to use.
|
64 |
+
# f"{model}-{dataset}" is the key for trained models.
|
65 |
+
# =========
|
66 |
+
models_config = {
|
67 |
+
"gpt2": {
|
68 |
+
"id": "gpt2",
|
69 |
+
"sae": "jbloom/GPT2-Small-SAEs-Reformatted",
|
70 |
+
"hooks": [
|
71 |
+
"blocks.11.hook_resid_pre",
|
72 |
+
"blocks.10.hook_resid_pre",
|
73 |
+
"blocks.9.hook_resid_pre",
|
74 |
+
"blocks.8.hook_resid_pre",
|
75 |
+
"blocks.7.hook_resid_pre",
|
76 |
+
"blocks.6.hook_resid_pre",
|
77 |
+
"blocks.5.hook_resid_pre",
|
78 |
+
"blocks.4.hook_resid_pre",
|
79 |
+
"blocks.3.hook_resid_pre",
|
80 |
+
"blocks.2.hook_resid_pre",
|
81 |
+
"blocks.1.hook_resid_pre",
|
82 |
+
"blocks.0.hook_resid_pre"
|
83 |
+
]
|
84 |
+
},
|
85 |
+
"gpt2-emgsd": {
|
86 |
+
"id": "holistic-ai/gpt2-EMGSD",
|
87 |
+
"sae": "jbloom/GPT2-Small-SAEs-Reformatted",
|
88 |
+
"hooks": [
|
89 |
+
"blocks.11.hook_resid_pre",
|
90 |
+
"blocks.10.hook_resid_pre",
|
91 |
+
"blocks.9.hook_resid_pre",
|
92 |
+
"blocks.8.hook_resid_pre",
|
93 |
+
"blocks.7.hook_resid_pre",
|
94 |
+
"blocks.6.hook_resid_pre",
|
95 |
+
"blocks.5.hook_resid_pre",
|
96 |
+
"blocks.4.hook_resid_pre",
|
97 |
+
"blocks.3.hook_resid_pre",
|
98 |
+
"blocks.2.hook_resid_pre",
|
99 |
+
"blocks.1.hook_resid_pre",
|
100 |
+
"blocks.0.hook_resid_pre"
|
101 |
+
]
|
102 |
+
}
|
103 |
+
}
|
corr_extract.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Project: corr-steer
|
3 |
+
|
4 |
+
This script loads a given dataset (using a datasets configuration), tokenizes the texts,
|
5 |
+
runs a LLM model forward to extract hidden states, and then - for each SAE hook defined
|
6 |
+
for the model - computes binary feature activations. Activations are thresholded and aggregated
|
7 |
+
(max over the sequence), and the point-biserial correlation is computed between each feature
|
8 |
+
and each label category.
|
9 |
+
|
10 |
+
For each hook, the script finds the top 10 features (sorted in descending order by absolute correlation)
|
11 |
+
per category and saves that record to a JSON file:
|
12 |
+
|
13 |
+
features/{model}.{dataset}.{hook}.json
|
14 |
+
|
15 |
+
Additionally, an aggregated JSON file
|
16 |
+
|
17 |
+
features/{model}.{dataset}.json
|
18 |
+
|
19 |
+
is created that, for each category, merges records from all hooks (each record includes its hook)
|
20 |
+
and retains the top 10 features overall.
|
21 |
+
|
22 |
+
This script is callable via Fire, e.g.:
|
23 |
+
python corr_extract.py --dataset emgsd --model gpt2
|
24 |
+
"""
|
25 |
+
|
26 |
+
import os
|
27 |
+
import json
|
28 |
+
import math
|
29 |
+
import torch
|
30 |
+
import numpy as np
|
31 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
32 |
+
from sae_lens import SAE
|
33 |
+
from datasets import load_dataset, concatenate_datasets
|
34 |
+
from tqdm import tqdm
|
35 |
+
from sklearn.preprocessing import LabelBinarizer
|
36 |
+
from scipy.stats import pointbiserialr
|
37 |
+
import wandb
|
38 |
+
import fire
|
39 |
+
|
40 |
+
from config import datasets_config, models_config
|
41 |
+
|
42 |
+
|
43 |
+
# =========
|
44 |
+
# Load a dataset given the dataset configuration.
|
45 |
+
#
|
46 |
+
# Returns:
|
47 |
+
# texts, labels, and the maximum token length.
|
48 |
+
# =========
|
49 |
+
def load_custom_dataset(dataset_name, limit: int):
|
50 |
+
config = datasets_config[dataset_name]
|
51 |
+
# For this example, we use the "train" split.
|
52 |
+
dataset = load_dataset(config["id"], split="test")
|
53 |
+
# Shuffle dataset for extracting features to divide validation
|
54 |
+
dataset = dataset.shuffle(seed=42).select(range(int(dataset.num_rows / 2)))
|
55 |
+
|
56 |
+
# Select only the specified columns.
|
57 |
+
dataset = dataset.select_columns(config["columns"])
|
58 |
+
|
59 |
+
# Apply filtering if specified.
|
60 |
+
if "filter" in config:
|
61 |
+
for key, val in config["filter"].items():
|
62 |
+
dataset = dataset.filter(lambda ex, key=key, val=val: ex[key] == val)
|
63 |
+
|
64 |
+
texts = []
|
65 |
+
labels = []
|
66 |
+
text_field = config["text_field"]
|
67 |
+
label_field = config["label_field"]
|
68 |
+
if limit:
|
69 |
+
dataset = dataset.select(range(limit))
|
70 |
+
for ex in tqdm(dataset, desc="Loading dataset"):
|
71 |
+
texts.append(str(ex[text_field]))
|
72 |
+
labels.append(str(ex[label_field]))
|
73 |
+
return texts, labels, config["max_length"]
|
74 |
+
|
75 |
+
# =========
|
76 |
+
# Tokenize a list of texts.
|
77 |
+
#
|
78 |
+
# Returns a list of tokenized (and device-mapped) inputs.
|
79 |
+
# =========
|
80 |
+
def tokenize_texts(tokenizer, texts, max_length, device):
|
81 |
+
tokenized = []
|
82 |
+
for text in tqdm(texts, desc="Tokenizing texts"):
|
83 |
+
encoding = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
|
84 |
+
for k, v in encoding.items():
|
85 |
+
encoding[k] = v.to(device)
|
86 |
+
tokenized.append(encoding)
|
87 |
+
return tokenized
|
88 |
+
|
89 |
+
# =========
|
90 |
+
# Extract aggregated binary feature activations for each SAE hook.
|
91 |
+
#
|
92 |
+
# For each sample, the model forward is run once to get all hidden states. Then for each hook,
|
93 |
+
# the corresponding hidden state (parsed from the hook string) is passed through its SAE. The output
|
94 |
+
# is thresholded (> 0) and aggregated via max pooling along the sequence dimension.
|
95 |
+
#
|
96 |
+
# Returns:
|
97 |
+
# A dictionary mapping hook names to a numpy array of shape (num_samples, num_features).
|
98 |
+
# =========
|
99 |
+
def extract_features(llm, model_name, tokens_list, hooks, device):
|
100 |
+
# Preload SAE models for each hook.
|
101 |
+
sae_models = {}
|
102 |
+
for hook in hooks:
|
103 |
+
sae, _, _ = SAE.from_pretrained(models_config[model_name]["sae"], hook, device=device)
|
104 |
+
sae_models[hook] = sae
|
105 |
+
|
106 |
+
features_by_hook = {hook: [] for hook in hooks}
|
107 |
+
|
108 |
+
for encoding in tqdm(tokens_list, desc="Extracting activations"):
|
109 |
+
with torch.no_grad():
|
110 |
+
outputs = llm(**encoding, output_hidden_states=True)
|
111 |
+
# For each hook, extract its corresponding hidden state.
|
112 |
+
for hook in hooks:
|
113 |
+
# Parse layer index from the hook string.
|
114 |
+
layer = int(hook.split(".")[1])
|
115 |
+
hidden_state = outputs.hidden_states[layer] # shape: (batch_size, seq_len, hidden_dim)
|
116 |
+
activations = sae_models[hook].encode(hidden_state)
|
117 |
+
# Remove the batch dimension (assumes batch size = 1) and move to CPU.
|
118 |
+
activations = activations.squeeze(0).cpu().numpy() # (seq_len, num_features)
|
119 |
+
# Threshold activations.
|
120 |
+
binary_acts = (activations > 0).astype(int)
|
121 |
+
# Aggregate over the sequence (max pooling).
|
122 |
+
aggregated = binary_acts.max(axis=0) # (num_features,)
|
123 |
+
features_by_hook[hook].append(aggregated)
|
124 |
+
# Convert lists to numpy arrays.
|
125 |
+
for hook in hooks:
|
126 |
+
features_by_hook[hook] = np.array(features_by_hook[hook])
|
127 |
+
return features_by_hook
|
128 |
+
|
129 |
+
# =========
|
130 |
+
# Compute correlations per label category.
|
131 |
+
#
|
132 |
+
# Given a feature activation array of shape (n_samples, n_features) and binary labels of shape
|
133 |
+
# (n_samples, n_categories) along with a list of category names, compute the point-biserial correlation
|
134 |
+
# for each feature against each category. For each category, sort by absolute correlation and keep the top 10.
|
135 |
+
#
|
136 |
+
# Returns a dictionary keyed by category.
|
137 |
+
# =========
|
138 |
+
def compute_correlations_by_category(feature_activations, binary_labels, categories):
|
139 |
+
results = {cat: [] for cat in categories}
|
140 |
+
n_features = feature_activations.shape[1]
|
141 |
+
for feat_idx in range(n_features):
|
142 |
+
feat_vec = feature_activations[:, feat_idx]
|
143 |
+
for cat_idx, cat in enumerate(categories):
|
144 |
+
lbl = binary_labels[:, cat_idx]
|
145 |
+
corr, _ = pointbiserialr(lbl, feat_vec)
|
146 |
+
results[cat].append({
|
147 |
+
"feature_index": feat_idx,
|
148 |
+
"correlation": corr
|
149 |
+
})
|
150 |
+
# For each category, sort and keep the top 10 records (by absolute correlation).
|
151 |
+
for cat in categories:
|
152 |
+
results[cat] = sorted(results[cat], key=lambda x: 0 if math.isnan(x["correlation"]) else abs(x["correlation"]), reverse=True)[:10]
|
153 |
+
return results
|
154 |
+
|
155 |
+
# =========
|
156 |
+
# Main function
|
157 |
+
#
|
158 |
+
# This function initializes wandb (project "corr-steer"), loads the specified dataset and LLM along with
|
159 |
+
# the SAE hooks (from models_config), tokenizes the texts, extracts feature activations,
|
160 |
+
# computes (per hook) the per-category top 10 correlation records, and writes out one JSON per hook as
|
161 |
+
# well as one aggregated JSON file that combines (and sorts) records across hooks.
|
162 |
+
#
|
163 |
+
# Run via, for example:
|
164 |
+
# python corr_extract.py --dataset emgsd --model gpt2
|
165 |
+
# =========
|
166 |
+
def main(dataset="emgsd", model="gpt2", limit=1000):
|
167 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
168 |
+
device = "mps" if torch.backends.mps.is_available() else device
|
169 |
+
|
170 |
+
# Initialize wandb.
|
171 |
+
wandb.init(project="corr-steer", config={"model": model, "dataset": dataset})
|
172 |
+
|
173 |
+
# Load tokenizer and the LLM.
|
174 |
+
print("Loading tokenizer and model...")
|
175 |
+
model_id = models_config[model]["id"]
|
176 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
177 |
+
llm = AutoModelForCausalLM.from_pretrained(model_id).to(device)
|
178 |
+
llm.eval()
|
179 |
+
|
180 |
+
# Load dataset texts, labels, and max token length.
|
181 |
+
print("Loading dataset...")
|
182 |
+
texts, labels, max_length = load_custom_dataset(dataset, limit)
|
183 |
+
|
184 |
+
# Binarize labels.
|
185 |
+
lb = LabelBinarizer()
|
186 |
+
binary_labels = lb.fit_transform(labels)
|
187 |
+
# Ensure binary_labels is 2D.
|
188 |
+
if binary_labels.ndim == 1:
|
189 |
+
binary_labels = binary_labels.reshape(-1, 1)
|
190 |
+
|
191 |
+
# Tokenize texts.
|
192 |
+
tokens_list = tokenize_texts(tokenizer, texts, max_length, device)
|
193 |
+
|
194 |
+
# Determine hooks from models_config.
|
195 |
+
if model in models_config:
|
196 |
+
hooks = models_config[model]["hooks"]
|
197 |
+
else:
|
198 |
+
raise ValueError(f"Model {model} is not configured in models_config.")
|
199 |
+
|
200 |
+
# Extract features for all hooks.
|
201 |
+
print("Extracting features for all hooks...")
|
202 |
+
features_by_hook = extract_features(llm, model, tokens_list, hooks, device)
|
203 |
+
|
204 |
+
out_dir = "features"
|
205 |
+
os.makedirs(out_dir, exist_ok=True)
|
206 |
+
|
207 |
+
# For aggregated results across hooks.
|
208 |
+
categories = lb.classes_
|
209 |
+
aggregated_results = {cat: [] for cat in categories}
|
210 |
+
|
211 |
+
# Process each hook individually.
|
212 |
+
for hook in hooks:
|
213 |
+
print(f"Computing per-category correlations for hook {hook} ...")
|
214 |
+
feat = features_by_hook[hook] # (n_samples, n_features)
|
215 |
+
hook_corrs = compute_correlations_by_category(feat, binary_labels, categories)
|
216 |
+
|
217 |
+
# Save per-hook file.
|
218 |
+
hook_filename = f"{model}.{dataset}.{hook}.json"
|
219 |
+
hook_filepath = os.path.join(out_dir, hook_filename)
|
220 |
+
with open(hook_filepath, "w", encoding="utf-8") as f:
|
221 |
+
json.dump(hook_corrs, f, indent=2, ensure_ascii=False)
|
222 |
+
print(f"Saved per-hook correlations to {hook_filepath}")
|
223 |
+
|
224 |
+
# Log each hook's results to wandb.
|
225 |
+
wandb.log({hook: hook_corrs})
|
226 |
+
|
227 |
+
# For aggregation, add hook info to each record.
|
228 |
+
for cat in categories:
|
229 |
+
for rec in hook_corrs[cat]:
|
230 |
+
rec_with_hook = rec.copy()
|
231 |
+
rec_with_hook["hook"] = hook
|
232 |
+
aggregated_results[cat].append(rec_with_hook)
|
233 |
+
|
234 |
+
# Now, for each category, sort aggregated records across hooks and retain top 10.
|
235 |
+
final_aggregated = {}
|
236 |
+
for cat in categories:
|
237 |
+
sorted_records = sorted(
|
238 |
+
aggregated_results[cat],
|
239 |
+
key=lambda x: 0 if math.isnan(x["correlation"]) else abs(x["correlation"]),
|
240 |
+
reverse=True
|
241 |
+
)[:10]
|
242 |
+
final_aggregated[cat] = sorted_records
|
243 |
+
|
244 |
+
agg_filename = f"{model}.{dataset}.json"
|
245 |
+
agg_filepath = os.path.join(out_dir, agg_filename)
|
246 |
+
with open(agg_filepath, "w", encoding="utf-8") as f:
|
247 |
+
json.dump(final_aggregated, f, indent=2, ensure_ascii=False)
|
248 |
+
print(f"Saved aggregated correlations to {agg_filepath}")
|
249 |
+
|
250 |
+
wandb.finish()
|
251 |
+
|
252 |
+
if __name__ == "__main__":
|
253 |
+
fire.Fire(main)
|
corr_steer/steer.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------
|
2 |
+
# Modular Steering Hook Logic
|
3 |
+
# ------------------------------------
|
4 |
+
|
5 |
+
class SteeringHook:
|
6 |
+
def __init__(self, sae, device):
|
7 |
+
"""
|
8 |
+
Initialize the SteeringHook with an SAE and device.
|
9 |
+
"""
|
10 |
+
self.sae = sae
|
11 |
+
self.device = device
|
12 |
+
self.feature_coeffs = [] # List of (feature_index, coefficient)
|
13 |
+
self.hooks = [] # Store hook handles
|
14 |
+
self.steering_enabled = False
|
15 |
+
|
16 |
+
def enable_steering(self, feature_coeffs):
|
17 |
+
"""
|
18 |
+
Enable steering by specifying feature coefficients.
|
19 |
+
Args:
|
20 |
+
feature_coeffs (list): List of (feature_index, coefficient).
|
21 |
+
"""
|
22 |
+
self.feature_coeffs = feature_coeffs
|
23 |
+
self.steering_enabled = True
|
24 |
+
|
25 |
+
def disable_steering(self):
|
26 |
+
"""
|
27 |
+
Disable steering and clear hooks.
|
28 |
+
"""
|
29 |
+
self.steering_enabled = False
|
30 |
+
self.feature_coeffs = []
|
31 |
+
self.remove_hooks()
|
32 |
+
|
33 |
+
def generate_hook(self):
|
34 |
+
"""
|
35 |
+
Create a steering hook function that modifies the residual output.
|
36 |
+
"""
|
37 |
+
def hook_fn(module, inputs, outputs):
|
38 |
+
if not self.steering_enabled:
|
39 |
+
return outputs
|
40 |
+
|
41 |
+
residual = outputs[0] # Residual output of the module
|
42 |
+
for feature_index, coeff in self.feature_coeffs:
|
43 |
+
steering_vector = self.sae.W_dec[feature_index].to(self.device).unsqueeze(0).unsqueeze(0)
|
44 |
+
residual = residual + coeff * steering_vector
|
45 |
+
return (residual, *outputs[1:])
|
46 |
+
|
47 |
+
return hook_fn
|
48 |
+
|
49 |
+
def register_hooks(self, model, block_idx):
|
50 |
+
"""
|
51 |
+
Register the steering hook to the specified block.
|
52 |
+
Args:
|
53 |
+
model (nn.Module): The target model.
|
54 |
+
block_idx (int): The block index to attach the hook.
|
55 |
+
"""
|
56 |
+
handle = model.transformer.h[block_idx].register_forward_hook(self.generate_hook())
|
57 |
+
self.hooks.append(handle)
|
58 |
+
|
59 |
+
def remove_hooks(self):
|
60 |
+
"""
|
61 |
+
Remove all registered hooks.
|
62 |
+
"""
|
63 |
+
for handle in self.hooks:
|
64 |
+
handle.remove()
|
65 |
+
self.hooks.clear()
|
demo/.gitignore
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Logs
|
2 |
+
logs
|
3 |
+
*.log
|
4 |
+
npm-debug.log*
|
5 |
+
yarn-debug.log*
|
6 |
+
yarn-error.log*
|
7 |
+
pnpm-debug.log*
|
8 |
+
lerna-debug.log*
|
9 |
+
|
10 |
+
node_modules
|
11 |
+
dist
|
12 |
+
dist-ssr
|
13 |
+
*.local
|
14 |
+
|
15 |
+
# Editor directories and files
|
16 |
+
.vscode/*
|
17 |
+
!.vscode/extensions.json
|
18 |
+
.idea
|
19 |
+
.DS_Store
|
20 |
+
*.suo
|
21 |
+
*.ntvs*
|
22 |
+
*.njsproj
|
23 |
+
*.sln
|
24 |
+
*.sw?
|
demo/README.md
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CorrSteer Frontend
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
|
5 |
+
CorrSteer demonstrates how text classification datasets can be used to steer large language models (LLMs), correlating with SAE (Sparse Autoencoder) features. This demo incorporates a modern tech stack for a seamless and efficient experience.
|
6 |
+
|
7 |
+
---
|
8 |
+
|
9 |
+
## How to Run the Demo
|
10 |
+
|
11 |
+
1. **Set Environment Variables:**
|
12 |
+
Create a `.env` file in the `demo` directory and include the following:
|
13 |
+
|
14 |
+
```env
|
15 |
+
VITE_API_BASE_URL=<your-api-url>
|
16 |
+
```
|
17 |
+
|
18 |
+
2. **Install Dependencies:**
|
19 |
+
|
20 |
+
```bash
|
21 |
+
pnpm i
|
22 |
+
```
|
23 |
+
|
24 |
+
3. **Start the Development Server:**
|
25 |
+
|
26 |
+
```bash
|
27 |
+
pnpm dev
|
28 |
+
```
|
29 |
+
|
30 |
+
The application will be available at `http://localhost:5173` by default.
|
31 |
+
|
32 |
+
4. **Build for Production (Optional):**
|
33 |
+
|
34 |
+
```bash
|
35 |
+
pnpm build
|
36 |
+
pnpm preview
|
37 |
+
```
|
38 |
+
|
39 |
+
---
|
40 |
+
|
41 |
+
## Key Features
|
42 |
+
|
43 |
+
- **Dataset & Model Selection:**
|
44 |
+
Select datasets and models using dropdown menus.
|
45 |
+
- **Streaming Outputs:**
|
46 |
+
Generate outputs from multiple models with live updates as data streams.
|
47 |
+
- **Interactive Tabs:**
|
48 |
+
Switch between different categories for customized prompts.
|
49 |
+
|
50 |
+
---
|
51 |
+
|
52 |
+
## Technology Stack
|
53 |
+
|
54 |
+
1. **Vite:**
|
55 |
+
- Development server and build tool.
|
56 |
+
|
57 |
+
2. **React:**
|
58 |
+
- UI library for building components and managing state.
|
59 |
+
|
60 |
+
3. **Tailwind CSS:**
|
61 |
+
- CSS framework for styling.
|
62 |
+
|
63 |
+
4. **ShadCN/UI:**
|
64 |
+
- Pre-built component library for UI elements.
|
65 |
+
|
66 |
+
---
|
67 |
+
|
68 |
+
## License
|
69 |
+
|
70 |
+
This project is licensed under the MIT License.
|
71 |
+
```
|
demo/components.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"$schema": "https://ui.shadcn.com/schema.json",
|
3 |
+
"style": "default",
|
4 |
+
"rsc": false,
|
5 |
+
"tsx": true,
|
6 |
+
"tailwind": {
|
7 |
+
"config": "tailwind.config.ts",
|
8 |
+
"css": "src/index.css",
|
9 |
+
"baseColor": "slate",
|
10 |
+
"cssVariables": true,
|
11 |
+
"prefix": ""
|
12 |
+
},
|
13 |
+
"aliases": {
|
14 |
+
"components": "@/components",
|
15 |
+
"utils": "@/lib/utils",
|
16 |
+
"ui": "@/components/ui",
|
17 |
+
"lib": "@/lib",
|
18 |
+
"hooks": "@/hooks"
|
19 |
+
},
|
20 |
+
"iconLibrary": "lucide"
|
21 |
+
}
|
demo/eslint.config.js
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import js from '@eslint/js'
|
2 |
+
import globals from 'globals'
|
3 |
+
import reactHooks from 'eslint-plugin-react-hooks'
|
4 |
+
import reactRefresh from 'eslint-plugin-react-refresh'
|
5 |
+
import tseslint from 'typescript-eslint'
|
6 |
+
|
7 |
+
export default tseslint.config(
|
8 |
+
{ ignores: ['dist'] },
|
9 |
+
{
|
10 |
+
extends: [js.configs.recommended, ...tseslint.configs.recommended],
|
11 |
+
files: ['**/*.{ts,tsx}'],
|
12 |
+
languageOptions: {
|
13 |
+
ecmaVersion: 2020,
|
14 |
+
globals: globals.browser,
|
15 |
+
},
|
16 |
+
plugins: {
|
17 |
+
'react-hooks': reactHooks,
|
18 |
+
'react-refresh': reactRefresh,
|
19 |
+
},
|
20 |
+
rules: {
|
21 |
+
...reactHooks.configs.recommended.rules,
|
22 |
+
'react-refresh/only-export-components': [
|
23 |
+
'warn',
|
24 |
+
{ allowConstantExport: true },
|
25 |
+
],
|
26 |
+
},
|
27 |
+
},
|
28 |
+
)
|
demo/index.html
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!doctype html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8" />
|
5 |
+
<link rel="icon" type="image/svg+xml" href="/finn.svg" />
|
6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
7 |
+
<title>CorrSteer</title>
|
8 |
+
</head>
|
9 |
+
<body>
|
10 |
+
<div id="root"></div>
|
11 |
+
<script type="module" src="/src/main.tsx"></script>
|
12 |
+
</body>
|
13 |
+
</html>
|
deploy.sh
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Deployment script for Hugging Face Spaces
|
4 |
+
# Usage: ./deploy.sh /path/to/your/huggingface/space
|
5 |
+
|
6 |
+
if [ $# -eq 0 ]; then
|
7 |
+
echo "Usage: $0 <path-to-huggingface-space>"
|
8 |
+
echo "Example: $0 /home/user/corr-steer-space"
|
9 |
+
exit 1
|
10 |
+
fi
|
11 |
+
|
12 |
+
SPACE_PATH="$1"
|
13 |
+
|
14 |
+
if [ ! -d "$SPACE_PATH" ]; then
|
15 |
+
echo "Error: Directory $SPACE_PATH does not exist"
|
16 |
+
exit 1
|
17 |
+
fi
|
18 |
+
|
19 |
+
echo "Copying files to Hugging Face Space: $SPACE_PATH"
|
20 |
+
|
21 |
+
# Copy essential files
|
22 |
+
cp Dockerfile "$SPACE_PATH/"
|
23 |
+
cp README.md "$SPACE_PATH/"
|
24 |
+
cp requirements.txt "$SPACE_PATH/"
|
25 |
+
cp start.sh "$SPACE_PATH/"
|
26 |
+
cp server.py "$SPACE_PATH/"
|
27 |
+
cp config.py "$SPACE_PATH/"
|
28 |
+
|
29 |
+
# Copy directories
|
30 |
+
cp -r demo/ "$SPACE_PATH/"
|
31 |
+
cp -r features/ "$SPACE_PATH/"
|
32 |
+
|
33 |
+
# Make start.sh executable
|
34 |
+
chmod +x "$SPACE_PATH/start.sh"
|
35 |
+
|
36 |
+
echo "Files copied successfully!"
|
37 |
+
echo ""
|
38 |
+
echo "Architecture:"
|
39 |
+
echo "- Single Docker container"
|
40 |
+
echo "- Flask serves both API (/api/*) and frontend (/*)"
|
41 |
+
echo "- Frontend built as static files during Docker build"
|
42 |
+
echo "- Port 7860 for Hugging Face Spaces"
|
43 |
+
echo ""
|
44 |
+
echo "Next steps:"
|
45 |
+
echo "1. cd $SPACE_PATH"
|
46 |
+
echo "2. git add ."
|
47 |
+
echo "3. git commit -m 'Deploy CorrSteer'"
|
48 |
+
echo "4. git push origin main"
|
49 |
+
echo "5. Monitor build logs in HF Spaces"
|
features/gpt2.emgsd.json
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"gender": [
|
3 |
+
{
|
4 |
+
"feature_index": 446,
|
5 |
+
"correlation": 0.43646558254084783,
|
6 |
+
"p_value": 2.0078186521554014e-181
|
7 |
+
},
|
8 |
+
{
|
9 |
+
"feature_index": 18672,
|
10 |
+
"correlation": 0.39005401881798113,
|
11 |
+
"p_value": 3.990362653948118e-142
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"feature_index": 7842,
|
15 |
+
"correlation": 0.38303316661317355,
|
16 |
+
"p_value": 1.0436891049906896e-136
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"feature_index": 10097,
|
20 |
+
"correlation": 0.3566824371157726,
|
21 |
+
"p_value": 1.5738547988431885e-117
|
22 |
+
},
|
23 |
+
{
|
24 |
+
"feature_index": 20455,
|
25 |
+
"correlation": 0.35604794762595793,
|
26 |
+
"p_value": 4.3338061984594966e-117
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"feature_index": 10963,
|
30 |
+
"correlation": 0.34987456956336604,
|
31 |
+
"p_value": 7.322243361005002e-113
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"feature_index": 7636,
|
35 |
+
"correlation": 0.3480422410323139,
|
36 |
+
"p_value": 1.2628128346319322e-111
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"feature_index": 12175,
|
40 |
+
"correlation": 0.3363904323430133,
|
41 |
+
"p_value": 5.940623346124206e-104
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"feature_index": 23204,
|
45 |
+
"correlation": 0.32981003754016014,
|
46 |
+
"p_value": 9.175664688774817e-100
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"feature_index": 12425,
|
50 |
+
"correlation": 0.315679171864645,
|
51 |
+
"p_value": 4.1092284867968456e-91
|
52 |
+
}
|
53 |
+
],
|
54 |
+
"lgbtq+": [
|
55 |
+
{
|
56 |
+
"feature_index": 4957,
|
57 |
+
"correlation": 0.9278012434805367,
|
58 |
+
"p_value": 0.0
|
59 |
+
},
|
60 |
+
{
|
61 |
+
"feature_index": 821,
|
62 |
+
"correlation": 0.8621067279866375,
|
63 |
+
"p_value": 0.0
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"feature_index": 23222,
|
67 |
+
"correlation": 0.7904532588108066,
|
68 |
+
"p_value": 0.0
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"feature_index": 17542,
|
72 |
+
"correlation": 0.7578473931771549,
|
73 |
+
"p_value": 0.0
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"feature_index": 1296,
|
77 |
+
"correlation": 0.7135903991331536,
|
78 |
+
"p_value": 0.0
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"feature_index": 10181,
|
82 |
+
"correlation": 0.7069358305925515,
|
83 |
+
"p_value": 0.0
|
84 |
+
},
|
85 |
+
{
|
86 |
+
"feature_index": 630,
|
87 |
+
"correlation": 0.6970711811966723,
|
88 |
+
"p_value": 0.0
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"feature_index": 7283,
|
92 |
+
"correlation": 0.6888782329392438,
|
93 |
+
"p_value": 0.0
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"feature_index": 9957,
|
97 |
+
"correlation": 0.6806925884098848,
|
98 |
+
"p_value": 0.0
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"feature_index": 14756,
|
102 |
+
"correlation": 0.6763654166758747,
|
103 |
+
"p_value": 0.0
|
104 |
+
}
|
105 |
+
],
|
106 |
+
"nationality": [
|
107 |
+
{
|
108 |
+
"feature_index": 16401,
|
109 |
+
"correlation": 0.6201132151869269,
|
110 |
+
"p_value": 0.0
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"feature_index": 9956,
|
114 |
+
"correlation": 0.5991624893732725,
|
115 |
+
"p_value": 0.0
|
116 |
+
},
|
117 |
+
{
|
118 |
+
"feature_index": 14205,
|
119 |
+
"correlation": 0.5218992817048258,
|
120 |
+
"p_value": 6.72102973572993e-272
|
121 |
+
},
|
122 |
+
{
|
123 |
+
"feature_index": 19268,
|
124 |
+
"correlation": -0.47079023450986934,
|
125 |
+
"p_value": 1.0316744103724932e-214
|
126 |
+
},
|
127 |
+
{
|
128 |
+
"feature_index": 12899,
|
129 |
+
"correlation": 0.4698551338221134,
|
130 |
+
"p_value": 9.37287400288644e-214
|
131 |
+
},
|
132 |
+
{
|
133 |
+
"feature_index": 14102,
|
134 |
+
"correlation": 0.46625178173229587,
|
135 |
+
"p_value": 4.3347117601426704e-210
|
136 |
+
},
|
137 |
+
{
|
138 |
+
"feature_index": 18643,
|
139 |
+
"correlation": 0.4645600803424083,
|
140 |
+
"p_value": 2.200542075099065e-208
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"feature_index": 22237,
|
144 |
+
"correlation": 0.4627396121645557,
|
145 |
+
"p_value": 1.4696460283983834e-206
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"feature_index": 19683,
|
149 |
+
"correlation": 0.4586563357732425,
|
150 |
+
"p_value": 1.6591498330413268e-202
|
151 |
+
},
|
152 |
+
{
|
153 |
+
"feature_index": 10387,
|
154 |
+
"correlation": 0.4464409913279595,
|
155 |
+
"p_value": 1.0445581068141617e-190
|
156 |
+
}
|
157 |
+
],
|
158 |
+
"profession": [
|
159 |
+
{
|
160 |
+
"feature_index": 19268,
|
161 |
+
"correlation": 0.6632217362569325,
|
162 |
+
"p_value": 0.0
|
163 |
+
},
|
164 |
+
{
|
165 |
+
"feature_index": 12738,
|
166 |
+
"correlation": 0.49434419021176973,
|
167 |
+
"p_value": 7.31805847873234e-240
|
168 |
+
},
|
169 |
+
{
|
170 |
+
"feature_index": 12240,
|
171 |
+
"correlation": 0.4822464487435754,
|
172 |
+
"p_value": 1.0652308403711134e-226
|
173 |
+
},
|
174 |
+
{
|
175 |
+
"feature_index": 3833,
|
176 |
+
"correlation": 0.4303762555805077,
|
177 |
+
"p_value": 6.587675550958632e-176
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"feature_index": 16688,
|
181 |
+
"correlation": 0.42313751736840466,
|
182 |
+
"p_value": 1.6989194964768673e-169
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"feature_index": 3658,
|
186 |
+
"correlation": 0.42073411062723626,
|
187 |
+
"p_value": 2.1105890047528277e-167
|
188 |
+
},
|
189 |
+
{
|
190 |
+
"feature_index": 4610,
|
191 |
+
"correlation": 0.4148617376031915,
|
192 |
+
"p_value": 2.344446120591354e-162
|
193 |
+
},
|
194 |
+
{
|
195 |
+
"feature_index": 3428,
|
196 |
+
"correlation": 0.4141357335326178,
|
197 |
+
"p_value": 9.7020391684076e-162
|
198 |
+
},
|
199 |
+
{
|
200 |
+
"feature_index": 9956,
|
201 |
+
"correlation": -0.41240054790897485,
|
202 |
+
"p_value": 2.850570734597228e-160
|
203 |
+
},
|
204 |
+
{
|
205 |
+
"feature_index": 14205,
|
206 |
+
"correlation": -0.4053658871423048,
|
207 |
+
"p_value": 2.081224608034156e-154
|
208 |
+
}
|
209 |
+
],
|
210 |
+
"race": [
|
211 |
+
{
|
212 |
+
"feature_index": 11047,
|
213 |
+
"correlation": 0.40880904254801465,
|
214 |
+
"p_value": 2.924658096595988e-157
|
215 |
+
},
|
216 |
+
{
|
217 |
+
"feature_index": 6520,
|
218 |
+
"correlation": 0.32754592044407255,
|
219 |
+
"p_value": 2.399662718793716e-98
|
220 |
+
},
|
221 |
+
{
|
222 |
+
"feature_index": 18320,
|
223 |
+
"correlation": 0.31448088732356977,
|
224 |
+
"p_value": 2.1188408928978433e-90
|
225 |
+
},
|
226 |
+
{
|
227 |
+
"feature_index": 22312,
|
228 |
+
"correlation": 0.30658660427290435,
|
229 |
+
"p_value": 8.652381542873663e-86
|
230 |
+
},
|
231 |
+
{
|
232 |
+
"feature_index": 22263,
|
233 |
+
"correlation": 0.2775426576622264,
|
234 |
+
"p_value": 5.101978921250436e-70
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"feature_index": 18492,
|
238 |
+
"correlation": 0.21686259675294528,
|
239 |
+
"p_value": 8.530010364272566e-43
|
240 |
+
},
|
241 |
+
{
|
242 |
+
"feature_index": 9459,
|
243 |
+
"correlation": 0.20860841730681046,
|
244 |
+
"p_value": 1.1641254990477776e-39
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"feature_index": 22936,
|
248 |
+
"correlation": 0.20146560425468682,
|
249 |
+
"p_value": 4.7105090636264745e-37
|
250 |
+
},
|
251 |
+
{
|
252 |
+
"feature_index": 7170,
|
253 |
+
"correlation": 0.1848603817962114,
|
254 |
+
"p_value": 2.287836815814386e-31
|
255 |
+
},
|
256 |
+
{
|
257 |
+
"feature_index": 4798,
|
258 |
+
"correlation": 0.18320283221180658,
|
259 |
+
"p_value": 7.917784359744477e-31
|
260 |
+
}
|
261 |
+
],
|
262 |
+
"religion": [
|
263 |
+
{
|
264 |
+
"feature_index": 18754,
|
265 |
+
"correlation": 0.4919855460080574,
|
266 |
+
"p_value": 2.968070741341591e-237
|
267 |
+
},
|
268 |
+
{
|
269 |
+
"feature_index": 17056,
|
270 |
+
"correlation": 0.4754689353084256,
|
271 |
+
"p_value": 1.4912893628128467e-219
|
272 |
+
},
|
273 |
+
{
|
274 |
+
"feature_index": 14242,
|
275 |
+
"correlation": 0.4479290137573203,
|
276 |
+
"p_value": 4.048270907596597e-192
|
277 |
+
},
|
278 |
+
{
|
279 |
+
"feature_index": 17886,
|
280 |
+
"correlation": 0.4412481619097158,
|
281 |
+
"p_value": 7.77154112070949e-186
|
282 |
+
},
|
283 |
+
{
|
284 |
+
"feature_index": 7870,
|
285 |
+
"correlation": 0.4383254166163272,
|
286 |
+
"p_value": 3.938542029802173e-183
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"feature_index": 903,
|
290 |
+
"correlation": 0.42866564467200124,
|
291 |
+
"p_value": 2.2285535850517167e-174
|
292 |
+
},
|
293 |
+
{
|
294 |
+
"feature_index": 21637,
|
295 |
+
"correlation": 0.42829559715114474,
|
296 |
+
"p_value": 4.7608700745146375e-174
|
297 |
+
},
|
298 |
+
{
|
299 |
+
"feature_index": 7699,
|
300 |
+
"correlation": 0.4279351187362136,
|
301 |
+
"p_value": 9.96388834015504e-174
|
302 |
+
},
|
303 |
+
{
|
304 |
+
"feature_index": 23932,
|
305 |
+
"correlation": 0.4146661267040663,
|
306 |
+
"p_value": 3.4386189140897016e-162
|
307 |
+
},
|
308 |
+
{
|
309 |
+
"feature_index": 18686,
|
310 |
+
"correlation": 0.4145648062343114,
|
311 |
+
"p_value": 4.1927743639700714e-162
|
312 |
+
}
|
313 |
+
]
|
314 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=2.0.0
|
2 |
+
transformers>=4.30.0
|
3 |
+
flask>=2.3.0
|
4 |
+
flask-cors>=4.0.0
|
5 |
+
sae-lens>=3.0.0
|
6 |
+
huggingface-hub>=0.16.0
|
7 |
+
gradio>=4.0.0
|
8 |
+
numpy>=1.21.0
|
server.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from flask import Flask, Response, request, send_from_directory, send_file
|
3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
4 |
+
from threading import Thread
|
5 |
+
from sae_lens import SAE
|
6 |
+
from flask_cors import CORS
|
7 |
+
from json import load
|
8 |
+
import os
|
9 |
+
|
10 |
+
# Example config reading
|
11 |
+
from config import datasets_config, models_config
|
12 |
+
|
13 |
+
# ------------------------------------
|
14 |
+
# Global Setup: load tokenizer/models
|
15 |
+
# ------------------------------------
|
16 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
17 |
+
device = "mps" if torch.backends.mps.is_available() else device
|
18 |
+
|
19 |
+
# Main tokenizer (GPT-2 style). Adjust if using different.
|
20 |
+
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
21 |
+
if tokenizer.pad_token_id is None:
|
22 |
+
tokenizer.pad_token_id = tokenizer.eos_token_id
|
23 |
+
|
24 |
+
# Original GPT-2
|
25 |
+
original_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
|
26 |
+
original_model.eval()
|
27 |
+
|
28 |
+
# "Trained"/"biased" GPT-2 model
|
29 |
+
trained_model = AutoModelForCausalLM.from_pretrained("holistic-ai/gpt2-EMGSD").to(device)
|
30 |
+
trained_model.eval()
|
31 |
+
|
32 |
+
# ------------------------------------
|
33 |
+
# Steering Hook Setup (optional)
|
34 |
+
# ------------------------------------
|
35 |
+
|
36 |
+
# Example steering feature(s)
|
37 |
+
hooks = []
|
38 |
+
def generate_pre_hook(sae: SAE, index: int, coeff: float):
|
39 |
+
def steering_hook(module, inputs):
|
40 |
+
"""
|
41 |
+
Simple version of a steering hook. Adds a weighted vector
|
42 |
+
to the residual. Customize if needed.
|
43 |
+
"""
|
44 |
+
residual = inputs[0]
|
45 |
+
steering_vector = sae.W_dec[index].to(device).unsqueeze(0).unsqueeze(0)
|
46 |
+
residual = residual + coeff * steering_vector
|
47 |
+
return (residual)
|
48 |
+
return steering_hook
|
49 |
+
def generate_post_hook(sae: SAE, index: int, coeff: float):
|
50 |
+
def steering_hook(module, inputs, outputs):
|
51 |
+
"""
|
52 |
+
Simple version of a steering hook. Adds a weighted vector
|
53 |
+
to the residual. Customize if needed.
|
54 |
+
"""
|
55 |
+
residual = outputs[0]
|
56 |
+
steering_vector = sae.W_dec[index].to(device).unsqueeze(0).unsqueeze(0)
|
57 |
+
residual = residual + coeff * steering_vector
|
58 |
+
return (residual, outputs[1], outputs[2])
|
59 |
+
return steering_hook
|
60 |
+
|
61 |
+
def register_steering(model, model_key: str, gen_type: str, dataset_key: str, category_key: str):
|
62 |
+
file_path = f"features/{model_key}.{dataset_key}.json"
|
63 |
+
with open(file_path, "r") as f:
|
64 |
+
feature_map = load(f)
|
65 |
+
top_features = feature_map[category_key]
|
66 |
+
if "+" in gen_type:
|
67 |
+
coeff = 75
|
68 |
+
elif "-" in gen_type:
|
69 |
+
coeff = 50
|
70 |
+
if "+" in gen_type:
|
71 |
+
filtered_features = list(filter(lambda x: x["correlation"] > 0, top_features))
|
72 |
+
elif "-" in gen_type:
|
73 |
+
filtered_features = list(filter(lambda x: x["correlation"] < 0, top_features))
|
74 |
+
if len(filtered_features) == 0:
|
75 |
+
filtered_features = list(filter(lambda x: x["correlation"] > 0, top_features))
|
76 |
+
coeff = 75
|
77 |
+
top_feature = filtered_features[0]
|
78 |
+
|
79 |
+
hook_point = "blocks.11.hook_resid_pre"
|
80 |
+
block_idx = int(hook_point.split(".")[1])
|
81 |
+
index = top_feature["feature_index"]
|
82 |
+
|
83 |
+
sae, cfg_dict, sparsity = SAE.from_pretrained(
|
84 |
+
models_config[model_key]["sae"],
|
85 |
+
hook_point,
|
86 |
+
device=device,
|
87 |
+
)
|
88 |
+
|
89 |
+
module = model.transformer.h[block_idx]
|
90 |
+
if "pre" in hook_point:
|
91 |
+
handle = module.register_forward_pre_hook(generate_pre_hook(sae, index, coeff))
|
92 |
+
elif "post" in hook_point:
|
93 |
+
handle = module.register_forward_hook(generate_post_hook(sae, index, coeff))
|
94 |
+
hooks.append(handle)
|
95 |
+
|
96 |
+
def remove_hooks():
|
97 |
+
for h in hooks:
|
98 |
+
h.remove()
|
99 |
+
hooks.clear()
|
100 |
+
|
101 |
+
# ------------------------------------
|
102 |
+
# Helper: streaming generator
|
103 |
+
# ------------------------------------
|
104 |
+
def stream_generate(model, prompt, max_new_tokens=50, temperature=1.0, top_p=0.1, repetition_penalty=10.0):
|
105 |
+
"""
|
106 |
+
Yields tokens as they are generated in a separate thread.
|
107 |
+
"""
|
108 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
109 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
110 |
+
|
111 |
+
generation_kwargs = dict(
|
112 |
+
**inputs,
|
113 |
+
streamer=streamer,
|
114 |
+
max_new_tokens=max_new_tokens,
|
115 |
+
do_sample=True,
|
116 |
+
temperature=temperature,
|
117 |
+
top_p=top_p,
|
118 |
+
pad_token_id=tokenizer.eos_token_id,
|
119 |
+
repetition_penalty=repetition_penalty,
|
120 |
+
)
|
121 |
+
|
122 |
+
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
123 |
+
thread.start()
|
124 |
+
|
125 |
+
for new_text in streamer:
|
126 |
+
yield new_text
|
127 |
+
|
128 |
+
# ------------------------------------
|
129 |
+
# Flask App
|
130 |
+
# ------------------------------------
|
131 |
+
app = Flask(__name__)
|
132 |
+
CORS(app)
|
133 |
+
|
134 |
+
# API routes first (to avoid conflicts with static serving)
|
135 |
+
@app.route("/api/generate", methods=["POST"])
|
136 |
+
def generate():
|
137 |
+
"""
|
138 |
+
Expects JSON like:
|
139 |
+
{
|
140 |
+
"model": "gpt2",
|
141 |
+
"dataset": "emgsd",
|
142 |
+
"category": "lgbtq+",
|
143 |
+
"type": "original" | "origin+steer" | "trained" | "trained-steer"
|
144 |
+
}
|
145 |
+
Streams back the generated text token by token.
|
146 |
+
"""
|
147 |
+
data = request.json
|
148 |
+
model_key = data["model"]
|
149 |
+
dataset_key = data["dataset"]
|
150 |
+
category_key = data["category"]
|
151 |
+
gen_type = data["type"]
|
152 |
+
|
153 |
+
# 1. Figure out prompt from config
|
154 |
+
try:
|
155 |
+
prompt_text = datasets_config[dataset_key]["category"][category_key]["prompt"]
|
156 |
+
except KeyError:
|
157 |
+
return Response("Invalid dataset/category combination.", status=400)
|
158 |
+
|
159 |
+
# 2. Select the model
|
160 |
+
if "trained" in gen_type:
|
161 |
+
chosen_model = trained_model
|
162 |
+
else:
|
163 |
+
chosen_model = original_model
|
164 |
+
|
165 |
+
# 3. Steering logic if "steer" in the request type
|
166 |
+
remove_hooks()
|
167 |
+
if "steer" in gen_type:
|
168 |
+
register_steering(chosen_model, model_key, gen_type, dataset_key, category_key)
|
169 |
+
|
170 |
+
# Return a streaming response of tokens
|
171 |
+
def token_stream():
|
172 |
+
for token in stream_generate(chosen_model, prompt_text):
|
173 |
+
yield token
|
174 |
+
remove_hooks()
|
175 |
+
|
176 |
+
return Response(token_stream(), mimetype="text/event-stream")
|
177 |
+
|
178 |
+
|
179 |
+
# Serve static files for HF Spaces (after API routes)
|
180 |
+
@app.route("/")
|
181 |
+
def serve_frontend():
|
182 |
+
return send_file("demo/dist/index.html")
|
183 |
+
|
184 |
+
@app.route("/<path:path>")
|
185 |
+
def serve_static(path):
|
186 |
+
if path.startswith('api/'):
|
187 |
+
return None
|
188 |
+
if os.path.exists(f"demo/dist/{path}"):
|
189 |
+
return send_from_directory("demo/dist", path)
|
190 |
+
return send_file("demo/dist/index.html")
|
191 |
+
|
192 |
+
|
193 |
+
if __name__ == "__main__":
|
194 |
+
port = int(os.environ.get("PORT", 5174))
|
195 |
+
app.run(host="0.0.0.0", port=port, debug=True)
|
start.sh
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Set environment variables for HF Spaces
|
4 |
+
export VITE_API_BASE_URL=""
|
5 |
+
|
6 |
+
# Start Flask backend in the background
|
7 |
+
echo "Starting Flask backend..."
|
8 |
+
python server.py &
|
9 |
+
|
10 |
+
# Wait for backend to start
|
11 |
+
sleep 10
|
12 |
+
|
13 |
+
# Start frontend development server
|
14 |
+
echo "Starting React frontend..."
|
15 |
+
cd demo
|
16 |
+
pnpm dev --host 0.0.0.0 --port 7860 &
|
17 |
+
|
18 |
+
# Keep the container running
|
19 |
+
wait
|
20 |
+
|
21 |
+
->
|
22 |
+
|
23 |
+
#!/bin/bash
|
24 |
+
|
25 |
+
# Set environment variables for HF Spaces
|
26 |
+
export PORT=7860
|
27 |
+
|
28 |
+
# Start Flask server (serves both API and frontend)
|
29 |
+
echo "Starting CorrSteer application..."
|
30 |
+
python server.py
|