# /// script
# dependencies = [
#     "torch",
#     "numpy",
# ]
# ///
"""Generate and save shared weights for consistent comparison.""" | |
import torch | |
import numpy as np | |
from pathlib import Path | |
# Model configuration
NUM_EXPERTS = 128
HIDDEN_SIZE = 1152
INTERMEDIATE_SIZE = 3072
TOP_K = 4

# Input configuration
BATCH_SIZE = 1
SEQ_LEN = 100
DTYPE = "float32"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seeds for reproducibility
WEIGHT_SEED = 999
EXPERT_SEED = 777
INPUT_SEED = 123
GENERAL_SEED = 42
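# Distinct seeds decouple the random streams: router init, expert init, and
# input generation can each be changed without perturbing the others.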

def set_seed(seed: int):
    """Set seeds for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

# Generate shared weights for all implementations
print("Generating shared weights...")

# Router weights
set_seed(WEIGHT_SEED)
router_weight = torch.empty(NUM_EXPERTS, HIDDEN_SIZE)
torch.nn.init.kaiming_uniform_(router_weight)
router_bias = torch.zeros(NUM_EXPERTS)
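# Illustrative shape check (not part of weight generation): a sketch of how a
# consumer might route tokens with these weights, assuming softmax-then-top-k
# selection — an assumption here; the implementations under comparison define
# the actual routing. The _demo_* names are local to this sketch.
_demo_logits = torch.randn(SEQ_LEN, HIDDEN_SIZE) @ router_weight.T + router_bias
_demo_scores, _demo_experts = torch.topk(torch.softmax(_demo_logits, dim=-1), TOP_K, dim=-1)
assert _demo_experts.shape == (SEQ_LEN, TOP_K)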
# Expert weights - the gate and up projections are fused into one tensor, so
# the last dimension is 2 * INTERMEDIATE_SIZE ([gate | up] halves); down_proj
# maps back from INTERMEDIATE_SIZE to HIDDEN_SIZE.
set_seed(EXPERT_SEED)
gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * INTERMEDIATE_SIZE).normal_(mean=0.0, std=0.02)
gate_up_proj_bias = torch.zeros(NUM_EXPERTS, 2 * INTERMEDIATE_SIZE)
down_proj = torch.empty(NUM_EXPERTS, INTERMEDIATE_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
down_proj_bias = torch.zeros(NUM_EXPERTS, HIDDEN_SIZE)
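# Illustrative shape check for a single expert, assuming a SwiGLU-style
# forward pass with the fused last dimension split into [gate | up] — an
# assumption about layout and activation, not the reference implementation.
_demo_x = torch.randn(SEQ_LEN, HIDDEN_SIZE)
_demo_h = _demo_x @ gate_up_proj[0] + gate_up_proj_bias[0]  # (SEQ_LEN, 2 * INTERMEDIATE_SIZE)
_demo_gate, _demo_up = _demo_h.chunk(2, dim=-1)
_demo_y = (torch.nn.functional.silu(_demo_gate) * _demo_up) @ down_proj[0] + down_proj_bias[0]
assert _demo_y.shape == (SEQ_LEN, HIDDEN_SIZE)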
# Save weights (Path keeps the destinations explicit and portable)
out_dir = Path(".")
torch.save(router_weight, out_dir / "router_weight.pt")
torch.save(router_bias, out_dir / "router_bias.pt")
torch.save(gate_up_proj, out_dir / "gate_up_proj.pt")
torch.save(gate_up_proj_bias, out_dir / "gate_up_proj_bias.pt")
torch.save(down_proj, out_dir / "down_proj.pt")
torch.save(down_proj_bias, out_dir / "down_proj_bias.pt")
print(f"Saved weights:") | |
print(f" Router: {tuple(router_weight.shape)}") | |
print(f" Gate/Up proj: {tuple(gate_up_proj.shape)}") | |
print(f" Down proj: {tuple(down_proj.shape)}") | |
print(f" Hidden size: {HIDDEN_SIZE}") |