Data norm issue during evaluation/inference
#8 opened by RonanMcGovern
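It seems the code in modeling_smolvla.py has a hack that avoids overwriting the normalisation parameters: "# HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset". It drops every checkpoint key starting with normalize_inputs, normalize_targets or unnormalize_outputs before the weights are loaded.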
I think this is OK during training, because a dataset is being read in and the stats can come from it. But during evaluation, this filtering needs to be disabled, because there is no dataset and the stats have to come from the model checkpoint, right?
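To make the concern concrete, here is a minimal torch-free sketch of what the prefix filter does at load time. It is not the lerobot code: `load_partial` is a hypothetical helper, the "model" is just a dict of floats, the key names are illustrative, and (as an assumption) a freshly built policy without dataset stats is modelled as holding +inf placeholders. Loading only overwrites keys present in the checkpoint, mimicking load_state_dict(strict=False).

```python
# Hypothetical, torch-free model of the loading step: a "model" is a dict of
# named values, and loading only overwrites keys that exist in both the model
# and the checkpoint, like load_state_dict(strict=False).
NORM_KEYS = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")


def load_partial(model_state: dict, checkpoint: dict, skip_normalization_stats: bool) -> dict:
    """Return a copy of model_state updated from checkpoint.

    When skip_normalization_stats is True, keys starting with any NORM_KEYS
    prefix are dropped from the checkpoint first (what the HACK does today).
    """
    if skip_normalization_stats:
        checkpoint = {k: v for k, v in checkpoint.items() if not k.startswith(NORM_KEYS)}
    updated = dict(model_state)  # copy so the original "model" is untouched
    for key, value in checkpoint.items():
        if key in updated:  # unexpected keys are ignored, as strict=False would
            updated[key] = value
    return updated


# Illustrative key names only. Placeholder stats are modelled as +inf (assumption).
fresh_model = {
    "model.layer.weight": 0.0,
    "normalize_inputs.mean": float("inf"),
    "unnormalize_outputs.std": float("inf"),
}
checkpoint = {
    "model.layer.weight": 1.5,
    "normalize_inputs.mean": 0.2,
    "unnormalize_outputs.std": 3.0,
}

with_hack = load_partial(fresh_model, checkpoint, skip_normalization_stats=True)
without_hack = load_partial(fresh_model, checkpoint, skip_normalization_stats=False)
print(with_hack["normalize_inputs.mean"])     # inf: stats were never loaded
print(without_hack["normalize_inputs.mean"])  # 0.2: stats come from the checkpoint
```

The weights load either way; only the normalisation entries differ, which is why the problem shows up at evaluation time rather than as a loading error.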
I think you want something like this:
    skip_normalization_stats: bool = False,
) -> torch.nn.Module:
    state_dict = safetensors.torch.load_file(filename, device=device)

    # Optional user-supplied renames (e.g. "model._orig_mod.//model.")
    if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
        state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)

    state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))

    # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
    # MODIFIED: Only skip normalization stats during training when dataset stats are provided
    norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
    if skip_normalization_stats:
        # Training mode: skip normalization stats so they come from the dataset
        state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}

    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    # Handle missing/unexpected keys based on whether we skipped normalization
    if skip_normalization_stats:
        # Training mode: normalization keys are expected to be missing (that's intentional)
        if not all(key.startswith(norm_keys) for key in missing) or unexpected:
            raise RuntimeError(
                f"SmolVLA {len(missing)} missing / {len(unexpected)} unexpected keys"
            )
    else:
        # Inference mode: warn about any missing/unexpected keys instead of failing
        if missing or unexpected:
            print(f"Warning: SmolVLA loading - {len(missing)} missing keys, {len(unexpected)} unexpected keys")
            if missing:
                print(f"Missing keys: {missing}")
            if unexpected:
                print(f"Unexpected keys: {unexpected}")

    # The annotated return type is torch.nn.Module, so hand the loaded model back
    return model
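The key-validation branches above can be exercised without a checkpoint. Below is a torch-free restatement as a hypothetical `check_load_result`, where `missing` and `unexpected` stand in for the two lists returned by torch.nn.Module.load_state_dict(strict=False); the key names are illustrative, and the inference branch returns the warning strings instead of printing them so they are easy to inspect.

```python
# Hypothetical, torch-free restatement of the proposed missing/unexpected-key
# handling. `missing` and `unexpected` stand in for the two lists returned by
# torch.nn.Module.load_state_dict(strict=False).
NORM_KEYS = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")


def check_load_result(missing: list, unexpected: list, skip_normalization_stats: bool) -> list:
    """Raise in training mode on anything but missing norm keys; otherwise return warnings."""
    if skip_normalization_stats:
        # Training: only the deliberately skipped normalization keys may be missing.
        if not all(key.startswith(NORM_KEYS) for key in missing) or unexpected:
            raise RuntimeError(
                f"SmolVLA {len(missing)} missing / {len(unexpected)} unexpected keys"
            )
        return []
    # Inference: report problems but keep going.
    warnings = []
    if missing:
        warnings.append(f"Missing keys: {missing}")
    if unexpected:
        warnings.append(f"Unexpected keys: {unexpected}")
    return warnings


# Training run: norm keys were filtered out on purpose, so nothing is reported.
print(check_load_result(["normalize_inputs.mean"], [], skip_normalization_stats=True))  # []

# Training run with a genuinely missing weight: fails loudly.
try:
    check_load_result(["model.layer.weight"], [], skip_normalization_stats=True)
except RuntimeError as err:
    print(err)  # SmolVLA 1 missing / 0 unexpected keys

# Inference run: the same situation only produces a warning.
print(check_load_result(["model.layer.weight"], [], skip_normalization_stats=False))
```

One consequence of the default skip_normalization_stats=False: any training call site that relied on the old unconditional filtering would now need to pass True explicitly.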
It's helpful!