File size: 2,284 Bytes
0298ad2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import os
import glob
import re
import shutil
from torchtitan.tools.logging import logger
# Directory-name patterns for the two checkpoint flavours. They are applied with
# ``fullmatch`` so that look-alike entries (``step-notes``, ``step-10-backup``)
# are never mistaken for checkpoints and deleted.
_DCP_DIR_RE = re.compile(r"step-(\d+)")
_HF_DIR_RE = re.compile(r"step-(\d+)-hf")


def _list_checkpoints_newest_first(checkpoint_dir, name_re):
    """Return directories in checkpoint_dir whose full name matches name_re, highest step first."""
    # Escape the base path so glob metacharacters in it ("[", "*", "?") are taken literally.
    pattern = os.path.join(glob.escape(checkpoint_dir), "step-*")
    found = []
    for path in glob.glob(pattern):
        match = name_re.fullmatch(os.path.basename(path))
        # Only real directories count; a stray file must not occupy a "keep" slot.
        if match and os.path.isdir(path):
            found.append((int(match.group(1)), path))
    found.sort(reverse=True)
    return [path for _, path in found]


def _prune_checkpoints(checkpoints, keep_latest_k, label):
    """Delete every entry of checkpoints (ordered newest first) beyond the first keep_latest_k."""
    checkpoints_to_delete = checkpoints[keep_latest_k:]
    if not checkpoints_to_delete:
        return
    logger.info(f"Deleting {len(checkpoints_to_delete)} old {label} checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
    for ckpt_path in checkpoints_to_delete:
        try:
            shutil.rmtree(ckpt_path)
        except OSError as e:
            # Best effort: log the failure and continue with the remaining directories.
            logger.error(f"Error removing directory {ckpt_path}: {e}")


def cleanup_local_checkpoints(checkpoint_dir: str, keep_latest_k: int) -> None:
    """Remove older local checkpoint directories, keeping only the latest k of each format.

    Two naming schemes are handled independently: DCP checkpoints (``step-<N>``)
    and HF checkpoints (``step-<N>-hf``). Within each scheme directories are
    ranked by the integer step ``N`` and everything after the ``keep_latest_k``
    highest steps is removed with ``shutil.rmtree``.

    Entries whose name does not exactly match one of the two schemes, and
    entries that are not directories, are left untouched and do not count
    towards ``keep_latest_k``.

    Args:
        checkpoint_dir: Directory that contains the ``step-*`` sub-directories.
        keep_latest_k: Number of checkpoints to keep per format. A value
            ``<= 0`` disables cleanup (all checkpoints are kept).

    Removal errors (``OSError``) are logged and do not propagate.
    """
    if keep_latest_k <= 0:
        return  # Keep all checkpoints

    logger.info(f"Cleaning up local checkpoints in {checkpoint_dir}, keeping latest {keep_latest_k}")

    # Cleanup DCP checkpoints (step-<N>)
    _prune_checkpoints(
        _list_checkpoints_newest_first(checkpoint_dir, _DCP_DIR_RE), keep_latest_k, "DCP"
    )
    # Cleanup HF checkpoints (step-<N>-hf)
    _prune_checkpoints(
        _list_checkpoints_newest_first(checkpoint_dir, _HF_DIR_RE), keep_latest_k, "HF"
    )
|