|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import importlib |
|
import inspect |
|
import os |
|
import sys |
|
from collections import defaultdict |
|
from typing import Any, Tuple, Union
|
|
|
import torch |
|
|
|
try: |
|
import tomllib |
|
except ModuleNotFoundError: |
|
import tomli as tomllib |
|
|
|
from torchtitan.tools.logging import logger |
|
|
|
TORCH_DTYPE_MAP = { |
|
"float16": torch.float16, |
|
"float32": torch.float32, |
|
"bfloat16": torch.bfloat16, |
|
} |
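
# Illustrative lookup (usage assumed, not defined in this file):
#   TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] -> torch.bfloat16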
|
|
|
|
|
def string_list(raw_arg): |
|
"""Comma-separated string list argument.""" |
|
return [s.strip() for s in raw_arg.split(",") if s.strip()] |
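
# Example (illustrative): string_list("a, b, c") -> ["a", "b", "c"]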
|
|
|
|
|
def check_string_list_argument(args_dict: dict[str, Any], fullargname: str):
|
section, name = fullargname.split(".") |
|
|
|
if ( |
|
section in args_dict |
|
and name in args_dict[section] |
|
and isinstance(args_dict[section][name], str) |
|
): |
|
sec = args_dict[section] |
|
sec[name] = string_list(sec[name]) |
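
    # Example (illustrative): a raw toml string "optimizer,lr_scheduler" stored under
    # args_dict["checkpoint"]["exclude_from_loading"] becomes ["optimizer", "lr_scheduler"].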
|
|
|
|
|
class JobConfig: |
|
""" |
|
    A helper class to manage the training configuration.

    Semantics:

    - The default config is loaded from a toml file. If no toml file is provided,

      the default config is loaded from argparse defaults.

    - If the toml file has missing keys, they are filled with argparse defaults.

    - If additional explicit cmd args are provided in addition to the toml

      file, they override the toml config and the argparse defaults.

    Precedence order: cmdline > toml > argparse default
|
|
|
Arg parsing semantics: |
|
|
|
    Each argument is named <section>.<option>, where <section> is the section name

    in the toml file and <option> is the name of the option. For example,

    model.name translates to:

        [model]

        name

    in the toml file.
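
    To illustrate precedence (values here are illustrative), a toml file containing

        [model]
        name = "llama3"

    sets model.name to "llama3", and an explicit --model.name on the command line
    takes precedence over the toml value.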
|
""" |
|
|
|
def __init__(self): |
|
self.args_dict = None |
|
|
|
self.parser = argparse.ArgumentParser(description="torchtitan arg parser.") |
|
|
|
self.parser.add_argument( |
|
"--job.config_file", |
|
type=str, |
|
default=None, |
|
help="Job config file", |
|
) |
|
|
|
|
|
self.parser.add_argument( |
|
"--job.dump_folder", |
|
type=str, |
|
default="./torchtitan/outputs", |
|
help="Folder to dump job outputs", |
|
) |
|
self.parser.add_argument( |
|
"--job.description", |
|
type=str, |
|
default="default job", |
|
help="Description of the job", |
|
) |
|
self.parser.add_argument( |
|
"--job.use_for_integration_test", |
|
action="store_true", |
|
help="Add this config to the integration test suite", |
|
) |
|
self.parser.add_argument( |
|
"--job.print_args", |
|
action="store_true", |
|
help="Print the args to terminal", |
|
) |
|
|
|
|
|
self.parser.add_argument( |
|
"--profiling.enable_profiling", |
|
action="store_true", |
|
help="Whether to enable pytorch profiler", |
|
) |
|
self.parser.add_argument( |
|
"--profiling.save_traces_folder", |
|
type=str, |
|
default="profile_traces", |
|
help="Trace files location", |
|
) |
|
self.parser.add_argument( |
|
"--profiling.profile_freq", |
|
type=int, |
|
default=10, |
|
help="How often to collect profiler traces, in iterations", |
|
) |
|
self.parser.add_argument( |
|
"--profiling.enable_memory_snapshot", |
|
action="store_true", |
|
help="Whether to dump memory snapshot", |
|
) |
|
self.parser.add_argument( |
|
"--profiling.save_memory_snapshot_folder", |
|
type=str, |
|
default="memory_snapshot", |
|
help="Memeory snapshot files location", |
|
) |
|
|
|
|
|
self.parser.add_argument( |
|
"--metrics.log_freq", |
|
type=int, |
|
default=10, |
|
help="How often to log metrics to TensorBoard, in iterations", |
|
) |
|
self.parser.add_argument( |
|
"--metrics.enable_tensorboard", |
|
action="store_true", |
|
help="Whether to log metrics to TensorBoard", |
|
) |
|
self.parser.add_argument( |
|
"--metrics.disable_color_printing", |
|
action="store_true", |
|
help="Whether to disable color printing in logs", |
|
) |
|
self.parser.add_argument( |
|
"--metrics.save_tb_folder", |
|
type=str, |
|
default="tb", |
|
help="Folder to dump TensorBoard states", |
|
) |
|
self.parser.add_argument( |
|
"--metrics.save_for_all_ranks", |
|
action="store_true", |
|
default=False, |
|
help=""" |
|
Whether to save TensorBoard/Wandb metrics only for rank 0 or for all ranks. |
|
When this option is False and pipeline_parallel_degree is > 1, the metrics |
|
component uses the 0th rank of the last stage pipeline group, which is the |
|
only stage that computes loss metrics. |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--metrics.enable_wandb", |
|
action="store_true", |
|
help="Whether to log metrics to Weights & Biases", |
|
) |
|
|
|
|
|
self.parser.add_argument( |
|
"--model.name", |
|
type=str, |
|
default="llama3", |
|
help="Which model to train", |
|
) |
|
self.parser.add_argument( |
|
"--model.flavor", |
|
type=str, |
|
default="debugmodel", |
|
help="Which model config to train", |
|
) |
|
self.parser.add_argument( |
|
"--model.norm_type", |
|
type=str, |
|
default="rmsnorm", |
|
choices=["layernorm", "np_layernorm", "rmsnorm"], |
|
help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm]", |
|
) |
|
self.parser.add_argument( |
|
"--model.use_flex_attn", |
|
action="store_true", |
|
help=""" |
|
Whether to use Flex Attention. |
|
            Mixed usage of SDPA and FlexAttention is not supported yet.
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--model.attn_mask_type", |
|
type=str, |
|
default="causal", |
|
choices=["causal", "block_causal"], |
|
help=""" |
|
Specifies the type of bias/mask used for attention. If SDPA is used, |
|
only the causal mask is supported by default. If FlexAttention is used, |
|
both causal and block_causal masks are supported. |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--model.tokenizer_path", |
|
type=str, |
|
default="./assets/tokenizer/original/tokenizer.model", |
|
help="Tokenizer path", |
|
) |
|
self.parser.add_argument( |
|
"--model.converters", |
|
type=string_list, |
|
nargs="+", |
|
default=[], |
|
help=""" |
|
Comma separated list of converters to apply to the model. |
|
|
|
For instance, the `float8` converter swaps `torch.nn.Linear` |
|
with `Float8Linear`. This feature requires you to install 'torchao' |
|
which can be found here: https://github.com/pytorch/ao |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--model.print_after_conversion", |
|
action="store_true", |
|
help=""" |
|
If true, model definition will be printed to stdout after all model |
|
converters have been applied. |
|
""", |
|
) |
|
|
|
|
|
self.parser.add_argument( |
|
"--optimizer.name", type=str, default="AdamW", help="Optimizer to use" |
|
) |
|
self.parser.add_argument( |
|
"--optimizer.lr", type=float, default=8e-4, help="Learning rate to use" |
|
) |
|
self.parser.add_argument( |
|
"--optimizer.eps", type=float, default=1e-8, help="Epsilon value to use" |
|
) |
|
self.parser.add_argument( |
|
"--optimizer.implementation", |
|
type=str, |
|
default="fused", |
|
choices=["for-loop", "foreach", "fused"], |
|
help=""" |
|
Specify which optimizer implementation to use: |
|
- 'fused': Use fused implementation (CUDA only) for best performance. |
|
- 'foreach': Use some horizontal fusion of tensors for better performance. |
|
- 'for-loop': Use the default implementation for the optimizer (slowest). |
|
- more info: https://pytorch.org/docs/stable/optim.html |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--optimizer.early_step_in_backward", |
|
action="store_true", |
|
help=""" |
|
            Whether to apply the optimizer step in the backward pass. Caution: optimizer_in_backward

            is not compatible with gradient clipping, and users should not call

            register_post_accumulate_grad_hook after the optimizer is built.""",
|
) |
|
|
|
|
|
self.parser.add_argument( |
|
"--lr_scheduler.warmup_steps", |
|
type=int, |
|
default=200, |
|
help="Steps for lr scheduler warmup, normally 1/5 of --training.steps", |
|
) |
|
self.parser.add_argument( |
|
"--lr_scheduler.decay_ratio", |
|
type=float, |
|
default=None, |
|
help=""" |
|
Controls the proportion of the training steps allocated to the learning rate decay phase. |
|
|
|
If `None`, the learning rate will begin decaying immediately after the warmup period. |
|
Otherwise, the learning rate will remain stable after the warmup period and |
|
only start decaying during the last `decay_ratio` portion of the total training steps. |
|
|
|
This is known as the Warmup-Stable-Decay (WSD) schedule, as described in https://arxiv.org/abs/2404.06395. |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--lr_scheduler.decay_type", |
|
type=str, |
|
default="linear", |
|
choices=["linear", "sqrt", "cosine"], |
|
help=""" |
|
Learning rate decay type to use during training: |
|
- 'linear': linearly decays learning rate from initial to final value |
|
- 'sqrt': decays learning rate following a 1 minus square root curve |
|
- 'cosine': smoothly decays learning rate following a cosine curve |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--lr_scheduler.lr_min", |
|
type=float, |
|
default=0.0, |
|
help=""" |
|
Min lr ratio for lr scheduler. |
|
|
|
If provided, the range of decay factor is scaled from 1 to `lr_min` |
|
to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`. |
|
""", |
|
) |
|
|
|
|
|
self.parser.add_argument( |
|
"--training.dataset", type=str, default="c4_test", help="Dataset to use" |
|
) |
|
self.parser.add_argument( |
|
"--training.dataset_path", |
|
type=str, |
|
help=""" |
|
Path to the dataset in the file system. If provided, data will be |
|
loaded from this path instead of downloaded.""", |
|
) |
|
self.parser.add_argument( |
|
"--training.batch_size", type=int, default=8, help="Batch size" |
|
) |
|
self.parser.add_argument( |
|
"--training.seq_len", type=int, default=2048, help="Sequence length" |
|
) |
|
self.parser.add_argument( |
|
"--training.max_norm", |
|
            type=float,  # typing.Union is not callable, so argparse cannot use it as a type
|
default=1.0, |
|
help="Max norm for gradient clipping", |
|
) |
|
self.parser.add_argument( |
|
"--training.steps", |
|
type=int, |
|
default=10000, |
|
help="How many train steps to run", |
|
) |
|
self.parser.add_argument( |
|
"--training.enable_cpu_offload", |
|
action="store_true", |
|
help=""" |
|
Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""", |
|
) |
|
self.parser.add_argument( |
|
"--training.mixed_precision_param", |
|
type=str, |
|
default="bfloat16", |
|
choices=["bfloat16", "float32"], |
|
help=""" |
|
torch dtype to use for parameters when applying mixed precision via FSDP. |
|
This feature only takes effect when data_parallel_shard_degree > 1 |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--training.mixed_precision_reduce", |
|
type=str, |
|
default="float32", |
|
choices=["float32"], |
|
help=""" |
|
torch dtype to use for reductions when applying mixed precision via FSDP. |
|
This feature only takes effect when data_parallel_shard_degree > 1 |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--training.compile", |
|
action="store_true", |
|
help="Whether to compile the model", |
|
) |
|
self.parser.add_argument( |
|
"--training.gc_freq", |
|
type=int, |
|
default=50, |
|
help="Python garbage control scheduling interval, in steps", |
|
) |
|
self.parser.add_argument( |
|
"--training.seed", |
|
type=int, |
|
default=None, |
|
help="Choose the base RNG seed used for training", |
|
) |
|
self.parser.add_argument( |
|
"--training.deterministic", |
|
action="store_true", |
|
help="Use deterministic algorithms wherever possible, may be slower", |
|
) |
|
|
|
|
|
self.parser.add_argument( |
|
"--parallelism.data_parallel_replicate_degree", |
|
type=int, |
|
default=1, |
|
help=""" |
|
The `data_parallel_replicate_degree` argument specifies the degree of |
|
data parallelism for weight replication. When this value is greater |
|
than 1, weights will be replicated across `data_parallel_replicate_degree` |
|
ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism |
|
method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the |
|
parallelism method used is DDP (Distributed Data Parallelism). |
|
1 means disabled.""", |
|
) |
|
self.parser.add_argument( |
|
"--parallelism.enable_compiled_autograd", |
|
action="store_true", |
|
help="Enable CompiledAutograd to compile the backward.", |
|
) |
|
self.parser.add_argument( |
|
"--parallelism.data_parallel_shard_degree", |
|
type=int, |
|
default=-1, |
|
help=""" |
|
The `data_parallel_shard_degree` argument specifies the degree of data |
|
parallelism for weight sharding. When this value is greater than 1, weights |
|
will be sharded across `data_parallel_shard_degree` ranks. If |
|
`data_parallel_replicate_degree` is also greater than 1, the parallelism |
|
method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the |
|
parallelism method used is FSDP (Fully Sharded Data Parallelism). |
|
|
|
-1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that |
|
only `data_parallel_shard_degree` can be negative. 1 means disabled.""", |
|
) |
|
self.parser.add_argument( |
|
"--parallelism.fsdp_reshard_after_forward", |
|
type=str, |
|
default="default", |
|
choices=["default", "always", "never"], |
|
help=""" |
|
`reshard_after_forward` specifies the policy for applying `reshard_after_forward` |
|
within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward, |
|
trading off memory and communication. See torch's `fully_shard` API for more documentation |
|
on `reshard_after_forward`. |
|
The supported policies include "default", "always" and "never": |
|
- "default" applies default resharding behavior, implementing "smart defaults" for known optimal |
|
scenarios. |
|
- "always" will enable `reshard_after_forward` for all forward passes. |
|
- "never" will disable `reshard_after_forward` for all forward passes. |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--parallelism.tensor_parallel_degree", |
|
type=int, |
|
default=1, |
|
help="Tensor Parallelism degree. 1 means disabled.", |
|
) |
|
self.parser.add_argument( |
|
"--parallelism.disable_loss_parallel", |
|
action="store_true", |
|
help="Whether to apply loss parallel when sequence parallel is enabled", |
|
) |
|
self.parser.add_argument( |
|
"--parallelism.enable_async_tensor_parallel", |
|
action="store_true", |
|
help="Whether to apply async tensor parallel (currently only effective when compile is enabled)", |
|
) |
|
self.parser.add_argument( |
|
"--parallelism.pipeline_parallel_degree", |
|
type=int, |
|
default=1, |
|
help=""" |
|
Pipeline Parallelism degree, or number of ranks. 1 means disabled. |
|
If using looped schedules, this still specifies the number of physical ranks, not the number |
|
            of stages. Stages per rank are inferred from the split points and the schedule.""",
|
) |
|
self.parser.add_argument( |
|
"--parallelism.pipeline_parallel_split_points", |
|
type=string_list, |
|
nargs="+", |
|
default=[], |
|
help=""" |
|
Specify comma-separated names of modules to use as the beginning of a split point. |
|
|
|
e.g. "layers.0,layers.2" will cause the model to be split into 3 stages, |
|
the first containing all the layers up to layers.0, |
|
the second containing layers.0 and up to layers.2, |
|
the third containing layers.2 and all the remaining layers. |
|
|
|
Note: fully-automated splitting may be enabled in the future, |
|
but currently the split points must be specified manually.""", |
|
) |
|
self.parser.add_argument( |
|
"--parallelism.pipeline_parallel_layers_per_stage", |
|
type=int, |
|
default=None, |
|
help=""" |
|
The number of layers per stage. If specified, the split points will be calculated from |
|
the number of layers and pipeline_parallel_degree. If not specified, the layers per stage will |
|
be inferred from the model, schedule, and pipeline_parallel_degree.""", |
|
) |
|
self.parser.add_argument( |
|
"--parallelism.pipeline_parallel_schedule", |
|
type=str, |
|
default="1F1B", |
|
help=""" |
|
Specify the Pipeline Parallel schedule to use. The supported schedules are: |
|
https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161. |
|
The schedule must be compatible with the split points and stages_per_rank. |
|
|
|
Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks, |
|
and split_points = number of stages - 1 |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--parallelism.pipeline_parallel_schedule_csv", |
|
type=str, |
|
default="", |
|
help=""" |
|
Specify the path to the pipeline parallel schedule csv file to use. |
|
The pipeline_parallel_schedule argument must be either |
|
PipelineScheduleSingle, PipelineScheduleMulti, or _PipelineScheduleRuntime. |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--parallelism.pipeline_parallel_microbatch_size", |
|
type=int, |
|
default=1, |
|
help=""" |
|
The size of each pipeline parallel microbatch (default 1). |
|
|
|
            This value is used to compute the total number of microbatches by dividing batch_size by
|
pipeline_parallel_microbatch_size. |
|
|
|
The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size. |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--parallelism.context_parallel_degree", |
|
type=int, |
|
default=1, |
|
help="Context parallelism degree. 1 means disabled.", |
|
) |
|
self.parser.add_argument( |
|
"--parallelism.context_parallel_rotate_method", |
|
type=str, |
|
default="allgather", |
|
help=""" |
|
The collective to use in context parallel SDPA for kv shards exchange. |
|
|
|
'allgather' means to all-gather all kv shards on ranks after the first sub-SDPA computation, |
|
|
|
'alltoall' means to all-to-all shuffle the kv shards. |
|
|
|
The default value is 'allgather'. |
|
""", |
|
) |
|
|
|
|
|
self.parser.add_argument( |
|
"--checkpoint.enable_checkpoint", |
|
action="store_true", |
|
help="Whether to enable checkpoint", |
|
) |
|
self.parser.add_argument( |
|
"--checkpoint.folder", |
|
type=str, |
|
default="checkpoint", |
|
help=""" |
|
The folder to store the checkpoints. |
|
When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}. |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--checkpoint.interval", |
|
type=int, |
|
default=500, |
|
help="Checkpointing interval in steps.", |
|
) |
|
self.parser.add_argument( |
|
"--checkpoint.model_weights_only", |
|
action="store_true", |
|
help=""" |
|
When model_weights_only=True, only model weights will be saved at the end of training. |
|
With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion. |
|
When model_weights_only=False, the full checkpoint will be saved. |
|
A full checkpoint includes model, optimizer and train_state, which can be used to resume training. |
|
The default value is false. |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--checkpoint.export_dtype", |
|
type=str, |
|
default="float32", |
|
choices=["float16", "bfloat16", "float32"], |
|
help=""" |
|
Converts to the specified precision when training completes and model_weights_only=true. |
|
Currently supports float32, float16, and bfloat16. |
|
The default value is float32. |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--checkpoint.create_seed_checkpoint", |
|
action="store_true", |
|
help=""" |
|
Initializes the full model without applying parallelisms, and then saves it as a seed checkpoint. |
|
Note: requires user to call train.py without specifying any parallelisms, e.g. NGPU=1. |
|
Could be implemented as a separate script, but this way shares more code. |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--checkpoint.async_mode", |
|
type=str, |
|
default="disabled", |
|
help=""" |
|
Which async checkpoint mode to use. Currently there are 3 different modes. |
|
1. "disabled": synchronized checkpointing will be used. |
|
2. "async": torch.distributed.checkpoint.async_save will be used. |
|
3. "async_with_pinned_mem": this option utilizes a dedicated pinned memory |
|
space and creates a separate process for faster GPU->CPU transfer |
|
performance and eliminating GIL contention. The cost is increased CPU |
|
memory usage. If insufficient CPU memory is available, performance may |
|
degrade due to memory paging. For most users, "async" should suffice as |
|
the performance overhead is typically small (on the order of tens of |
|
seconds) compared to checkpointing frequency. This mode can be employed |
|
to pursue near-zero checkpointing times (e.g., < 1 second) given |
|
appropriate hardware support such as ample CPU memory and fast PCIe. |
|
|
|
"disabled" is the default mode. |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--checkpoint.keep_latest_k", |
|
type=int, |
|
default=10, |
|
help=""" |
|
            Keeps only the latest k checkpoints, purging older ones. If 0, keep all checkpoints.
|
K cannot be 1 as the last one may be in the process of being saved. As a result, |
|
the metadata of the last one may not be ready yet. The default value is 10 to avoid |
|
filling up the disk. |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--checkpoint.load_step", |
|
type=int, |
|
default=-1, |
|
help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.", |
|
) |
|
self.parser.add_argument( |
|
"--checkpoint.exclude_from_loading", |
|
type=string_list, |
|
nargs="*", |
|
default=[], |
|
help=""" |
|
Exclude specific keys from being loaded from the checkpoint. |
|
Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'. |
|
This will load the model only, excluding the specified keys. |
|
""", |
|
) |
|
|
|
|
|
self.parser.add_argument( |
|
"--activation_checkpoint.mode", |
|
type=str, |
|
default="selective", |
|
help="Type of activation checkpointing to use ['none', 'full', 'selective']", |
|
) |
|
self.parser.add_argument( |
|
"--activation_checkpoint.selective_ac_option", |
|
type=str, |
|
default="2", |
|
help=""" |
|
Selective activation checkpointing options ['int', 'op']. |
|
'int' (e.g., 2) for every nth layer, or 'op' for op level ac. |
|
""", |
|
) |
|
|
|
|
|
self.parser.add_argument( |
|
"--float8.enable_fsdp_float8_all_gather", |
|
action="store_true", |
|
help="Whether enable float8 all-gather in FSDP, recommended for tensorwise scaling", |
|
) |
|
self.parser.add_argument( |
|
"--float8.precompute_float8_dynamic_scale_for_fsdp", |
|
action="store_true", |
|
help="Whether precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling", |
|
) |
|
self.parser.add_argument( |
|
"--float8.force_recompute_fp8_weight_in_bwd", |
|
action="store_true", |
|
help=""" |
|
Whether to force the recomputation of FP8 weights during backward pass. |
|
When using FSDP with tensorwise scaling, it is recommended to enable |
|
`force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights |
|
for backward computation. |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--float8.recipe_name", |
|
type=str, |
|
default=None, |
|
choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"], |
|
help=""" |
|
If specified, creates float8 config from recipe name, valid choices are |
|
`tensorwise`, `rowwise` and `rowwise_with_gw_hp`. |
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--float8.filter_fqns", |
|
type=string_list, |
|
default=[], |
|
nargs="+", |
|
help=""" |
|
Comma-separated list of fully qualified names of modules to skip applying float8 training to. |
|
nn.Linear modules with any dim size not divisible by 16 are always skipped due to hardware requirements. |
|
            Example: --float8.filter_fqns "attention.wq,attention.wk,attention.wv,output"
|
""", |
|
) |
|
|
|
|
|
self.parser.add_argument( |
|
"--comm.init_timeout_seconds", |
|
type=int, |
|
default=300, |
|
help="Timeout for communication operations, during initialization and first train step.", |
|
) |
|
self.parser.add_argument( |
|
"--comm.train_timeout_seconds", |
|
type=int, |
|
default=100, |
|
help=( |
|
"Timeout for communication operations after the first train step -- " |
|
"usually a tighter bound than during initialization." |
|
), |
|
) |
|
self.parser.add_argument( |
|
"--comm.trace_buf_size", |
|
type=int, |
|
default=20000, |
|
help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled", |
|
) |
|
|
|
|
|
self.parser.add_argument( |
|
"--memory_estimation.enabled", |
|
help="Whether to estimate memory usage for FSDP", |
|
action="store_true", |
|
) |
|
|
|
self.parser.add_argument( |
|
"--memory_estimation.disable_fake_mode", |
|
help="Whether to estimate memory under FakeTensorMode", |
|
action="store_true", |
|
) |
|
|
|
self.parser.add_argument( |
|
"--fault_tolerance.enable", |
|
action="store_true", |
|
help=""" |
|
            Enable TorchFT integration. When TorchFT is enabled, HSDP will be used,

            --parallelism.data_parallel_replicate_degree should be 1, and

            --fault_tolerance.group_size will be used to control the maximum

            replicate group size, as the replicate group size is dynamic.
|
|
|
Note that this is still an experimental feature. |
|
""", |
|
) |
|
|
|
|
|
self.parser.add_argument( |
|
"--fault_tolerance.replica_id", |
|
type=int, |
|
default=0, |
|
help="The TorchFT replica ID of this run.", |
|
) |
|
self.parser.add_argument( |
|
"--fault_tolerance.group_size", |
|
type=int, |
|
default=0, |
|
help=""" |
|
            The number of TorchFT replicate groups. This number will be used by the

            dataloader to split the dataset across the replicate groups and the FSDP

            dimension.
|
""", |
|
) |
|
self.parser.add_argument( |
|
"--fault_tolerance.min_replica_size", |
|
type=int, |
|
default=1, |
|
help="The minimum number of FT replica for each step.", |
|
) |
|
|
|
self.parser.add_argument( |
|
"--experimental.custom_import", |
|
type=str, |
|
default="", |
|
help=""" |
|
            This option enables importing external modules.

            Currently, it only supports dotted module paths (e.g., some_package.model_x).

            It is the user's responsibility to ensure that the specified module can be

            imported successfully. One way to achieve this is to place your module

            inside the ``torchtitan/torchtitan`` folder and run ``pip install -e .`` to

            make it available for import.
|
""", |
|
) |
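
        # Illustrative usage (module name taken from the help text above):
        #   --experimental.custom_import some_package.model_x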
|
|
|
self.parser.add_argument( |
|
"--experimental.custom_args_module", |
|
type=str, |
|
default="", |
|
help=""" |
|
            This option allows users to extend TorchTitan's existing JobConfig by importing

            a customized module. Similar to ``--experimental.custom_import``, the user

            needs to ensure that the path can be imported. The module should contain

            exactly one public function, with the signature

            ``def func(parser: argparse.ArgumentParser) -> None:``. The user can use the

            given parser to add new arguments by calling ``parser.add_argument``.
|
""", |
|
) |
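
        # Illustrative sketch (not shipped code) of a module that
        # --experimental.custom_args_module could point at, e.g. a hypothetical
        # my_pkg/extra_args.py exposing a single public function:
        #
        #     import argparse
        #
        #     def extend_parser(parser: argparse.ArgumentParser) -> None:
        #         parser.add_argument(
        #             "--custom_section.some_option",
        #             type=int,
        #             default=0,
        #             help="Example option added by a user-provided module",
        #         )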
|
|
|
self._is_parsed = False |
|
        self._allow_unknown_args = False
|
|
|
def maybe_add_custom_args(self) -> None: |
|
"""Add custom arguments to the parser if --experimental.custom_args_module is set. |
|
|
|
Note: This function should be called before the parser is used to parse arguments. |
|
""" |
|
if self._is_parsed: |
|
raise RuntimeError( |
|
"JobConfig has already been parsed. We could not add new arguments." |
|
) |
|
|
|
        self._allow_unknown_args = True
|
self.parse_args(sys.argv[1:]) |
|
        self._allow_unknown_args = False
|
|
|
if self.experimental.custom_args_module: |
|
module = importlib.import_module(self.experimental.custom_args_module) |
|
public_functions = [ |
|
name |
|
for name, func in inspect.getmembers(module) |
|
if inspect.isfunction(func) and not name.startswith("_") |
|
] |
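            # The imported module is expected to expose exactly one public function
            # (see the --experimental.custom_args_module help text).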
|
func = getattr(module, public_functions[0]) |
|
func(self.parser) |
|
|
|
def to_dict(self): |
|
return self.args_dict |
|
|
|
def parse_args(self, args_list: list = sys.argv[1:]): |
|
self._is_parsed = True |
|
args, cmd_args = self.parse_args_from_command_line(args_list) |
|
config_file = getattr(args, "job.config_file", None) |
|
|
|
args_dict = self._args_to_two_level_dict(args) |
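
        # Values from the toml file override the argparse defaults collected above.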
|
if config_file is not None: |
|
try: |
|
with open(config_file, "rb") as f: |
|
for k, v in tomllib.load(f).items(): |
|
|
|
args_dict[k] |= v |
|
except (FileNotFoundError, tomllib.TOMLDecodeError) as e: |
|
logger.exception( |
|
f"Error while loading the configuration file: {config_file}" |
|
) |
|
logger.exception(f"Error details: {str(e)}") |
|
raise e |
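
        # Options declared with the string_list type may appear in the toml file as
        # plain comma-separated strings; normalize them into lists here.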
|
|
|
|
|
|
|
string_list_argnames = self._get_string_list_argument_names() |
|
for n in string_list_argnames: |
|
check_string_list_argument(args_dict, n) |
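
        # Explicit command-line args override both the toml values and the argparse defaults.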
|
|
|
|
|
cmd_args_dict = self._args_to_two_level_dict(cmd_args) |
|
for section, section_args in cmd_args_dict.items(): |
|
for k, v in section_args.items(): |
|
args_dict[section][k] = v |
|
|
|
self.args_dict = args_dict |
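
        # Expose each section as an attribute namespace, e.g. self.model.name or self.training.steps.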
|
|
|
for k, v in args_dict.items(): |
|
class_type = type(k.title(), (), v) |
|
setattr(self, k, class_type()) |
|
self._validate_config() |
|
|
|
def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict: |
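        """Split dotted arg names into nested sections, e.g. "model.name" -> args_dict["model"]["name"]."""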
|
args_dict = defaultdict(defaultdict) |
|
for k, v in vars(args).items(): |
|
first_level_key, second_level_key = k.split(".", 1) |
|
args_dict[first_level_key][second_level_key] = v |
|
return args_dict |
|
|
|
def _validate_config(self) -> None: |
|
|
|
|
|
if not os.path.exists(self.model.tokenizer_path): |
|
logger.warning( |
|
f"Tokenizer path {self.model.tokenizer_path} does not exist!" |
|
) |
|
old_tokenizer_path = ( |
|
"torchtitan/datasets/tokenizer/original/tokenizer.model" |
|
) |
|
if os.path.exists(old_tokenizer_path): |
|
self.model.tokenizer_path = old_tokenizer_path |
|
logger.warning( |
|
f"Temporarily switching to previous default tokenizer path {old_tokenizer_path}. " |
|
"Please update your config." |
|
) |
|
|
|
def _get_string_list_argument_names(self) -> list[str]: |
|
"""Get the parser argument names of type `string_list`.""" |
|
string_list_args = [ |
|
v.dest for v in self.parser._actions if v.type is string_list |
|
] |
|
return string_list_args |
|
|
|
def parse_args_from_command_line( |
|
self, args_list |
|
) -> Tuple[argparse.Namespace, argparse.Namespace]: |
|
""" |
|
        Parse command line arguments and return the parsed args and the command-line-only args.
|
""" |
|
        if self._allow_unknown_args:
|
args, _ = self.parser.parse_known_args(args_list) |
|
else: |
|
args = self.parser.parse_args(args_list) |
|
string_list_argnames = set(self._get_string_list_argument_names()) |
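
        # Re-parse with argparse.SUPPRESS as the default so the auxiliary namespace
        # contains only the args that were explicitly passed on the command line.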
|
|
|
|
|
aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS) |
|
for arg, val in vars(args).items(): |
|
if isinstance(val, bool): |
|
aux_parser.add_argument( |
|
"--" + arg, action="store_true" if val else "store_false" |
|
) |
|
elif arg in string_list_argnames: |
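                # Keep the string_list type so comma-separated values are split into a
                # list; inferring type(val) here would give `list`, which would split
                # the raw string into individual characters.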
|
|
|
|
|
|
|
aux_parser.add_argument("--" + arg, type=string_list) |
|
else: |
|
aux_parser.add_argument("--" + arg, type=type(val)) |
|
|
|
cmd_args, _ = aux_parser.parse_known_args(args_list) |
|
|
|
return args, cmd_args |
|
|