This is the sparse autoencoder (SAE) used in the demo. The following script shows how to use it to steer the model:

import torch
from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import json

# Minimal SAE implementation
class SAE(torch.nn.Module):
    """Minimal sparse autoencoder: a linear encoder with ReLU features
    and a linear decoder that maps features back to the residual stream.

    Attributes:
        encode: Linear map from model hidden size to SAE feature space.
        decode: Linear map from SAE feature space back to the hidden size.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.encode = torch.nn.Linear(input_size, hidden_size, bias=True)
        self.decode = torch.nn.Linear(hidden_size, input_size, bias=True)

    def forward(self, x):
        """Return ``(reconstruction, features)`` for activations ``x``."""
        acts = self.encode(x).relu()
        return self.decode(acts), acts

# Minimal steerable model wrapper
class SteerableOlmo2:
    """Wraps an OLMo-2 causal LM and clamps SAE features at one layer.

    A forward hook on ``model.model.layers[steering_layer]`` decomposes that
    layer's hidden states through the SAE, overwrites selected feature
    activations, and re-assembles the hidden states. The SAE reconstruction
    error is preserved so that unsteered content passes through unchanged.

    Fix vs. original: the registered hook handle could never be detached, so
    constructing multiple wrappers (or discarding one) leaked hooks on the
    shared model. ``remove_hook()`` now detaches it explicitly.
    """

    def __init__(self, model, sae, steering_layer):
        self.model = model
        self.sae = sae
        self.steering_layer = steering_layer
        self.steering_features = {}  # feature index -> clamped activation value
        self.hook = None
        self._register_hook()

    def _steering_hook(self, module, input, output):
        # No active steering: pass the layer output through untouched.
        if not self.steering_features:
            return output

        hidden_states = output[0]
        recon, feats = self.sae(hidden_states)
        # Keep the SAE's reconstruction error so everything the SAE does not
        # capture survives the round trip unchanged.
        error = hidden_states - recon

        # Clamp the selected features to fixed values at every position.
        feats_steered = feats.clone()
        for idx, value in self.steering_features.items():
            feats_steered[..., idx] = value

        # Decode the edited features and re-add the reconstruction error.
        recon_steered = self.sae.decode(feats_steered)
        hidden_steered = recon_steered + error

        return (hidden_steered,) + output[1:]

    def _register_hook(self):
        # Hook the decoder layer whose residual stream the SAE was trained on.
        target_layer = self.model.model.layers[self.steering_layer]
        self.hook = target_layer.register_forward_hook(self._steering_hook)

    def remove_hook(self):
        """Detach the steering hook from the model. Safe to call twice."""
        if self.hook is not None:
            self.hook.remove()
            self.hook = None

    def set_steering(self, feature_idx, value):
        """Clamp SAE feature ``feature_idx`` to ``value`` during generation."""
        self.steering_features[feature_idx] = value

    def clear_steering(self):
        """Disable all steering; the hook becomes a pass-through."""
        self.steering_features = {}

    def generate(self, *args, **kwargs):
        """Delegate to the wrapped model's ``generate``."""
        return self.model.generate(*args, **kwargs)

# Usage example
def main():
    """Demo: generate with and without SAE concept steering on OLMo-2 7B."""
    from transformers import Olmo2ForCausalLM

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Base model + tokenizer (bfloat16 to match the SAE weights).
    model_name = "allenai/OLMo-2-1124-7B-Instruct"
    model = Olmo2ForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Fetch the SAE weights and its config from the Hub.
    repo = "open-concept-steering/olmo2-7b-sae-65k-v1"
    weights_path = hf_hub_download(repo_id=repo, filename="sae_weights.safetensors")
    config_path = hf_hub_download(repo_id=repo, filename="sae_config.json")

    state_dict = load_file(weights_path, device=device)
    with open(config_path, "r") as f:
        cfg = json.load(f)

    sae = SAE(cfg['input_size'], cfg['hidden_size']).to(device).to(torch.bfloat16)
    sae.load_state_dict(state_dict)

    # Attach the steering hook at a middle decoder layer.
    steerable = SteerableOlmo2(model, sae, 15)

    prompt = "What's your favorite hobby?"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Baseline generation, no feature clamped.
    print("Without steering:")
    baseline = steerable.generate(inputs.input_ids, max_new_tokens=50, do_sample=True, temperature=0.7)
    print(tokenizer.decode(baseline[0], skip_special_tokens=True))

    # Clamp feature 758 (Batman/superhero concept) at strength ~11.
    print("\nWith Batman/superhero steering:")
    steerable.set_steering(758, 11.0)
    # Other features worth trying (uncomment):
    # steerable.set_steering(29940, 13.0)  # Japan feature
    # steerable.set_steering(65023, 6.0)   # Baseball feature
    steered = steerable.generate(inputs.input_ids, max_new_tokens=50, do_sample=True, temperature=0.7)
    print(tokenizer.decode(steered[0], skip_special_tokens=True))

    # Reset steering state for any further use of the wrapper.
    steerable.clear_steering()

if __name__ == "__main__":
    main()
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for open-concept-steering/olmo2-7b-sae-65k-v1

Datasets used to train open-concept-steering/olmo2-7b-sae-65k-v1

Space using open-concept-steering/olmo2-7b-sae-65k-v1 1