"""Compare per-rank PyTorch checkpoint files and report whether they match."""

import torch

# Import the exception type to catch
from torch._subclasses.fake_tensor import DataDependentOutputException


def load_checkpoint(path):
    """Load a checkpoint file with all tensors mapped to the CPU.

    Args:
        path: Location of the checkpoint file, passed to ``torch.load``.

    Returns:
        Whatever object was serialized in the file.
    """
    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    return torch.load(path, map_location=torch.device('cpu'))


def tensors_allclose(val_cur, val_ref):
    """Return True when two tensors match in shape, dtype and (close) values.

    Args:
        val_cur: Tensor from the checkpoint being checked.
        val_ref: Tensor from the reference checkpoint.

    Returns:
        ``True`` if both tensors have the same shape and dtype and
        ``torch.allclose`` considers their values equal, ``False`` otherwise.
    """
    # torch.allclose broadcasts its inputs and raises on mismatched dtypes,
    # so a layout mismatch has to be rejected before the values are compared.
    if val_cur.shape != val_ref.shape or val_cur.dtype != val_ref.dtype:
        return False
    try:
        return bool(torch.allclose(val_cur, val_ref))
    except DataDependentOutputException:
        # Fake tensors cannot produce a data-dependent result; retry on
        # detached CPU copies.
        # NOTE(review): confirm that detach().cpu() really materializes the
        # data for the tensor types stored in these checkpoints.
        real_val_cur = val_cur.detach().cpu() if hasattr(val_cur, "detach") else val_cur
        real_val_ref = val_ref.detach().cpu() if hasattr(val_ref, "detach") else val_ref
        return bool(torch.allclose(real_val_cur, real_val_ref))


def _values_equal(cur, ref):
    """Compare two checkpoint values, descending into dicts, lists and tuples."""
    cur_is_tensor = isinstance(cur, torch.Tensor)
    ref_is_tensor = isinstance(ref, torch.Tensor)
    if cur_is_tensor or ref_is_tensor:
        # A plain ``!=`` on tensors yields a tensor whose truth value is
        # ambiguous, so tensors always go through tensors_allclose.
        return cur_is_tensor and ref_is_tensor and tensors_allclose(cur, ref)
    if isinstance(cur, dict) and isinstance(ref, dict):
        return cur.keys() == ref.keys() and all(
            _values_equal(cur[key], ref[key]) for key in cur
        )
    if isinstance(cur, (list, tuple)) and isinstance(ref, (list, tuple)):
        return (
            type(cur) is type(ref)
            and len(cur) == len(ref)
            and all(_values_equal(c, r) for c, r in zip(cur, ref))
        )
    return not (cur != ref)


def _checkpoint_matches(path, ckpt, ref_path, ref_ckpt):
    """Compare one checkpoint with the reference, printing every difference."""
    # Check if the types match.
    if type(ckpt) is not type(ref_ckpt):
        print(f"Type mismatch: {path} is of type {type(ckpt)}, expected {type(ref_ckpt)}")
        return False

    # If the checkpoints are not dictionaries, compare them directly.
    if not isinstance(ckpt, dict):
        if _values_equal(ckpt, ref_ckpt):
            return True
        print(f"Checkpoint {path} differs from {ref_path}.")
        return False

    if ckpt.keys() != ref_ckpt.keys():
        print(f"Key mismatch in {path}.")
        return False

    matches = True
    for key, val_cur in ckpt.items():
        val_ref = ref_ckpt[key]
        if isinstance(val_ref, torch.Tensor) and isinstance(val_cur, torch.Tensor):
            if not tensors_allclose(val_cur, val_ref):
                print(f"Tensor values differ for key '{key}' in {path}.")
                matches = False
        elif not _values_equal(val_cur, val_ref):
            print(f"Value for key '{key}' differs in {path}.")
            matches = False
    return matches


def compare_checkpoints(paths):
    """Compare every checkpoint in ``paths`` against the first one.

    Differences are printed as they are found, followed by a one-line
    summary.

    Args:
        paths: Iterable of checkpoint file paths. The first path is the
            reference; repeated paths are only processed once.

    Returns:
        ``True`` if no difference was found (also when ``paths`` is empty),
        ``False`` otherwise.
    """
    unique_paths = list(dict.fromkeys(paths))
    if not unique_paths:
        print("No checkpoints to compare.")
        return True

    # Use the first checkpoint as the reference. It is not compared with
    # itself: torch.allclose treats NaN as unequal to NaN, so a reference
    # holding NaNs would otherwise be reported as differing from itself.
    ref_path, *other_paths = unique_paths
    ref_ckpt = load_checkpoint(ref_path)

    all_same = True
    for path in other_paths:
        # Load one checkpoint at a time so only two are in memory at once.
        ckpt = load_checkpoint(path)
        if not _checkpoint_matches(path, ckpt, ref_path, ref_ckpt):
            all_same = False

    if all_same:
        print("All checkpoints are identical.")
    else:
        print("Not all checkpoints are identical.")
    return all_same


# Generate file paths for ranks 0 through 7.
paths = [
    "grpo_cg_packfix_128_24576_1_1e-6_0_0_1_new_veRL/global_step_40/actor/"
    f"model_world_size_8_rank_{i}.pt"
    for i in range(8)
]

if __name__ == "__main__":
    compare_checkpoints(paths)