import torch
from flask import Flask, Response, request, send_from_directory, send_file
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
from sae_lens import SAE
from flask_cors import CORS
from json import load
import os
# Example config reading
from config import datasets_config, models_config
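# The config module is assumed to expose dicts shaped roughly like this
# (illustrative values only; the real config.py defines the actual entries):
#
#   models_config = {"gpt2": {"sae": "gpt2-small-res-jb"}}
#   datasets_config = {
#       "emgsd": {"category": {"lgbtq+": {"prompt": "..."}}},
#   }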
# ------------------------------------
# Global Setup: load tokenizer/models
# ------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else device
# Main tokenizer (GPT-2 style). Adjust if using different.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Original GPT-2
original_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
original_model.eval()
# "Trained"/"biased" GPT-2 model
trained_model = AutoModelForCausalLM.from_pretrained("holistic-ai/gpt2-EMGSD").to(device)
trained_model.eval()
# ------------------------------------
# Steering Hook Setup (optional)
# ------------------------------------
# Example steering feature(s)
hooks = []
def generate_pre_hook(sae: SAE, index: int, coeff: float):
    def steering_hook(module, inputs):
        """
        Simple version of a steering hook. Adds a weighted SAE decoder
        vector to the residual stream. Customize if needed.
        """
        residual = inputs[0]
        # Decoder row for the chosen feature, broadcast over (batch, seq, d_model).
        steering_vector = sae.W_dec[index].to(device).unsqueeze(0).unsqueeze(0)
        residual = residual + coeff * steering_vector
        # Return a tuple so the hook replaces the module's positional inputs.
        return (residual,) + inputs[1:]
    return steering_hook

def generate_post_hook(sae: SAE, index: int, coeff: float):
    def steering_hook(module, inputs, outputs):
        """
        Simple version of a steering hook. Adds a weighted SAE decoder
        vector to the residual stream. Customize if needed.
        """
        residual = outputs[0]
        steering_vector = sae.W_dec[index].to(device).unsqueeze(0).unsqueeze(0)
        residual = residual + coeff * steering_vector
        # Preserve any extra block outputs (e.g. attention weights) untouched.
        return (residual,) + outputs[1:]
    return steering_hook

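# Usage sketch (illustrative values): steering one block directly with a factory
# above. The SAE release name, feature index 123, and coeff 75 are placeholders,
# not values taken from this app's config.
#
#   sae, _, _ = SAE.from_pretrained("gpt2-small-res-jb", "blocks.11.hook_resid_pre", device=device)
#   handle = original_model.transformer.h[11].register_forward_pre_hook(
#       generate_pre_hook(sae, index=123, coeff=75)
#   )
#   ...  # every forward pass now adds coeff * sae.W_dec[123] to the residual
#   handle.remove()
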
def register_steering(model, model_key: str, gen_type: str, dataset_key: str, category_key: str):
    file_path = f"features/{model_key}.{dataset_key}.json"
    with open(file_path, "r") as f:
        feature_map = load(f)
    top_features = feature_map[category_key]
    # "+" types steer toward the category (positively correlated features);
    # "-" types steer away from it (negatively correlated features).
    if "+" in gen_type:
        coeff = 75
        filtered_features = [feat for feat in top_features if feat["correlation"] > 0]
    else:
        coeff = 50
        filtered_features = [feat for feat in top_features if feat["correlation"] < 0]
    # Fall back to positively correlated features if the filter came up empty.
    if len(filtered_features) == 0:
        filtered_features = [feat for feat in top_features if feat["correlation"] > 0]
        coeff = 75
    top_feature = filtered_features[0]
    hook_point = "blocks.11.hook_resid_pre"
    block_idx = int(hook_point.split(".")[1])
    index = top_feature["feature_index"]
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        models_config[model_key]["sae"],
        hook_point,
        device=device,
    )
    module = model.transformer.h[block_idx]
    if "pre" in hook_point:
        handle = module.register_forward_pre_hook(generate_pre_hook(sae, index, coeff))
    elif "post" in hook_point:
        handle = module.register_forward_hook(generate_post_hook(sae, index, coeff))
    hooks.append(handle)
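# The features/<model>.<dataset>.json files are assumed to map each category to a
# list of feature entries, best match first, shaped like (illustrative values):
#
#   {"lgbtq+": [{"feature_index": 123, "correlation": 0.42}, ...]}
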
def remove_hooks():
    for h in hooks:
        h.remove()
    hooks.clear()
# ------------------------------------
# Helper: streaming generator
# ------------------------------------
def stream_generate(model, prompt, max_new_tokens=50, temperature=1.0, top_p=0.1, repetition_penalty=10.0):
    """
    Yields text chunks as they are generated. Generation runs in a
    background thread; the streamer hands chunks back to this one.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=repetition_penalty,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    for new_text in streamer:
        yield new_text
    # The streamer is exhausted only once generation finishes, so this returns quickly.
    thread.join()
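# Usage sketch: consume the generator directly, e.g. from a REPL.
#
#   for chunk in stream_generate(original_model, "Once upon a time", max_new_tokens=20):
#       print(chunk, end="", flush=True)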
# ------------------------------------
# Flask App
# ------------------------------------
app = Flask(__name__)
CORS(app)
# API routes first (to avoid conflicts with static serving)
@app.route("/api/generate", methods=["POST"])
def generate():
"""
Expects JSON like:
{
"model": "gpt2",
"dataset": "emgsd",
"category": "lgbtq+",
"type": "original" | "origin+steer" | "trained" | "trained-steer"
}
Streams back the generated text token by token.
"""
data = request.json
model_key = data["model"]
dataset_key = data["dataset"]
category_key = data["category"]
gen_type = data["type"]
# 1. Figure out prompt from config
try:
prompt_text = datasets_config[dataset_key]["category"][category_key]["prompt"]
except KeyError:
return Response("Invalid dataset/category combination.", status=400)
# 2. Select the model
if "trained" in gen_type:
chosen_model = trained_model
else:
chosen_model = original_model
# 3. Steering logic if "steer" in the request type
remove_hooks()
if "steer" in gen_type:
register_steering(chosen_model, model_key, gen_type, dataset_key, category_key)
# Return a streaming response of tokens
def token_stream():
for token in stream_generate(chosen_model, prompt_text):
yield token
remove_hooks()
return Response(token_stream(), mimetype="text/event-stream")
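# Example request (host/port assume the defaults in the __main__ block below):
#
#   curl -N -X POST http://localhost:5174/api/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model": "gpt2", "dataset": "emgsd", "category": "lgbtq+", "type": "trained-steer"}'
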
# Serve static files for HF Spaces (after API routes)
@app.route("/")
def serve_frontend():
return send_file("demo/dist/index.html")
@app.route("/<path:path>")
def serve_static(path):
if path.startswith('api/'):
return None
if os.path.exists(f"demo/dist/{path}"):
return send_from_directory("demo/dist", path)
return send_file("demo/dist/index.html")
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 5174))
    app.run(host="0.0.0.0", port=port, debug=True)