
Data norm issue during evaluation/inference

#8
by RonanMcGovern - opened

It seems the loading code in modeling_smolvla.py has a hack that avoids overwriting the normalisation parameters: "# HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset".

I think this is OK during training, because a dataset is being read in and can supply the stats. But during evaluation/inference the filter needs to be commented out, because the stats need to come from the model checkpoint, right?
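To make the concern concrete, here is a minimal sketch in plain Python (not LeRobot code) of mean/std normalisation. The stats values and the `normalize`/`unnormalize` helpers are made up for illustration, and the placeholder values are an assumption: the library's actual defaults may differ. The point is only that if the checkpoint stats are dropped and nothing replaces them, the policy sees inputs on a different scale than it saw in training.

```python
# Hypothetical stats for illustration only; not taken from any real checkpoint.
checkpoint_stats = {"mean": 10.0, "std": 2.0}   # what the model was trained with
placeholder_stats = {"mean": 0.0, "std": 1.0}   # assumed leftover if stats are skipped and never set


def normalize(x: float, stats: dict) -> float:
    """Standard mean/std normalisation: (x - mean) / std."""
    return (x - stats["mean"]) / stats["std"]


def unnormalize(y: float, stats: dict) -> float:
    """Inverse of normalize: y * std + mean."""
    return y * stats["std"] + stats["mean"]


raw_state = 12.0
seen_in_training = normalize(raw_state, checkpoint_stats)         # (12 - 10) / 2
seen_with_placeholder = normalize(raw_state, placeholder_stats)   # (12 - 0) / 1
```

With these made-up numbers the same raw value maps to 1.0 with the checkpoint stats and 12.0 with the placeholder stats, so the model input (and, by the same logic, the unnormalised action output) would be off.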

I think you want something like this:

    skip_normalization_stats: bool = False,
) -> torch.nn.Module:
    state_dict = safetensors.torch.load_file(filename, device=device)

    # Optional user-supplied renames (e.g. "model._orig_mod.//model.")
    if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
        state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)

    state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))

    # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
    # MODIFIED: Only skip normalization stats during training when dataset stats are provided
    norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
    if skip_normalization_stats:
        # Training mode: Skip normalization stats so they come from dataset
        state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}

    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    # Handle missing/unexpected keys based on whether we skipped normalization
    if skip_normalization_stats:
        # Training mode: Expect normalization keys to be missing (that's intentional)
        if not all(key.startswith(norm_keys) for key in missing) or unexpected:
            raise RuntimeError(
                f"SmolVLA {len(missing)} missing / {len(unexpected)} unexpected keys"
            )
    else:
        # Inference mode: Warn about any missing/unexpected keys
        if missing or unexpected:
            print(f"Warning: SmolVLA loading - {len(missing)} missing keys, {len(unexpected)} unexpected keys")
            if missing:
                print(f"Missing keys: {missing}")
            if unexpected:
                print(f"Unexpected keys: {unexpected}")

    return model
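The filtering step can be tried in isolation with plain dicts, without torch. This is only a sketch: `filter_state_dict` is a hypothetical helper name and the checkpoint keys below are illustrative, not the real SmolVLA key names. It relies on `str.startswith` accepting a tuple of prefixes.

```python
NORM_KEYS = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")


def filter_state_dict(state_dict: dict, skip_normalization_stats: bool) -> dict:
    """Drop normalisation entries only when the caller asks for it (training)."""
    if not skip_normalization_stats:
        return dict(state_dict)
    return {k: v for k, v in state_dict.items() if not k.startswith(NORM_KEYS)}


# Illustrative key names; a real checkpoint will have different ones.
checkpoint = {
    "normalize_inputs.state.mean": [0.1],
    "unnormalize_outputs.action.std": [0.5],
    "model.layer.weight": [1.0, 2.0],
}

training_view = filter_state_dict(checkpoint, skip_normalization_stats=True)
inference_view = filter_state_dict(checkpoint, skip_normalization_stats=False)

# Keys the model would report as missing after loading the training view.
missing_in_training = sorted(set(checkpoint) - set(training_view))
only_norm_missing = all(k.startswith(NORM_KEYS) for k in missing_in_training)
```

In the training view only the model weights remain and every missing key is a normalisation key, which is the condition the strict check above expects. In the inference view nothing is dropped, so the stats come from the checkpoint.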

It's helpful!
