File size: 1,924 Bytes
bad4ddc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# /// script
# dependencies = [
#     "torch",
#     "numpy",
# ]
# ///

"""Generate and save shared weights for consistent comparison."""
import torch
import numpy as np
from pathlib import Path

# ---------------------------------------------------------------------------
# Model shape
# ---------------------------------------------------------------------------
NUM_EXPERTS: int = 128
HIDDEN_SIZE: int = 1152
INTERMEDIATE_SIZE: int = 3072  # not used by the weight generation below
TOP_K: int = 4

# ---------------------------------------------------------------------------
# Input shape / runtime placement
# (none of these are used by the weight generation below)
# ---------------------------------------------------------------------------
BATCH_SIZE: int = 1
SEQ_LEN: int = 100
DTYPE: str = "float32"
DEVICE: str
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ---------------------------------------------------------------------------
# RNG seeds — each seed is passed to set_seed() before a group of draws so
# the generated values are reproducible run to run.
# ---------------------------------------------------------------------------
WEIGHT_SEED: int = 999   # router weights
EXPERT_SEED: int = 777   # expert projection weights
INPUT_SEED: int = 123
GENERAL_SEED: int = 42

def set_seed(seed: int) -> None:
    """Seed torch's and NumPy's global RNGs, plus the CUDA RNGs when present.

    Args:
        seed: Value handed to every seeding call below.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Nothing more to seed on a CPU-only machine.
    if not torch.cuda.is_available():
        return

    # Seed the current CUDA device, then every visible CUDA device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# ---------------------------------------------------------------------------
# Generate shared weights for all implementations.
#
# The tensors below are written to ``*.pt`` files (relative paths, so they land
# in the current working directory) so that every implementation being compared
# can load identical parameters. Each parameter group is drawn right after its
# own set_seed() call; keep the creation order inside a group unchanged,
# otherwise the random draws — and therefore the saved values — shift.
#
# NOTE(review): the tensors are created with torch's default device and dtype;
# DEVICE and DTYPE from the configuration are not applied here. The expert
# inner width is HIDDEN_SIZE, not INTERMEDIATE_SIZE. Confirm both match what
# the consuming implementations expect before changing either.
# NOTE(review): assuming the default float32 dtype, gate_up_proj alone is
# ~1.4 GB and down_proj ~0.7 GB, in memory and on disk.
# ---------------------------------------------------------------------------
print("Generating shared weights...")

# Router: (NUM_EXPERTS, HIDDEN_SIZE) weight with Kaiming-uniform init, and a
# zero bias of shape (NUM_EXPERTS,).
set_seed(WEIGHT_SEED)
router_weight = torch.empty(NUM_EXPERTS, HIDDEN_SIZE)
torch.nn.init.kaiming_uniform_(router_weight)
router_bias = torch.zeros(NUM_EXPERTS)

# Experts: fused gate/up weight of shape (NUM_EXPERTS, HIDDEN_SIZE, 2 * HIDDEN_SIZE)
# and down weight of shape (NUM_EXPERTS, HIDDEN_SIZE, HIDDEN_SIZE), both drawn
# from N(0, 0.02^2), with zero biases. Both weights consume the same seeded RNG
# stream, so gate_up_proj must be drawn before down_proj.
set_seed(EXPERT_SEED)
gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
gate_up_proj_bias = torch.zeros(NUM_EXPERTS, 2 * HIDDEN_SIZE)
down_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
down_proj_bias = torch.zeros(NUM_EXPERTS, HIDDEN_SIZE)

# Persist every tensor under a fixed file name that the consumers load.
torch.save(router_weight, 'router_weight.pt')
torch.save(router_bias, 'router_bias.pt')
torch.save(gate_up_proj, 'gate_up_proj.pt')
torch.save(gate_up_proj_bias, 'gate_up_proj_bias.pt')
torch.save(down_proj, 'down_proj.pt')
torch.save(down_proj_bias, 'down_proj_bias.pt')

# Plain string: the header line has no placeholders (was a needless f-string).
print("Saved weights:")
print(f"  Router: {tuple(router_weight.shape)}")
print(f"  Gate/Up proj: {tuple(gate_up_proj.shape)}")
print(f"  Down proj: {tuple(down_proj.shape)}")
print(f"  Hidden size: {HIDDEN_SIZE}")