# /// script
# dependencies = [
# "torch",
# "numpy",
# ]
# ///

"""Generate and save shared weights for consistent comparison."""
import torch
import numpy as np
from pathlib import Path
# Model configuration
NUM_EXPERTS = 128
HIDDEN_SIZE = 1152
INTERMEDIATE_SIZE = 3072
TOP_K = 4

# Input configuration
BATCH_SIZE = 1
SEQ_LEN = 100
DTYPE = "float32"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seeds for reproducibility
WEIGHT_SEED = 999
EXPERT_SEED = 777
INPUT_SEED = 123
GENERAL_SEED = 42

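# Optional footprint estimate (a sketch for sizing the host machine, derived
# from the configuration above): the expert stack dominates memory, with
# E * (H*2I + 2I + I*H + H) float32 parameters.
expert_params = NUM_EXPERTS * (
    HIDDEN_SIZE * 2 * INTERMEDIATE_SIZE
    + 2 * INTERMEDIATE_SIZE
    + INTERMEDIATE_SIZE * HIDDEN_SIZE
    + HIDDEN_SIZE
)
print(f"Expert stack: {expert_params / 1e9:.2f}B params "
      f"(~{expert_params * 4 / 2**30:.1f} GiB in float32)")
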
def set_seed(seed: int):
    """Set seeds for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

# Generate shared weights for all implementations
print("Generating shared weights...")
# Router weights
set_seed(WEIGHT_SEED)
router_weight = torch.empty(NUM_EXPERTS, HIDDEN_SIZE)
torch.nn.init.kaiming_uniform_(router_weight)
router_bias = torch.zeros(NUM_EXPERTS)

# Expert weights - gate/up stored as a combined projection: hidden -> 2 * intermediate
set_seed(EXPERT_SEED)
gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * INTERMEDIATE_SIZE).normal_(mean=0.0, std=0.02)
gate_up_proj_bias = torch.zeros(NUM_EXPERTS, 2 * INTERMEDIATE_SIZE)
down_proj = torch.empty(NUM_EXPERTS, INTERMEDIATE_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
down_proj_bias = torch.zeros(NUM_EXPERTS, HIDDEN_SIZE)

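# Sanity-check the layouts before saving: the router follows nn.Linear's
# (out_features, in_features) convention, while the per-expert tensors are
# (expert, in, out) so a token batch can right-multiply them directly.
assert router_weight.shape == (NUM_EXPERTS, HIDDEN_SIZE)
assert gate_up_proj.shape == (NUM_EXPERTS, HIDDEN_SIZE, 2 * INTERMEDIATE_SIZE)
assert gate_up_proj_bias.shape == (NUM_EXPERTS, 2 * INTERMEDIATE_SIZE)
assert down_proj.shape == (NUM_EXPERTS, INTERMEDIATE_SIZE, HIDDEN_SIZE)
assert down_proj_bias.shape == (NUM_EXPERTS, HIDDEN_SIZE)
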
# Save weights
torch.save(router_weight, 'router_weight.pt')
torch.save(router_bias, 'router_bias.pt')
torch.save(gate_up_proj, 'gate_up_proj.pt')
torch.save(gate_up_proj_bias, 'gate_up_proj_bias.pt')
torch.save(down_proj, 'down_proj.pt')
torch.save(down_proj_bias, 'down_proj_bias.pt')

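# Round-trip check: loading a file back should return exactly the tensor that
# was saved (weights_only=True is plain-tensor safe loading in recent PyTorch).
assert torch.equal(torch.load('router_weight.pt', weights_only=True), router_weight)
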
print(f"Saved weights:")
print(f" Router: {tuple(router_weight.shape)}")
print(f" Gate/Up proj: {tuple(gate_up_proj.shape)}")
print(f" Down proj: {tuple(down_proj.shape)}")
print(f" Hidden size: {HIDDEN_SIZE}") |