seonglae committed on
Commit
889f722
·
0 Parent(s):

feat: hf space corr-steer

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ wandb
2
+ __pycache__
3
+ .env
Dockerfile ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use Node.js 18 as base image
2
+ FROM node:18-bullseye
3
+
4
+ # Install Python 3.9 and pip
5
+ RUN apt-get update && apt-get install -y \
6
+ python3.9 \
7
+ python3-pip \
8
+ python3-dev \
9
+ build-essential \
10
+ && rm -rf /var/lib/apt/lists/*
11
+
12
+ # Set up aliases for python
13
+ RUN ln -s /usr/bin/python3.9 /usr/bin/python
14
+
15
+ # Install pnpm
16
+ RUN npm install -g pnpm
17
+
18
+ # Set working directory
19
+ WORKDIR /app
20
+
21
+ # Copy Python requirements and install dependencies
22
+ COPY requirements.txt .
23
+ RUN pip3 install --no-cache-dir -r requirements.txt
24
+
25
+ # Copy Python files
26
+ COPY server.py .
27
+ COPY config.py .
28
+ COPY features/ ./features/
29
+
30
+ # Copy frontend files
31
+ COPY demo/ ./demo/
32
+
33
+ # Install frontend dependencies and build
34
+ WORKDIR /app/demo
35
+ RUN pnpm install
36
+ RUN pnpm build
37
+
38
+ # Verify build output
39
+ RUN ls -la dist/
40
+
41
+ # Go back to app directory
42
+ WORKDIR /app
43
+
44
+ # Copy startup script
45
+ COPY start.sh .
46
+ RUN chmod +x start.sh
47
+
48
+ # Expose port 7860 for HF Spaces
49
+ EXPOSE 7860
50
+
51
+ # Set environment variables
52
+ ENV NODE_ENV=production
53
+ ENV FLASK_ENV=production
54
+ ENV PYTHONPATH=/app
55
+
56
+ # Start both services
57
+ CMD ["./start.sh"]
config.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =========
2
+ # Configuration: Define dataset information.
3
+ # =========
4
# Registry of datasets, keyed by the short name passed on the command line.
# Per-entry fields:
#   id          - Hugging Face Hub dataset identifier
#   columns     - columns kept when the dataset is loaded
#   filter      - optional {column: value} equality filters
#   text_field  - column holding the input text
#   label_field - column holding the class label
#   max_length  - token limit used when truncating texts
#   category    - optional per-label settings (here: a demo prompt per label)
datasets_config = {
    # Stereotype dataset; only rows whose category is "stereotype" are kept.
    "emgsd": {
        "id": "holistic-ai/EMGSD",
        "columns": ["category", "stereotype_type", "text"],
        "filter": {"category": "stereotype"},
        "text_field": "text",
        "label_field": "stereotype_type",
        "max_length": 512,
        "category": {
            "lgbtq+": {"prompt": "The gay people are"},
            "nationality": {"prompt": "The French people are"},
            "race": {"prompt": "The Black people are"},
        },
    },
    # Yahoo Answers topics, using the question title as the text.
    "yahoo_question": {
        "id": "community-datasets/yahoo_answers_topics",
        "columns": ["topic", "question_title", "question_content"],
        "text_field": "question_title",
        "label_field": "topic",
        "max_length": 512,
    },
    # Yahoo Answers topics, using the best answer as the text.
    "yahoo_answer": {
        "id": "community-datasets/yahoo_answers_topics",
        "columns": ["topic", "best_answer"],
        "text_field": "best_answer",
        "label_field": "topic",
        "max_length": 512,
    },
    "science": {
        "id": "knowledgator/Scientific-text-classification",
        "columns": ["text", "label"],
        "text_field": "text",
        "label_field": "label",
        "max_length": 512,
    },
    # Wikipedia chunks; the article title serves as the label.
    "wiki256": {
        "id": "seonglae/wikipedia-256",
        "columns": ["text", "title"],
        "text_field": "text",
        "label_field": "title",
        "max_length": 512,
    },
    "wiki512": {
        "id": "seonglae/wikipedia-512",
        "columns": ["text", "title"],
        "text_field": "text",
        "label_field": "title",
        "max_length": 1024,
    },
}
60
+
61
+ # =========
62
+ # Configuration: Define model-specific information.
63
+ # For "gpt2", we specify the SAE source and the list of hooks to use.
64
+ # f"{model}-{dataset}" is the key for trained models.
65
+ # =========
66
# "blocks.<layer>.hook_resid_pre" for layers 11 down to 0, in that order.
_GPT2_RESID_PRE_HOOKS = [
    f"blocks.{layer}.hook_resid_pre" for layer in range(11, -1, -1)
]

# Registry of models, keyed by the short name passed on the command line.
# Per-entry fields:
#   id    - Hugging Face Hub model identifier
#   sae   - SAE release the hooks are loaded from
#   hooks - SAE hook names to analyse (each entry gets its own list copy)
models_config = {
    "gpt2": {
        "id": "gpt2",
        "sae": "jbloom/GPT2-Small-SAEs-Reformatted",
        "hooks": list(_GPT2_RESID_PRE_HOOKS),
    },
    "gpt2-emgsd": {
        "id": "holistic-ai/gpt2-EMGSD",
        "sae": "jbloom/GPT2-Small-SAEs-Reformatted",
        "hooks": list(_GPT2_RESID_PRE_HOOKS),
    },
}
corr_extract.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Project: corr-steer
3
+
4
+ This script loads a given dataset (using a datasets configuration), tokenizes the texts,
5
+ runs an LLM forward pass to extract hidden states, and then - for each SAE hook defined
6
+ for the model - computes binary feature activations. Activations are thresholded and aggregated
7
+ (max over the sequence), and the point-biserial correlation is computed between each feature
8
+ and each label category.
9
+
10
+ For each hook, the script finds the top 10 features (sorted in descending order by absolute correlation)
11
+ per category and saves that record to a JSON file:
12
+
13
+ features/{model}.{dataset}.{hook}.json
14
+
15
+ Additionally, an aggregated JSON file
16
+
17
+ features/{model}.{dataset}.json
18
+
19
+ is created that, for each category, merges records from all hooks (each record includes its hook)
20
+ and retains the top 10 features overall.
21
+
22
+ This script is callable via Fire, e.g.:
23
+ python corr_extract.py --dataset emgsd --model gpt2
24
+ """
25
+
26
+ import os
27
+ import json
28
+ import math
29
+ import torch
30
+ import numpy as np
31
+ from transformers import AutoTokenizer, AutoModelForCausalLM
32
+ from sae_lens import SAE
33
+ from datasets import load_dataset, concatenate_datasets
34
+ from tqdm import tqdm
35
+ from sklearn.preprocessing import LabelBinarizer
36
+ from scipy.stats import pointbiserialr
37
+ import wandb
38
+ import fire
39
+
40
+ from config import datasets_config, models_config
41
+
42
+
43
+ # =========
44
+ # Load a dataset given the dataset configuration.
45
+ #
46
+ # Returns:
47
+ # texts, labels, and the maximum token length.
48
+ # =========
49
def load_custom_dataset(dataset_name, limit: int):
    """Load texts and labels for one configured dataset.

    Args:
        dataset_name: Key into ``datasets_config``.
        limit: Maximum number of examples to return. A falsy value
            (``0`` / ``None``) disables the cap.

    Returns:
        A ``(texts, labels, max_length)`` tuple: two parallel lists of
        strings and the ``max_length`` value from the dataset's config.

    Raises:
        KeyError: If ``dataset_name`` is not present in ``datasets_config``.
    """
    config = datasets_config[dataset_name]
    # NOTE(review): the "test" split is what gets loaded here (an earlier
    # comment said "train"); confirm this is the intended split.
    dataset = load_dataset(config["id"], split="test")
    # Shuffle with a fixed seed and keep only the first half of the rows for
    # feature extraction (presumably the other half is held out for validation).
    dataset = dataset.shuffle(seed=42).select(range(int(dataset.num_rows / 2)))

    # Select only the specified columns.
    dataset = dataset.select_columns(config["columns"])

    # Apply each configured equality filter. key/val are bound as lambda
    # defaults so every filter sees its own pair, not the loop's last one.
    for key, val in config.get("filter", {}).items():
        dataset = dataset.filter(lambda ex, key=key, val=val: ex[key] == val)

    if limit:
        # Clamp to the rows that survived filtering: selecting indices past
        # the end of the dataset raises instead of returning what is available.
        dataset = dataset.select(range(min(limit, dataset.num_rows)))

    text_field = config["text_field"]
    label_field = config["label_field"]
    texts = []
    labels = []
    for ex in tqdm(dataset, desc="Loading dataset"):
        texts.append(str(ex[text_field]))
        labels.append(str(ex[label_field]))
    return texts, labels, config["max_length"]
74
+
75
+ # =========
76
+ # Tokenize a list of texts.
77
+ #
78
+ # Returns a list of tokenized (and device-mapped) inputs.
79
+ # =========
80
def tokenize_texts(tokenizer, texts, max_length, device):
    """Tokenize each text separately and move its tensors to ``device``.

    Args:
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", truncation=True, max_length=...)``.
        texts: Iterable of input strings.
        max_length: Truncation length passed to the tokenizer.
        device: Target passed to each tensor's ``.to(...)``.

    Returns:
        A list with one encoding per input text, in input order.
    """
    encoded_inputs = []
    for sample in tqdm(texts, desc="Tokenizing texts"):
        batch = tokenizer(
            sample,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        )
        # Replace every tensor in the encoding with its on-device copy.
        for field_name, tensor in batch.items():
            batch[field_name] = tensor.to(device)
        encoded_inputs.append(batch)
    return encoded_inputs
88
+
89
+ # =========
90
+ # Extract aggregated binary feature activations for each SAE hook.
91
+ #
92
+ # For each sample, the model forward is run once to get all hidden states. Then for each hook,
93
+ # the corresponding hidden state (parsed from the hook string) is passed through its SAE. The output
94
+ # is thresholded (> 0) and aggregated via max pooling along the sequence dimension.
95
+ #
96
+ # Returns:
97
+ # A dictionary mapping hook names to a numpy array of shape (num_samples, num_features).
98
+ # =========
99
def extract_features(llm, model_name, tokens_list, hooks, device):
    """Compute per-sample binary SAE feature activations for every hook.

    The model is run once per sample with ``output_hidden_states=True``. For
    each hook, the hidden state at the hook's layer index is encoded by that
    hook's SAE, thresholded at ``> 0`` and max-pooled over the sequence.

    Args:
        llm: Causal LM called as ``llm(**encoding, output_hidden_states=True)``.
        model_name: Key into ``models_config``; selects the SAE release.
        tokens_list: Per-sample encodings (the code assumes batch size 1).
        hooks: Hook names of the form ``"blocks.<layer>.<hook_point>"``.
        device: Device the SAEs are loaded onto.

    Returns:
        Dict mapping each hook name to a numpy array built from one
        ``(num_features,)`` 0/1 vector per sample.
    """
    # Loop-invariant lookups: resolve the SAE release and layer indices once.
    sae_release = models_config[model_name]["sae"]
    layer_by_hook = {hook: int(hook.split(".")[1]) for hook in hooks}

    # Preload one SAE per hook.
    sae_models = {}
    for hook in hooks:
        sae, _, _ = SAE.from_pretrained(sae_release, hook, device=device)
        sae_models[hook] = sae

    features_by_hook = {hook: [] for hook in hooks}

    for encoding in tqdm(tokens_list, desc="Extracting activations"):
        # Keep BOTH the LM forward pass and the SAE encoding under no_grad:
        # this is pure inference, and tensors that track gradients cannot be
        # converted with .numpy().
        with torch.no_grad():
            outputs = llm(**encoding, output_hidden_states=True)
            for hook in hooks:
                # presumably hidden_states[layer] is the residual stream
                # entering block `layer` (i.e. hook_resid_pre) - TODO confirm
                hidden_state = outputs.hidden_states[layer_by_hook[hook]]
                activations = sae_models[hook].encode(hidden_state)
                # Drop the batch dimension (assumes batch size 1), move to CPU.
                activations = activations.squeeze(0).cpu().numpy()
                # Binarize, then max-pool over the sequence: a feature counts
                # as active for the sample if it fired at any position.
                binary_acts = (activations > 0).astype(int)
                features_by_hook[hook].append(binary_acts.max(axis=0))

    # Stack the per-sample vectors into one array per hook.
    for hook in hooks:
        features_by_hook[hook] = np.array(features_by_hook[hook])
    return features_by_hook
128
+
129
+ # =========
130
+ # Compute correlations per label category.
131
+ #
132
+ # Given a feature activation array of shape (n_samples, n_features) and binary labels of shape
133
+ # (n_samples, n_categories) along with a list of category names, compute the point-biserial correlation
134
+ # for each feature against each category. For each category, sort by absolute correlation and keep the top 10.
135
+ #
136
+ # Returns a dictionary keyed by category.
137
+ # =========
138
def compute_correlations_by_category(feature_activations, binary_labels, categories, top_k=10):
    """Rank features by point-biserial correlation with each label category.

    Args:
        feature_activations: Array of shape ``(n_samples, n_features)``.
        binary_labels: Array of shape ``(n_samples, n_categories)``; column
            ``i`` is the 0/1 indicator for ``categories[i]``.
        categories: Category names, one per label column.
        top_k: Number of records kept per category (default 10).

    Returns:
        Dict mapping each category to at most ``top_k`` records of the form
        ``{"feature_index": int, "correlation": float}``, sorted by
        descending absolute correlation. Undefined (NaN) correlations are
        ranked as if they were 0.
    """
    def _is_constant(vec):
        # Correlation is undefined when either variable has no variance.
        return vec.size < 2 or vec.min() == vec.max()

    def _rank_key(record):
        corr = record["correlation"]
        return 0 if math.isnan(corr) else abs(corr)

    # Slice the label columns and test them for constancy once, outside the
    # per-feature loop.
    label_columns = [binary_labels[:, idx] for idx in range(len(categories))]
    label_constant = [_is_constant(col) for col in label_columns]

    results = {cat: [] for cat in categories}
    for feat_idx in range(feature_activations.shape[1]):
        feat_vec = feature_activations[:, feat_idx]
        feat_constant = _is_constant(feat_vec)
        for cat, lbl, lbl_constant in zip(categories, label_columns, label_constant):
            if feat_constant or lbl_constant:
                # SciPy would only emit a constant-input warning and return
                # NaN here; skip the call (most sparse features never fire).
                corr = float("nan")
            else:
                corr, _ = pointbiserialr(lbl, feat_vec)
                corr = float(corr)  # builtin float keeps records JSON-friendly
            results[cat].append({
                "feature_index": feat_idx,
                "correlation": corr
            })

    # For each category, sort and keep the top records (by absolute correlation).
    return {
        cat: sorted(results[cat], key=_rank_key, reverse=True)[:top_k]
        for cat in categories
    }
154
+
155
+ # =========
156
+ # Main function
157
+ #
158
+ # This function initializes wandb (project "corr-steer"), loads the specified dataset and LLM along with
159
+ # the SAE hooks (from models_config), tokenizes the texts, extracts feature activations,
160
+ # computes (per hook) the per-category top 10 correlation records, and writes out one JSON per hook as
161
+ # well as one aggregated JSON file that combines (and sorts) records across hooks.
162
+ #
163
+ # Run via, for example:
164
+ # python corr_extract.py --dataset emgsd --model gpt2
165
+ # =========
166
def main(dataset="emgsd", model="gpt2", limit=1000):
    """Extract SAE features and write per-category correlation rankings.

    Loads the configured LLM and dataset, computes binary SAE feature
    activations for every hook listed in ``models_config[model]``, and
    correlates them with one-hot labels. Writes one JSON file per hook
    (``features/{model}.{dataset}.{hook}.json``) and one aggregated file
    (``features/{model}.{dataset}.json``) holding the top 10 records per
    category across all hooks. Results are also logged to wandb.

    Args:
        dataset: Key into ``datasets_config``.
        model: Key into ``models_config``.
        limit: Maximum number of dataset examples to process.

    Raises:
        ValueError: If ``model`` is not configured in ``models_config``.
    """
    # Validate before any config lookup, wandb run or model download, so an
    # unknown model yields this error rather than a KeyError further down.
    if model not in models_config:
        raise ValueError(f"Model {model} is not configured in models_config.")
    model_id = models_config[model]["id"]
    hooks = models_config[model]["hooks"]

    # "mps" takes precedence when available, otherwise CUDA, otherwise CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "mps" if torch.backends.mps.is_available() else device

    # Initialize wandb.
    wandb.init(project="corr-steer", config={"model": model, "dataset": dataset})

    # Load tokenizer and the LLM.
    print("Loading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    llm = AutoModelForCausalLM.from_pretrained(model_id).to(device)
    llm.eval()

    # Load dataset texts, labels, and max token length.
    print("Loading dataset...")
    texts, labels, max_length = load_custom_dataset(dataset, limit)

    # Binarize labels.
    lb = LabelBinarizer()
    binary_labels = lb.fit_transform(labels)
    # Ensure binary_labels is 2D.
    if binary_labels.ndim == 1:
        binary_labels = binary_labels.reshape(-1, 1)
    categories = lb.classes_
    # With exactly two classes LabelBinarizer emits a single column (1 marks
    # classes_[1]). Expand it to one indicator column per class so column i
    # always corresponds to categories[i].
    if len(categories) == 2 and binary_labels.shape[1] == 1:
        binary_labels = np.hstack([1 - binary_labels, binary_labels])

    # Tokenize texts.
    tokens_list = tokenize_texts(tokenizer, texts, max_length, device)

    # Extract features for all hooks.
    print("Extracting features for all hooks...")
    features_by_hook = extract_features(llm, model, tokens_list, hooks, device)

    out_dir = "features"
    os.makedirs(out_dir, exist_ok=True)

    # For aggregated results across hooks.
    aggregated_results = {cat: [] for cat in categories}

    # Process each hook individually.
    for hook in hooks:
        print(f"Computing per-category correlations for hook {hook} ...")
        feat = features_by_hook[hook]  # (n_samples, n_features)
        hook_corrs = compute_correlations_by_category(feat, binary_labels, categories)

        # Save per-hook file.
        hook_filename = f"{model}.{dataset}.{hook}.json"
        hook_filepath = os.path.join(out_dir, hook_filename)
        with open(hook_filepath, "w", encoding="utf-8") as f:
            json.dump(hook_corrs, f, indent=2, ensure_ascii=False)
        print(f"Saved per-hook correlations to {hook_filepath}")

        # Log each hook's results to wandb.
        wandb.log({hook: hook_corrs})

        # For aggregation, tag a copy of each record with its hook so the
        # per-hook results stay untouched.
        for cat in categories:
            for rec in hook_corrs[cat]:
                rec_with_hook = rec.copy()
                rec_with_hook["hook"] = hook
                aggregated_results[cat].append(rec_with_hook)

    # Now, for each category, sort aggregated records across hooks and retain
    # the top 10 (NaN correlations rank as 0).
    final_aggregated = {}
    for cat in categories:
        sorted_records = sorted(
            aggregated_results[cat],
            key=lambda x: 0 if math.isnan(x["correlation"]) else abs(x["correlation"]),
            reverse=True
        )[:10]
        final_aggregated[cat] = sorted_records

    agg_filename = f"{model}.{dataset}.json"
    agg_filepath = os.path.join(out_dir, agg_filename)
    with open(agg_filepath, "w", encoding="utf-8") as f:
        json.dump(final_aggregated, f, indent=2, ensure_ascii=False)
    print(f"Saved aggregated correlations to {agg_filepath}")

    wandb.finish()
251
+
252
+ if __name__ == "__main__":
253
+ fire.Fire(main)
corr_steer/steer.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------
2
+ # Modular Steering Hook Logic
3
+ # ------------------------------------
4
+
5
class SteeringHook:
    """Add SAE decoder directions to a transformer block's output.

    Typical use: ``register_hooks(model, block_idx)`` once, then toggle
    steering with ``enable_steering(...)`` / ``disable_steering()``.
    Note that ``disable_steering()`` also removes the registered hooks.
    """

    def __init__(self, sae, device):
        """
        Initialize the SteeringHook with an SAE and device.
        Args:
            sae: Object exposing ``W_dec``, indexable by feature index to
                obtain that feature's decoder vector.
            device: Device recorded for this hook; stored as ``self.device``.
        """
        self.sae = sae
        self.device = device
        self.feature_coeffs = []  # List of (feature_index, coefficient)
        self.hooks = []  # Store hook handles
        self.steering_enabled = False

    def enable_steering(self, feature_coeffs):
        """
        Enable steering by specifying feature coefficients.
        Args:
            feature_coeffs (list): List of (feature_index, coefficient).
        """
        # Copy so later mutation of the caller's list cannot silently change
        # the active steering configuration.
        self.feature_coeffs = list(feature_coeffs)
        self.steering_enabled = True

    def disable_steering(self):
        """
        Disable steering, clear the coefficients and remove all hooks.
        """
        self.steering_enabled = False
        self.feature_coeffs = []
        self.remove_hooks()

    def generate_hook(self):
        """
        Create a forward-hook function that modifies the residual output.
        Returns:
            A ``hook_fn(module, inputs, outputs)`` callable. While steering
            is disabled it returns ``outputs`` unchanged; otherwise it adds
            ``coeff * W_dec[feature_index]`` for every configured feature.
        """
        def hook_fn(module, inputs, outputs):
            if not self.steering_enabled:
                return outputs

            # Module outputs may be a tuple/list whose first element is the
            # residual, or the residual tensor itself; support both.
            is_sequence = isinstance(outputs, (tuple, list))
            residual = outputs[0] if is_sequence else outputs
            for feature_index, coeff in self.feature_coeffs:
                # Align the decoder vector with the residual's device and
                # dtype; it broadcasts over the leading dimensions.
                steering_vector = self.sae.W_dec[feature_index].to(
                    device=residual.device, dtype=residual.dtype
                )
                residual = residual + coeff * steering_vector
            if is_sequence:
                return (residual, *outputs[1:])
            return residual

        return hook_fn

    def register_hooks(self, model, block_idx):
        """
        Register the steering hook to the specified block.
        Args:
            model (nn.Module): The target model; must expose ``transformer.h``.
            block_idx (int): The block index to attach the hook.
        """
        handle = model.transformer.h[block_idx].register_forward_hook(self.generate_hook())
        self.hooks.append(handle)

    def remove_hooks(self):
        """
        Remove all registered hooks.
        """
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()
demo/.gitignore ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Logs
2
+ logs
3
+ *.log
4
+ npm-debug.log*
5
+ yarn-debug.log*
6
+ yarn-error.log*
7
+ pnpm-debug.log*
8
+ lerna-debug.log*
9
+
10
+ node_modules
11
+ dist
12
+ dist-ssr
13
+ *.local
14
+
15
+ # Editor directories and files
16
+ .vscode/*
17
+ !.vscode/extensions.json
18
+ .idea
19
+ .DS_Store
20
+ *.suo
21
+ *.ntvs*
22
+ *.njsproj
23
+ *.sln
24
+ *.sw?
demo/README.md ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CorrSteer Frontend
2
+
3
+ ## Overview
4
+
5
+ CorrSteer demonstrates how text classification datasets can be used to steer large language models (LLMs): dataset labels are correlated with SAE (Sparse Autoencoder) feature activations, and the most correlated features are used for steering. The demo is built on a modern frontend stack (listed below).
6
+
7
+ ---
8
+
9
+ ## How to Run the Demo
10
+
11
+ 1. **Set Environment Variables:**
12
+ Create a `.env` file in the `demo` directory and include the following:
13
+
14
+ ```env
15
+ VITE_API_BASE_URL=<your-api-url>
16
+ ```
17
+
18
+ 2. **Install Dependencies:**
19
+
20
+ ```bash
21
+ pnpm i
22
+ ```
23
+
24
+ 3. **Start the Development Server:**
25
+
26
+ ```bash
27
+ pnpm dev
28
+ ```
29
+
30
+ The application will be available at `http://localhost:5173` by default.
31
+
32
+ 4. **Build for Production (Optional):**
33
+
34
+ ```bash
35
+ pnpm build
36
+ pnpm preview
37
+ ```
38
+
39
+ ---
40
+
41
+ ## Key Features
42
+
43
+ - **Dataset & Model Selection:**
44
+ Select datasets and models using dropdown menus.
45
+ - **Streaming Outputs:**
46
+ Generate outputs from multiple models with live updates as data streams.
47
+ - **Interactive Tabs:**
48
+ Switch between different categories for customized prompts.
49
+
50
+ ---
51
+
52
+ ## Technology Stack
53
+
54
+ 1. **Vite:**
55
+ - Development server and build tool.
56
+
57
+ 2. **React:**
58
+ - UI library for building components and managing state.
59
+
60
+ 3. **Tailwind CSS:**
61
+ - CSS framework for styling.
62
+
63
+ 4. **ShadCN/UI:**
64
+ - Pre-built component library for UI elements.
65
+
66
+ ---
67
+
68
+ ## License
69
+
70
+ This project is licensed under the MIT License.
demo/components.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "$schema": "https://ui.shadcn.com/schema.json",
3
+ "style": "default",
4
+ "rsc": false,
5
+ "tsx": true,
6
+ "tailwind": {
7
+ "config": "tailwind.config.ts",
8
+ "css": "src/index.css",
9
+ "baseColor": "slate",
10
+ "cssVariables": true,
11
+ "prefix": ""
12
+ },
13
+ "aliases": {
14
+ "components": "@/components",
15
+ "utils": "@/lib/utils",
16
+ "ui": "@/components/ui",
17
+ "lib": "@/lib",
18
+ "hooks": "@/hooks"
19
+ },
20
+ "iconLibrary": "lucide"
21
+ }
demo/eslint.config.js ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import js from '@eslint/js'
2
+ import globals from 'globals'
3
+ import reactHooks from 'eslint-plugin-react-hooks'
4
+ import reactRefresh from 'eslint-plugin-react-refresh'
5
+ import tseslint from 'typescript-eslint'
6
+
7
+ export default tseslint.config(
8
+ { ignores: ['dist'] },
9
+ {
10
+ extends: [js.configs.recommended, ...tseslint.configs.recommended],
11
+ files: ['**/*.{ts,tsx}'],
12
+ languageOptions: {
13
+ ecmaVersion: 2020,
14
+ globals: globals.browser,
15
+ },
16
+ plugins: {
17
+ 'react-hooks': reactHooks,
18
+ 'react-refresh': reactRefresh,
19
+ },
20
+ rules: {
21
+ ...reactHooks.configs.recommended.rules,
22
+ 'react-refresh/only-export-components': [
23
+ 'warn',
24
+ { allowConstantExport: true },
25
+ ],
26
+ },
27
+ },
28
+ )
demo/index.html ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <link rel="icon" type="image/svg+xml" href="/finn.svg" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>CorrSteer</title>
8
+ </head>
9
+ <body>
10
+ <div id="root"></div>
11
+ <script type="module" src="/src/main.tsx"></script>
12
+ </body>
13
+ </html>
deploy.sh ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Deployment script for Hugging Face Spaces
# Usage: ./deploy.sh /path/to/your/huggingface/space
#
# Copies the Docker build inputs (Dockerfile, README, requirements, start
# script, server/config modules, frontend and feature files) into a local
# checkout of the Space repository, then prints the manual git steps.

# Abort on the first failed command, unset variable or failed pipeline stage,
# so a failed copy can never be followed by the success message below.
set -euo pipefail

if [ $# -eq 0 ]; then
    echo "Usage: $0 <path-to-huggingface-space>" >&2
    echo "Example: $0 /home/user/corr-steer-space" >&2
    exit 1
fi

SPACE_PATH="$1"

if [ ! -d "$SPACE_PATH" ]; then
    echo "Error: Directory $SPACE_PATH does not exist" >&2
    exit 1
fi

echo "Copying files to Hugging Face Space: $SPACE_PATH"

# Copy essential files
cp Dockerfile "$SPACE_PATH/"
cp README.md "$SPACE_PATH/"
cp requirements.txt "$SPACE_PATH/"
cp start.sh "$SPACE_PATH/"
cp server.py "$SPACE_PATH/"
cp config.py "$SPACE_PATH/"

# Copy directories
cp -r demo/ "$SPACE_PATH/"
cp -r features/ "$SPACE_PATH/"

# Make start.sh executable
chmod +x "$SPACE_PATH/start.sh"

echo "Files copied successfully!"
echo ""
echo "Architecture:"
echo "- Single Docker container"
echo "- Flask serves both API (/api/*) and frontend (/*)"
echo "- Frontend built as static files during Docker build"
echo "- Port 7860 for Hugging Face Spaces"
echo ""
echo "Next steps:"
echo "1. cd $SPACE_PATH"
echo "2. git add ."
echo "3. git commit -m 'Deploy CorrSteer'"
echo "4. git push origin main"
echo "5. Monitor build logs in HF Spaces"
features/gpt2.emgsd.json ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "gender": [
3
+ {
4
+ "feature_index": 446,
5
+ "correlation": 0.43646558254084783,
6
+ "p_value": 2.0078186521554014e-181
7
+ },
8
+ {
9
+ "feature_index": 18672,
10
+ "correlation": 0.39005401881798113,
11
+ "p_value": 3.990362653948118e-142
12
+ },
13
+ {
14
+ "feature_index": 7842,
15
+ "correlation": 0.38303316661317355,
16
+ "p_value": 1.0436891049906896e-136
17
+ },
18
+ {
19
+ "feature_index": 10097,
20
+ "correlation": 0.3566824371157726,
21
+ "p_value": 1.5738547988431885e-117
22
+ },
23
+ {
24
+ "feature_index": 20455,
25
+ "correlation": 0.35604794762595793,
26
+ "p_value": 4.3338061984594966e-117
27
+ },
28
+ {
29
+ "feature_index": 10963,
30
+ "correlation": 0.34987456956336604,
31
+ "p_value": 7.322243361005002e-113
32
+ },
33
+ {
34
+ "feature_index": 7636,
35
+ "correlation": 0.3480422410323139,
36
+ "p_value": 1.2628128346319322e-111
37
+ },
38
+ {
39
+ "feature_index": 12175,
40
+ "correlation": 0.3363904323430133,
41
+ "p_value": 5.940623346124206e-104
42
+ },
43
+ {
44
+ "feature_index": 23204,
45
+ "correlation": 0.32981003754016014,
46
+ "p_value": 9.175664688774817e-100
47
+ },
48
+ {
49
+ "feature_index": 12425,
50
+ "correlation": 0.315679171864645,
51
+ "p_value": 4.1092284867968456e-91
52
+ }
53
+ ],
54
+ "lgbtq+": [
55
+ {
56
+ "feature_index": 4957,
57
+ "correlation": 0.9278012434805367,
58
+ "p_value": 0.0
59
+ },
60
+ {
61
+ "feature_index": 821,
62
+ "correlation": 0.8621067279866375,
63
+ "p_value": 0.0
64
+ },
65
+ {
66
+ "feature_index": 23222,
67
+ "correlation": 0.7904532588108066,
68
+ "p_value": 0.0
69
+ },
70
+ {
71
+ "feature_index": 17542,
72
+ "correlation": 0.7578473931771549,
73
+ "p_value": 0.0
74
+ },
75
+ {
76
+ "feature_index": 1296,
77
+ "correlation": 0.7135903991331536,
78
+ "p_value": 0.0
79
+ },
80
+ {
81
+ "feature_index": 10181,
82
+ "correlation": 0.7069358305925515,
83
+ "p_value": 0.0
84
+ },
85
+ {
86
+ "feature_index": 630,
87
+ "correlation": 0.6970711811966723,
88
+ "p_value": 0.0
89
+ },
90
+ {
91
+ "feature_index": 7283,
92
+ "correlation": 0.6888782329392438,
93
+ "p_value": 0.0
94
+ },
95
+ {
96
+ "feature_index": 9957,
97
+ "correlation": 0.6806925884098848,
98
+ "p_value": 0.0
99
+ },
100
+ {
101
+ "feature_index": 14756,
102
+ "correlation": 0.6763654166758747,
103
+ "p_value": 0.0
104
+ }
105
+ ],
106
+ "nationality": [
107
+ {
108
+ "feature_index": 16401,
109
+ "correlation": 0.6201132151869269,
110
+ "p_value": 0.0
111
+ },
112
+ {
113
+ "feature_index": 9956,
114
+ "correlation": 0.5991624893732725,
115
+ "p_value": 0.0
116
+ },
117
+ {
118
+ "feature_index": 14205,
119
+ "correlation": 0.5218992817048258,
120
+ "p_value": 6.72102973572993e-272
121
+ },
122
+ {
123
+ "feature_index": 19268,
124
+ "correlation": -0.47079023450986934,
125
+ "p_value": 1.0316744103724932e-214
126
+ },
127
+ {
128
+ "feature_index": 12899,
129
+ "correlation": 0.4698551338221134,
130
+ "p_value": 9.37287400288644e-214
131
+ },
132
+ {
133
+ "feature_index": 14102,
134
+ "correlation": 0.46625178173229587,
135
+ "p_value": 4.3347117601426704e-210
136
+ },
137
+ {
138
+ "feature_index": 18643,
139
+ "correlation": 0.4645600803424083,
140
+ "p_value": 2.200542075099065e-208
141
+ },
142
+ {
143
+ "feature_index": 22237,
144
+ "correlation": 0.4627396121645557,
145
+ "p_value": 1.4696460283983834e-206
146
+ },
147
+ {
148
+ "feature_index": 19683,
149
+ "correlation": 0.4586563357732425,
150
+ "p_value": 1.6591498330413268e-202
151
+ },
152
+ {
153
+ "feature_index": 10387,
154
+ "correlation": 0.4464409913279595,
155
+ "p_value": 1.0445581068141617e-190
156
+ }
157
+ ],
158
+ "profession": [
159
+ {
160
+ "feature_index": 19268,
161
+ "correlation": 0.6632217362569325,
162
+ "p_value": 0.0
163
+ },
164
+ {
165
+ "feature_index": 12738,
166
+ "correlation": 0.49434419021176973,
167
+ "p_value": 7.31805847873234e-240
168
+ },
169
+ {
170
+ "feature_index": 12240,
171
+ "correlation": 0.4822464487435754,
172
+ "p_value": 1.0652308403711134e-226
173
+ },
174
+ {
175
+ "feature_index": 3833,
176
+ "correlation": 0.4303762555805077,
177
+ "p_value": 6.587675550958632e-176
178
+ },
179
+ {
180
+ "feature_index": 16688,
181
+ "correlation": 0.42313751736840466,
182
+ "p_value": 1.6989194964768673e-169
183
+ },
184
+ {
185
+ "feature_index": 3658,
186
+ "correlation": 0.42073411062723626,
187
+ "p_value": 2.1105890047528277e-167
188
+ },
189
+ {
190
+ "feature_index": 4610,
191
+ "correlation": 0.4148617376031915,
192
+ "p_value": 2.344446120591354e-162
193
+ },
194
+ {
195
+ "feature_index": 3428,
196
+ "correlation": 0.4141357335326178,
197
+ "p_value": 9.7020391684076e-162
198
+ },
199
+ {
200
+ "feature_index": 9956,
201
+ "correlation": -0.41240054790897485,
202
+ "p_value": 2.850570734597228e-160
203
+ },
204
+ {
205
+ "feature_index": 14205,
206
+ "correlation": -0.4053658871423048,
207
+ "p_value": 2.081224608034156e-154
208
+ }
209
+ ],
210
+ "race": [
211
+ {
212
+ "feature_index": 11047,
213
+ "correlation": 0.40880904254801465,
214
+ "p_value": 2.924658096595988e-157
215
+ },
216
+ {
217
+ "feature_index": 6520,
218
+ "correlation": 0.32754592044407255,
219
+ "p_value": 2.399662718793716e-98
220
+ },
221
+ {
222
+ "feature_index": 18320,
223
+ "correlation": 0.31448088732356977,
224
+ "p_value": 2.1188408928978433e-90
225
+ },
226
+ {
227
+ "feature_index": 22312,
228
+ "correlation": 0.30658660427290435,
229
+ "p_value": 8.652381542873663e-86
230
+ },
231
+ {
232
+ "feature_index": 22263,
233
+ "correlation": 0.2775426576622264,
234
+ "p_value": 5.101978921250436e-70
235
+ },
236
+ {
237
+ "feature_index": 18492,
238
+ "correlation": 0.21686259675294528,
239
+ "p_value": 8.530010364272566e-43
240
+ },
241
+ {
242
+ "feature_index": 9459,
243
+ "correlation": 0.20860841730681046,
244
+ "p_value": 1.1641254990477776e-39
245
+ },
246
+ {
247
+ "feature_index": 22936,
248
+ "correlation": 0.20146560425468682,
249
+ "p_value": 4.7105090636264745e-37
250
+ },
251
+ {
252
+ "feature_index": 7170,
253
+ "correlation": 0.1848603817962114,
254
+ "p_value": 2.287836815814386e-31
255
+ },
256
+ {
257
+ "feature_index": 4798,
258
+ "correlation": 0.18320283221180658,
259
+ "p_value": 7.917784359744477e-31
260
+ }
261
+ ],
262
+ "religion": [
263
+ {
264
+ "feature_index": 18754,
265
+ "correlation": 0.4919855460080574,
266
+ "p_value": 2.968070741341591e-237
267
+ },
268
+ {
269
+ "feature_index": 17056,
270
+ "correlation": 0.4754689353084256,
271
+ "p_value": 1.4912893628128467e-219
272
+ },
273
+ {
274
+ "feature_index": 14242,
275
+ "correlation": 0.4479290137573203,
276
+ "p_value": 4.048270907596597e-192
277
+ },
278
+ {
279
+ "feature_index": 17886,
280
+ "correlation": 0.4412481619097158,
281
+ "p_value": 7.77154112070949e-186
282
+ },
283
+ {
284
+ "feature_index": 7870,
285
+ "correlation": 0.4383254166163272,
286
+ "p_value": 3.938542029802173e-183
287
+ },
288
+ {
289
+ "feature_index": 903,
290
+ "correlation": 0.42866564467200124,
291
+ "p_value": 2.2285535850517167e-174
292
+ },
293
+ {
294
+ "feature_index": 21637,
295
+ "correlation": 0.42829559715114474,
296
+ "p_value": 4.7608700745146375e-174
297
+ },
298
+ {
299
+ "feature_index": 7699,
300
+ "correlation": 0.4279351187362136,
301
+ "p_value": 9.96388834015504e-174
302
+ },
303
+ {
304
+ "feature_index": 23932,
305
+ "correlation": 0.4146661267040663,
306
+ "p_value": 3.4386189140897016e-162
307
+ },
308
+ {
309
+ "feature_index": 18686,
310
+ "correlation": 0.4145648062343114,
311
+ "p_value": 4.1927743639700714e-162
312
+ }
313
+ ]
314
+ }
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.30.0
3
+ flask>=2.3.0
4
+ flask-cors>=4.0.0
5
+ sae-lens>=3.0.0
6
+ huggingface-hub>=0.16.0
7
+ gradio>=4.0.0
8
+ numpy>=1.21.0
server.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from flask import Flask, Response, request, send_from_directory, send_file
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
4
+ from threading import Thread
5
+ from sae_lens import SAE
6
+ from flask_cors import CORS
7
+ from json import load
8
+ import os
9
+
10
+ # Example config reading
11
+ from config import datasets_config, models_config
12
+
13
+ # ------------------------------------
14
+ # Global Setup: load tokenizer/models
15
+ # ------------------------------------
16
# Choose the compute device: start from CPU, upgrade to CUDA when it is
# available, and let the MPS backend take precedence whenever it is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
if torch.backends.mps.is_available():
    device = "mps"

# Main tokenizer (GPT-2 style), shared by both models below.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token_id is None:
    # Fall back to the end-of-sequence token when no pad token is defined.
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Original GPT-2, moved to the selected device and put in inference mode.
original_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

# "Trained"/"biased" GPT-2 model, prepared the same way.
trained_model = AutoModelForCausalLM.from_pretrained("holistic-ai/gpt2-EMGSD").to(device).eval()

# ------------------------------------
# Steering Hook Setup (optional)
# ------------------------------------

# Handles of the steering hooks that are currently registered on a model.
hooks = []
38
def generate_pre_hook(sae: "SAE", index: int, coeff: float):
    """Build a forward pre-hook that steers a module's input along one SAE feature.

    Args:
        sae: Object exposing ``W_dec``; row ``index`` is used as the steering direction.
        index: Row of ``sae.W_dec`` to add.
        coeff: Scale applied to the steering direction before it is added.

    Returns:
        A callable ``hook(module, inputs)`` for
        ``torch.nn.Module.register_forward_pre_hook``.
    """
    def steering_hook(module, inputs):
        """Return the positional inputs with the steering vector added to the first one."""
        residual = inputs[0]
        # Follow the activation's own device and dtype instead of a module-level
        # global, so the hook keeps working if the model is moved or cast.
        steering_vector = sae.W_dec[index].to(device=residual.device, dtype=residual.dtype)
        # Two leading singleton dims so the vector broadcasts over the leading
        # axes of the activation (assumed (batch, seq, d_model) — TODO confirm).
        steered = residual + coeff * steering_vector.unsqueeze(0).unsqueeze(0)
        # Pass every other positional argument through untouched; returning only
        # the residual would drop them from the module call.
        return (steered, *inputs[1:])
    return steering_hook
49
def generate_post_hook(sae: "SAE", index: int, coeff: float):
    """Build a forward hook that steers a module's output along one SAE feature.

    Args:
        sae: Object exposing ``W_dec``; row ``index`` is used as the steering direction.
        index: Row of ``sae.W_dec`` to add.
        coeff: Scale applied to the steering direction before it is added.

    Returns:
        A callable ``hook(module, inputs, outputs)`` for
        ``torch.nn.Module.register_forward_hook``.
    """
    def steering_hook(module, inputs, outputs):
        """Return the module output with the steering vector added to the hidden state."""
        # Some modules return the hidden state directly rather than a tuple;
        # indexing a bare tensor with [0] would silently pick the first batch row.
        is_tuple = isinstance(outputs, tuple)
        residual = outputs[0] if is_tuple else outputs
        steering_vector = sae.W_dec[index].to(device=residual.device, dtype=residual.dtype)
        steered = residual + coeff * steering_vector.unsqueeze(0).unsqueeze(0)
        if not is_tuple:
            return steered
        # Preserve however many extra outputs the module produced; the tuple
        # length is not fixed, so hard-coding outputs[1] and outputs[2] can fail.
        return (steered, *outputs[1:])
    return steering_hook
60
+
61
def register_steering(model, model_key: str, gen_type: str, dataset_key: str, category_key: str):
    """Attach a steering hook for one SAE feature to ``model`` and record its handle.

    The feature is read from ``features/{model_key}.{dataset_key}.json`` under
    ``category_key``. A ``"+"`` in ``gen_type`` selects the first entry with a
    positive correlation (coefficient 75); a ``"-"`` selects the first entry
    with a negative correlation (coefficient 50), falling back to the first
    positive entry (coefficient 75) when there is no negative one.

    Args:
        model: Model whose ``transformer.h[block_idx]`` module receives the hook.
        model_key: Key into ``models_config``; also part of the feature file name.
        gen_type: Generation type string; must contain ``"+"`` or ``"-"``.
        dataset_key: Key into ``datasets_config``; also part of the feature file name.
        category_key: Category whose feature list is used.

    Raises:
        KeyError: Unknown model, dataset or category key.
        ValueError: ``gen_type`` names no steering direction, or no usable
            feature exists for the category.
        FileNotFoundError: The feature file does not exist.
    """
    # These keys come from the HTTP request and are interpolated into a file
    # path below, so only values that are real config entries are accepted.
    if model_key not in models_config:
        raise KeyError(f"Unknown model key: {model_key!r}")
    if dataset_key not in datasets_config:
        raise KeyError(f"Unknown dataset key: {dataset_key!r}")

    file_path = f"features/{model_key}.{dataset_key}.json"
    with open(file_path, "r") as f:
        feature_map = load(f)
    top_features = feature_map[category_key]

    positive = [feat for feat in top_features if feat["correlation"] > 0]
    negative = [feat for feat in top_features if feat["correlation"] < 0]

    if "+" in gen_type:
        candidates, coeff = positive, 75
    elif "-" in gen_type:
        # No negatively correlated feature: fall back to the positive ones
        # with the positive-steering coefficient.
        candidates, coeff = (negative, 50) if negative else (positive, 75)
    else:
        raise ValueError(f"gen_type {gen_type!r} does not specify a steering direction ('+' or '-')")
    if not candidates:
        raise ValueError(f"No usable steering feature for category {category_key!r}")
    top_feature = candidates[0]

    hook_point = "blocks.11.hook_resid_pre"
    block_idx = int(hook_point.split(".")[1])
    index = top_feature["feature_index"]

    loaded = SAE.from_pretrained(
        models_config[model_key]["sae"],
        hook_point,
        device=device,
    )
    # NOTE(review): older sae-lens releases return (sae, cfg_dict, sparsity),
    # newer ones appear to return the SAE alone — accept both; confirm against
    # the installed version.
    sae = loaded[0] if isinstance(loaded, tuple) else loaded

    module = model.transformer.h[block_idx]
    if "pre" in hook_point:
        handle = module.register_forward_pre_hook(generate_pre_hook(sae, index, coeff))
    elif "post" in hook_point:
        handle = module.register_forward_hook(generate_post_hook(sae, index, coeff))
    else:
        raise ValueError(f"Unsupported hook point: {hook_point!r}")
    hooks.append(handle)
95
+
96
def remove_hooks():
    """Detach every registered steering hook and empty the shared ``hooks`` list."""
    for handle in tuple(hooks):
        handle.remove()
    # Clear in place so other references to the list see it emptied too.
    del hooks[:]
100
+
101
+ # ------------------------------------
102
+ # Helper: streaming generator
103
+ # ------------------------------------
104
def stream_generate(model, prompt, max_new_tokens=50, temperature=1.0, top_p=0.1, repetition_penalty=10.0):
    """
    Run ``model.generate`` on ``prompt`` in a background thread and yield
    each piece of decoded text as the streamer produces it.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    # skip_prompt: only newly generated text is streamed, not the prompt itself.
    token_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = {
        **encoded,
        "streamer": token_streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "pad_token_id": tokenizer.eos_token_id,
        "repetition_penalty": repetition_penalty,
    }

    # Generation runs in its own thread so this generator can consume the
    # streamer while tokens are still being produced.
    worker = Thread(target=model.generate, kwargs=generation_kwargs)
    worker.start()

    yield from token_streamer
127
+
128
+ # ------------------------------------
129
+ # Flask App
130
+ # ------------------------------------
131
app = Flask(__name__)
CORS(app)

# API routes first (to avoid conflicts with static serving)
@app.route("/api/generate", methods=["POST"])
def generate():
    """
    Expects JSON like:
    {
        "model": "gpt2",
        "dataset": "emgsd",
        "category": "lgbtq+",
        "type": "original" | "origin+steer" | "trained" | "trained-steer"
    }
    Streams back the generated text token by token.

    Responds with 400 when the body is not a JSON object, a field is missing
    or not a string, the dataset/category pair is unknown, or no steering
    feature can be set up for the request.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return Response("Request body must be a JSON object.", status=400)

    required = ("model", "dataset", "category", "type")
    invalid = [name for name in required if not isinstance(data.get(name), str)]
    if invalid:
        return Response(f"Missing or non-string field(s): {', '.join(invalid)}.", status=400)

    model_key = data["model"]
    dataset_key = data["dataset"]
    category_key = data["category"]
    gen_type = data["type"]

    # 1. Figure out prompt from config
    try:
        prompt_text = datasets_config[dataset_key]["category"][category_key]["prompt"]
    except KeyError:
        return Response("Invalid dataset/category combination.", status=400)

    # 2. Select the model
    chosen_model = trained_model if "trained" in gen_type else original_model

    # 3. Steering logic if "steer" in the request type
    # NOTE: hooks live in one module-level list, so this call also drops hooks
    # that belong to any other request still streaming.
    remove_hooks()
    if "steer" in gen_type:
        try:
            register_steering(chosen_model, model_key, gen_type, dataset_key, category_key)
        except (KeyError, IndexError, ValueError, FileNotFoundError):
            remove_hooks()
            return Response("No steering feature available for this model/dataset/category.", status=400)

    # Return a streaming response of tokens
    def token_stream():
        try:
            yield from stream_generate(chosen_model, prompt_text)
        finally:
            # Also runs when the client disconnects mid-stream, so a steering
            # hook never stays attached to the shared model.
            remove_hooks()

    return Response(token_stream(), mimetype="text/event-stream")
177
+
178
+
179
+ # Serve static files for HF Spaces (after API routes)
180
# Built frontend bundle, relative to the application root.
_FRONTEND_DIST = "demo/dist"


@app.route("/")
def serve_frontend():
    """Serve the frontend entry page."""
    return send_file(f"{_FRONTEND_DIST}/index.html")


@app.route("/<path:path>")
def serve_static(path):
    """Serve a built frontend asset, falling back to the entry page.

    Unknown ``api/`` paths get a 404 rather than the entry page. Any other
    path that is not an existing file under the bundle directory returns
    ``index.html`` so client-side routes still load.
    """
    if path.startswith("api/"):
        # A Flask view must return a response; returning None raises TypeError
        # and surfaces as a 500.
        return Response("Not found", status=404)
    # Check against the application root, which is also what
    # send_from_directory resolves a relative directory against, rather than
    # the process working directory.
    if os.path.isfile(os.path.join(app.root_path, _FRONTEND_DIST, path)):
        return send_from_directory(_FRONTEND_DIST, path)
    return send_file(f"{_FRONTEND_DIST}/index.html")
191
+
192
+
193
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 5174))
    # The server binds to every interface, so the interactive debugger (which
    # allows code execution) must stay off unless explicitly requested. Debug
    # mode also starts the reloader, which would import this module — and load
    # both models — a second time.
    debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", port=port, debug=debug)
start.sh ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Container entry point for the Hugging Face Space.
# The frontend is built into demo/dist during the image build and server.py
# serves it alongside the API, so a single process is all that is started here.

# Port used by the Space; server.py reads it from the environment.
export PORT=7860

# Start Flask server (serves both API and frontend)
echo "Starting CorrSteer application..."
# exec replaces the shell so container stop signals reach Python directly.
exec python server.py