# Web-viewer page header captured with the file (not part of the script):
#   Spaces: Running
#   File size: 7,093 Bytes
#   bad4ddc
# (The viewer's line-number gutter, 1-149, was also captured here and is omitted.)
# /// script
# dependencies = [
# "torch",
# "numpy",
# ]
# ///
import torch
from torch import nn
from torch.nn import functional as F
from utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os
# Resolve the upstream artifact directory from the environment ('.' if unset).
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
_artifact_root = Path(data_dir)

# Report where the weights come from and what is available in that directory.
print(f"Loading weights from: {data_dir}")
print(f"Files in directory: {list(_artifact_root.glob('*'))}")

# Load the six shared tensors in a fixed order and bind them to module names.
# NOTE(review): torch.load is pickle-based; only point this at trusted artifacts.
(
    router_weight,
    router_bias,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
) = (
    torch.load(_artifact_root / file_name)
    for file_name in (
        'router_weight.pt',
        'router_bias.pt',
        'gate_up_proj.pt',
        'gate_up_proj_bias.pt',
        'down_proj.pt',
        'down_proj_bias.pt',
    )
)

print("Loaded shared weights from artifacts")

# Print simple sums so the loaded tensors can be compared across runs.
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")
class GptOssRouter(nn.Module):
    """Top-k router that scores flattened tokens against every expert.

    The projection weight and bias are cloned from the tensors passed to the
    constructor. The top-k width, expert count, and hidden width are read from
    the module-level ``TOP_K``, ``NUM_EXPERTS`` and ``HIDDEN_SIZE`` constants.
    """

    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        # Clone so the module owns its parameters independently of the inputs.
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        """Route tokens to their top-k experts.

        Args:
            hidden_states: Tensor whose trailing dimension is ``hidden_dim``;
                it is flattened to two dimensions before scoring.

        Returns:
            A ``(router_scores, router_indices)`` tuple. ``router_scores`` has
            the shape of the full logits and is zero everywhere except at the
            selected experts, which hold the softmax over the top-k logits.
            ``router_indices`` holds the selected expert ids per token.
        """
        tokens = hidden_states.reshape(-1, self.hidden_dim)
        logits = F.linear(tokens, self.weight, self.bias)

        # Keep only the k best experts per token and normalise among them.
        top_logits, router_indices = logits.topk(self.top_k, dim=-1)
        top_probs = F.softmax(top_logits, dim=1, dtype=top_logits.dtype)

        # Scatter the normalised scores back into a dense per-expert layout.
        router_scores = torch.zeros_like(logits)
        router_scores.scatter_(1, router_indices, top_probs)
        return router_scores, router_indices
class GptOssExperts(nn.Module):
    """Bank of gated-MLP experts applied to tokens according to router output.

    Each expert has a fused gate/up projection (gate in the even columns, up
    in the odd columns) and a down projection, both with biases. All four
    parameter tensors are indexed by expert along their first dimension and
    are cloned from the tensors passed to the constructor.

    Two execution strategies are available and chosen in ``forward``:

    * a per-expert loop that only processes the tokens routed to each expert
      (used on CPU or when the module is in training mode), and
    * a dense batched path that runs every expert on every token and then
      weights the results by the routing scores (used otherwise).
    """

    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.expert_dim = self.hidden_size
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
        # Scale inside the sigmoid gate and the clamp bound for gate/up values.
        self.alpha = 1.702
        self.limit = 7.0

    def _gated_activation(self, gate_up):
        """Split interleaved gate/up columns, clamp them, and apply the gate."""
        gate, up = gate_up[..., ::2], gate_up[..., 1::2]
        # The gate is only bounded from above; up is bounded on both sides.
        gate = gate.clamp(min=None, max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        return (up + 1) * glu

    def _forward_routed(self, hidden_states, router_indices, routing_weights, batch_size, num_experts):
        """Run each selected expert only on the tokens that were routed to it."""
        next_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # one_hot appends an expert axis; permute puts that axis first so
            # the mask can be indexed per expert.
            expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Experts with at least one routed token.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for hit in expert_hit:
            # nonzero() yields one-element rows; take the scalar expert id.
            expert_idx = hit[0]
            with torch.no_grad():
                _, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
            gated_output = self._gated_activation(gate_up)
            out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
            weighted_output = out * routing_weights[token_idx, expert_idx, None]
            # Accumulate: a token selected by several experts sums their outputs.
            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
        return next_states.view(batch_size, -1, self.hidden_size)

    def _forward_dense(self, hidden_states, routing_weights, batch_size, num_experts):
        """Run every expert on every token, then weight by the routing scores."""
        hidden_states = hidden_states.repeat(num_experts, 1)
        hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
        gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
        next_states = torch.bmm(self._gated_activation(gate_up), self.down_proj)
        next_states = next_states + self.down_proj_bias[..., None, :]
        next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
        next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
        return next_states.sum(dim=0)

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        """Apply the experts to ``hidden_states`` and combine their outputs.

        Args:
            hidden_states: Input whose first dimension is the batch size and
                whose trailing dimension is ``hidden_size``.
            router_indices: Integer tensor of selected expert ids per token
                (presumably ``(tokens, top_k)`` - confirm against the router).
                Required on the per-expert loop path.
            routing_weights: Per-token, per-expert scores with the expert
                axis at dimension 1. Always required.

        Returns:
            Tensor viewed as ``(batch_size, -1, hidden_size)``.

        Raises:
            ValueError: If ``routing_weights`` is missing, or if
                ``router_indices`` is missing on the per-expert loop path.
        """
        if routing_weights is None:
            raise ValueError("routing_weights is required")

        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        num_experts = routing_weights.shape[1]

        if hidden_states.device.type == "cpu" or self.training:
            if router_indices is None:
                raise ValueError("router_indices is required on the per-expert (CPU/training) path")
            return self._forward_routed(hidden_states, router_indices, routing_weights, batch_size, num_experts)
        return self._forward_dense(hidden_states, routing_weights, batch_size, num_experts)
class GptOssMoEMLP(nn.Module):
    """Mixture-of-experts MLP: a router followed by the expert bank.

    The first two constructor tensors configure the router; the remaining
    four configure the experts.
    """

    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        expert_tensors = (gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
        self.router = GptOssRouter(router_weight, router_bias)
        self.experts = GptOssExperts(*expert_tensors)

    def forward(self, hidden_states):
        """Route ``hidden_states`` and return ``(routed_output, router_scores)``."""
        scores, indices = self.router(hidden_states)
        mixed = self.experts(
            hidden_states,
            router_indices=indices,
            routing_weights=scores,
        )
        return mixed, scores
# Run the model
set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== GPT-OSS Implementation ===")

# Initialize model with loaded weights, moved to the benchmark device/dtype.
# NOTE(review): the model is never switched to eval mode, and GptOssExperts
# takes its per-expert loop path whenever `self.training` is true - confirm
# that benchmarking this path (rather than the dense path) is intended.
model = GptOssMoEMLP(
    router_weight.to(device, dtype=dtype),
    router_bias.to(device, dtype=dtype),
    gate_up_proj.to(device, dtype=dtype),
    gate_up_proj_bias.to(device, dtype=dtype),
    down_proj.to(device, dtype=dtype),
    down_proj_bias.to(device, dtype=dtype)
).to(device=device, dtype=dtype)

# Sums after conversion, for comparison with the sums printed at load time.
print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")

# Benchmark the model using different input tensors on each iteration
tokens = BATCH_SIZE * SEQ_LEN
input_shape = (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)

with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens,
                   save_json="gptoss_results.json", input_shape=input_shape, input_seed_base=INPUT_SEED) as bench:
    output, stats = bench(model)
    # The model returns (routed_out, router_scores); report the routed output.
    print(f"\nOutput sum: {output[0].sum().item():.6f}")