import os
from json import load
from threading import Thread

import torch
from flask import Flask, Response, abort, request, send_file, send_from_directory
from flask_cors import CORS
from sae_lens import SAE
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Example config reading
from config import datasets_config, models_config

# ------------------------------------
# Global Setup: load tokenizer/models
# ------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else device

# Main tokenizer (GPT-2 style). Adjust if using a different model family.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Original GPT-2
original_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
original_model.eval()

# Fine-tuned ("biased") GPT-2 model
trained_model = AutoModelForCausalLM.from_pretrained("holistic-ai/gpt2-EMGSD").to(device)
trained_model.eval()

# ------------------------------------
# Steering Hook Setup (optional)
# ------------------------------------
# Handles for currently registered steering hooks.
hooks = []


def generate_pre_hook(sae: SAE, index: int, coeff: float):
    def steering_hook(module, inputs):
        """
        Simple steering hook: adds a weighted SAE decoder direction to the
        residual stream before the block runs. Customize if needed.
        """
        residual = inputs[0]
        # W_dec[index] has shape (d_model,); reshape to (1, 1, d_model) so it
        # broadcasts over the batch and sequence dimensions.
        steering_vector = sae.W_dec[index].to(device).unsqueeze(0).unsqueeze(0)
        residual = residual + coeff * steering_vector
        # Forward pre-hooks must return the replacement positional args as a tuple.
        return (residual,)

    return steering_hook


def generate_post_hook(sae: SAE, index: int, coeff: float):
    def steering_hook(module, inputs, outputs):
        """
        Simple steering hook: adds a weighted SAE decoder direction to the
        residual stream after the block runs. Customize if needed.
        """
        residual = outputs[0]
        steering_vector = sae.W_dec[index].to(device).unsqueeze(0).unsqueeze(0)
        residual = residual + coeff * steering_vector
        # Preserve the remaining block outputs (their number depends on
        # use_cache / output_attentions); only the hidden states are replaced.
        return (residual,) + outputs[1:]

    return steering_hook


def register_steering(model, model_key: str, gen_type: str, dataset_key: str, category_key: str):
    file_path = f"features/{model_key}.{dataset_key}.json"
    with open(file_path, "r") as f:
        feature_map = load(f)
    top_features = feature_map[category_key]

    # "+" steers with positively correlated features (stronger coefficient);
    # "-" steers with negatively correlated features (gentler coefficient).
    if "+" in gen_type:
        coeff = 75
        filtered_features = [f for f in top_features if f["correlation"] > 0]
    elif "-" in gen_type:
        coeff = 50
        filtered_features = [f for f in top_features if f["correlation"] < 0]
        # Fall back to positively correlated features if no negative ones exist.
        if not filtered_features:
            filtered_features = [f for f in top_features if f["correlation"] > 0]
            coeff = 75

    top_feature = filtered_features[0]
    hook_point = "blocks.11.hook_resid_pre"
    block_idx = int(hook_point.split(".")[1])
    index = top_feature["feature_index"]

    sae, cfg_dict, sparsity = SAE.from_pretrained(
        models_config[model_key]["sae"],
        hook_point,
        device=device,
    )

    module = model.transformer.h[block_idx]
    if "pre" in hook_point:
        handle = module.register_forward_pre_hook(generate_pre_hook(sae, index, coeff))
    elif "post" in hook_point:
        handle = module.register_forward_hook(generate_post_hook(sae, index, coeff))
    hooks.append(handle)


def remove_hooks():
    for h in hooks:
        h.remove()
    hooks.clear()
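
# ------------------------------------
# Illustrative sketch (not wired into the app): how the steering hooks shift
# the residual stream. The shapes below are assumptions based on GPT-2 small
# (d_model = 768); `n_features` is an arbitrary stand-in for the SAE decoder
# width, not a value read from the real checkpoint.
# ------------------------------------
def _steering_shape_demo():
    batch, seq_len, d_model, n_features = 1, 4, 768, 24576
    residual = torch.zeros(batch, seq_len, d_model)         # stand-in for inputs[0]
    decoder = torch.randn(n_features, d_model)              # stand-in for sae.W_dec
    steering_vector = decoder[0].unsqueeze(0).unsqueeze(0)  # (1, 1, d_model)
    steered = residual + 75 * steering_vector               # broadcasts over batch/seq
    assert steered.shape == residual.shape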
""" streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) inputs = tokenizer(prompt, return_tensors="pt").to(device) generation_kwargs = dict( **inputs, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, top_p=top_p, pad_token_id=tokenizer.eos_token_id, repetition_penalty=repetition_penalty, ) thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() for new_text in streamer: yield new_text # ------------------------------------ # Flask App # ------------------------------------ app = Flask(__name__) CORS(app) # API routes first (to avoid conflicts with static serving) @app.route("/api/generate", methods=["POST"]) def generate(): """ Expects JSON like: { "model": "gpt2", "dataset": "emgsd", "category": "lgbtq+", "type": "original" | "origin+steer" | "trained" | "trained-steer" } Streams back the generated text token by token. """ data = request.json model_key = data["model"] dataset_key = data["dataset"] category_key = data["category"] gen_type = data["type"] # 1. Figure out prompt from config try: prompt_text = datasets_config[dataset_key]["category"][category_key]["prompt"] except KeyError: return Response("Invalid dataset/category combination.", status=400) # 2. Select the model if "trained" in gen_type: chosen_model = trained_model else: chosen_model = original_model # 3. Steering logic if "steer" in the request type remove_hooks() if "steer" in gen_type: register_steering(chosen_model, model_key, gen_type, dataset_key, category_key) # Return a streaming response of tokens def token_stream(): for token in stream_generate(chosen_model, prompt_text): yield token remove_hooks() return Response(token_stream(), mimetype="text/event-stream") # Serve static files for HF Spaces (after API routes) @app.route("/") def serve_frontend(): return send_file("demo/dist/index.html") @app.route("/") def serve_static(path): if path.startswith('api/'): return None if os.path.exists(f"demo/dist/{path}"): return send_from_directory("demo/dist", path) return send_file("demo/dist/index.html") if __name__ == "__main__": port = int(os.environ.get("PORT", 5174)) app.run(host="0.0.0.0", port=port, debug=True)