DeepSeek-Tiny with MLA-o V0.1

A 6-layer DeepSeek-V3 model with MLA plus a shared output latent space ("MLA-o"), trained for research on shared subspaces in Transformer attention mechanisms.

Model Description

  • Model Type: Transformer Decoder (DeepSeek-V3 based)
  • Architecture: 6-layer decoder with Mixture of Experts
  • Parameters: 16.17M
  • Hidden Size: 256
  • Attention Heads: 8
  • Head Dimension: 32
  • Sequence Length: 1,024 tokens
  • Query Latent Dimension: 96
  • Key-Value Latent Dimension: 64
  • Output Latent Dimension: 96
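
As a quick sanity check, most of these dimensions can be read back from the repository's config.json. This is a minimal sketch assuming the query and key-value latent dimensions map onto transformers' standard MLA fields q_lora_rank and kv_lora_rank; the output latent dimension (96) is not a config field and only appears in the patched o_proj described under Usage. Expected values (from the list above) are shown as comments.

from transformers import AutoConfig

config = AutoConfig.from_pretrained("ChrisMcCormick/deepseek-tiny-mla-o-v0.1")

print(config.num_hidden_layers)    # 6
print(config.hidden_size)          # 256
print(config.num_attention_heads)  # 8
print(config.q_lora_rank)          # 96  (query latent dimension, assumed mapping)
print(config.kv_lora_rank)         # 64  (key-value latent dimension, assumed mapping)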

Performance

  • SST-2 Accuracy: 86.24%
  • WikiText-103 Perplexity: 29.33

Research Context

This model is part of the shared-subspaces research project investigating the impact of shared output latent spaces in Transformer attention mechanisms.

Output Subspace Decomposition

This model implements a shared output latent space where the attention output projection W^O is decomposed into:

W^O = W^OA · W^OB

where W^OA consists of per-head projections into the shared output latent space and W^OB is a single projection, shared across all heads, back to the model dimension.
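
To make the shapes concrete, below is a minimal sketch of the factorization (not the exact training code) using the dimensions listed above: 8 heads × 32-dim values give a 256-dim concatenated input, projected into a 96-dim output latent and back up to the 256-dim model width. The RMSNorm applied in the latent space (shown under Usage) is omitted here; its gain vector adds only 96 parameters.

import torch
import torch.nn as nn

n_heads, v_head_dim, d_model, r_o = 8, 32, 256, 96

# Dense output projection: concatenated head outputs -> model dimension
w_o_dense = nn.Linear(n_heads * v_head_dim, d_model, bias=False)   # 256 -> 256

# Factored version: W^O = W^OA (per-head, into the shared latent) · W^OB (shared, back to d_model)
w_oa = nn.Linear(n_heads * v_head_dim, r_o, bias=False)            # 256 -> 96
w_ob = nn.Linear(r_o, d_model, bias=False)                         # 96  -> 256

x = torch.randn(1, n_heads * v_head_dim)  # concatenated per-head attention outputs
print(w_ob(w_oa(x)).shape)                # torch.Size([1, 256])

# Parameter comparison: dense vs. factored output projection
print(sum(p.numel() for p in w_o_dense.parameters()))                    # 65536
print(sum(p.numel() for p in [*w_oa.parameters(), *w_ob.parameters()]))  # 49152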

Usage

Rather than overwrite the entire attention implementation, we simply patch each layer's o_proj module with an nn.Sequential. This is an easy way to modify the model prior to pre-training, but it means from_pretrained can't load the decomposed weights on its own.

The code below applies the patch and then loads the decomposed weights manually.

import torch
import torch.nn as nn
from transformers import DeepseekV3ForCausalLM, AutoTokenizer
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

def load_mla_o_model(repo_id="ChrisMcCormick/deepseek-tiny-mla-o-v0.1"):
    """
    Load the MLA-o model with output subspace decomposition
    """
    
    print("\n<<Ignore the 'weights not used' warning>>\n")

    # Load base model (without decomposed weights)
    model = DeepseekV3ForCausalLM.from_pretrained(repo_id)
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    print("\nPatching weights...\n")

    # Download the safetensors file to get the decomposed weights
    weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    weights = load_file(weights_path)
    
    # Apply output subspace decomposition to all attention layers
    for layer_idx, layer in enumerate(model.model.layers):
        attn = layer.self_attn
        
        # Calculate dimensions
        in_features = attn.num_heads * attn.v_head_dim  # 8 * 32 = 256
        o_latent_dim = 96  # Output latent dimension
        out_features = model.config.hidden_size  # 256
        bias = bool(getattr(model.config, "attention_bias", False))
        
        # Replace o_proj with sequential decomposition
        attn.o_proj = nn.Sequential(
            nn.Linear(in_features, o_latent_dim, bias=False),      # W^OA: 256 -> 96
            nn.RMSNorm(o_latent_dim, eps=model.config.rms_norm_eps),  # Normalization
            nn.Linear(o_latent_dim, out_features, bias=bias),      # W^OB: 96 -> 256
        )
        
        # Load the decomposed weights
        layer_prefix = f"model.layers.{layer_idx}.self_attn.o_proj"
        
        # Load W^OA weights (o_proj.0.weight)
        w_oa_key = f"{layer_prefix}.0.weight"
        if w_oa_key in weights:
            attn.o_proj[0].weight.data = weights[w_oa_key]
        
        # Load RMSNorm weights (o_proj.1.weight)
        w_norm_key = f"{layer_prefix}.1.weight"
        if w_norm_key in weights:
            attn.o_proj[1].weight.data = weights[w_norm_key]
        
        # Load W^OB weights (o_proj.2.weight)
        w_ob_key = f"{layer_prefix}.2.weight"
        if w_ob_key in weights:
            attn.o_proj[2].weight.data = weights[w_ob_key]
        
        # Load W^OB bias if it exists
        w_ob_bias_key = f"{layer_prefix}.2.bias"
        if w_ob_bias_key in weights and attn.o_proj[2].bias is not None:
            attn.o_proj[2].bias.data = weights[w_ob_bias_key]
    
    print("Model loaded and patched.")
    return model, tokenizer

# Load the model
model, tokenizer = load_mla_o_model()

# Generate text
inputs = tokenizer("The future of AI is", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(
        **inputs, 
        max_length=50, 
        temperature=0.7, 
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

print("Generated text:")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Training Details

  • Pre-training Dataset: WikiText-103
  • Optimizer: AdamW
  • Learning Rate: 5e-4
  • Weight Decay: 0.01
  • Precision: bfloat16
  • Compilation: torch.compile with inductor backend
  • Training Steps: 12,500
  • Effective Batch Size: 1,024
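
For reference, below is a minimal sketch of how these settings might be wired together; the actual training script is not included in this card, and the data loader, micro-batch/accumulation split, and learning-rate schedule are placeholders.

import torch
from torch.optim import AdamW

# Illustrative only: combining the listed settings around the patched model.
model = model.to("cuda", dtype=torch.bfloat16)     # bfloat16 precision
model = torch.compile(model, backend="inductor")   # torch.compile with the inductor backend
optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)

micro_batch, accum_steps = 64, 16                  # hypothetical split; 64 * 16 = 1,024 effective
for step in range(12_500):
    for _ in range(accum_steps):
        batch = next(train_iter)                   # WikiText-103 batches of 1,024 tokens (loader not shown)
        loss = model(**batch).loss / accum_steps
        loss.backward()
    optimizer.step()
    optimizer.zero_grad()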

Limitations

  • Small scale model (16M parameters) intended for research purposes
  • Trained on limited data compared to production models