import os
from json import load
from threading import Thread

import torch
from flask import Flask, Response, abort, request, send_from_directory, send_file
from flask_cors import CORS
from sae_lens import SAE
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

from config import datasets_config, models_config


# Prefer CUDA, then Apple Silicon (MPS), falling back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else device

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token_id is None:
    # GPT-2 has no dedicated pad token; reuse EOS so generation can pad.
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Baseline GPT-2 and the fine-tuned comparison model.
original_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
original_model.eval()

trained_model = AutoModelForCausalLM.from_pretrained("holistic-ai/gpt2-EMGSD").to(device)
trained_model.eval()


# Handles for steering hooks currently attached to a model, kept so they
# can be removed between requests.
hooks = []


def generate_pre_hook(sae: SAE, index: int, coeff: float):
    def steering_hook(module, inputs):
        """
        Simple version of a steering hook. Adds a weighted SAE decoder
        vector to the residual stream. Customize if needed.
        """
        residual = inputs[0]
        steering_vector = sae.W_dec[index].to(device).unsqueeze(0).unsqueeze(0)
        residual = residual + coeff * steering_vector
        # Return a tuple so any remaining positional inputs are preserved.
        return (residual,) + inputs[1:]
    return steering_hook


def generate_post_hook(sae: SAE, index: int, coeff: float):
    def steering_hook(module, inputs, outputs):
        """
        Simple version of a steering hook. Adds a weighted SAE decoder
        vector to the residual stream. Customize if needed.
        """
        residual = outputs[0]
        steering_vector = sae.W_dec[index].to(device).unsqueeze(0).unsqueeze(0)
        residual = residual + coeff * steering_vector
        # The block may return a variable number of outputs (e.g. cached
        # key/values), so splice rather than indexing fixed positions.
        return (residual,) + outputs[1:]
    return steering_hook
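

# Shape sketch for the steering update above (illustrative, assuming the
# standard GPT-2 residual layout): `residual` is (batch, seq_len, d_model)
# and `sae.W_dec[index]` is (d_model,), so the two `unsqueeze` calls
# broadcast the decoder direction to (1, 1, d_model) and every token
# position is shifted by `coeff * direction`. A hypothetical manual
# registration would look like:
#
#     handle = original_model.transformer.h[11].register_forward_pre_hook(
#         generate_pre_hook(sae, feature_index, 75.0)
#     )
#     ...  # generate as usual, then: handle.remove()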


def register_steering(model, model_key: str, gen_type: str, dataset_key: str, category_key: str):
    file_path = f"features/{model_key}.{dataset_key}.json"
    with open(file_path, "r") as f:
        feature_map = load(f)
    top_features = feature_map[category_key]

    # "+" steers along positively correlated features with a stronger
    # coefficient; "-" steers along negatively correlated ones.
    if "+" in gen_type:
        coeff = 75
        filtered_features = list(filter(lambda x: x["correlation"] > 0, top_features))
    elif "-" in gen_type:
        coeff = 50
        filtered_features = list(filter(lambda x: x["correlation"] < 0, top_features))
    else:
        raise ValueError(f"gen_type {gen_type!r} must contain '+' or '-' to steer.")

    # Fall back to positively correlated features if nothing matched.
    if len(filtered_features) == 0:
        filtered_features = list(filter(lambda x: x["correlation"] > 0, top_features))
        coeff = 75
    top_feature = filtered_features[0]

    hook_point = "blocks.11.hook_resid_pre"
    block_idx = int(hook_point.split(".")[1])
    index = top_feature["feature_index"]

    sae, cfg_dict, sparsity = SAE.from_pretrained(
        models_config[model_key]["sae"],
        hook_point,
        device=device,
    )

    module = model.transformer.h[block_idx]
    if "pre" in hook_point:
        handle = module.register_forward_pre_hook(generate_pre_hook(sae, index, coeff))
    elif "post" in hook_point:
        handle = module.register_forward_hook(generate_post_hook(sae, index, coeff))
    hooks.append(handle)
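
# Example call (values are illustrative; real keys come from config.py and
# the request payload):
#
#     register_steering(original_model, "gpt2", "origin+steer", "emgsd", "lgbtq+")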


def remove_hooks():
    for h in hooks:
        h.remove()
    hooks.clear()


def stream_generate(model, prompt, max_new_tokens=50, temperature=1.0, top_p=0.1, repetition_penalty=10.0):
    """
    Yields tokens as they are generated in a separate thread.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=repetition_penalty,
    )

    # model.generate blocks until finished, so run it on a worker thread and
    # consume tokens from the streamer as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    for new_text in streamer:
        yield new_text
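
# Quick local sanity check (illustrative prompt; not part of the request flow):
#
#     for chunk in stream_generate(original_model, "Once upon a time"):
#         print(chunk, end="", flush=True)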


app = Flask(__name__)
CORS(app)


@app.route("/api/generate", methods=["POST"])
def generate():
    """
    Expects JSON like:
    {
        "model": "gpt2",
        "dataset": "emgsd",
        "category": "lgbtq+",
        "type": "original" | "origin+steer" | "trained" | "trained-steer"
    }
    Streams back the generated text token by token.
    """
    data = request.json
    model_key = data["model"]
    dataset_key = data["dataset"]
    category_key = data["category"]
    gen_type = data["type"]

    try:
        prompt_text = datasets_config[dataset_key]["category"][category_key]["prompt"]
    except KeyError:
        return Response("Invalid dataset/category combination.", status=400)

    if "trained" in gen_type:
        chosen_model = trained_model
    else:
        chosen_model = original_model

    # Drop any hooks left over from a previous request before (optionally)
    # registering fresh steering hooks for this one.
    remove_hooks()
    if "steer" in gen_type:
        register_steering(chosen_model, model_key, gen_type, dataset_key, category_key)

    def token_stream():
        for token in stream_generate(chosen_model, prompt_text):
            yield token
        remove_hooks()

    return Response(token_stream(), mimetype="text/event-stream")
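
# Example request (values are illustrative; the port matches the default below):
#
#     curl -N -X POST http://localhost:5174/api/generate \
#       -H "Content-Type: application/json" \
#       -d '{"model": "gpt2", "dataset": "emgsd", "category": "lgbtq+", "type": "origin+steer"}'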


@app.route("/")
def serve_frontend():
    return send_file("demo/dist/index.html")


@app.route("/<path:path>")
def serve_static(path):
    # Never let the static fallback shadow API routes.
    if path.startswith('api/'):
        abort(404)
    if os.path.exists(f"demo/dist/{path}"):
        return send_from_directory("demo/dist", path)
    # Single-page-app fallback: unknown paths get index.html.
    return send_file("demo/dist/index.html")


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 5174))
    app.run(host="0.0.0.0", port=port, debug=True)