
Matryoshka Sparse Autoencoders (MSAE) for CLIP

This repository provides PyTorch implementations of Matryoshka Sparse Autoencoders (MSAEs) trained on the image encoder of CLIP (ViT-L/14 and ViT-B/16). These models are designed to learn interpretable, hierarchical features from complex multimodal representations.

For a deeper dive into the underlying theory and the full research implementation, please see the original MSAE repository and the accompanying paper.

What is a Sparse Autoencoder (SAE)?

Sparse autoencoders (SAEs) are useful for detecting and steering interpretable features within complex neural networks. They learn a sparse representation of the data: for any given input, only a small number of latent neurons are active, and these few active neurons are enough to reconstruct the input. This sparsity makes the representation more interpretable, since each active neuron can be associated with a specific feature or concept. As a result, SAEs can be used to identify and manipulate specific features in the data, which makes them powerful tools for understanding and controlling the behavior of neural networks.
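Here is a minimal NumPy sketch of a TopK sparse autoencoder forward pass, in the spirit of the TopKReLU models in this repository. It is not the code in sae.py. The weight names (W_enc, b_enc, W_dec, b_dec), the random initialisation, and subtracting the decoder bias before encoding (a common TopK-SAE convention) are all assumptions made for illustration.

```python
import numpy as np

def topk_relu(z: np.ndarray, k: int) -> np.ndarray:
    """Keep the k largest pre-activations in each row (after ReLU) and zero the rest."""
    out = np.zeros_like(z)
    # Indices of the k largest entries per row; their order among themselves does not matter.
    idx = np.argpartition(z, -k, axis=-1)[..., -k:]
    vals = np.take_along_axis(z, idx, axis=-1)
    np.put_along_axis(out, idx, np.maximum(vals, 0.0), axis=-1)
    return out

def sae_forward(x, W_enc, b_enc, W_dec, b_dec, k):
    """Encode x into sparse latents, then reconstruct x from those latents."""
    latents = topk_relu((x - b_dec) @ W_enc + b_enc, k)  # (batch, n_latents)
    recon = latents @ W_dec + b_dec                      # (batch, n_inputs)
    return latents, recon

# Hypothetical random weights with the ViT-L/14 sizes from the example filename.
rng = np.random.default_rng(0)
n_inputs, n_latents, k = 768, 6144, 64
W_enc = rng.normal(scale=0.02, size=(n_inputs, n_latents))
W_dec = rng.normal(scale=0.02, size=(n_latents, n_inputs))
b_enc, b_dec = np.zeros(n_latents), np.zeros(n_inputs)

x = rng.normal(size=(4, n_inputs))  # stand-in for 4 CLIP image embeddings
latents, recon = sae_forward(x, W_enc, b_enc, W_dec, b_dec, k)
print((latents != 0).sum(axis=1))  # number of active latents per input (never more than k)
```

Only k of the 6,144 latents can be non-zero for each input, and each of those latents contributes one row of W_dec to the reconstruction. This is what makes each active latent a candidate for a single, nameable concept.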

Key Features

  • Interpretability: SAEs learn to decompose complex representations into sparse, interpretable features. This allows for a better understanding of what the model has learned.
  • Hierarchical Features: The Matryoshka SAE (MSAE) architecture learns features at multiple granularities simultaneously, from fine-grained details to high-level concepts.
  • Model Steering: By identifying and manipulating specific features, you can steer the behavior of the CLIP model.
  • Simple Integration: The provided sae.py module allows for easy loading and integration of the trained models into your own projects.
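One way to picture the Matryoshka idea is to apply several nested TopK operations to the same pre-activations and combine their reconstruction losses. Under this scheme, the few strongest latents must already give a usable reconstruction, and each larger k adds detail. The NumPy sketch below follows that reading. The helper names and the concrete reverse-weighting scheme (weights n, n-1, ..., 1, normalised) are hypothetical. The exact loss used for these checkpoints is defined in the original MSAE repository.

```python
import numpy as np

def topk_relu(z, k):
    """Keep the k largest pre-activations in each row (after ReLU) and zero the rest."""
    out = np.zeros_like(z)
    idx = np.argpartition(z, -k, axis=-1)[..., -k:]
    np.put_along_axis(out, idx, np.maximum(np.take_along_axis(z, idx, axis=-1), 0.0), axis=-1)
    return out

def loss_weights(n, weighting):
    """Hypothetical weights: 'UW' is uniform, 'RW' favours the smallest k (assumed scheme)."""
    if weighting == "UW":
        w = np.ones(n)
    elif weighting == "RW":
        w = np.arange(n, 0, -1, dtype=float)  # n, n-1, ..., 1 for ascending k
    else:
        raise ValueError(f"unknown weighting: {weighting}")
    return w / w.sum()

def matryoshka_loss(x, W_enc, b_enc, W_dec, b_dec, ks, weighting="UW"):
    """Weighted sum of mean squared reconstruction errors, one per nested k."""
    ks = sorted(ks)
    pre = (x - b_dec) @ W_enc + b_enc  # pre-activations shared by every granularity
    losses = []
    for k in ks:
        recon = topk_relu(pre, k) @ W_dec + b_dec
        losses.append(float(np.mean(np.sum((x - recon) ** 2, axis=-1))))
    total = float(np.dot(loss_weights(len(ks), weighting), losses))
    return total, losses

# Toy sizes and random weights, purely for illustration.
rng = np.random.default_rng(0)
n_inputs, n_latents = 32, 256
W_enc = rng.normal(scale=0.1, size=(n_inputs, n_latents))
W_dec = rng.normal(scale=0.1, size=(n_latents, n_inputs))
b_enc, b_dec = np.zeros(n_latents), np.zeros(n_inputs)
x = rng.normal(size=(16, n_inputs))

total_uw, per_k = matryoshka_loss(x, W_enc, b_enc, W_dec, b_dec, ks=[8, 16, 32], weighting="UW")
total_rw, _ = matryoshka_loss(x, W_enc, b_enc, W_dec, b_dec, ks=[8, 16, 32], weighting="RW")
print(per_k, total_uw, total_rw)
```

Because the top-8 latents are a subset of the top-16, which are a subset of the top-32, a single set of weights is trained to work at every granularity at once.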

Repository Structure

The repository is organized as follows:

  • sae.py: A self-contained Python module with the SAE and MSAE model implementations for running inference.
  • clip_disect_20k.txt: A vocabulary file containing 20,000 concept names used for interpreting the learned features.
  • ViT-L_14/: Contains the trained SAE models for the CLIP ViT-L/14 image encoder.
  • ViT-B_16/: Contains the trained SAE models for the CLIP ViT-B/16 image encoder.

Each model directory (ViT-L_14 and ViT-B_16) is further subdivided into:

  • centered/: Models trained on mean-centered features.
  • not_centered/: Models trained on non-centered features.

Additionally, each directory contains .pth files for the model weights and .npy files for the concept matching scores.
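Models in centered/ expect their input embeddings to be shifted by a mean before encoding, and their reconstructions to be shifted back afterwards. This README does not say which mean vector the centered models expect. The sketch below assumes the per-dimension mean of the training embeddings and uses random stand-in data, so check the original MSAE repository for the actual statistics.

```python
import numpy as np

def center(embeddings, mean):
    """Subtract the per-dimension mean (assumed to be the one used during SAE training)."""
    return embeddings - mean

def uncenter(reconstruction, mean):
    """Add the mean back so a reconstruction lives in the original CLIP embedding space."""
    return reconstruction + mean

rng = np.random.default_rng(0)
# Hypothetical stand-in for CLIP image embeddings of the training set.
train_embeddings = rng.normal(loc=0.5, size=(1000, 768))
mean = train_embeddings.mean(axis=0)

batch = rng.normal(loc=0.5, size=(8, 768))
centered = center(batch, mean)      # feed this to a model from centered/
restored = uncenter(centered, mean)  # apply to the decoder output
```

Models in not_centered/ would take the embeddings as they are, without either step.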

Understanding the Model Names

The model filenames follow a consistent naming convention that encodes the model's hyperparameters. Here's how to interpret a typical filename:

{n_latents}_{n_inputs}_{activation}_{k}_{weighting}_{tied}_{normalized}_{soft_cap}_{dataset}.pth

Where:

  • n_latents: The number of latent features in the SAE.
  • n_inputs: The input dimensionality (e.g., 768 for ViT-L/14, 512 for ViT-B/16).
  • activation: The activation function used (e.g., TopKReLU).
  • k: The smallest number of active latents used by the TopK activation during training. MSAE trains with several values of k, and this field records the smallest one.
  • weighting: Whether the model was trained with uniform weighting (UW) or reverse weighting (RW).
  • tied: Indicates if the model encoder is tied to the decoder.
  • normalized: Indicates if the model was trained with normalized inputs.
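  • soft_cap: The soft-capping setting for the latent features. It is written as a number (e.g., 0.0) rather than True/False.
  • dataset: The dataset used for training (e.g., cc3m). In the released filenames, this field continues with further underscore-separated components, so everything after soft_cap belongs to it (e.g., cc3m_ViT-L~14_train_image_2905936_768).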
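Since the hyperparameters are encoded by position, a filename can be split back into its fields. The parser below is a hypothetical helper and is not part of sae.py. It makes three assumptions: the activation name contains no underscores, tied and normalized are written as True or False, and everything after soft_cap belongs to the dataset field, as in the example filename used in the How to Use section.

```python
from dataclasses import dataclass

@dataclass
class SAEConfig:
    n_latents: int
    n_inputs: int
    activation: str
    k: int
    weighting: str
    tied: bool
    normalized: bool
    soft_cap: float
    dataset: str

def parse_model_name(filename: str) -> SAEConfig:
    """Split an SAE weight filename (optionally with a directory) into its named fields."""
    stem = filename.rsplit("/", 1)[-1]  # drop directories such as ViT-L_14/centered/
    if stem.endswith(".pth"):
        stem = stem[: -len(".pth")]
    # Eight fixed fields; the remainder (which may contain underscores) is the dataset field.
    parts = stem.split("_", 8)
    if len(parts) != 9:
        raise ValueError(f"unexpected filename: {filename}")
    n_latents, n_inputs, activation, k, weighting, tied, normalized, soft_cap, dataset = parts
    return SAEConfig(
        n_latents=int(n_latents),
        n_inputs=int(n_inputs),
        activation=activation,
        k=int(k),
        weighting=weighting,
        tied=(tied == "True"),
        normalized=(normalized == "True"),
        soft_cap=float(soft_cap),
        dataset=dataset,
    )

cfg = parse_model_name(
    "ViT-L_14/centered/6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768.pth"
)
print(cfg)
```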

The concept matching scores are stored in .npy files with a similar naming convention: Concept_Interpreter_{model_name}_{vocab_name}.npy, where vocab_name indicates the vocabulary used for concept matching.
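Following that convention, the scores file for a model can be derived from the model's .pth path. The helper below is hypothetical. It assumes that model_name is the .pth filename without its extension and that the .npy file sits in the same directory as the weights. Both assumptions match the pair of files downloaded in the How to Use example.

```python
from pathlib import PurePosixPath

def concept_scores_filename(model_filename: str, vocab_name: str) -> str:
    """Build the Concept_Interpreter_{model_name}_{vocab_name}.npy path for a .pth model file."""
    path = PurePosixPath(model_filename)
    model_name = path.stem  # the .pth filename without its extension
    # Keep the model's directory (e.g. ViT-L_14/centered/) for the scores file.
    return str(path.with_name(f"Concept_Interpreter_{model_name}_{vocab_name}.npy"))

model_file = "ViT-L_14/centered/6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768.pth"
print(concept_scores_filename(model_file, "disect_ViT-L~14_-1_text_20000_768"))
```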

How to Use

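To get started, you'll need PyTorch, NumPy, and huggingface_hub installed. The huggingface_hub package is used below to download files from the Hub.

pip install torch numpy huggingface_hub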

First, copy the sae.py file from this repository into your working directory. Then load a model together with its concept matching scores and vocabulary:

import torch
import numpy as np
from sae import SAE
from huggingface_hub import hf_hub_download

# Download the SAE model weights
weights_path = hf_hub_download(
    repo_id="WolodjaZ/MSAE",
    filename="ViT-L_14/centered/6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768.pth"
)
sae_model = SAE(weights_path)

# Download the concept matching scores for the model
vocab_path = hf_hub_download(
    repo_id="WolodjaZ/MSAE",
    filename="ViT-L_14/centered/Concept_Interpreter_6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768_disect_ViT-L~14_-1_text_20000_768.npy"
)
concept_match_scores = np.load(vocab_path)

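# Download and load the vocabulary names (one concept per line)
vocab_names_path = hf_hub_download(
    repo_id="WolodjaZ/MSAE",
    filename="clip_disect_20k.txt"
)
with open(vocab_names_path, 'r', encoding='utf-8') as f:
    vocab_names = [line.strip() for line in f]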

print(f"Concept match scores shape: {concept_match_scores.shape}")
print(f"Vocabulary size: {len(vocab_names)}")

# Now you can use the model to encode and decode your own data
# For a detailed example, please refer to the demo notebook in the original repository:
# https://github.com/WolodjaZ/MSAE/blob/main/demo.ipynb

This example demonstrates how to load a specific SAE model together with its concept matching scores and vocabulary. To load any of the other available models, change the filename passed to hf_hub_download. For a complete guide to using the model for feature extraction and steering, please refer to the demo notebook in the original MSAE repository.
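Before you open the demo notebook, the sketch below shows two common operations on SAE outputs. The first lists the best-matching vocabulary entries for a latent. The second steers the model by shifting one latent before decoding. Both helpers are hypothetical and independent of the sae.py API. They assume that concept_match_scores has shape (n_latents, vocab_size), so check the printed shape and transpose the array if needed. They also assume that decoding is the linear map latents @ W_dec + b_dec, with W_dec of shape (n_latents, n_inputs).

```python
import numpy as np

def top_concepts(scores, vocab_names, latent_idx, n=5):
    """Return the n (concept, score) pairs that best match one latent.

    Assumes scores has shape (n_latents, vocab_size)."""
    row = np.asarray(scores[latent_idx])
    best = np.argsort(row)[::-1][:n]  # indices of the n highest scores, best first
    return [(vocab_names[i], float(row[i])) for i in best]

def steer(latents, W_dec, b_dec, latent_idx, delta):
    """Add `delta` to one latent and decode the edited latents back to embedding space."""
    edited = np.array(latents, dtype=float, copy=True)  # leave the caller's array untouched
    edited[..., latent_idx] += delta
    return edited @ W_dec + b_dec

# Toy example with made-up numbers: 2 latents, a 3-word vocabulary, 3-dim embeddings.
scores = np.array([[0.1, 0.9, 0.5],
                   [0.7, 0.2, 0.3]])
vocab = ["cat", "dog", "car"]
print(top_concepts(scores, vocab, latent_idx=0, n=2))  # [('dog', 0.9), ('car', 0.5)]

W_dec = np.array([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0]])
b_dec = np.zeros(3)
latents = np.zeros((1, 2))
print(steer(latents, W_dec, b_dec, latent_idx=1, delta=2.0))
```

Increasing a latent by delta moves the reconstruction by delta times that latent's decoder row. With the real model, the latents and decoder parameters would come from sae.py. See the demo notebook for how they are exposed.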

Citation

Paper: https://arxiv.org/abs/2502.20578
