# CorrSteer / server.py
import torch
from flask import Flask, Response, abort, request, send_from_directory, send_file
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
from sae_lens import SAE
from flask_cors import CORS
from json import load
import os
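# Assumed dependencies (PyPI names; pin versions as needed):
#   pip install torch transformers flask flask-cors sae-lens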
# Example config reading
from config import datasets_config, models_config
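# config.py is not shown here. A minimal sketch of the shapes this server
# assumes (keys and values are illustrative, not the real config):
#   models_config = {"gpt2": {"sae": "gpt2-small-res-jb"}}
#   datasets_config = {"emgsd": {"category": {"lgbtq+": {"prompt": "..."}}}}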
# ------------------------------------
# Global Setup: load tokenizer/models
# ------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else device
# Main tokenizer (GPT-2). Adjust if using a different base model.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Original GPT-2
original_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
original_model.eval()
# "Trained"/"biased" GPT-2 model
trained_model = AutoModelForCausalLM.from_pretrained("holistic-ai/gpt2-EMGSD").to(device)
trained_model.eval()
# ------------------------------------
# Steering Hook Setup (optional)
# ------------------------------------
# Handles for registered steering hooks, so they can be removed between requests
hooks = []
def generate_pre_hook(sae: SAE, index: int, coeff: float):
    def steering_hook(module, inputs):
        """
        Simple steering hook: adds a weighted SAE decoder direction to the
        residual stream before the block runs. Customize if needed.
        """
        residual = inputs[0]
        # W_dec[index] has shape (d_model,); unsqueeze to (1, 1, d_model) so it
        # broadcasts over the (batch, seq, d_model) residual.
        steering_vector = sae.W_dec[index].to(device).unsqueeze(0).unsqueeze(0)
        residual = residual + coeff * steering_vector
        # Return a tuple to replace the block's positional inputs.
        return (residual,)
    return steering_hook
def generate_post_hook(sae: SAE, index: int, coeff: float):
    def steering_hook(module, inputs, outputs):
        """
        Simple steering hook: adds a weighted SAE decoder direction to the
        block's output hidden states. Customize if needed.
        """
        residual = outputs[0]
        steering_vector = sae.W_dec[index].to(device).unsqueeze(0).unsqueeze(0)
        residual = residual + coeff * steering_vector
        # GPT-2 blocks return a variable-length tuple (hidden_states, present,
        # [attentions, ...]); preserve everything after the hidden states.
        return (residual,) + tuple(outputs[1:])
    return steering_hook
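# Each features/<model>.<dataset>.json file is expected to map a category key
# to a list of feature records carrying "feature_index" and "correlation"
# fields, as read below (values here are hypothetical):
#   {"lgbtq+": [{"feature_index": 1234, "correlation": 0.82}, ...]}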
def register_steering(model, model_key: str, gen_type: str, dataset_key: str, category_key: str):
    file_path = f"features/{model_key}.{dataset_key}.json"
    with open(file_path, "r") as f:
        feature_map = load(f)
    top_features = feature_map[category_key]
    # "+" types (e.g. "origin+steer") steer toward positively correlated
    # features; "-" types (e.g. "trained-steer") toward negatively correlated
    # ones. Steering request types always contain one of the two.
    if "+" in gen_type:
        coeff = 75
        filtered_features = [x for x in top_features if x["correlation"] > 0]
    else:
        coeff = 50
        filtered_features = [x for x in top_features if x["correlation"] < 0]
    # Fall back to positively correlated features if none match.
    if len(filtered_features) == 0:
        filtered_features = [x for x in top_features if x["correlation"] > 0]
        coeff = 75
    top_feature = filtered_features[0]
    hook_point = "blocks.11.hook_resid_pre"
    block_idx = int(hook_point.split(".")[1])
    index = top_feature["feature_index"]
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        models_config[model_key]["sae"],
        hook_point,
        device=device,
    )
    module = model.transformer.h[block_idx]
    if "pre" in hook_point:
        handle = module.register_forward_pre_hook(generate_pre_hook(sae, index, coeff))
    elif "post" in hook_point:
        handle = module.register_forward_hook(generate_post_hook(sae, index, coeff))
    hooks.append(handle)
def remove_hooks():
    for h in hooks:
        h.remove()
    hooks.clear()
# ------------------------------------
# Helper: streaming generator
# ------------------------------------
def stream_generate(model, prompt, max_new_tokens=50, temperature=1.0, top_p=0.1, repetition_penalty=10.0):
    """
    Yields decoded text chunks as they are generated in a background thread.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=repetition_penalty,
    )
    # generate() blocks, so run it in a worker thread and consume the streamer here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    for new_text in streamer:
        yield new_text
    thread.join()
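# Usage sketch outside Flask (prompt is hypothetical):
#   for chunk in stream_generate(original_model, "The quick brown fox"):
#       print(chunk, end="", flush=True)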
# ------------------------------------
# Flask App
# ------------------------------------
app = Flask(__name__)
CORS(app)
# API routes first (to avoid conflicts with static serving)
@app.route("/api/generate", methods=["POST"])
def generate():
"""
Expects JSON like:
{
"model": "gpt2",
"dataset": "emgsd",
"category": "lgbtq+",
"type": "original" | "origin+steer" | "trained" | "trained-steer"
}
Streams back the generated text token by token.
"""
data = request.json
model_key = data["model"]
dataset_key = data["dataset"]
category_key = data["category"]
gen_type = data["type"]
# 1. Figure out prompt from config
try:
prompt_text = datasets_config[dataset_key]["category"][category_key]["prompt"]
except KeyError:
return Response("Invalid dataset/category combination.", status=400)
# 2. Select the model
if "trained" in gen_type:
chosen_model = trained_model
else:
chosen_model = original_model
# 3. Steering logic if "steer" in the request type
remove_hooks()
if "steer" in gen_type:
register_steering(chosen_model, model_key, gen_type, dataset_key, category_key)
# Return a streaming response of tokens
def token_stream():
for token in stream_generate(chosen_model, prompt_text):
yield token
remove_hooks()
return Response(token_stream(), mimetype="text/event-stream")
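# Example request against a local run (port matches the default in __main__):
#   curl -N -X POST http://localhost:5174/api/generate \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gpt2", "dataset": "emgsd", "category": "lgbtq+", "type": "origin+steer"}'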
# Serve static files for HF Spaces (after API routes)
@app.route("/")
def serve_frontend():
return send_file("demo/dist/index.html")
@app.route("/<path:path>")
def serve_static(path):
if path.startswith('api/'):
return None
if os.path.exists(f"demo/dist/{path}"):
return send_from_directory("demo/dist", path)
return send_file("demo/dist/index.html")
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 5174))
    app.run(host="0.0.0.0", port=port, debug=True)