---
license: mit
language:
- en
datasets:
- wikitext
- glue
pipeline_tag: text-generation
tags:
- transformer
- attention
- mla
- research
- output-subspace
---

# DeepSeek-Tiny with MLA-o V0.1

A 6-layer DeepSeek-V3 model using MLA plus a shared output latent space ("MLA-o"), trained for research on shared subspaces in Transformer attention mechanisms.

## Model Description

- **Model Type**: Transformer Decoder (DeepSeek-V3 based)
- **Architecture**: 6-layer decoder with Mixture of Experts
- **Parameters**: 16.17M
- **Hidden Size**: 256
- **Attention Heads**: 8
- **Head Dimension**: 32
- **Sequence Length**: 1,024 tokens
- **Query Latent Dimension**: 96
- **Key-Value Latent Dimension**: 64
- **Output Latent Dimension**: 96
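
The remaining architecture details (MoE settings, vocabulary size, etc.) live in the checkpoint's `config.json`; an easy way to inspect them is shown below. Note that the output latent dimension belongs to the custom MLA-o patch and is hard-coded in the loading helper under Usage rather than read from the config.

```python
from transformers import AutoConfig

# Inspect the full architecture settings straight from the Hub checkpoint.
config = AutoConfig.from_pretrained("ChrisMcCormick/deepseek-tiny-mla-o-v0.1")
print(config)
```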

## Performance

- **SST-2 Accuracy**: 86.24%
- **WikiText-103 Perplexity**: 29.33

## Research Context

This model is part of the [shared-subspaces](https://github.com/chrisjmccormick/shared-subspaces) research project investigating the impact of shared output latent spaces in Transformer attention mechanisms.

### Output Subspace Decomposition

This model implements a shared output latent space where the attention output projection W^O is decomposed into:
```
W^O = W^OA · W^OB
```
where `W^OA` holds the per-head projections into the latent space and `W^OB` is a single shared projection back to the model dimension.
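
To make the shapes concrete, here is the per-layer parameter arithmetic using the dimensions listed above (illustrative arithmetic only, not model code):

```python
num_heads, v_head_dim, hidden_size, o_latent_dim = 8, 32, 256, 96

# Full output projection: (num_heads * v_head_dim) x hidden_size
full_o_proj = num_heads * v_head_dim * hidden_size   # 65,536 weights

# Decomposed projection: W^OA (heads -> latent), RMSNorm, W^OB (latent -> hidden)
w_oa = num_heads * v_head_dim * o_latent_dim         # 24,576 weights
norm = o_latent_dim                                  #     96 weights
w_ob = o_latent_dim * hidden_size                    # 24,576 weights

print(full_o_proj)          # 65536
print(w_oa + norm + w_ob)   # 49248
```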

## Usage

Rather than overwrite the entire attention layer, we simply patch its `o_proj` module with an `nn.Sequential`. This is an easy way to modify the model prior to pre-training, but it means the checkpoint's decomposed weights have to be loaded manually.

The code below applies the patch and then loads those weights into the patched modules.

```python
import torch
import torch.nn as nn
from transformers import DeepseekV3ForCausalLM, AutoTokenizer
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

def load_mla_o_model(repo_id="ChrisMcCormick/deepseek-tiny-mla-o-v0.1"):
    """
    Load the MLA-o model with output subspace decomposition
    """
    
    print("\n<<Ignore the 'weights not used' warning>>\n")

    # Load base model (without decomposed weights)
    model = DeepseekV3ForCausalLM.from_pretrained(repo_id)
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    print("\nPatching weights...\n")

    # Download the safetensors file to get the decomposed weights
    weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    weights = load_file(weights_path)
    
    # Apply output subspace decomposition to all attention layers
    for layer_idx, layer in enumerate(model.model.layers):
        attn = layer.self_attn
        
        # Calculate dimensions
        in_features = attn.num_heads * attn.v_head_dim  # 8 * 32 = 256
        o_latent_dim = 96  # Output latent dimension
        out_features = model.config.hidden_size  # 256
        bias = bool(getattr(model.config, "attention_bias", False))
        
        # Replace o_proj with sequential decomposition
        attn.o_proj = nn.Sequential(
            nn.Linear(in_features, o_latent_dim, bias=False),      # W^OA: 256 -> 96
            nn.RMSNorm(o_latent_dim, eps=model.config.rms_norm_eps),  # Normalization (nn.RMSNorm requires PyTorch >= 2.4)
            nn.Linear(o_latent_dim, out_features, bias=bias),      # W^OB: 96 -> 256
        )
        
        # Load the decomposed weights
        layer_prefix = f"model.layers.{layer_idx}.self_attn.o_proj"
        
        # Load W^OA weights (o_proj.0.weight)
        w_oa_key = f"{layer_prefix}.0.weight"
        if w_oa_key in weights:
            attn.o_proj[0].weight.data = weights[w_oa_key]
        
        # Load RMSNorm weights (o_proj.1.weight)
        w_norm_key = f"{layer_prefix}.1.weight"
        if w_norm_key in weights:
            attn.o_proj[1].weight.data = weights[w_norm_key]
        
        # Load W^OB weights (o_proj.2.weight)
        w_ob_key = f"{layer_prefix}.2.weight"
        if w_ob_key in weights:
            attn.o_proj[2].weight.data = weights[w_ob_key]
        
        # Load W^OB bias if it exists
        w_ob_bias_key = f"{layer_prefix}.2.bias"
        if w_ob_bias_key in weights and attn.o_proj[2].bias is not None:
            attn.o_proj[2].bias.data = weights[w_ob_bias_key]
    
    print("Model loaded and patched.")
    return model, tokenizer

# Load the model
model, tokenizer = load_mla_o_model()

# Generate text
inputs = tokenizer("The future of AI is", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(
        **inputs, 
        max_length=50, 
        temperature=0.7, 
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

print("Generated text:")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

```
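
To confirm the patch took effect, inspect one layer's `o_proj`; it should print as the three-module `Sequential` rather than a single `Linear`. Patching at the `o_proj` level keeps the stock `DeepseekV3Attention` forward pass untouched, since it simply calls `self.o_proj(...)` on the concatenated head outputs.

```python
# The patched projection should show: Linear(256 -> 96), RMSNorm(96), Linear(96 -> 256)
print(model.model.layers[0].self_attn.o_proj)
```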

## Training Details

- **Pre-training Dataset**: WikiText-103
- **Optimizer**: AdamW
- **Learning Rate**: 5e-4
- **Weight Decay**: 0.01
- **Precision**: bfloat16
- **Compilation**: torch.compile with inductor backend
- **Training Steps**: 12,500
- **Effective Batch Size**: 1,024
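
For reference, a minimal sketch of a matching optimization setup is below. The actual training script lives in the shared-subspaces repository; the `train_loader`, the loss handling, and the absence of a learning-rate schedule here are simplifying assumptions.

```python
import torch

# Optimizer and hyperparameters from the list above.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)

# torch.compile with the inductor backend.
model = torch.compile(model, backend="inductor")

model.train()
for step, batch in enumerate(train_loader):  # train_loader: WikiText-103 batches (hypothetical), with labels
    # bfloat16 mixed precision.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if step + 1 == 12_500:  # Total training steps.
        break
```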

## Limitations

- Small-scale model (16M parameters) intended for research purposes
- Trained on limited data compared to production models