# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import json import os import time from datetime import timedelta import torch from datasets import interleave_datasets, load_dataset from torch.distributed.elastic.multiprocessing.errors import record from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer import fla # noqa from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss from fla.ops.common.utils import prepare_position_ids from flame.components.checkpoint import TrainState from flame.config_manager import JobConfig from flame.data import build_dataloader, shuffle from flame.models.parallelize_fla import parallelize_fla from flame.models.pipeline_fla import pipeline_fla from flame.tools.utils import get_nparams_and_flops from flame.utils.checkpoint import cleanup_local_checkpoints from flame.utils.convert_dcp_to_hf import save_pretrained from flame.utils.hf_utils import upload_checkpoint_to_hf from datetime import datetime from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.ft import FTParallelDims, init_ft_manager from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.metrics import build_device_memory_monitor, build_metrics_processor, ensure_pp_loss_visible from torchtitan.components.optimizer import build_optimizers from torchtitan.distributed import ParallelDims from torchtitan.distributed import utils as dist_utils from torchtitan.protocols.model_converter import build_model_converters from torchtitan.protocols.train_spec import TrainSpec, get_train_spec, register_train_spec from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger from torchtitan.tools.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling def build_tokenizer(job_config: JobConfig) -> AutoTokenizer: return AutoTokenizer.from_pretrained(job_config.model.tokenizer_path) register_train_spec( TrainSpec( name="fla", cls=AutoModelForCausalLM, config=AutoConfig, parallelize_fn=parallelize_fla, pipelining_fn=pipeline_fla, build_optimizers_fn=build_optimizers, build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_dataloader, build_tokenizer_fn=build_tokenizer, build_loss_fn=build_cross_entropy_loss, ) ) # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html @record def main(job_config: JobConfig): logger.info(f"Starting job: {job_config.job.description}") if job_config.experimental.custom_model_path: utils.import_module_from_path(job_config.experimental.custom_model_path) # used for colorful printing color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color if job_config.job.print_args: logger.info( f"{color.green}{json.dumps(job_config.to_dict(), indent=2, sort_keys=True)}{color.reset}" ) # take control of garbage collection to avoid stragglers gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq) device_module, device_type = utils.device_module, utils.device_type device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") # Device has to be set before creating TorchFT manager. device_module.set_device(device) ft_manager = init_ft_manager(job_config) run_specific_repo_id = None if getattr(job_config.checkpoint, "hf_upload_enabled", False): hf_repo_base = getattr(job_config.checkpoint, "hf_repo_base_name", None) if hf_repo_base: # Generate timestamp (adjust format if desired) timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") run_specific_repo_id = f"{hf_repo_base}-{timestamp}" logger.info(f"Target Hugging Face repository for this run: {run_specific_repo_id}") else: logger.warning("HF Hub upload enabled, but 'checkpoint.hf_repo_base_name' is not set.") # Disable upload if base name is missing job_config.checkpoint.hf_upload_enabled = False # init distributed world_size = int(os.environ["WORLD_SIZE"]) if not ft_manager.enabled: parallel_dims = ParallelDims( dp_shard=job_config.training.data_parallel_shard_degree, dp_replicate=job_config.training.data_parallel_replicate_degree, cp=job_config.experimental.context_parallel_degree, tp=job_config.training.tensor_parallel_degree, pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=not job_config.training.disable_loss_parallel, ) else: parallel_dims = FTParallelDims( dp_shard=job_config.training.data_parallel_shard_degree, dp_replicate=job_config.training.data_parallel_replicate_degree, cp=job_config.experimental.context_parallel_degree, tp=job_config.training.tensor_parallel_degree, pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=not job_config.training.disable_loss_parallel, ft_manager=ft_manager, ) dist_utils.init_distributed(job_config) # initialize device memory monitor and get peak flops for MFU calculation device_memory_monitor = build_device_memory_monitor() gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") # build meshes world_mesh = parallel_dims.build_mesh(device_type=device_type) if parallel_dims.dp_enabled: dp_mesh = world_mesh["dp"] dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 if parallel_dims.pp_enabled: raise NotImplementedError( "Pipeline parallelism is not supported in this version" ) """ ! TODO[flame]: We need to fix the pipeline parallelism for flame [x] Match the key of models' components with the actual naming [ ] Fix the post-init and tie-embedding for pipeline parallelism, HF's transformer automatically forces to tie if head is None, we need to handle this case [ ] """ pp_mesh = world_mesh["pp"] # Set random seed, and maybe enable deterministic mode (mainly for debugging, expect perf loss) dist_utils.set_determinism( world_mesh, device, job_config.training.seed, job_config.training.deterministic ) train_spec = get_train_spec(job_config.model.name) logger.info("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained( job_config.model.tokenizer_path, trust_remote_code=True, model_max_length=int(1e10), ) logger.info(f"{tokenizer}") logger.info( f"Loading dataset {job_config.training.dataset}" f":{job_config.training.dataset_name}" if job_config.training.dataset_name is not None else "" ) min_num_shards = dp_degree * job_config.training.num_workers if len(job_config.training.dataset.split(",")) == 1: dataset = load_dataset( path=job_config.training.dataset, name=getattr(job_config.training, "dataset_name", None), data_dir=getattr(job_config.training, "data_dir", None), data_files=getattr(job_config.training, "data_files", None), split=job_config.training.dataset_split or "train", trust_remote_code=True, streaming=job_config.training.streaming, num_proc=( job_config.training.num_workers if not job_config.training.streaming else None ), ) logger.info(f"{dataset}") logger.info(f"Shuffling the dataset with seed {job_config.training.seed}") if not job_config.training.streaming: # the states of map-style dataset is recoverable after shuffling dataset = dataset.shuffle( seed=job_config.training.seed ).to_iterable_dataset(num_shards=min_num_shards) else: if dataset.num_shards < min_num_shards: logger.warning( f"{color.red}" f"Dataset {job_config.training.dataset} has insufficient shards ({dataset.num_shards}). " f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × " f"{job_config.training.num_workers} dataloader workers. " f"Disabling the streaming mode and resharding dataset to {min_num_shards} shards." f"{color.reset}" ) dataset = ( load_dataset( path=job_config.training.dataset, name=getattr(job_config.training, "dataset_name", None), data_dir=getattr(job_config.training, "data_dir", None), data_files=getattr(job_config.training, "data_files", None), split=job_config.training.dataset_split or "train", trust_remote_code=True, streaming=False, num_proc=job_config.training.num_workers, ) .shuffle(seed=job_config.training.seed) .to_iterable_dataset(num_shards=min_num_shards) ) else: dataset = shuffle(dataset, seed=job_config.training.seed) else: datasets = job_config.training.dataset.split(",") if job_config.training.dataset_name is not None: dataset_names = [ name or None for name in job_config.training.dataset_name.split(",") ] assert len(dataset_names) == len(datasets), ( "The number of dataset names must match the number of datasets" ) else: dataset_names = [None] * len(datasets) if job_config.training.dataset_split is not None: dataset_splits = [ split or "train" for split in job_config.training.dataset_split.split(",") ] assert len(dataset_splits) == len(datasets), ( "The number of dataset splits must match the number of datasets" ) else: dataset_splits = ["train"] * len(datasets) if job_config.training.data_dir is not None: data_dirs = [ data_dir or None for data_dir in job_config.training.data_dir.split(",") ] assert len(data_dirs) == len(datasets), ( "The number of data dirs must match the number of datasets" ) else: data_dirs = [None] * len(datasets) if job_config.training.data_files is not None: data_files = job_config.training.data_files.split(",") assert len(data_files) == len(datasets), ( "The number of data files must match the number of datasets" ) else: data_files = [None] * len(datasets) if job_config.training.data_probs is not None: data_probs = [float(p) for p in job_config.training.data_probs.split(",")] assert len(data_probs) == len(datasets), ( "The number of data probabilities must match the number of datasets" ) else: raise ValueError( "Data sampling probabilities are required if using multiple datasets" ) subsets = [] for i, prob in enumerate(data_probs): subset = load_dataset( path=datasets[i], name=dataset_names[i], data_dir=data_dirs[i], data_files=data_files[i], split=dataset_splits[i], trust_remote_code=True, streaming=job_config.training.streaming, num_proc=( job_config.training.num_workers if not job_config.training.streaming else None ), ) logger.info( f"Subset {color.cyan}{datasets[i]}" + (f":{dataset_names[i]} " if dataset_names[i] else " ") + f"(p = {prob:.3f}){color.reset}:\n" + f"{subset}" ) logger.info(f"Shuffling the dataset with seed {job_config.training.seed}") if not job_config.training.streaming: # the states of map-style dataset is recoverable after shuffling subset = subset.shuffle( seed=job_config.training.seed ).to_iterable_dataset(num_shards=min_num_shards) else: if subset.num_shards < min_num_shards: logger.warning( f"{color.red}" f"Dataset {datasets[i]} has insufficient shards ({subset.num_shards}). " f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × " f"{job_config.training.num_workers} dataloader workers. " f"Resharding dataset to {min_num_shards} shards and disabling streaming mode." f"{color.reset}" ) # again, it's ok to directly shuffle the map-style dataset # we expect an error raised if the map-style dataset still has not enough data shards subset = ( load_dataset( path=datasets[i], name=dataset_names[i], data_dir=data_dirs[i], data_files=data_files[i], split=dataset_splits[i], trust_remote_code=True, streaming=False, num_proc=job_config.training.num_workers, ) .shuffle(seed=job_config.training.seed) .to_iterable_dataset(min_num_shards) ) else: # we set relatively small buffer size here as interleaving could provide some randomness subset = shuffle( subset, seed=job_config.training.seed, buffer_size=max(128, 1024 // len(datasets)), ) if "text" in subset.column_names: subset = subset.select_columns("text") elif "content" in subset.column_names: subset = subset.select_columns("content") else: raise ValueError( f"Subset {datasets[i]} has no 'text' or 'content' column" ) subsets.append(subset) logger.info( f"Interleaving {len(subsets)} datasets with probabilities {data_probs}" ) dataset = interleave_datasets( datasets=subsets, probabilities=data_probs, stopping_strategy="all_exhausted", seed=job_config.training.seed, ) logger.info(f"{dataset}") logger.info("Building dataloader...") dataloader = build_dataloader( dataset=dataset, tokenizer=tokenizer, rank=dp_rank, world_size=dp_degree, batch_size=job_config.training.batch_size, seq_len=job_config.training.seq_len, context_len=job_config.training.context_len, varlen=job_config.training.varlen, num_workers=job_config.training.num_workers, pin_memory=job_config.training.pin_memory, persistent_workers=job_config.training.persistent_workers, snapshot_every_n_steps=job_config.checkpoint.interval, ) logger.info(f"Loading model config from {job_config.model.config}") model_config = AutoConfig.from_pretrained(job_config.model.config) # set the model configs from training inputs: # 1. norm type to decide which norm layer to use # 2. disable fused norm if TP is enabled # 3. vocab size from tokenizer # 4. context_len base on inputs if parallel_dims.tp_enabled: if model_config.fuse_norm: logger.warning( f"{color.red}" f"Fused norm is not compatible with tensor parallelism. " f"Disabling it for now." f"{color.reset}" ) model_config.fuse_norm = False if parallel_dims.loss_parallel_enabled: if model_config.fuse_cross_entropy: logger.warning( f"{color.red}" f"Loss parallel enabled. Disabling fused cross entropy for now." f"{color.reset}" ) model_config.fuse_cross_entropy = False model_config.vocab_size = max(tokenizer.vocab_size, model_config.vocab_size) logger.info( f"Building model from the config\n{color.green}{model_config}{color.reset}" ) with torch.device("meta"): model = AutoModelForCausalLM.from_config(model_config) if ( getattr(model_config, "fuse_cross_entropy", False) and FusedLinearCrossEntropyLoss is not None ): model.criterion = FusedLinearCrossEntropyLoss( num_chunks=8 // parallel_dims.tp ) # defer weight initialization until after parallelisms are applied model.apply(lambda m: setattr(m, "_is_hf_initialized", False)) logger.info(f"{color.blue}\n{model}{color.reset}\n") # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) model_converters.convert(model) # calculate model size and flops per token model_param_count, num_flops_per_token = get_nparams_and_flops( model, model_config, job_config.training.context_len ) # move sharded model to CPU/GPU and initialize weights via DTensor if job_config.checkpoint.create_seed_checkpoint: init_device = "cpu" elif job_config.training.enable_cpu_offload: init_device = "cpu" else: init_device = device_type # apply parallelisms and initialization if parallel_dims.pp_enabled: # apply PT-D Pipeline Parallel ( pp_schedule, model_parts, has_first_stage, has_last_stage, ) = train_spec.pipelining_fn( model, pp_mesh, parallel_dims, job_config, device, model_config, train_spec.loss_fn, ) # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead del model # For PP with looped schedules, each item in model_parts is one stage-model-chunk. # We need to iterate through model_parts to apply SPMD parallelisms, compilation, # optimizer, and checkpointing for m in model_parts: # apply SPMD-style PT-D techniques train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config) m.to_empty(device=init_device) with torch.no_grad(): m.post_init() m.train() # confirm that user will be able to view loss metrics on the console ensure_pp_loss_visible(parallel_dims, job_config, color) else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) model.to_empty(device=init_device) with torch.no_grad(): model.post_init() model.train() model_parts = [model] device_mem_stats = device_memory_monitor.get_peak_stats() logger.info( f"{device_type.upper()} memory usage for model: " f"{device_mem_stats.max_reserved_gib:.2f}GiB" f"({device_mem_stats.max_reserved_pct:.2f}%)" ) # build optimizer after applying parallelisms to the model optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager) lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config) # Post optimizer step model converters hook. # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 # where it issues a single all-reduce for all parameters at once for better performance optimizers.register_step_post_hook( lambda *args, **kwargs: model_converters.post_optimizer_hook(model_parts) ) train_state = TrainState() # load initial checkpoint checkpoint = CheckpointManager( dataloader=dataloader, model_parts=model_parts, optimizers=optimizers, lr_schedulers=lr_schedulers, states={"train_state": train_state}, job_config=job_config, ft_manager=ft_manager, ) if job_config.checkpoint.create_seed_checkpoint: assert world_size == 1, ( "Must create seed checkpoint using a single device, to disable sharding" ) assert job_config.checkpoint.enable_checkpoint, ( "Must enable checkpointing when creating a seed checkpoint" ) checkpoint.save(curr_step=0, force=True) logger.info("Created seed checkpoint") return checkpoint.load(step=job_config.checkpoint.load_step) metric_logger = build_metrics_processor(job_config, parallel_dims) # Set dependent attributes for metric_logger metric_logger.num_flops_per_token = num_flops_per_token metric_logger.optimizers = optimizers # Pass optimizers if needed by logger logic metric_logger.lr_schedulers = ( lr_schedulers # Pass schedulers if needed by logger logic ) # plot losses loaded from checkpoint (if any) to TensorBoard # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq if train_state.step > 0 and len(metric_logger.data_loading_times) > 0: for idx, step in enumerate(train_state.log_steps): metric_logger.log( step, global_avg_loss=train_state.global_avg_losses[idx], global_max_loss=train_state.global_max_losses[idx], ) data_iterator = iter(dataloader) train_context = dist_utils.get_train_context( parallel_dims.loss_parallel_enabled, job_config.experimental.enable_compiled_autograd, ) # variables used to keep info for metrics logging device_memory_monitor.reset_peak_stats() global_batch_size = ( job_config.training.batch_size * dp_degree * job_config.training.gradient_accumulation_steps ) num_tokens_per_step = global_batch_size * job_config.training.seq_len # train loop logger.info(f"{color.red}***** Running training *****{color.reset}") logger.info(f"{color.green} Training starts at step {train_state.step + 1}") logger.info( f"{color.green} Number of tokens per sequence = {job_config.training.seq_len:,}" ) logger.info( f"{color.green} Gradient Accumulation steps = {job_config.training.gradient_accumulation_steps}" ) logger.info( f"{color.green} Instantaneous batch size (per device) = {job_config.training.batch_size:,}" ) logger.info( f"{color.green} Global batch size (w. parallel, distributed & accumulation) = {global_batch_size:,}" f" ({num_tokens_per_step:,} tokens)" ) logger.info( f"{color.green} Total optimization steps = {job_config.training.steps:,} " f"({job_config.training.steps * num_tokens_per_step:,} tokens)" ) logger.info( f"{color.green} Warmup steps = {job_config.lr_scheduler.warmup_steps:,}" f" ({job_config.lr_scheduler.warmup_steps * num_tokens_per_step:,} tokens)" ) logger.info( f"{color.green} Number of parameters = {model_param_count:,} {color.reset}" ) with ( maybe_enable_profiling( job_config, global_step=train_state.step ) as torch_profiler, maybe_enable_memory_snapshot( job_config, global_step=train_state.step ) as memory_profiler, ): while train_state.step < job_config.training.steps: train_state.step += 1 gc_handler.run(train_state.step) optimizers.zero_grad() losses = [] # do gradient accumulation if enabled for _ in range(job_config.training.gradient_accumulation_steps): # get batch data_load_start = time.perf_counter() batch = next(data_iterator) input_ids, labels = batch["input_ids"], batch["labels"] # Update metrics processor state before forward/backward metric_logger.ntokens_since_last_log += labels.numel() metric_logger.data_loading_times.append( time.perf_counter() - data_load_start ) input_ids = input_ids.to(device_type) """ TODO[flame]: We need to carefully handle the position_ids for TP/CP Depending on the Models'PE, the position_ids might be different. e.g. for TP For RoPE, all ranks have the same position_ids. [FOR HF model] For sinusoidal, each rank has the coresponding chunked position_ids. [FOR HF model] e.g. for CP, [optional_context_parallel_ctx shoudl automatically distbute the position_ids] Each rank has the coresponding chunked position_ids. [FOR All model] """ labels = labels.to(device_type) cu_seqlens = ( batch["cu_seqlens"].to(device_type) if "cu_seqlens" in batch else None ) if cu_seqlens is not None: position_ids = prepare_position_ids(cu_seqlens).to(torch.int32) else: position_ids = ( torch.arange(0, input_ids.shape[1], device=device_type) .repeat(input_ids.shape[0], 1) .to(torch.int32) ) # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( cp_mesh=world_mesh["cp"], cp_buffers=[input_ids, labels, position_ids], cp_seq_dims=[1, 1, 1], cp_no_restore_buffers={input_ids, labels, position_ids}, cp_rotate_method=job_config.experimental.context_parallel_rotate_method, ) if parallel_dims.cp_enabled else None ) # #! TODO[flame], we should distribute the position_ids as well with CP if parallel_dims.pp_enabled: raise NotImplementedError( "Pipeline parallelism is not supported in this version" ) # Pipeline Parallel forward / backward inside step() call with train_context(optional_context_parallel_ctx): targets, losses = ( (labels, []) if has_last_stage else (None, None) ) if has_first_stage: pp_schedule.step(input_ids, target=targets, losses=losses) else: pp_schedule.step(target=targets, losses=losses) # accumulate losses across pipeline microbatches # TODO: PP+FSDP unexpectedly puts the loss back to the CPU loss = ( torch.mean(torch.stack(losses)).to(device) if has_last_stage else torch.tensor([-1.0], device=device) ) else: # Non-PP forward / backward with train_context(optional_context_parallel_ctx): output = model( input_ids=input_ids, labels=labels, position_ids=position_ids, cu_seqlens=cu_seqlens, ) loss = ( output.loss / job_config.training.gradient_accumulation_steps ) loss.backward() losses.append(loss) loss = sum(losses) # clip gradients grad_norm = dist_utils.clip_grad_norm_( [p for m in model_parts for p in m.parameters()], job_config.training.max_norm, foreach=True, pp_mesh=pp_mesh if parallel_dims.pp_enabled else None, ) # optimizer step checkpoint.maybe_wait_for_staging() if job_config.training.skip_nan_inf and ( grad_norm.isnan() or grad_norm.isinf() ): logger.warning( f"Skipping optimizer step - detected invalid gradient norm: {grad_norm:.4f}" ) optimizers.zero_grad() train_state.skipped_step += 1 else: optimizers.step() lr_schedulers.step() # log metrics - Use MetricsProcessor if metric_logger.should_log(train_state.step): if ( parallel_dims.dp_replicate_enabled or parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled ): loss = loss.detach() # Use dist_mean/max on the accumulated loss for the step global_avg_loss, global_max_loss = ( dist_utils.dist_mean( loss, world_mesh["dp_cp"], ), dist_utils.dist_max( loss, world_mesh["dp_cp"], ), ) else: # Scale back the loss before logging global_avg_loss = global_max_loss = loss.item() # Update train state tokens and elapsed time time_now = time.perf_counter() time_delta = ( time_now - metric_logger.time_last_log ) # Use metric_logger's time train_state.token += ( metric_logger.ntokens_since_last_log # Use tokens tracked by metric_logger * parallel_dims.world_size / parallel_dims.non_data_parallel_size ) train_state.elapsed += timedelta(seconds=time_delta) train_state.log_steps.append(train_state.step) train_state.global_avg_losses.append(global_avg_loss) train_state.global_max_losses.append(global_max_loss) # Log using the metric processor last_lr = lr_schedulers.schedulers[0].get_last_lr()[0] eta = ( train_state.elapsed * (job_config.training.steps - train_state.step) / train_state.step ) metric_logger.log( train_state.step, global_avg_loss, global_max_loss, extra_metrics={ "optimizer/lr": last_lr, "optimizer/grad_norm": grad_norm.item(), "optimizer/skipped_step": train_state.skipped_step, }, ) logger.info( f"{color.blue}lr: {last_lr:.4e} gnorm: {grad_norm:5.2f} " f"{color.magenta}[{str(train_state.elapsed).split('.')[0]:>8}<{str(eta).split('.')[0]:>8}]{color.reset}" ) checkpoint.save( train_state.step, force=(train_state.step == job_config.training.steps) ) if torch.distributed.get_rank() == 0: if job_config.checkpoint.enable_checkpoint: hf_target_path = None dcp_save_path = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder, f"step-{train_state.step}") # TODO: Haven't tested this one yet if getattr(job_config.checkpoint, "convert_to_hf_on_save", False): try: # Get the path where DCP was just saved # Check CheckpointManager API for the best way, assuming get_save_path exists hf_target_path = f"{dcp_save_path}" # e.g., .../checkpoint/step-1000-hf logger.info(f"Converting step {train_state.step} DCP checkpoint to HF format at: {hf_target_path}") save_pretrained( # Call the imported function path=hf_target_path, # Pass target HF path as 'path' step=train_state.step, config=job_config.model.config, # Pass model config path/id tokenizer=job_config.model.tokenizer_path # Pass tokenizer path/id ) logger.info(f"Successfully converted step {train_state.step} to HF format.") except Exception as e: logger.error(f"Failed to convert checkpoint step {train_state.step} to HF format: {e}", exc_info=True) base_checkpoint_dir = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder) if getattr(job_config.checkpoint, "hf_upload_enabled", True): upload_format = getattr(job_config.checkpoint, "hf_upload_format", "hf") keep_k_hub = getattr(job_config.checkpoint, "hf_keep_latest_k", 5) local_path_to_upload = None if upload_format == "hf": if hf_target_path and os.path.isdir(hf_target_path): local_path_to_upload = hf_target_path elif upload_format == "dcp": if dcp_save_path and os.path.isdir(dcp_save_path): local_path_to_upload = dcp_save_path if local_path_to_upload: try: upload_checkpoint_to_hf( local_path=local_path_to_upload, step=train_state.step, hf_repo_id_for_run=run_specific_repo_id, upload_format=upload_format, hf_keep_latest_k=job_config.checkpoint.keep_latest_k, ) except Exception as e: logger.error(f"Failed during HF Hub upload for step {train_state.step}: {e}", exc_info=True) # signal the profiler that the next profiling step has started if torch_profiler: torch_profiler.step() if memory_profiler: memory_profiler.step() # reduce timeout after first train step for faster signal # (assuming lazy init and compilation are finished) if train_state.step == 1: dist_utils.set_pg_timeouts( timeout=timedelta(seconds=job_config.comm.train_timeout_seconds), world_mesh=world_mesh, ) if torch.distributed.get_rank() == 0: logger.info("Sleeping 2 seconds for other ranks to complete") time.sleep(2) metric_logger.close() logger.info("Training completed") if __name__ == "__main__": init_logger() config = JobConfig() config.parse_args() main(config) torch.distributed.destroy_process_group()