---
license: mit
language:
- en
datasets:
- wikitext
- glue
pipeline_tag: text-generation
tags:
- transformer
- attention
- mla
- research
- output-subspace
---
# DeepSeek-Tiny with MLA-o V0.1
A 6-layer DeepSeek-V3 model with MLA plus a shared output latent space ("MLA-o"), trained for research on shared subspaces in Transformer attention mechanisms.
## Model Description
- **Model Type**: Transformer Decoder (DeepSeek-V3 based)
- **Architecture**: 6-layer decoder with Mixture of Experts
- **Parameters**: 16.17M
- **Hidden Size**: 256
- **Attention Heads**: 8
- **Head Dimension**: 32
- **Sequence Length**: 1,024 tokens
- **Query Latent Dimension**: 96
- **Key-Value Latent Dimension**: 64
- **Output Latent Dimension**: 96
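For reference, these hyperparameters map roughly onto the Hugging Face `DeepseekV3Config` as sketched below. The query and key-value latent dimensions are the standard MLA ranks (`q_lora_rank`, `kv_lora_rank`); the output latent dimension is specific to MLA-o and is not a standard config field. Fields not listed above (e.g., the MoE and RoPE settings) are left at their defaults here, so this is an approximation rather than the exact training configuration.
```python
from transformers import DeepseekV3Config

# Approximate mapping of the hyperparameters above onto config fields.
# MoE settings and the rope/nope head-dimension split are left at their
# defaults and may differ from the actual training configuration.
config = DeepseekV3Config(
    hidden_size=256,               # model dimension
    num_hidden_layers=6,           # 6 decoder layers
    num_attention_heads=8,         # 8 attention heads
    v_head_dim=32,                 # value head dimension
    q_lora_rank=96,                # query latent dimension
    kv_lora_rank=64,               # key-value latent dimension
    max_position_embeddings=1024,  # sequence length
)

o_latent_dim = 96  # output latent dimension (applied by patching o_proj; see Usage)
```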
## Performance
- **SST-2 Accuracy**: 86.24%
- **WikiText-103 Perplexity**: 29.33
## Research Context
This model is part of the [shared-subspaces](https://github.com/chrisjmccormick/shared-subspaces) research project investigating the impact of shared output latent spaces in Transformer attention mechanisms.
### Output Subspace Decomposition
This model implements a shared output latent space where the attention output projection W^O is decomposed into:
```
W^O = W^OA · W^OB
```
where W^OA holds the per-head projections into the shared output latent space and W^OB is a single shared projection back to the model dimension.
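Purely for intuition about the shapes involved (per layer, ignoring the RMSNorm applied in the latent space and any bias term), the factorization under the dimensions listed above works out to:
```python
# Back-of-the-envelope parameter count for one layer's output projection,
# using the dimensions from the Model Description; this is only a sketch.
num_heads, v_head_dim, hidden_size, o_latent_dim = 8, 32, 256, 96

dense_params = (num_heads * v_head_dim) * hidden_size       # W^O:  256 x 256 = 65,536
factored_params = ((num_heads * v_head_dim) * o_latent_dim  # W^OA: 256 x 96
                   + o_latent_dim * hidden_size)             # W^OB: 96 x 256

print(dense_params, factored_params)  # 65536 49152
```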
## Usage
Rather than overwrite the entire attention layer, we simply patch the `o_proj` module with an `nn.Sequential`. It's an easy way to modify the model prior to pre-training, but loading the saved weights back in takes a little more work.
The code below applies the patch and then loads the decomposed weights manually.
```python
import torch
import torch.nn as nn
from transformers import DeepseekV3ForCausalLM, AutoTokenizer
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download


def load_mla_o_model(repo_id="ChrisMcCormick/deepseek-tiny-mla-o-v0.1"):
    """
    Load the MLA-o model with output subspace decomposition
    """
    print("\n<<Ignore the 'weights not used' warning>>\n")

    # Load base model (without decomposed weights)
    model = DeepseekV3ForCausalLM.from_pretrained(repo_id)
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    print("\nPatching weights...\n")

    # Download the safetensors file to get the decomposed weights
    weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    weights = load_file(weights_path)

    # Apply output subspace decomposition to all attention layers
    for layer_idx, layer in enumerate(model.model.layers):
        attn = layer.self_attn

        # Calculate dimensions
        in_features = attn.num_heads * attn.v_head_dim  # 8 * 32 = 256
        o_latent_dim = 96  # Output latent dimension
        out_features = model.config.hidden_size  # 256
        bias = bool(getattr(model.config, "attention_bias", False))

        # Replace o_proj with sequential decomposition
        attn.o_proj = nn.Sequential(
            nn.Linear(in_features, o_latent_dim, bias=False),         # W^OA: 256 -> 96
            nn.RMSNorm(o_latent_dim, eps=model.config.rms_norm_eps),  # Normalization
            nn.Linear(o_latent_dim, out_features, bias=bias),         # W^OB: 96 -> 256
        )

        # Load the decomposed weights
        layer_prefix = f"model.layers.{layer_idx}.self_attn.o_proj"

        # Load W^OA weights (o_proj.0.weight)
        w_oa_key = f"{layer_prefix}.0.weight"
        if w_oa_key in weights:
            attn.o_proj[0].weight.data = weights[w_oa_key]

        # Load RMSNorm weights (o_proj.1.weight)
        w_norm_key = f"{layer_prefix}.1.weight"
        if w_norm_key in weights:
            attn.o_proj[1].weight.data = weights[w_norm_key]

        # Load W^OB weights (o_proj.2.weight)
        w_ob_key = f"{layer_prefix}.2.weight"
        if w_ob_key in weights:
            attn.o_proj[2].weight.data = weights[w_ob_key]

        # Load W^OB bias if it exists
        w_ob_bias_key = f"{layer_prefix}.2.bias"
        if w_ob_bias_key in weights and attn.o_proj[2].bias is not None:
            attn.o_proj[2].bias.data = weights[w_ob_bias_key]

    print("Model loaded and patched.")
    return model, tokenizer


# Load the model
model, tokenizer = load_mla_o_model()

# Generate text
inputs = tokenizer("The future of AI is", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=50,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

print("Generated text:")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
## Training Details
- **Pre-training Dataset**: WikiText-103
- **Optimizer**: AdamW
- **Learning Rate**: 5e-4
- **Weight Decay**: 0.01
- **Precision**: bfloat16
- **Compilation**: torch.compile with inductor backend
- **Training Steps**: 12,500
- **Effective Batch Size**: 1,024
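A minimal sketch of this setup is shown below. It reuses `load_mla_o_model()` from the Usage section, replaces the WikiText-103 dataloader with a single dummy batch, and omits gradient accumulation (which is how the effective batch size of 1,024 would be reached), so it illustrates the configuration rather than reproducing the actual training script.
```python
import torch
from torch.optim import AdamW

# Sketch only: optimizer, precision, and compile settings from the list above,
# applied to one dummy batch instead of the real WikiText-103 dataloader.
model, tokenizer = load_mla_o_model()
model.train()

optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)
model = torch.compile(model, backend="inductor")

batch = tokenizer("A placeholder training example.", return_tensors="pt")
batch["labels"] = batch["input_ids"].clone()  # causal-LM loss over the same tokens

# bfloat16 autocast; use device_type="cuda" when training on GPU.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = model(**batch).loss

loss.backward()
optimizer.step()
optimizer.zero_grad()
```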
## Limitations
- Small-scale model (16M parameters) intended for research purposes
- Trained on limited data compared to production models