#!/usr/bin/env python3
# ───────────────────  make local repo override any wheel  ────────────────────
import sys, os
sys.path.insert(0, os.path.abspath("."))

# ───────────────────  Flash-Attention + CUDA stubs  ──────────────────────────
import types, torch, torch.nn.functional as F, importlib.machinery as im
flash_pkg = types.ModuleType("flash_attn")
flash_pkg.__spec__ = im.ModuleSpec("flash_attn", loader=None, is_package=True); flash_pkg.__path__ = []
sys.modules["flash_attn"] = flash_pkg
fa = types.ModuleType("flash_attn.flash_attn_interface")
fa.__spec__ = im.ModuleSpec("flash_attn.flash_attn_interface", loader=None)
def _sdpa(qkv, *_, causal=False, **__):
    # qkv is flash-attn's packed (total_tokens, 3, heads, dim); SDPA wants (batch, heads, seq, dim)
    q, k, v = qkv.unbind(1)
    q, k, v = (t.transpose(0, 1).unsqueeze(0) for t in (q, k, v))  # (1, heads, total, dim)
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal).squeeze(0).transpose(0, 1)
for s in ("flash_attn_unpadded_qkvpacked_func", "flash_attn_unpadded_kvpacked_func",
          "flash_attn_varlen_qkvpacked_func", "flash_attn_varlen_kvpacked_func"):
    setattr(fa, s, _sdpa)
sys.modules["flash_attn.flash_attn_interface"] = fa; flash_pkg.flash_attn_interface = fa
pad = types.ModuleType("flash_attn.bert_padding")
pad.__spec__ = im.ModuleSpec("flash_attn.bert_padding", loader=None)
pad.pad_input = lambda x, *a, **k: (x, None); pad.unpad_input = lambda x, *a, **k: x
sys.modules["flash_attn.bert_padding"] = pad; flash_pkg.bert_padding = pad

# When no GPU is present, pin the CUDA introspection calls that model code may
# touch at import time to harmless CPU-only values.
if not torch.cuda.is_available():
    torch.cuda.is_available = lambda: False
    torch.cuda.get_device_capability = lambda dev=None: (0, 0)
    torch.cuda.current_device = lambda: 0
    torch.cuda.get_device_properties = lambda dev=None: types.SimpleNamespace(major=0, minor=0)

# If the flash_attn wheel is absent, fake its package metadata so version probes
# succeed. packages_distributions() requires Python 3.10+.
import importlib.metadata as _im
if "flash_attn" not in _im.packages_distributions():
    rv, rd = _im.version, _im.distribution
    _im.version      = lambda p: "0.0.0" if p == "flash_attn" else rv(p)
    _im.distribution = lambda p: types.SimpleNamespace(version="0.0.0") if p == "flash_attn" else rd(p)
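
# Optional self-check for the shims above (a sketch; safe to delete):
#   import importlib.metadata
#   assert importlib.metadata.version("flash_attn")   # "0.0.0" via the shim, or the real wheel
#   import flash_attn.flash_attn_interface            # resolves to the SDPA stub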

# ───────────────────  std imports  ───────────────────────────────────────────
from pathlib import Path
import argparse, json, shutil
from huggingface_hub import hf_hub_download
from transformers import AutoConfig
from gr00t.model.gr00t_n1 import GR00T_N1_5

# ───────────────────  helpers  ───────────────────────────────────────────────
def patched_cfg():
    """Fetch the upstream config and force model_type to 'gr00t_n1_5' if it differs."""
    p = hf_hub_download("nvidia/GR00T-N1.5-3B", "config.json")
    d = json.loads(Path(p).read_text())
    if d.get("model_type") != "gr00t_n1_5":
        d["model_type"] = "gr00t_n1_5"
        patched = Path(p).with_name("config_patched.json")
        patched.write_text(json.dumps(d))
        return str(patched)
    return p

def build_blank():
    cfg = AutoConfig.from_pretrained(patched_cfg(),
                                     trust_remote_code=True,
                                     local_files_only=True)
    cfg.backbone_cfg.update(dict(tune_llm=True))          # enable tuning of the LLM tower
    cfg.backbone_cfg.pop("checkpoint_path", None)
    cfg.backbone_cfg.pop("use_pretrained",   None)
    cfg.action_head_cfg.pop("checkpoint_path", None)
    torch.manual_seed(0)
    return GR00T_N1_5(cfg, local_model_path="")           # random weights
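
# Because build_blank() seeds torch's RNG before construction, repeated calls are
# reproducible, e.g. (a sketch, not run by this script):
#   m1, m2 = build_blank(), build_blank()
#   assert all(torch.equal(p, q) for p, q in zip(m1.parameters(), m2.parameters()))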

def maybe_add_lm_head(model):
    """Ensure lm_head is properly initialized with weights"""
    # Navigate to the language model
    lm = model.backbone.eagle_model.language_model
    
    # Get dimensions from embed_tokens
    embed_tokens = lm.model.embed_tokens
    vocab_size = embed_tokens.num_embeddings
    hidden_size = embed_tokens.embedding_dim
    
    print(f"Embedding dimensions: vocab_size={vocab_size}, hidden_size={hidden_size}")
    
    # Expected dimensions for GR00T-N1.5-3B: vocab_size=151680, hidden_size=2048
    if vocab_size != 151680 or hidden_size != 2048:
        print(f"⚠️  Warning: unexpected dimensions; expected vocab=151680, hidden=2048, "
              f"got vocab={vocab_size}, hidden={hidden_size}")
    
    # Even when a blank build creates the lm_head attribute, its weights may not be
    # meaningfully initialized, so it is always replaced with a fresh head below.
    if hasattr(lm, "lm_head"):
        print(f"lm_head attribute exists: {lm.lm_head is not None}")
        print("Creating new lm_head with proper initialization...")
    else:
        print("lm_head attribute missing, creating...")
    
    # Create a new lm_head with proper initialization
    # Note: nn.Linear uses (in_features, out_features), so it's (hidden_size, vocab_size)
    new_lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
    
    # Initialize weights with normal distribution (std=0.02 is standard for LM heads)
    torch.nn.init.normal_(new_lm_head.weight, mean=0.0, std=0.02)
    
    # Convert to bfloat16 to match backbone
    new_lm_head.weight.data = new_lm_head.weight.data.to(torch.bfloat16)
    
    # Replace the lm_head
    lm.lm_head = new_lm_head
    
    print(f"βœ“ Created lm_head: Linear({hidden_size}, {vocab_size}, bias=False)")
    print(f"  Weight shape: {lm.lm_head.weight.shape}")
    print(f"  Weight dtype: {lm.lm_head.weight.dtype}")
    print(f"  Parameters: {lm.lm_head.weight.numel() / 1e6:.1f}M")

def set_mixed(model):
    """Set mixed precision: backbone in bf16, action head in fp32"""
    for n,p in model.named_parameters():
        if n.startswith("backbone.") or "lm_head" in n:
            p.data = p.data.to(torch.bfloat16)
        else:
            p.data = p.data.to(torch.float32)
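    # Quick dtype audit (a sketch, handy when debugging the bf16/fp32 split):
    #   from collections import Counter
    #   print(Counter(str(p.dtype) for _, p in model.named_parameters()))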

def copy_tokenizer(out):
    # Not every file is guaranteed to exist for this tokenizer (e.g. vocab.txt),
    # so missing ones are skipped silently.
    for f in ("tokenizer.json", "tokenizer_config.json", "vocab.txt", "special_tokens_map.json"):
        try:
            shutil.copy(hf_hub_download("nvidia/GR00T-N1.5-3B", f), out / f)
        except Exception:
            pass

def diagnose_model(model):
    """Print diagnostic info about the model"""
    print("\nModel diagnostics:")
    total_params = sum(p.numel() for p in model.parameters())
    print(f"  Total params: {total_params/1e6:,.0f}M")
    
    # Check for key components
    has_lm_head = False
    lm_head_params = 0
    lm_head_location = None
    
    for name, param in model.named_parameters():
        if "lm_head" in name:
            has_lm_head = True
            lm_head_params += param.numel()
            lm_head_location = name
    
    print(f"  Has lm_head: {'βœ“' if has_lm_head else 'βœ—'}")
    if has_lm_head:
        print(f"  lm_head params: {lm_head_params/1e6:,.0f}M")
        print(f"  lm_head location: {lm_head_location}")
        
        # Check if the params are actually counted in the total
        lm = model.backbone.eagle_model.language_model
        if hasattr(lm, 'lm_head') and lm.lm_head is not None:
            actual_params = lm.lm_head.weight.numel()
            print(f"  lm_head actual params: {actual_params/1e6:,.0f}M")
            print(f"  lm_head weight shape: {lm.lm_head.weight.shape}")
            print(f"  lm_head weight dtype: {lm.lm_head.weight.dtype}")

def validate_model_architecture(model):
    """Validate model against the architecture specification"""
    print("\n" + "="*60)
    print("ARCHITECTURE VALIDATION")
    print("="*60)
    
    # Expected architecture based on the spec
    expected_shapes = {
        # Key layers to check - using actual parameter names with .weight suffix
        "backbone.eagle_model.language_model.lm_head.weight": (151680, 2048),
        "backbone.eagle_model.language_model.model.embed_tokens.weight": (151680, 2048),
        "backbone.eagle_model.language_model.model.norm.weight": (2048,),
        "backbone.eagle_model.mlp1.0.weight": (2048, 1152),
        "backbone.eagle_model.mlp1.0.bias": (2048,),
        "action_head.position_embedding.weight": (1024, 1536),  # Fixed: added .weight
        "action_head.vlln.weight": (2048,),
        "action_head.vlln.bias": (2048,),
    }
    
    errors = []
    warnings = []
    
    # Get all parameters
    param_dict = dict(model.named_parameters())
    
    # Debug: print actual action_head parameter names to see the pattern
    action_head_params = [name for name in param_dict.keys() if name.startswith("action_head.position")]
    if action_head_params:
        print("\nFound position embedding parameters:")
        for name in action_head_params[:5]:
            print(f"  {name}: {param_dict[name].shape}")
    
    # Check key shapes
    for name, expected_shape in expected_shapes.items():
        if name in param_dict:
            actual_shape = tuple(param_dict[name].shape)
            if actual_shape != expected_shape:
                errors.append(f"Shape mismatch for {name}: expected {expected_shape}, got {actual_shape}")
            else:
                print(f"βœ“ {name}: {actual_shape}")
        else:
            errors.append(f"Missing parameter: {name}")
    
    # Check dtypes
    dtype_issues = []
    for name, param in param_dict.items():
        if name.startswith("backbone."):
            if param.dtype != torch.bfloat16:
                dtype_issues.append(f"{name}: expected bfloat16, got {param.dtype}")
        elif name.startswith("action_head."):
            if param.dtype != torch.float32:
                dtype_issues.append(f"{name}: expected float32, got {param.dtype}")
    
    if dtype_issues:
        warnings.extend(dtype_issues[:5])  # Only show first 5
    
    # Count parameters by component
    component_params = {
        "backbone": 0,
        "action_head": 0,
        "other": 0
    }
    
    for name, param in param_dict.items():
        count = param.numel()
        if name.startswith("backbone."):
            component_params["backbone"] += count
        elif name.startswith("action_head."):
            component_params["action_head"] += count
        else:
            component_params["other"] += count
    
    # Special check for lm_head
    lm_head_found = False
    lm_head_params = 0
    for name, param in param_dict.items():
        if "lm_head" in name:
            lm_head_found = True
            lm_head_params += param.numel()
    
    # Report results
    print("\nValidation Results:")
    print(f"  Errors: {len(errors)}")
    print(f"  Warnings: {len(warnings)}")
    
    if errors:
        print("\n❌ ERRORS:")
        for error in errors:
            print(f"  - {error}")
    
    if warnings:
        print("\n⚠️  WARNINGS (showing first 5):")
        for warning in warnings[:5]:
            print(f"  - {warning}")
        if len(warnings) > 5:
            print(f"  ... and {len(warnings) - 5} more")
    
    print("\nπŸ“Š Parameter Summary:")
    total = sum(component_params.values())
    print(f"  Total: {total/1e6:,.1f}M")
    print(f"  Backbone: {component_params['backbone']/1e6:,.1f}M")
    print(f"  Action Head: {component_params['action_head']/1e6:,.1f}M")
    if component_params['other'] > 0:
        print(f"  Other: {component_params['other']/1e6:,.1f}M")
    
    print(f"\n  lm_head found: {'βœ“' if lm_head_found else 'βœ—'}")
    if lm_head_found:
        print(f"  lm_head params: {lm_head_params/1e6:.1f}M (expected: 311.1M)")
    
    # Expected totals based on NVIDIA model
    expected_total = 2724  # Million params
    actual_total = total / 1e6
    diff = actual_total - expected_total
    
    print(f"\n  Expected total: {expected_total}M")
    print(f"  Actual total: {actual_total:.1f}M")
    print(f"  Difference: {diff:+.1f}M")
    
    if abs(diff) < 1:  # Within 1M params
        print("\nβœ… Model architecture matches expected specification!")
    else:
        print("\n❌ Model architecture does NOT match specification!")
    
    return len(errors) == 0

# ───────────────────  main  ──────────────────────────────────────────────────
def main(device: str, out_dir: str):
    print("="*60)
    print("Creating blank GR00T-N1.5-3B model")
    print("="*60)
    
    model = build_blank()
    
    # Add diagnostics before adding lm_head
    print("\nBefore adding lm_head:")
    diagnose_model(model)
    
    maybe_add_lm_head(model)
    
    # Add diagnostics after adding lm_head
    print("\nAfter adding lm_head:")
    diagnose_model(model)
    
    set_mixed(model)
    model = model.to(device)
    
    # Validate against architecture spec
    validate_model_architecture(model)

    out = Path(out_dir).expanduser()
    out.mkdir(parents=True, exist_ok=True)
    
    print(f"\nSaving model to {out}...")
    model.save_pretrained(out, max_shard_size="2GB")
    copy_tokenizer(out)
    (out/"README.md").write_text("Random GR00T-N1.5-3B | backbone bf16 | action_head fp32 | Apache-2.0\n")
    
    # Final summary
    print("\n" + "="*60)
    print("FINAL SUMMARY")
    print("="*60)
    print(f"βœ… Saved blank model ({sum(p.numel() for p in model.parameters())/1e6:,.0f}M params) β†’ {out}")
    print(f"βœ… Model has lm_head with {model.backbone.eagle_model.language_model.lm_head.weight.numel()/1e6:.1f}M params")
    print(f"βœ… Ready for training with Apache-2.0 license")

# ───────────────────  CLI  ───────────────────────────────────────────────────
if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--device", default="cpu")
    ap.add_argument("--out_dir", default="DolphinGR00T-N1.5-3B-Zero")
    args = ap.parse_args()
    main(args.device, args.out_dir)
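
# Example invocation (script name illustrative):
#   python make_blank_gr00t.py --device cpu --out_dir DolphinGR00T-N1.5-3B-Zero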