zaydzuhri committed on
Commit 8fbfec1 · verified · 1 Parent(s): 64f196f

Add files using upload-large-folder tool

Files changed (50)
  1. LICENSE +21 -0
  2. README.md +471 -0
  3. config.json +34 -0
  4. configs/delta_net_1B.json +29 -0
  5. configs/delta_net_340M.json +27 -0
  6. configs/gla_340M.json +24 -0
  7. configs/gla_7B.json +25 -0
  8. configs/gsa_340M.json +29 -0
  9. configs/hgrn2_340M.json +20 -0
  10. configs/mtp_transformer_120M.json +19 -0
  11. configs/mtp_transformer_1B.json +23 -0
  12. configs/mtp_transformer_340M.json +19 -0
  13. configs/mtp_transformer_7B.json +22 -0
  14. configs/top_transformer_120M.json +20 -0
  15. configs/top_transformer_1B.json +24 -0
  16. configs/top_transformer_340M.json +20 -0
  17. configs/top_transformer_7B.json +23 -0
  18. configs/transformer_120M.json +18 -0
  19. configs/transformer_1B.json +22 -0
  20. configs/transformer_340M.json +18 -0
  21. configs/transformer_7B.json +21 -0
  22. download_checkpoint.py +35 -0
  23. fla/__init__.py +110 -0
  24. fla/utils.py +223 -0
  25. flame/__init__.py +1 -0
  26. flame/__pycache__/train.cpython-312.pyc +0 -0
  27. flame/config_manager.py +940 -0
  28. flame/data.py +570 -0
  29. flame/models/__init__.py +0 -0
  30. flame/tools/utils.py +41 -0
  31. flame/train.py +897 -0
  32. generation_config.json +7 -0
  33. model.safetensors.index.json +298 -0
  34. pyproject.toml +43 -0
  35. setup.py +51 -0
  36. special_tokens_map.json +23 -0
  37. tb/20250716-2210/wandb/debug-internal.log +90 -0
  38. tb/20250716-2210/wandb/debug.log +28 -0
  39. tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/files/config.yaml +150 -0
  40. tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/files/requirements.txt +101 -0
  41. tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/files/wandb-metadata.json +146 -0
  42. tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/files/wandb-summary.json +1 -0
  43. tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/logs/debug-core.log +16 -0
  44. tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/logs/debug.log +28 -0
  45. tokenizer.json +0 -0
  46. tokenizer_config.json +44 -0
  47. torchtitan/__init__.py +15 -0
  48. torchtitan/config_manager.py +947 -0
  49. torchtitan/train.py +482 -0
  50. train.sh +121 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023-2025 Songlin Yang, Yu Zhang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,471 @@
+ <div align="center">
+
+ # 🔥 Flame: Flash Linear Attention Made Easy
+
+ </div>
+
+ Welcome to 🔥 `flame`, a minimal framework built on `torchtitan` for training Flash Linear Attention (FLA) models (and, more broadly, arbitrary autoregressive language models) with blazing efficiency.
+
+ **Feature Highlights:**
+
+ - 🚀 Minimal, easy-to-use, extensible training framework
+ - 🤗 Seamless integration with `fla` and `transformers`
+ - 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and multi-dataset support
+ - 🔮 4D parallelism (coming soon)
+
+ ## Setup
+
+ To get started, clone the `flame` repository and install the required dependencies:
+
+ ```bash
+ git clone https://github.com/fla-org/flame.git
+ cd flame
+ pip install .
+ ```
+
+ `flame` keeps its dependencies minimal, including only `fla` and `torchtitan` as submodules.
+ After installation, initialize and update the submodules:
+ ```sh
+ git submodule update --init --recursive
+ ```
+
+ ## Dataset Preparation
+ To download the dataset to your local disk, create a new Python file with the following content and execute it:
+
+ ```py
+ from datasets import load_dataset
+
+ # load fineweb-edu with parallel processing
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")
+
+ # or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
+ ```
+
+ ## Training Recipes
+
+ Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus in streaming mode.
+
+ > [!WARNING]
+ > If the dataset has not been downloaded beforehand, streaming mode will attempt to fetch it from a remote server and download it on the fly, which can be highly unstable during training due to network issues.
+ > For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)); otherwise, we assume you are only testing a new corpus.
+
+ ```sh
+ bash train.sh \
+   --job.config_file flame/models/fla.toml \
+   --job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr3e-4.cosine \
+   --model.config configs/transformer_340M.json \
+   --model.tokenizer_path fla-hub/transformer-1.3B-100B \
+   --optimizer.name AdamW \
+   --optimizer.eps 1e-15 \
+   --optimizer.lr 3e-4 \
+   --lr_scheduler.warmup_steps 1024 \
+   --lr_scheduler.lr_min 0.1 \
+   --lr_scheduler.decay_type cosine \
+   --training.batch_size 1 \
+   --training.seq_len 65536 \
+   --training.context_len 4096 \
+   --training.varlen \
+   --training.gradient_accumulation_steps 1 \
+   --training.steps 20480 \
+   --training.max_norm 1.0 \
+   --training.skip_nan_inf \
+   --training.dataset HuggingFaceFW/fineweb-edu \
+   --training.dataset_name sample-100BT \
+   --training.dataset_split train \
+   --training.streaming \
+   --training.num_workers 32 \
+   --training.prefetch_factor 2 \
+   --training.seed 42 \
+   --training.compile \
+   --checkpoint.interval 2048 \
+   --checkpoint.load_step -1 \
+   --checkpoint.keep_latest_k 2 \
+   --metrics.log_freq 1
+ ```
+
+ You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
+ **For single-GPU debugging, set `NGPU=1`.**
+
+ We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
+ By default, the learning rate is set to 3e-4 with a cosine scheduler. Other schedulers, such as warmup-stable-decay (`wsd`), are also supported.
+
+ **Key parameters:**
+ - `--lr_scheduler.decay_ratio`: The proportion of steps allocated to the decay phase. The learning rate remains stable after the warmup period and only starts decaying during the last `decay_ratio` portion of the total training steps; this is the Warmup-Stable-Decay (WSD) schedule (see the sketch after this list).
+ - `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
+ - `--training.steps`: Total number of training steps.
+ - `--training.batch_size`: Batch size per device; must be 1 if `--training.varlen` is set.
+ - `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
+ - `--training.context_len`: The maximum allowed length of a single sample. In non-varlen mode, this is equivalent to `seq_len`.
+ - `--training.varlen`: Whether to conduct variable-length sequence training.
+ - `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.
+
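+ The WSD schedule can be pictured with the small sketch below (a simplified illustration, not `flame`'s actual scheduler code; the linear warmup/decay shapes and the default values are assumptions for the example):
+
+ ```py
+ def wsd_lr(step, base_lr=3e-4, warmup_steps=1024, total_steps=20480,
+            decay_ratio=0.1, min_lr_ratio=0.1):
+     """Simplified Warmup-Stable-Decay: linear warmup, constant plateau,
+     then decay to min_lr_ratio * base_lr over the last decay_ratio of steps."""
+     decay_start = int(total_steps * (1 - decay_ratio))
+     if step < warmup_steps:                       # warmup phase
+         return base_lr * step / warmup_steps
+     if step < decay_start:                        # stable phase
+         return base_lr
+     progress = (step - decay_start) / max(total_steps - decay_start, 1)
+     return base_lr * (1 - progress * (1 - min_lr_ratio))   # decay phase
+ ```
+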
+ > [!WARNING]
+ > The total number of sequences processed per optimizer step, referred to as `global_batch_size`, is calculated as `batch_size × gradient_accumulation_steps × num_gpus`.
+ > Each step processes `global_batch_size × seq_len` tokens.
+ > Monitor `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of these hyperparameters!
+
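+ As a quick sanity check, the single-node recipe above (`NGPU=8`) works out as follows (a back-of-the-envelope calculation, not trainer output):
+
+ ```py
+ # Recipe above: batch_size=1, gradient_accumulation_steps=1, NGPU=8, seq_len=65536, steps=20480
+ global_batch_size = 1 * 1 * 8                  # 8 sequences per optimizer step
+ tokens_per_step = global_batch_size * 65536    # 524,288 tokens per step
+ total_tokens = tokens_per_step * 20480         # ~10.7B tokens over the full run
+ ```
+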
+ For a detailed explanation of all parameters, run:
+
+ ```sh
+ bash train.sh -h
+ ```
+
+ <details>
+ <summary>Usage</summary>
+
117
+ ```py
118
+ options:
119
+ -h, --help show this help message and exit
120
+ --job.config_file JOB.CONFIG_FILE
121
+ Job config file
122
+ --job.dump_folder JOB.DUMP_FOLDER
123
+ Folder to dump job outputs
124
+ --job.description JOB.DESCRIPTION
125
+ Description of the job
126
+ --job.use_for_integration_test
127
+ Add this config to the integration test suite
128
+ --job.print_args Print the args to terminal
129
+ --model.config MODEL.CONFIG
130
+ Path to the model config
131
+ --model.norm_type MODEL.NORM_TYPE
132
+ Type of layer normalization to use [layernorm,
133
+ np_layernorm, rmsnorm, fused_rmsnorm]
134
+ --model.tokenizer_path MODEL.TOKENIZER_PATH
135
+ Tokenizer path
136
+ --profiling.enable_profiling
137
+ Whether to enable pytorch profiler
138
+ --profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
139
+ Trace files location
140
+ --profiling.profile_freq PROFILING.PROFILE_FREQ
141
+ How often to collect profiler traces, in iterations
142
+ --profiling.enable_memory_snapshot
143
+ Whether to dump memory snapshot
144
+ --profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
145
+ Memory snapshot files location
146
+ --optimizer.name OPTIMIZER.NAME
147
+ Optimizer to use
148
+ --optimizer.eps OPTIMIZER.EPS
149
+ Epsilon value for the optimizer.
150
+ --optimizer.fused Whether the fused implementation (CUDA only) is used.
151
+ --optimizer.scheduler {wsd,cosine,linear}
152
+ Scheduler to use. Currently supported: wsd, cosine,
153
+ and linear.
154
+ --optimizer.lr OPTIMIZER.LR
155
+ Learning rate to use
156
+ --optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
157
+ Min lr ratio for lr scheduler
158
+ --optimizer.early_step_in_backward
159
+ Whether to apply optimizer in the backward. Caution,
160
+ optimizer_in_backward is not compatible with gradients
161
+ clipping, users should not call
162
+ register_post_accumulate_grad_hook after the optimizer
163
+ is built.
164
+ --training.batch_size TRAINING.BATCH_SIZE
165
+ Batch size
166
+ --training.seq_len TRAINING.SEQ_LEN
167
+ Sequence length
168
+ --training.context_len TRAINING.CONTEXT_LEN
169
+ Max length allowed for each sequence
170
+ --training.varlen Whether to take sequences of variable length as input
171
+ --training.warmup_steps TRAINING.WARMUP_STEPS
172
+ Steps for lr scheduler warmup, normally 1/5 of
173
+ --training.steps
174
+ --training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
175
+ Number of steps to accumulate gradients before
176
+ updating parameters
177
+ --training.steps TRAINING.STEPS
178
+ How many train steps to run
179
+ --training.max_norm TRAINING.MAX_NORM
180
+ Max norm for gradient clipping
181
+ --training.skip_nan_inf
182
+ Skip batch updates when NaN or INF gradients are
183
+ encountered during training
184
+ --training.dataset TRAINING.DATASET
185
+ Dataset to use, with comma separated values
186
+ --training.dataset_name TRAINING.DATASET_NAME
187
+ The name of the dataset config, with comma separated
188
+ values if provided
189
+ --training.dataset_split TRAINING.DATASET_SPLIT
190
+ Dataset split to use, with comma separated values if
191
+ provided
192
+ --training.data_dir TRAINING.DATA_DIR
193
+ Data dirs to use, with comma separated values if
194
+ provided
195
+ --training.data_files TRAINING.DATA_FILES
196
+ Data files to use, with comma separated values if
197
+ provided
198
+ --training.data_probs TRAINING.DATA_PROBS
199
+ Data sampling probabilities, with comma separated
200
+ values if provided
201
+ --training.streaming Whether to load dataset in streaming mode, used for
202
+ huge dataset
203
+ --training.num_workers TRAINING.NUM_WORKERS
204
+ Number of subprocesses to use for data loading. 0
205
+ means that the data will be loaded in the main
206
+ process.
207
+ --training.prefetch_factor TRAINING.PREFETCH_FACTOR
208
+ Number of batches loaded in advance by each worker. 2
209
+ means there will be a total of 2 * num_workers batches
210
+ prefetched across all workers.
211
+ --training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
212
+ The `data_parallel_replicate_degree` argument
213
+ specifies the degree of data parallelism for weight
214
+ replication. When this value is greater than 1,
215
+ weights will be replicated across
216
+ `data_parallel_replicate_degree` ranks. If
217
+ `data_parallel_shard_degree` is also greater than 1,
218
+ the parallelism method used is HSDP (Hybrid Sharded
219
+ Data Parallelism). Otherwise, the parallelism method
220
+ used is DDP (Distributed Data Parallelism). 1 means
221
+ disabled.
222
+ --training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
223
+ The `data_parallel_shard_degree` argument specifies
224
+ the degree of data parallelism for weight sharding.
225
+ When this value is greater than 1, weights will be
226
+ sharded across `data_parallel_shard_degree` ranks. If
227
+ `data_parallel_replicate_degree` is also greater than
228
+ 1, the parallelism method used is HSDP (Hybrid Sharded
229
+ Data Parallelism). Otherwise, the parallelism method
230
+ used is FSDP (Fully Sharded Data Parallelism). -1
231
+ means leftover ranks will be used (After
232
+ DP_REPLICATE/SP/PP). Note that only
233
+ `data_parallel_shard_degree` can be negative. 1 means
234
+ disabled.
235
+ --training.enable_cpu_offload
236
+ Whether to apply CPU offloading of parameters,
237
+ gradients, and optimizer states in FSDP
238
+ --training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
239
+ Tensor Parallelism degree. 1 means disabled.
240
+ --training.disable_loss_parallel
241
+ Whether to apply loss parallel when sequence parallel
242
+ is enabled
243
+ --training.mixed_precision_param {bfloat16,float32}
244
+ torch dtype to use for parameters when applying mixed
245
+ precision via FSDP. This feature only takes effect
246
+ when data_parallel_shard_degree > 1
247
+ --training.mixed_precision_reduce {float32}
248
+ torch dtype to use for reductions when applying mixed
249
+ precision via FSDP. This feature only takes effect
250
+ when data_parallel_shard_degree > 1
251
+ --training.compile Whether to compile the model
252
+ --training.gc_freq TRAINING.GC_FREQ
253
+ Python garbage control scheduling interval, in steps
254
+ --training.seed TRAINING.SEED
255
+ Choose the base RNG seed used for training
256
+ --training.deterministic
257
+ Use deterministic algorithms wherever possible, may be
258
+ slower
259
+ --metrics.log_freq METRICS.LOG_FREQ
260
+ How often to log metrics to TensorBoard, in iterations
261
+ --metrics.enable_tensorboard
262
+ Whether to log metrics to TensorBoard
263
+ --metrics.disable_color_printing
264
+ Whether to disable color printing in logs
265
+ --metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
266
+ Folder to dump TensorBoard states
267
+ --metrics.rank_0_only
268
+ Whether to save TensorBoard metrics only for rank 0 or
269
+ for all ranks. When pipeline_parallel_degree is > 1,
270
+ this option uses the 0th rank of the last stage
271
+ pipeline group, which is the only stage that computes
272
+ loss metrics.
273
+ --metrics.enable_wandb
274
+ Whether to log metrics to Weights & Biases
275
+ --experimental.enable_async_tensor_parallel
276
+ Whether to apply async tensor parallel (currently only
277
+ effective when compile is enabled)
278
+ --experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
279
+ Pipeline Parallelism degree, or number of ranks. 1
280
+ means disabled. If using looped schedules, this still
281
+ specifies the number of physical ranks, not the number
282
+ of stages. Stages per rank are inferred from split
283
+ points degree, and schedule.
284
+ --experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
285
+ Specify comma-separated names of modules to use as the
286
+ beginning of a split point. e.g. "layers.0,layers.2"
287
+ will cause the model to be split into 3 stages, the
288
+ first containing all the layers up to layers.0, the
289
+ second containing layers.0 and up to layers.2, the
290
+ third containing layers.2 and all the remaining
291
+ layers. Note: fully-automated splitting may be enabled
292
+ in the future, but currently the split points must be
293
+ specified manually.
294
+ --experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
295
+ Specify the Pipeline Parallel schedule to use. The
296
+ supported schedules are: https://github.com/pytorch/py
297
+ torch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/to
298
+ rch/distributed/pipelining/schedules.py#L2161. The
299
+ schedule must be compatible with the split points and
300
+ stages_per_rank. Looped schedules (e.g.
301
+ Interleaved1F1B) require specifying
302
+ pipeline_parallel_degree = number of ranks, and
303
+ split_points = number of stages - 1
304
+ --experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
305
+ Specify the path to the pipeline parallel schedule csv
306
+ file to use. The pipeline_parallel_schedule argument
307
+ must be either PipelineScheduleSingle,
308
+ PipelineScheduleMulti, or _PipelineScheduleRuntime.
309
+ --experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
310
+ How many microbatches to split the global training
311
+ batch into when using pipeline parallelism. The global
312
+ training batch size must be evenly divisible by the
313
+ number of microbatches. The default value will be the
314
+ number of pipeline stages, if unspecified.
315
+ --experimental.enable_compiled_autograd
316
+ Enable CompiledAutograd to compile the backward.
317
+ --experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
318
+ Context parallelism degree. 1 means disabled.
319
+ --experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
320
+ The collective to use in context parallel SDPA for kv
321
+ shards exchange. 'allgather' means to all-gather all
322
+ kv shards on ranks after the first sub-SDPA
323
+ computation, 'alltoall' means to all-to-all shuffle
324
+ the kv shards. The default value is 'allgather'.
325
+ --checkpoint.enable_checkpoint
326
+ Whether to enable checkpoint
327
+ --checkpoint.folder CHECKPOINT.FOLDER
328
+ The folder to store the checkpoints. When
329
+ enable_checkpoint is set to true, checkpoints will be
330
+ in {--job.dump_folder}/{--checkpoint.folder}.
331
+ --checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
332
+ Checkpointing interval unit of measurement ['step',
333
+ 'seconds']
334
+ --checkpoint.interval CHECKPOINT.INTERVAL
335
+ Checkpointing interval, in steps or seconds depending
336
+ on --checkpoint.interval_type
337
+ --checkpoint.model_weights_only
338
+ When model_weights_only=True, only model weights will
339
+ be saved at the end of training. With this,
340
+ checkpoints can be loaded using `torch.load(...,
341
+ weights_only=True)` after conversion. When
342
+ model_weights_only=False, the full checkpoint will be
343
+ saved. A full checkpoint includes model, optimizer and
344
+ train_state, which can be used to resume training. The
345
+ default value is false.
346
+ --checkpoint.export_dtype {float16,bfloat16,float32}
347
+ Converts to the specified precision when training
348
+ completes and model_weights_only=true. Currently
349
+ supports float32, float16, and bfloat16. The default
350
+ value is float32.
351
+ --checkpoint.create_seed_checkpoint
352
+ Initializes the full model without applying
353
+ parallelisms, and then saves it as a seed checkpoint.
354
+ Note: requires user to call train.py without
355
+ specifying any parallelisms, e.g. NGPU=1. Could be
356
+ implemented as a separate script, but this way shares
357
+ more code.
358
+ --checkpoint.async_mode CHECKPOINT.ASYNC_MODE
359
+ Which async checkpoint mode to use. Currently there
360
+ are 3 different modes. 1. "disabled": synchronized
361
+ checkpointing will be used. 2. "async":
362
+ torch.distributed.checkpoint.async_save will be used.
363
+ 3. "async_with_pinned_mem": this option utilizes a
364
+ dedicated pinned memory space and creates a separate
365
+ process for faster GPU->CPU transfer performance and
366
+ eliminating GIL contention. The cost is increased CPU
367
+ memory usage. If insufficient CPU memory is available,
368
+ performance may degrade due to memory paging. For most
369
+ users, "async" should suffice as the performance
370
+ overhead is typically small (on the order of tens of
371
+ seconds) compared to checkpointing frequency. This
372
+ mode can be employed to pursue near-zero checkpointing
373
+ times (e.g., < 1 second) given appropriate hardware
374
+ support such as ample CPU memory and fast PCIe.
375
+ "disabled" is the default mode.
376
+ --checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
377
+ Keeps only the latest k checkpoints, purging older
378
+ ones. If 0, keep all checkpoints. 0 is the default
379
+ value.
380
+ --checkpoint.load_step CHECKPOINT.LOAD_STEP
381
+ Load the checkpoint at the specified step. If -1, load
382
+ the latest checkpoint.
383
+ --float8.enable_float8_linear
384
+ If true, swaps `torch.nn.Linear` with `Float8Linear`.
385
+ This feature requires you to install 'torchao' which
386
+ can be found here: https://github.com/pytorch/ao
387
+ --float8.enable_fsdp_float8_all_gather
388
+ Whether to enable float8 all-gather in FSDP
389
+ --float8.precompute_float8_dynamic_scale_for_fsdp
390
+ Whether precompute float8 scales dynamically for FSDP
391
+ --float8.scaling_type_input {dynamic,delayed}
392
+ float8 scaling for input, dynamic (default) or delayed
393
+ --float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
394
+ float8 scaling for weight, dynamic (default) or delayed
395
+ --float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
396
+ float8 scaling for grad output, dynamic (default) or delayed
397
+ --comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
398
+ Timeout for communication operations, during
399
+ initialization and first train step.
400
+ --comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
401
+ Timeout for communication operations after the first
402
+ train step -- usually a tighter bound than during
403
+ initialization.
404
+ --comm.trace_buf_size COMM.TRACE_BUF_SIZE
405
+ Flight recorder ring buffer size, >0 means recording
406
+ by default, 0 means disabled
407
+ --memory_estimation.enabled
408
+ Whether to estimate memory usage for FSDP
409
+ --memory_estimation.disable_fake_mode
410
+ Whether to estimate memory under FakeTensorMode
411
+ ```
412
+ </details>
+
+ ### Training with `torch.compile`
+
+ Since `torch 2.0`, `torch.compile` has been available as a way to seamlessly accelerate training.
+ In `flame`, you can enable it simply by adding the `--training.compile` flag to your training script.
+
+ However, `fla` integrates numerous fused kernels for acceleration, which may conflict with `torch.compile`.
+ We are actively working on resolving these issues to make compilation transparent to users.
+ In the meantime, please ensure you are using the latest dependencies.
+
+ Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.
+
+ ### Training with multiple datasets
+
+ If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), you will need to train on multiple datasets.
+ `flame` makes multi-dataset training easy.
+ For example, you can specify the following arguments to train on 6 datasets with different sampling proportions:
+
+ ```sh
+ --training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
+ --training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
+ ```
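+
+ Under the hood, this kind of probabilistic mixing is conceptually similar to the `datasets.interleave_datasets` API sketched below (an illustration only, not `flame`'s actual data pipeline; the two-dataset subset and the `train` splits are assumptions):
+
+ ```py
+ from datasets import interleave_datasets, load_dataset
+
+ # Illustrative sketch: sample from two streaming corpora with fixed probabilities.
+ edu = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", split="train", streaming=True)
+ code = load_dataset("OpenCoder-LLM/opc-fineweb-code-corpus", split="train", streaming=True)
+ mixed = interleave_datasets([edu, code], probabilities=[0.8, 0.2], seed=42)
+ print(next(iter(mixed)).keys())
+ ```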
+
+ ### ~Finalizing training~
+
+ > [!NOTE]
+ > We have done this conversion automatically in the training script since our latest updates.
+
+ Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
+ To facilitate this, we provide a straightforward conversion script:
+
+ ```sh
+ python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
+ ```
+ After this, your model will be in the 🤗 format, ready to be shared or deployed.
+ You can then easily publish your model using `huggingface_hub` for wider accessibility.
+
+ ### Continual training
+
+ If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
+ This allows you to seamlessly resume training with `flame`.
+ ```sh
+ python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
+ ```
+ Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
+ The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.
+
+ Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.
+
+ ## Multi-node training
+
+ If you have access to multiple GPU nodes, consider leveraging them for optimal performance.
+ This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).
+
+ To set up multi-node training:
+ * Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes.
+ * If you're using a job scheduler like Slurm, it will handle these variables for you.
+
+ `torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
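+
+ For reference, a two-node launch with `torchrun` looks roughly like the following (a generic sketch based on the PyTorch elastic docs; the entry point is assumed to be `flame/train.py`, and `train.sh` may wrap `torchrun` with different flags):
+
+ ```sh
+ # Run on each of the two nodes; MASTER_ADDR points to the rank-0 node.
+ export MASTER_ADDR=<ip> MASTER_PORT=<port>
+ torchrun --nnodes 2 --nproc_per_node 8 \
+   --rdzv_backend c10d --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
+   flame/train.py --job.config_file flame/models/fla.toml
+ ```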
config.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "architectures": [
3
+ "MTPTransformerForCausalLM"
4
+ ],
5
+ "bos_token_id": 1,
6
+ "elementwise_affine": true,
7
+ "eos_token_id": 2,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "fuse_swiglu": true,
11
+ "hidden_act": "swish",
12
+ "hidden_ratio": 4,
13
+ "hidden_size": 2048,
14
+ "initializer_range": 0.006,
15
+ "intermediate_size": null,
16
+ "max_position_embeddings": 8192,
17
+ "model_type": "mtp_transformer",
18
+ "n_future_tokens": 4,
19
+ "norm_eps": 1e-06,
20
+ "num_heads": 32,
21
+ "num_hidden_layers": 32,
22
+ "num_kv_heads": null,
23
+ "pad_token_id": 2,
24
+ "qk_norm": false,
25
+ "qkv_bias": false,
26
+ "rope_theta": 10000.0,
27
+ "tie_word_embeddings": false,
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.50.3",
30
+ "use_cache": true,
31
+ "use_custom_backward": false,
32
+ "vocab_size": 32000,
33
+ "window_size": null
34
+ }
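The `config.json` above describes the uploaded MTP Transformer checkpoint. A hedged loading sketch follows (it assumes the custom `mtp_transformer` architecture is importable and registered with `transformers`' Auto classes before loading; otherwise `from_pretrained` cannot resolve it, and the repository id is a placeholder):

```py
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

repo = "zaydzuhri/<this-repo>"  # placeholder repository id

# Assumption: the custom `mtp_transformer` model classes have already been
# imported and registered with the transformers Auto* registries at this point.
config = AutoConfig.from_pretrained(repo)
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(repo, torch_dtype="auto")
print(config.model_type, config.n_future_tokens)  # mtp_transformer 4
```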
configs/delta_net_1B.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "conv_size": 4,
6
+ "eos_token_id": 2,
7
+ "expand_k": 1,
8
+ "expand_v": 1,
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "hidden_act": "swish",
12
+ "hidden_ratio": 4,
13
+ "hidden_size": 2048,
14
+ "initializer_range": 0.006,
15
+ "intermediate_size": null,
16
+ "model_type": "delta_net",
17
+ "norm_eps": 1e-06,
18
+ "num_heads": 16,
19
+ "num_hidden_layers": 24,
20
+ "pad_token_id": 2,
21
+ "qk_activation": "silu",
22
+ "qk_norm": "l2",
23
+ "tie_word_embeddings": false,
24
+ "use_beta": true,
25
+ "use_cache": true,
26
+ "use_gate": false,
27
+ "use_output_norm": true,
28
+ "use_short_conv": true
29
+ }
configs/delta_net_340M.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "conv_size": 4,
5
+ "eos_token_id": 2,
6
+ "expand_k": 1,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "hidden_act": "swish",
10
+ "hidden_ratio": 4,
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.006,
13
+ "intermediate_size": null,
14
+ "model_type": "delta_net",
15
+ "norm_eps": 1e-06,
16
+ "norm_first": false,
17
+ "num_heads": 8,
18
+ "num_hidden_layers": 24,
19
+ "qk_activation": "silu",
20
+ "qk_norm": "l2",
21
+ "tie_word_embeddings": false,
22
+ "use_beta": true,
23
+ "use_cache": true,
24
+ "use_gate": false,
25
+ "use_output_norm": true,
26
+ "use_short_conv": true
27
+ }
configs/gla_340M.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.006,
14
+ "intermediate_size": null,
15
+ "model_type": "gla",
16
+ "num_heads": 4,
17
+ "num_hidden_layers": 24,
18
+ "norm_eps": 1e-06,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "vocab_size": 32000
24
+ }
configs/gla_7B.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 4096,
13
+ "initializer_range": 0.006,
14
+ "intermediate_size": 11008,
15
+ "model_type": "gla",
16
+ "norm_eps": 1e-06,
17
+ "num_heads": 16,
18
+ "num_hidden_layers": 32,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "use_output_gate": true,
24
+ "use_short_conv": false
25
+ }
configs/gsa_340M.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "conv_size": 4,
4
+ "eos_token_id": 2,
5
+ "expand_k": 1,
6
+ "expand_v": 1,
7
+ "elementwise_affine": false,
8
+ "feature_map": "swish",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "gate_logit_normalizer": 4,
12
+ "hidden_act": "swish",
13
+ "hidden_ratio": 4,
14
+ "hidden_size": 1024,
15
+ "initializer_range": 0.006,
16
+ "intermediate_size": null,
17
+ "model_type": "gsa",
18
+ "num_heads": 4,
19
+ "num_hidden_layers": 24,
20
+ "num_slots": 64,
21
+ "norm_eps": 1e-06,
22
+ "share_conv_kernel": true,
23
+ "tie_word_embeddings": false,
24
+ "use_cache": true,
25
+ "use_norm": true,
26
+ "use_output_gate": true,
27
+ "use_rope": false,
28
+ "use_short_conv": false
29
+ }
configs/hgrn2_340M.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "expand_ratio": 128,
6
+ "fuse_cross_entropy": true,
7
+ "fuse_norm": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 1024,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "model_type": "hgrn2",
14
+ "num_heads": 8,
15
+ "num_hidden_layers": 24,
16
+ "norm_eps": 1e-06,
17
+ "tie_word_embeddings": false,
18
+ "use_cache": true,
19
+ "vocab_size": 32000
20
+ }
configs/mtp_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "mtp_transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "n_future_tokens": 4
19
+ }
configs/mtp_transformer_1B.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "mtp_transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 32,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "n_future_tokens": 4
23
+ }
configs/mtp_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "mtp_transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "n_future_tokens": 4
19
+ }
configs/mtp_transformer_7B.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_ratio": 4,
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.006,
11
+ "intermediate_size": 14336,
12
+ "model_type": "mtp_transformer",
13
+ "norm_eps": 1e-06,
14
+ "num_heads": 32,
15
+ "num_hidden_layers": 30,
16
+ "num_kv_heads": 8,
17
+ "rope_theta": 10000.0,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "window_size": null,
21
+ "n_future_tokens": 4
22
+ }
configs/top_transformer_120M.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "top_transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "use_top_loss": true,
19
+ "top_window_size": 2048
20
+ }
configs/top_transformer_1B.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "top_transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 32,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "use_top_loss": true,
23
+ "top_window_size": 4096
24
+ }
configs/top_transformer_340M.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "top_transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "use_top_loss": true,
19
+ "top_window_size": 4096
20
+ }
configs/top_transformer_7B.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_ratio": 4,
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.006,
11
+ "intermediate_size": 14336,
12
+ "model_type": "top_transformer",
13
+ "norm_eps": 1e-06,
14
+ "num_heads": 32,
15
+ "num_hidden_layers": 30,
16
+ "num_kv_heads": 8,
17
+ "rope_theta": 10000.0,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "window_size": null,
21
+ "use_top_loss": true,
22
+ "top_window_size": 4096
23
+ }
configs/transformer_120M.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000
18
+ }
configs/transformer_1B.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 32,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false
22
+ }
configs/transformer_340M.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000
18
+ }
configs/transformer_7B.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_ratio": 4,
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.006,
11
+ "intermediate_size": 14336,
12
+ "model_type": "transformer",
13
+ "norm_eps": 1e-06,
14
+ "num_heads": 32,
15
+ "num_hidden_layers": 30,
16
+ "num_kv_heads": 8,
17
+ "rope_theta": 10000.0,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "window_size": null
21
+ }
download_checkpoint.py ADDED
@@ -0,0 +1,35 @@
1
+ import argparse
2
+ import os
3
+ from huggingface_hub import HfApi, HfFolder, snapshot_download
4
+
5
+ def main(args):
6
+ api = HfApi()
7
+ token = HfFolder.get_token()
8
+ experiment_checkpoint_folder = os.path.join(args.experiment_checkpoint_folder, "checkpoint")
9
+ os.makedirs(
10
+ experiment_checkpoint_folder,
11
+ exist_ok=True
12
+ )
13
+
14
+ snapshot_download(
15
+ repo_id=args.repo_id,
16
+ token=token,
17
+ local_dir=experiment_checkpoint_folder,
18
+ )
19
+
20
+ if __name__ == "__main__":
21
+ parser = argparse.ArgumentParser(description="Download a checkpoint from Hugging Face Hub.")
22
+ parser.add_argument(
23
+ "--repo_id",
24
+ type=str,
25
+ required=True,
26
+ help="The repository ID on Hugging Face Hub.",
27
+ )
28
+ parser.add_argument(
29
+ "--experiment_checkpoint_folder",
30
+ type=str,
31
+ required=True,
32
+ help="The local directory to save the downloaded checkpoint.",
33
+ )
34
+ args = parser.parse_args()
35
+ main(args)
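A typical invocation of `download_checkpoint.py` (the repository id and output folder below are placeholders) might be:

```sh
# Downloads the snapshot into exp/my-run/checkpoint/ so flame can resume from it.
python download_checkpoint.py \
  --repo_id your-username/your-run-checkpoints \
  --experiment_checkpoint_folder exp/my-run
```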
fla/__init__.py ADDED
@@ -0,0 +1,110 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from fla.layers import (
4
+ ABCAttention,
5
+ Attention,
6
+ BasedLinearAttention,
7
+ BitAttention,
8
+ DeltaNet,
9
+ GatedDeltaNet,
10
+ GatedDeltaProduct,
11
+ GatedLinearAttention,
12
+ GatedSlotAttention,
13
+ HGRN2Attention,
14
+ HGRNAttention,
15
+ LightNetAttention,
16
+ LinearAttention,
17
+ MultiScaleRetention,
18
+ NativeSparseAttention,
19
+ ReBasedLinearAttention,
20
+ RWKV6Attention,
21
+ RWKV7Attention
22
+ )
23
+ from fla.models import (
24
+ ABCForCausalLM,
25
+ ABCModel,
26
+ BitNetForCausalLM,
27
+ BitNetModel,
28
+ DeltaNetForCausalLM,
29
+ DeltaNetModel,
30
+ GatedDeltaNetForCausalLM,
31
+ GatedDeltaNetModel,
32
+ GatedDeltaProductForCausalLM,
33
+ GatedDeltaProductModel,
34
+ GLAForCausalLM,
35
+ GLAModel,
36
+ GSAForCausalLM,
37
+ GSAModel,
38
+ HGRN2ForCausalLM,
39
+ HGRN2Model,
40
+ HGRNForCausalLM,
41
+ LightNetForCausalLM,
42
+ LightNetModel,
43
+ LinearAttentionForCausalLM,
44
+ LinearAttentionModel,
45
+ NSAForCausalLM,
46
+ NSAModel,
47
+ RetNetForCausalLM,
48
+ RetNetModel,
49
+ RWKV6ForCausalLM,
50
+ RWKV6Model,
51
+ RWKV7ForCausalLM,
52
+ RWKV7Model,
53
+ TransformerForCausalLM,
54
+ TransformerModel
55
+ )
56
+
57
+ __all__ = [
58
+ 'ABCAttention',
59
+ 'Attention',
60
+ 'BasedLinearAttention',
61
+ 'BitAttention',
62
+ 'DeltaNet',
63
+ 'GatedDeltaNet',
64
+ 'GatedDeltaProduct',
65
+ 'GatedLinearAttention',
66
+ 'GatedSlotAttention',
67
+ 'HGRNAttention',
68
+ 'HGRN2Attention',
69
+ 'LightNetAttention',
70
+ 'LinearAttention',
71
+ 'MultiScaleRetention',
72
+ 'NativeSparseAttention',
73
+ 'ReBasedLinearAttention',
74
+ 'RWKV6Attention',
75
+ 'RWKV7Attention',
76
+ 'ABCForCausalLM',
77
+ 'ABCModel',
78
+ 'BitNetForCausalLM',
79
+ 'BitNetModel',
80
+ 'DeltaNetForCausalLM',
81
+ 'DeltaNetModel',
82
+ 'GatedDeltaNetForCausalLM',
83
+ 'GatedDeltaNetModel',
84
+ 'GatedDeltaProductForCausalLM',
85
+ 'GatedDeltaProductModel',
86
+ 'GLAForCausalLM',
87
+ 'GLAModel',
88
+ 'GSAForCausalLM',
89
+ 'GSAModel',
90
+ 'HGRNForCausalLM',
91
+ 'HGRNModel',
92
+ 'HGRN2ForCausalLM',
93
+ 'HGRN2Model',
94
+ 'LightNetForCausalLM',
95
+ 'LightNetModel',
96
+ 'LinearAttentionForCausalLM',
97
+ 'LinearAttentionModel',
98
+ 'NSAForCausalLM',
99
+ 'NSAModel',
100
+ 'RetNetForCausalLM',
101
+ 'RetNetModel',
102
+ 'RWKV6ForCausalLM',
103
+ 'RWKV6Model',
104
+ 'RWKV7ForCausalLM',
105
+ 'RWKV7Model',
106
+ 'TransformerForCausalLM',
107
+ 'TransformerModel',
108
+ ]
109
+
110
+ __version__ = '0.1.2'
fla/utils.py ADDED
@@ -0,0 +1,223 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import contextlib
4
+ import functools
5
+ import os
6
+ from enum import Enum
7
+ from functools import lru_cache
8
+ from typing import Any, Callable, Dict, Literal, Optional, Tuple
9
+
10
+ import torch
11
+ import triton
12
+ from packaging import version
13
+
14
+
15
+ def tensor_cache(
16
+ fn: Callable[..., torch.Tensor]
17
+ ) -> Callable[..., torch.Tensor]:
18
+ """
19
+ A decorator that caches the most recent result of a function with tensor inputs.
20
+
21
+ This decorator will store the output of the decorated function for the most recent set of input tensors.
22
+ If the function is called again with the same input tensors, it will return the cached result.
23
+
24
+
25
+ Args:
26
+ fn (Callable[..., torch.Tensor]):
27
+ The function to be decorated. It should take tensor inputs and return tensor outputs.
28
+
29
+ Returns:
30
+ Callable[..., torch.Tensor]:
31
+ A wrapped version of the input function with single-entry caching.
32
+ """
33
+ last_args: Optional[Tuple] = None
34
+ last_kwargs: Optional[Dict] = None
35
+ last_result: Any = None
36
+
37
+ @functools.wraps(fn)
38
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
39
+ nonlocal last_args, last_kwargs, last_result
40
+
41
+ if last_args is not None and last_kwargs is not None:
42
+ if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
43
+ if all(a is b for a, b in zip(args, last_args)) and \
44
+ all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()):
45
+ return last_result
46
+
47
+ result = fn(*args, **kwargs)
48
+ last_args, last_kwargs, last_result = args, kwargs, result
49
+ return result
50
+
51
+ return wrapper
52
+
53
+
54
+ def input_guard(
55
+ fn: Callable[..., torch.Tensor]
56
+ ) -> Callable[..., torch.Tensor]:
57
+ """
58
+ A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
59
+ """
60
+
61
+ @functools.wraps(fn)
62
+ def wrapper(*args, **kwargs):
63
+ contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
64
+ contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}
65
+
66
+ tensor = None
67
+ for arg in args:
68
+ if isinstance(arg, torch.Tensor):
69
+ tensor = arg
70
+ break
71
+ if tensor is None:
72
+ for value in kwargs.values():
73
+ if isinstance(value, torch.Tensor):
74
+ tensor = value
75
+ break
76
+
77
+ if tensor is not None:
78
+ ctx = custom_device_ctx(tensor.device.index)
79
+ else:
80
+ ctx = contextlib.nullcontext()
81
+
82
+ with ctx:
83
+ return fn(*contiguous_args, **contiguous_kwargs)
84
+
85
+ return wrapper
86
+
87
+
88
+ contiguous = input_guard
89
+
90
+
91
+ def require_version(version, hint):
92
+ """
93
+ Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
94
+ """
95
+ def decorator(fn):
96
+ @functools.wraps(fn)
97
+ def wrapper(ctx, *args, **kwargs):
98
+ from transformers.utils.versions import require_version
99
+ require_version(version, hint)
100
+ return fn(ctx,
101
+ *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),
102
+ **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})
103
+ return wrapper
104
+ return decorator
105
+
106
+
107
+ def checkpoint(fn):
108
+ def wrapper(*args, **kwargs):
109
+ return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
110
+ return wrapper
111
+
112
+
113
+ @lru_cache(maxsize=None)
114
+ def check_pytorch_version(version_s: str = '2.4') -> bool:
115
+ return version.parse(torch.__version__) >= version.parse(version_s)
116
+
117
+
118
+ def _cpu_device_warning():
119
+ import warnings
120
+ warnings.warn(('Triton is not supported on current platform, roll back to CPU.'), stacklevel=1)
121
+
122
+
123
+ @lru_cache(maxsize=None)
124
+ def get_multiprocessor_count(tensor_idx: int = 0) -> int:
125
+ try:
126
+ # Only works if Homogeneous hardware
127
+ # TEMPORARY FIX since old version introduce graph break
128
+ return torch.cuda.get_device_properties().multi_processor_count
129
+ except BaseException:
130
+ _cpu_device_warning()
131
+ return -1
132
+
133
+
134
+ @lru_cache(maxsize=None)
135
+ def get_available_device() -> str:
136
+ try:
137
+ return triton.runtime.driver.active.get_current_target().backend
138
+ except BaseException:
139
+ _cpu_device_warning()
140
+ return 'cpu'
141
+
142
+
143
+ @lru_cache(maxsize=None)
144
+ def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
145
+ device = get_available_device()
146
+ if device == 'cuda':
147
+ return 'nvidia'
148
+ elif device == 'hip':
149
+ return 'amd'
150
+ elif device == 'xpu':
151
+ return 'intel'
152
+ else:
153
+ return device
154
+
155
+
156
+ # For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
157
+ # However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
158
+ # Therefore, we need to check the triton backend to determine the actual GPU vendor.
159
+ device = get_available_device() if get_available_device() != 'hip' else 'cuda'
160
+ device_torch_lib = getattr(torch, device)
161
+ device_platform = _check_platform()
162
+
163
+ is_amd = (device_platform == 'amd')
164
+ is_intel = (device_platform == 'intel')
165
+ is_nvidia = (device_platform == 'nvidia')
166
+ is_intel_alchemist = (is_intel and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
167
+ is_nvidia_hopper = (is_nvidia and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
168
+ use_cuda_graph = (is_nvidia and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')
169
+
170
+ # Nvidia Ampere or newer, haven't check AMD and intel yet.
171
+ is_tf32_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8)
172
+ is_gather_supported = hasattr(triton.language, 'gather')
173
+
174
+
175
+ def get_all_max_shared_mem():
176
+ try:
177
+ return [
178
+ triton.runtime.driver.active.utils.get_device_properties(i)['max_shared_mem']
179
+ for i in range(device_torch_lib.device_count())
180
+ ]
181
+ except BaseException:
182
+ _cpu_device_warning()
183
+ return [-1]
184
+
185
+
186
+ class Backend(Enum):
187
+ ADA = 101376 # RTX 4090
188
+ AMPERE = 166912 # A100
189
+ HOPPER = 232448 # H100
190
+ DEFAULT = 102400 # Default
191
+
192
+ @classmethod
193
+ def get_shared_memory(cls, arch: str) -> int:
194
+ try:
195
+ return cls[arch.upper()].value
196
+ except KeyError:
197
+ return cls.DEFAULT.value
198
+
199
+
200
+ @lru_cache(maxsize=None)
201
+ def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
202
+ try:
203
+ device_shared_mem_list = get_all_max_shared_mem()
204
+ max_shared_memory = device_shared_mem_list[tensor_idx]
205
+ return max_shared_memory >= Backend.get_shared_memory(arch)
206
+ except Exception:
207
+ return False
208
+
209
+
210
+ if check_pytorch_version('2.4'):
211
+ device = 'cuda' if device == 'cpu' else device
212
+ autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
213
+ autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)
214
+
215
+ def custom_device_ctx(index: int):
216
+ return device_torch_lib.device(index)
217
+ else:
218
+ assert device == 'cuda', 'Only cuda device is supported for PyTorch version < 2.4.0.'
219
+ autocast_custom_fwd = device_torch_lib.amp.custom_fwd
220
+ autocast_custom_bwd = device_torch_lib.amp.custom_bwd
221
+
222
+ def custom_device_ctx(index: int):
223
+ return torch.cuda.device(index)
flame/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.1.0"
flame/__pycache__/train.cpython-312.pyc ADDED
Binary file (38.1 kB).
 
flame/config_manager.py ADDED
@@ -0,0 +1,940 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import sys
9
+ from collections import defaultdict
10
+ from typing import Tuple
11
+
12
+ import torch
13
+
14
+ try:
15
+ import tomllib
16
+ except ModuleNotFoundError:
17
+ import tomli as tomllib
18
+
19
+ from torchtitan.tools.logging import logger
20
+
21
+ TORCH_DTYPE_MAP = {
22
+ "float16": torch.float16,
23
+ "float32": torch.float32,
24
+ "bfloat16": torch.bfloat16,
25
+ }
26
+
27
+
28
+ def string_list(raw_arg):
29
+ """Comma-separated string list argument."""
30
+ return [s.strip() for s in raw_arg.split(",") if s.strip()]
31
+
32
+
33
+ def check_string_list_argument(args_dict: dict[str, any], fullargname: str):
34
+ section, name = fullargname.split(".")
35
+ # Split string list which are still raw strings.
36
+ if (
37
+ section in args_dict
38
+ and name in args_dict[section]
39
+ and isinstance(args_dict[section][name], str)
40
+ ):
41
+ sec = args_dict[section]
42
+ sec[name] = string_list(sec[name])
43
+
44
+
45
+ class JobConfig:
46
+ """
47
+ A helper class to manage the train configuration.
48
+ Semantics:
49
+ - Default config is loaded from a toml file. If no toml file is provided,
50
+ then the default config is loaded from argparse defaults.
51
+ - if toml file has missing keys, they are filled with argparse defaults.
52
+ - if additional explicit cmd args are provided in addition to the toml
53
+ file, they will override the toml config and the argparse defaults
54
+
55
+ precedence order: cmdline > toml > argparse default
56
+
57
+ Arg parsing semantics:
58
+
59
+ Each argument starts with <prefix>_ which is the section name in the toml file
60
+ followed by the name of the option in the toml file. For example,
61
+ model.name translates to:
62
+ [model]
63
+ name
64
+ in the toml file
65
+ """
66
+
67
+ def __init__(self):
68
+ self.args_dict = None
69
+ # main parser
70
+ self.parser = argparse.ArgumentParser(description="torchtitan arg parser.")
71
+
72
+ self.parser.add_argument(
73
+ "--job.config_file",
74
+ type=str,
75
+ default=None,
76
+ help="Job config file",
77
+ )
78
+
79
+ # job level configs
80
+ self.parser.add_argument(
81
+ "--job.dump_folder",
82
+ type=str,
83
+ default="./torchtitan/outputs",
84
+ help="Folder to dump job outputs",
85
+ )
86
+ self.parser.add_argument(
87
+ "--job.description",
88
+ type=str,
89
+ default="default job",
90
+ help="Description of the job",
91
+ )
92
+ self.parser.add_argument(
93
+ "--job.use_for_integration_test",
94
+ action="store_true",
95
+ help="Add this config to the integration test suite",
96
+ )
97
+ self.parser.add_argument(
98
+ "--job.print_args",
99
+ action="store_true",
100
+ help="Print the args to terminal",
101
+ )
102
+
103
+ # model configs
104
+ self.parser.add_argument(
105
+ "--model.name",
106
+ type=str,
107
+ default="fla",
108
+ help="Which model to train",
109
+ )
110
+ self.parser.add_argument(
111
+ "--model.config",
112
+ type=str,
113
+ default="fla-hub/transformer-1.3B-100B",
114
+ help="Path to the model config",
115
+ )
116
+ self.parser.add_argument(
117
+ "--model.tokenizer_path",
118
+ type=str,
119
+ default="fla-hub/transformer-1.3B-100B",
120
+ help="Tokenizer path",
121
+ )
122
+ self.parser.add_argument(
123
+ "--model.converters",
124
+ type=string_list,
125
+ nargs="+",
126
+ default=[],
127
+ help="""
128
+ Comma separated list of converters to apply to the model.
129
+ For instance, the `float8` converter swaps `torch.nn.Linear`
130
+ with `Float8Linear`. This feature requires you to install 'torchao'
131
+ which can be found here: https://github.com/pytorch/ao
132
+ """,
133
+ )
134
+ self.parser.add_argument(
135
+ "--model.print_after_conversion",
136
+ action="store_true",
137
+ help="""
138
+ If true, model definition will be printed to stdout after all model
139
+ converters have been applied.
140
+ """,
141
+ )
142
+
143
+ # profiling configs
144
+ self.parser.add_argument(
145
+ "--profiling.enable_profiling",
146
+ action="store_true",
147
+ help="Whether to enable pytorch profiler",
148
+ )
149
+ self.parser.add_argument(
150
+ "--profiling.save_traces_folder",
151
+ type=str,
152
+ default="profile_traces",
153
+ help="Trace files location",
154
+ )
155
+ self.parser.add_argument(
156
+ "--profiling.profile_freq",
157
+ type=int,
158
+ default=10,
159
+ help="How often to collect profiler traces, in iterations",
160
+ )
161
+ self.parser.add_argument(
162
+ "--profiling.enable_memory_snapshot",
163
+ action="store_true",
164
+ help="Whether to dump memory snapshot",
165
+ )
166
+ self.parser.add_argument(
167
+ "--profiling.save_memory_snapshot_folder",
168
+ type=str,
169
+ default="memory_snapshot",
170
+ help="Memeory snapshot files location",
171
+ )
172
+
173
+ # optimizer configs
174
+ self.parser.add_argument(
175
+ "--optimizer.name", type=str, default="AdamW", help="Optimizer to use"
176
+ )
177
+ self.parser.add_argument(
178
+ "--optimizer.eps",
179
+ type=float,
180
+ default=1e-8,
181
+ help="Epsilon value for the optimizer.",
182
+ )
183
+ self.parser.add_argument(
184
+ "--optimizer.lr", type=float, default=8e-4, help="Learning rate to use"
185
+ )
186
+ self.parser.add_argument(
187
+ "--optimizer.implementation",
188
+ type=str,
189
+ default="fused",
190
+ choices=["for-loop", "foreach", "fused"],
191
+ help="""
192
+ Specify which optimizer implementation to use:
193
+ - 'fused': Use fused implementation (CUDA only) for best performance.
194
+ - 'foreach': Use some horizontal fusion of tensors for better performance.
195
+ - 'for-loop': Use the default implementation for the optimizer (slowest).
196
+ - more info: https://pytorch.org/docs/stable/optim.html
197
+ """,
198
+ )
199
+ self.parser.add_argument(
200
+ "--optimizer.early_step_in_backward",
201
+ action="store_true",
202
+ help="""
203
+ Whether to apply the optimizer in the backward pass. Caution: optimizer_in_backward
204
+ is not compatible with gradient clipping; users should not call
205
+ register_post_accumulate_grad_hook after the optimizer is built.""",
206
+ )
207
+
208
+ # lr scheduler configs
209
+ self.parser.add_argument(
210
+ "--lr_scheduler.warmup_steps",
211
+ type=int,
212
+ default=200,
213
+ help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
214
+ )
215
+ self.parser.add_argument(
216
+ "--lr_scheduler.decay_ratio",
217
+ type=float,
218
+ default=None,
219
+ help="""
220
+ Controls the proportion of the training steps allocated to the learning rate decay phase.
221
+
222
+ If `None`, the learning rate will begin decaying immediately after the warmup period.
223
+ Otherwise, the learning rate will remain stable after the warmup period and
224
+ only start decaying during the last `decay_ratio` portion of the total training steps.
225
+
226
+ This is known as the Warmup-Stable-Decay (WSD) schedule, as described in https://arxiv.org/abs/2404.06395.
227
+ """,
228
+ )
229
+ self.parser.add_argument(
230
+ "--lr_scheduler.decay_type",
231
+ type=str,
232
+ default="linear",
233
+ choices=["linear", "sqrt", "cosine"],
234
+ help="""
235
+ Learning rate decay type to use during training:
236
+ - 'linear': linearly decays learning rate from initial to final value
237
+ - 'sqrt': decays learning rate following a 1 minus square root curve
238
+ - 'cosine': smoothly decays learning rate following a cosine curve
239
+ """,
240
+ )
241
+ self.parser.add_argument(
242
+ "--lr_scheduler.lr_min",
243
+ type=float,
244
+ default=0.0,
245
+ help="""
246
+ Min lr ratio for lr scheduler.
247
+
248
+ If provided, the range of decay factor is scaled from 1 to `lr_min`
249
+ to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`.
250
+ """,
251
+ )
252
+
253
+ # training configs
254
+ self.parser.add_argument(
255
+ "--training.batch_size", type=int, default=8, help="Batch size"
256
+ )
257
+ self.parser.add_argument(
258
+ "--training.seq_len", type=int, default=2048, help="Sequence length"
259
+ )
260
+ self.parser.add_argument(
261
+ "--training.context_len",
262
+ type=int,
263
+ default=2048,
264
+ help="Max length allowed for each sequence",
265
+ )
266
+ self.parser.add_argument(
267
+ "--training.varlen",
268
+ action="store_true",
269
+ help="Whether to take sequences of variable length as input",
270
+ )
271
+ self.parser.add_argument(
272
+ "--training.gradient_accumulation_steps",
273
+ type=int,
274
+ default=1,
275
+ help="Number of steps to accumulate gradients before updating parameters",
276
+ )
277
+ self.parser.add_argument(
278
+ "--training.steps",
279
+ type=int,
280
+ default=10000,
281
+ help="How many train steps to run",
282
+ )
283
+ self.parser.add_argument(
284
+ "--training.max_norm",
285
+ type=float,
286
+ default=1.0,
287
+ help="Max norm for gradient clipping",
288
+ )
289
+ self.parser.add_argument(
290
+ "--training.skip_nan_inf",
291
+ action="store_true",
292
+ help="Skip batch updates when NaN or INF gradients are encountered during training",
293
+ )
294
+ self.parser.add_argument(
295
+ "--training.dataset",
296
+ default="HuggingFaceFW/fineweb-edu",
297
+ help="Dataset to use, with comma separated values",
298
+ )
299
+ self.parser.add_argument(
300
+ "--training.dataset_name",
301
+ default=None,
302
+ help="The name of the dataset config, with comma separated values if provided",
303
+ )
304
+ self.parser.add_argument(
305
+ "--training.dataset_split",
306
+ default=None,
307
+ help="Dataset split to use, with comma separated values if provided",
308
+ )
309
+ self.parser.add_argument(
310
+ "--training.data_dir",
311
+ default=None,
312
+ help="Data dirs to use, with comma separated values if provided",
313
+ )
314
+ self.parser.add_argument(
315
+ "--training.data_files",
316
+ default=None,
317
+ help="Data files to use, with comma separated values if provided",
318
+ )
319
+ self.parser.add_argument(
320
+ "--training.data_probs",
321
+ default=None,
322
+ help="Data sampling probabilities, with comma separated values if provided",
323
+ )
324
+ self.parser.add_argument(
325
+ "--training.streaming",
326
+ action="store_true",
327
+ help="Whether to load dataset in streaming mode, used for huge dataset",
328
+ )
329
+ self.parser.add_argument(
330
+ "--training.num_workers",
331
+ type=int,
332
+ default=32,
333
+ help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
334
+ )
335
+ self.parser.add_argument(
336
+ "--training.prefetch_factor",
337
+ type=int,
338
+ default=2,
339
+ help="Number of batches loaded in advance by each worker."
340
+ "2 means there will be a total of 2 * num_workers batches prefetched across all workers.",
341
+ )
342
+ self.parser.add_argument(
343
+ "--training.data_parallel_replicate_degree",
344
+ type=int,
345
+ default=1,
346
+ help="""
347
+ The `data_parallel_replicate_degree` argument specifies the degree of
348
+ data parallelism for weight replication. When this value is greater
349
+ than 1, weights will be replicated across `data_parallel_replicate_degree`
350
+ ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism
351
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
352
+ parallelism method used is DDP (Distributed Data Parallelism).
353
+ 1 means disabled.""",
354
+ )
355
+ self.parser.add_argument(
356
+ "--training.data_parallel_shard_degree",
357
+ type=int,
358
+ default=-1,
359
+ help="""
360
+ The `data_parallel_shard_degree` argument specifies the degree of data
361
+ parallelism for weight sharding. When this value is greater than 1, weights
362
+ will be sharded across `data_parallel_shard_degree` ranks. If
363
+ `data_parallel_replicate_degree` is also greater than 1, the parallelism
364
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
365
+ parallelism method used is FSDP (Fully Sharded Data Parallelism).
366
+
367
+ -1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that
368
+ only `data_parallel_shard_degree` can be negative. 1 means disabled.""",
369
+ )
370
+ self.parser.add_argument(
371
+ "--training.enable_cpu_offload",
372
+ action="store_true",
373
+ help="""
374
+ Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
375
+ )
376
+ self.parser.add_argument(
377
+ "--training.tensor_parallel_degree",
378
+ type=int,
379
+ default=1,
380
+ help="Tensor Parallelism degree. 1 means disabled.",
381
+ )
382
+ self.parser.add_argument(
383
+ "--training.disable_loss_parallel",
384
+ action="store_true",
385
+ help="Whether to apply loss parallel when sequence parallel is enabled",
386
+ )
387
+ self.parser.add_argument(
388
+ "--training.fsdp_reshard_after_forward",
389
+ type=str,
390
+ default="default",
391
+ choices=["default", "always", "never"],
392
+ help="""
393
+ `reshard_after_forward` specifies the policy for applying `reshard_after_forward`
394
+ within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward,
395
+ trading off memory and communication. See torch's `fully_shard` API for more documentation
396
+ on `reshard_after_forward`.
397
+ The supported policies include "default", "always" and "never":
398
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal
399
+ scenarios.
400
+ - "always" will enable `reshard_after_forward` for all forward passes.
401
+ - "never" will disable `reshard_after_forward` for all forward passes.
402
+ """,
403
+ )
404
+ self.parser.add_argument(
405
+ "--training.mixed_precision_param",
406
+ type=str,
407
+ default="bfloat16",
408
+ choices=["bfloat16", "float32"],
409
+ help="""
410
+ torch dtype to use for parameters when applying mixed precision via FSDP.
411
+ This feature only takes effect when data_parallel_shard_degree > 1
412
+ """,
413
+ )
414
+ self.parser.add_argument(
415
+ "--training.mixed_precision_reduce",
416
+ type=str,
417
+ default="float32",
418
+ choices=["float32"],
419
+ help="""
420
+ torch dtype to use for reductions when applying mixed precision via FSDP.
421
+ This feature only takes effect when data_parallel_shard_degree > 1
422
+ """,
423
+ )
424
+ self.parser.add_argument(
425
+ "--training.compile",
426
+ action="store_true",
427
+ help="Whether to compile the model",
428
+ )
429
+ self.parser.add_argument(
430
+ "--training.gc_freq",
431
+ type=int,
432
+ default=50,
433
+ help="Python garbage control scheduling interval, in steps",
434
+ )
435
+ self.parser.add_argument(
436
+ "--training.seed",
437
+ type=int,
438
+ default=42,
439
+ help="Choose the base RNG seed used for training",
440
+ )
441
+ self.parser.add_argument(
442
+ "--training.deterministic",
443
+ action="store_true",
444
+ help="Use deterministic algorithms wherever possible, may be slower",
445
+ )
446
+ # metrics configs
447
+ self.parser.add_argument(
448
+ "--metrics.log_freq",
449
+ type=int,
450
+ default=10,
451
+ help="How often to log metrics to TensorBoard, in iterations",
452
+ )
453
+ self.parser.add_argument(
454
+ "--metrics.enable_tensorboard",
455
+ action="store_true",
456
+ help="Whether to log metrics to TensorBoard",
457
+ )
458
+ self.parser.add_argument(
459
+ "--metrics.disable_color_printing",
460
+ action="store_true",
461
+ help="Whether to disable color printing in logs",
462
+ )
463
+ self.parser.add_argument(
464
+ "--metrics.save_tb_folder",
465
+ type=str,
466
+ default="tb",
467
+ help="Folder to dump TensorBoard states",
468
+ )
469
+ self.parser.add_argument(
470
+ "--metrics.save_for_all_ranks",
471
+ action="store_true",
472
+ default=False,
473
+ help="""
474
+ Whether to save TensorBoard/Wandb metrics only for rank 0 or for all ranks.
475
+ When this option is False and pipeline_parallel_degree is > 1, the metrics
476
+ component uses the 0th rank of the last stage pipeline group, which is the
477
+ only stage that computes loss metrics.
478
+ """,
479
+ )
480
+ self.parser.add_argument(
481
+ "--metrics.enable_wandb",
482
+ action="store_true",
483
+ help="Whether to log metrics to Weights & Biases",
484
+ )
485
+
486
+ self.parser.add_argument(
487
+ "--experimental.enable_async_tensor_parallel",
488
+ action="store_true",
489
+ help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
490
+ )
491
+ self.parser.add_argument(
492
+ "--experimental.pipeline_parallel_degree",
493
+ type=int,
494
+ default=1,
495
+ help="""
496
+ Pipeline Parallelism degree, or number of ranks. 1 means disabled.
497
+ If using looped schedules, this still specifies the number of physical ranks, not the number
498
+ of stages. Stages per rank are inferred from split points degree, and schedule.""",
499
+ )
500
+ self.parser.add_argument(
501
+ "--experimental.pipeline_parallel_split_points",
502
+ type=string_list,
503
+ nargs="+",
504
+ default=[],
505
+ help="""
506
+ Specify comma-separated names of modules to use as the beginning of a split point.
507
+
508
+ e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
509
+ the first containing all the layers up to layers.0,
510
+ the second containing layers.0 and up to layers.2,
511
+ the third containing layers.2 and all the remaining layers.
512
+
513
+ Note: fully-automated splitting may be enabled in the future,
514
+ but currently the split points must be specified manually.""",
515
+ )
516
+ self.parser.add_argument(
517
+ "--experimental.pipeline_parallel_schedule",
518
+ type=str,
519
+ default="1F1B",
520
+ help="""
521
+ Specify the Pipeline Parallel schedule to use. The supported schedules are:
522
+ https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
523
+ The schedule must be compatible with the split points and stages_per_rank.
524
+
525
+ Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks,
526
+ and split_points = number of stages - 1
527
+ """,
528
+ )
529
+ self.parser.add_argument(
530
+ "--experimental.pipeline_parallel_schedule_csv",
531
+ type=str,
532
+ default="",
533
+ help="""
534
+ Specify the path to the pipeline parallel schedule csv file to use.
535
+ The pipeline_parallel_schedule argument must be either
536
+ PipelineScheduleSingle, PipelineScheduleMulti, or _PipelineScheduleRuntime.
537
+ """,
538
+ )
539
+
540
+ self.parser.add_argument(
541
+ "--experimental.pipeline_parallel_microbatches",
542
+ type=int,
543
+ default=None,
544
+ help="""
545
+ How many microbatches to split the global training batch into when using pipeline parallelism.
546
+
547
+ The global training batch size must be evenly divisible by the number of microbatches.
548
+
549
+ The default value will be the number of pipeline stages, if unspecified.
550
+ """,
551
+ )
552
+ self.parser.add_argument(
553
+ "--experimental.enable_compiled_autograd",
554
+ action="store_true",
555
+ help="Enable CompiledAutograd to compile the backward.",
556
+ )
557
+ self.parser.add_argument(
558
+ "--experimental.context_parallel_degree",
559
+ type=int,
560
+ default=1,
561
+ help="Context parallelism degree. 1 means disabled.",
562
+ )
563
+ self.parser.add_argument(
564
+ "--experimental.context_parallel_rotate_method",
565
+ type=str,
566
+ default="allgather",
567
+ help="""
568
+ The collective to use in context parallel SDPA for kv shards exchange.
569
+
570
+ 'allgather' means to all-gather all kv shards on ranks after the first sub-SDPA computation,
571
+
572
+ 'alltoall' means to all-to-all shuffle the kv shards.
573
+
574
+ The default value is 'allgather'.
575
+ """,
576
+ )
577
+ # I'm not particularly fond of this. Users can choose to write their own wrapper
578
+ # module and import the TorchTitan training loop and execute it, which looks cleaner.
579
+ # One reason to provide this option is to allow users to use the existing run script.
580
+ # While the script is pretty trivial now, we may add more logic when integrating
581
+ # with TorchFT.
582
+ # This option is subject to change and may be deleted in the future.
583
+ self.parser.add_argument(
584
+ "--experimental.custom_model_path",
585
+ type=str,
586
+ default="",
587
+ help="""
588
+ The --custom_model_path option allows specifying a custom path to a model module
589
+ that is not natively implemented within TorchTitan.
590
+ Acceptable values are the file system path to the module (e.g., my_models/model_x)
591
+ or the dotted import module (e.g., some_package.model_x).
592
+ """,
593
+ )
594
+ # checkpointing configs
595
+ self.parser.add_argument(
596
+ "--checkpoint.enable_checkpoint",
597
+ action="store_true",
598
+ help="Whether to enable checkpoint",
599
+ )
600
+ self.parser.add_argument(
601
+ "--checkpoint.folder",
602
+ type=str,
603
+ default="checkpoint",
604
+ help="""
605
+ The folder to store the checkpoints.
606
+ When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
607
+ """,
608
+ )
609
+ self.parser.add_argument(
610
+ "--checkpoint.interval",
611
+ type=int,
612
+ default=500,
613
+ help="Checkpointing interval in steps.",
614
+ )
615
+ self.parser.add_argument(
616
+ "--checkpoint.model_weights_only",
617
+ action="store_true",
618
+ help="""
619
+ When model_weights_only=True, only model weights will be saved at the end of training.
620
+ With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion.
621
+ When model_weights_only=False, the full checkpoint will be saved.
622
+ A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
623
+ The default value is false.
624
+ """,
625
+ )
626
+ self.parser.add_argument(
627
+ "--checkpoint.export_dtype",
628
+ type=str,
629
+ default="float32",
630
+ choices=["float16", "bfloat16", "float32"],
631
+ help="""
632
+ Converts to the specified precision when training completes and model_weights_only=true.
633
+ Currently supports float32, float16, and bfloat16.
634
+ The default value is float32.
635
+ """,
636
+ )
637
+ self.parser.add_argument(
638
+ "--checkpoint.create_seed_checkpoint",
639
+ action="store_true",
640
+ help="""
641
+ Initializes the full model without applying parallelisms, and then saves it as a seed checkpoint.
642
+ Note: requires user to call train.py without specifying any parallelisms, e.g. NGPU=1.
643
+ Could be implemented as a separate script, but this way shares more code.
644
+ """,
645
+ )
646
+ self.parser.add_argument(
647
+ "--checkpoint.async_mode",
648
+ type=str,
649
+ default="disabled",
650
+ help="""
651
+ Which async checkpoint mode to use. Currently there are 3 different modes.
652
+ 1. "disabled": synchronized checkpointing will be used.
653
+ 2. "async": torch.distributed.checkpoint.async_save will be used.
654
+ 3. "async_with_pinned_mem": this option utilizes a dedicated pinned memory
655
+ space and creates a separate process for faster GPU->CPU transfer
656
+ performance and to eliminate GIL contention. The cost is increased CPU
657
+ memory usage. If insufficient CPU memory is available, performance may
658
+ degrade due to memory paging. For most users, "async" should suffice as
659
+ the performance overhead is typically small (on the order of tens of
660
+ seconds) compared to checkpointing frequency. This mode can be employed
661
+ to pursue near-zero checkpointing times (e.g., < 1 second) given
662
+ appropriate hardware support such as ample CPU memory and fast PCIe.
663
+
664
+ "disabled" is the default mode.
665
+ """,
666
+ )
667
+ self.parser.add_argument(
668
+ "--checkpoint.keep_latest_k",
669
+ type=int,
670
+ default=0,
671
+ help="""
672
+ Keeps only the latest k checkpoints, purging older ones. If 0, keep all checkpoints.
673
+ 0 is the default value. k cannot be 1 as the last one may be in the process of being
674
+ saved. As a result, the metadata of the last one may not be ready yet.
675
+ """,
676
+ )
677
+ self.parser.add_argument(
678
+ "--checkpoint.load_step",
679
+ type=int,
680
+ default=-1,
681
+ help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
682
+ )
683
+ self.parser.add_argument(
684
+ "--checkpoint.exclude_from_loading",
685
+ type=string_list,
686
+ nargs="*",
687
+ default=[],
688
+ help="""
689
+ Exclude specific keys from being loaded from the checkpoint.
690
+ Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
691
+ This will load the model only, excluding the specified keys.
692
+ """,
693
+ )
694
+ self.parser.add_argument(
695
+ "--checkpoint.convert_to_hf_on_save",
696
+ action="store_true",
697
+ help="""
698
+ If true, automatically convert the saved DCP checkpoint to Hugging Face format
699
+ in a parallel directory (e.g., step-1000-hf) after each save.
700
+ """,
701
+ )
702
+ self.parser.add_argument(
703
+ "--checkpoint.hf_upload_enabled",
704
+ action="store_true",
705
+ help="Enable uploading converted Hugging Face checkpoints to the Hub.",
706
+ )
707
+ self.parser.add_argument(
708
+ "--checkpoint.hf_repo_base_name",
709
+ type=str,
710
+ default=None,
711
+ help="Hugging Face Hub repository ID to upload checkpoints to (e.g., 'username/repo').",
712
+ )
713
+ self.parser.add_argument(
714
+ "--checkpoint.hf_upload_format",
715
+ type=str,
716
+ default="dcp",
717
+ choices=["dcp", "hf"],
718
+ help="""
719
+ Format to upload to Hugging Face Hub. 'dcp' for DCP format, 'hf' for Hugging Face format.
720
+ Note: 'hf' is only supported for models with a single pipeline stage.
721
+ """,
722
+ )
723
+ # activation checkpointing configs
724
+ self.parser.add_argument(
725
+ "--activation_checkpoint.mode",
726
+ type=str,
727
+ default="selective",
728
+ help="Type of activation checkpointing to use ['none', 'full', 'selective']",
729
+ )
730
+ self.parser.add_argument(
731
+ "--activation_checkpoint.selective_ac_option",
732
+ type=str,
733
+ default="2", # 2 = checkpoint every other layer
734
+ help="""
735
+ Selective activation checkpointing options ['int', 'op'].
736
+ 'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
737
+ """,
738
+ )
739
+
740
+ self.parser.add_argument(
741
+ "--activation_offload.mode",
742
+ type=str,
743
+ default="none",
744
+ help="""
745
+ Whether to use activation offload. Options are ['none', 'full'].
746
+ """,
747
+ )
748
+
749
+ # float8 configs
750
+ self.parser.add_argument(
751
+ "--float8.enable_fsdp_float8_all_gather",
752
+ action="store_true",
753
+ help="Whether enable float8 all-gather in FSDP, recommended for tensorwise scaling",
754
+ )
755
+ self.parser.add_argument(
756
+ "--float8.precompute_float8_dynamic_scale_for_fsdp",
757
+ action="store_true",
758
+ help="Whether precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling",
759
+ )
760
+ self.parser.add_argument(
761
+ "--float8.force_recompute_fp8_weight_in_bwd",
762
+ action="store_true",
763
+ help="""
764
+ Whether to force the recomputation of FP8 weights during backward pass.
765
+ When using FSDP with tensorwise scaling, it is recommended to enable
766
+ `force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights
767
+ for backward computation.
768
+ """,
769
+ )
770
+ self.parser.add_argument(
771
+ "--float8.recipe_name",
772
+ type=str,
773
+ default=None,
774
+ choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"],
775
+ help="""
776
+ If specified, creates float8 config from recipe name, valid choices are
777
+ `tensorwise`, `rowwise` and `rowwise_with_gw_hp`.
778
+ """,
779
+ )
780
+
781
+ # communications library settings
782
+ self.parser.add_argument(
783
+ "--comm.init_timeout_seconds",
784
+ type=int,
785
+ default=300,
786
+ help="Timeout for communication operations, during initialization and first train step.",
787
+ )
788
+ self.parser.add_argument(
789
+ "--comm.train_timeout_seconds",
790
+ type=int,
791
+ default=100,
792
+ help=(
793
+ "Timeout for communication operations after the first train step -- "
794
+ "usually a tighter bound than during initialization."
795
+ ),
796
+ )
797
+ self.parser.add_argument(
798
+ "--comm.trace_buf_size",
799
+ type=int,
800
+ default=20000,
801
+ help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
802
+ )
803
+
804
+ # memory estimation settings
805
+ self.parser.add_argument(
806
+ "--memory_estimation.enabled",
807
+ help="Whether to estimate memory usage for FSDP",
808
+ action="store_true",
809
+ )
810
+
811
+ self.parser.add_argument(
812
+ "--memory_estimation.disable_fake_mode",
813
+ help="Whether to estimate memory under FakeTensorMode",
814
+ action="store_true",
815
+ )
816
+
817
+ self.parser.add_argument(
818
+ "--fault_tolerance.enable",
819
+ action="store_true",
820
+ help="""
821
+ Enable TorchFT integration. When TorchFT is enabled, HSDP will be used.
822
+ And --fault_tolerance.data_parallel_replicate_degree should be 1 and
823
+ --fault_tolerance.group_size will be used to control the maximum
824
+ replicate group size as the replicate group size is dynamic.
825
+
826
+ Note that this is still an experimental feature.
827
+ """,
828
+ )
829
+
830
+ self.parser.add_argument(
831
+ "--fault_tolerance.replica_id",
832
+ type=int,
833
+ default=0,
834
+ help="The TorchFT replica ID of this run.",
835
+ )
836
+
837
+ self.parser.add_argument(
838
+ "--fault_tolerance.group_size",
839
+ type=int,
840
+ default=0,
841
+ help="""
842
+ The number of TorchFT replicate groups. This number will be used by the
843
+ dataloader to split the dataset across the replicate groups and FSDP
844
+ dimension.
845
+ """,
846
+ )
847
+
848
+ self.parser.add_argument(
849
+ "--fault_tolerance.min_replica_size",
850
+ type=int,
851
+ default=1,
852
+ help="The minimum number of FT replica for each step.",
853
+ )
854
+
855
+ def to_dict(self):
856
+ return self.args_dict
857
+
858
+ def parse_args(self, args_list: list = sys.argv[1:]):
859
+ args, cmd_args = self.parse_args_from_command_line(args_list)
860
+ config_file = getattr(args, "job.config_file", None)
861
+ # build up a two level dict
862
+ args_dict = self._args_to_two_level_dict(args)
863
+ if config_file is not None:
864
+ try:
865
+ with open(config_file, "rb") as f:
866
+ for k, v in tomllib.load(f).items():
867
+ # to prevent overwrite of non-specified keys
868
+ args_dict[k] |= v
869
+ except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
870
+ logger.exception(
871
+ f"Error while loading the configuration file: {config_file}"
872
+ )
873
+ logger.exception(f"Error details: {str(e)}")
874
+ raise e
875
+
876
+ # Checking string-list arguments are properly split into a list
877
+ # if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
878
+ string_list_argnames = self._get_string_list_argument_names()
879
+ for n in string_list_argnames:
880
+ check_string_list_argument(args_dict, n)
881
+
882
+ # override args dict with cmd_args
883
+ cmd_args_dict = self._args_to_two_level_dict(cmd_args)
884
+ for section, section_args in cmd_args_dict.items():
885
+ for k, v in section_args.items():
886
+ args_dict[section][k] = v
887
+
888
+ self.args_dict = args_dict
889
+
890
+ for k, v in args_dict.items():
891
+ class_type = type(k.title(), (), v)
892
+ setattr(self, k, class_type())
893
+ self._validate_config()
894
+
895
+ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
896
+ args_dict = defaultdict(defaultdict)
897
+ for k, v in vars(args).items():
898
+ first_level_key, second_level_key = k.split(".", 1)
899
+ args_dict[first_level_key][second_level_key] = v
900
+ return args_dict
901
+
902
+ def _validate_config(self) -> None:
903
+ # TODO: Add more mandatory validations
904
+ assert self.model.config
905
+ assert self.model.tokenizer_path
906
+
907
+ def _get_string_list_argument_names(self) -> list[str]:
908
+ """Get the parser argument names of type `string_list`."""
909
+ string_list_args = [
910
+ v.dest for v in self.parser._actions if v.type is string_list
911
+ ]
912
+ return string_list_args
913
+
914
+ def parse_args_from_command_line(
915
+ self, args_list
916
+ ) -> Tuple[argparse.Namespace, argparse.Namespace]:
917
+ """
918
+ Parse command line arguments and return the parsed args and the command line only args
919
+ """
920
+ args = self.parser.parse_args(args_list)
921
+ string_list_argnames = set(self._get_string_list_argument_names())
922
+
923
+ # aux parser to parse the command line only args, with no defaults from main parser
924
+ aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
925
+ for arg, val in vars(args).items():
926
+ if isinstance(val, bool):
927
+ aux_parser.add_argument(
928
+ "--" + arg, action="store_true" if val else "store_false"
929
+ )
930
+ elif arg in string_list_argnames:
931
+ # without this special case, type inference breaks here,
932
+ # since the inferred type is just 'list' and it ends up flattening
933
+ # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
934
+ aux_parser.add_argument("--" + arg, type=string_list)
935
+ else:
936
+ aux_parser.add_argument("--" + arg, type=type(val))
937
+
938
+ cmd_args, _ = aux_parser.parse_known_args(args_list)
939
+
940
+ return args, cmd_args
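A minimal usage sketch for the JobConfig defined above, illustrating the documented precedence (cmdline > toml > argparse default); train.toml is a hypothetical config file:

from flame.config_manager import JobConfig

config = JobConfig()
config.parse_args([
    "--job.config_file", "train.toml",  # hypothetical toml, e.g. with [optimizer] lr = 1e-4
    "--optimizer.lr", "3e-4",           # command-line value wins over the toml value
])
print(config.optimizer.lr)           # 0.0003
print(config.training.batch_size)    # 8 (argparse default) unless train.toml overrides it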
flame/data.py ADDED
@@ -0,0 +1,570 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ import pickle
7
+ from copy import deepcopy
8
+ from dataclasses import dataclass
9
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Union
10
+
11
+ import datasets
12
+ import numpy as np
13
+ import torch
14
+ from datasets import Dataset, IterableDataset
15
+ from datasets.iterable_dataset import ShufflingConfig
16
+ from torch.distributed.checkpoint.stateful import Stateful
17
+ from torchdata.stateful_dataloader import StatefulDataLoader
18
+ from transformers import PreTrainedTokenizer
19
+
20
+ from torchtitan.tools.logging import logger
21
+
22
+
23
+ class BufferShuffledIterableDataset(IterableDataset):
24
+ def __init__(
25
+ self,
26
+ dataset: Dataset,
27
+ tokenizer: PreTrainedTokenizer,
28
+ seq_len: int = 2048,
29
+ rank: int = 0,
30
+ world_size: int = 1,
31
+ buffer_size: int = 1024,
32
+ ) -> BufferShuffledIterableDataset:
33
+ self.dataset = dataset
34
+ self.tokenizer = tokenizer
35
+
36
+ self.data = dataset.shard(world_size, rank)
37
+ self.seq_len = seq_len
38
+
39
+ self.rank = rank
40
+ self.world_size = world_size
41
+ self.buffer_size = buffer_size
42
+
43
+ if tokenizer.vocab_size < torch.iinfo(torch.int16).max:
44
+ self.dtype = torch.int16
45
+ elif tokenizer.vocab_size < torch.iinfo(torch.int32).max:
46
+ self.dtype = torch.int32
47
+ else:
48
+ self.dtype = torch.int64
49
+ self.states = None
50
+ self.buffer = torch.tensor([], dtype=self.dtype)
51
+ self.tokens = []
52
+ self.rand_id = 0
53
+ self.token_id = 0
54
+ self.rng_state = None
55
+ self._epoch = 0
56
+
57
+ def __iter__(self):
58
+ g = torch.Generator()
59
+ g.manual_seed(self._epoch + self.rank)
60
+ if self.rng_state is not None:
61
+ g.set_state(self.rng_state)
62
+
63
+ rand_it = self.randint(0, self.buffer_size, g=g)
64
+ if self.states is not None:
65
+ self.data.load_state_dict(self.states)
66
+
67
+ # max number of tokens allowed in the chunk buffer
68
+ n_tokens = self.buffer_size * self.seq_len
69
+
70
+ while True:
71
+ for sample in self.tokenize(self.data):
72
+ # keep appending the samples to the token buffer
73
+ self.tokens += sample
74
+ # if the token buffer is full, start sampling
75
+ # NOTE: we first convert the token ids to a tensor of shape [n_chunks, seq_len] for efficiency
76
+ if len(self.buffer) == 0 and len(self.tokens) >= n_tokens:
77
+ self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=self.dtype).view(self.buffer_size, -1)
78
+ self.tokens = self.tokens[n_tokens:]
79
+ if len(self.buffer) == self.buffer_size:
80
+ yield from self.sample(rand_it)
81
+
82
+ n_chunks = len(self.tokens) // self.seq_len
83
+ # handle the leftover tokens in the buffer
84
+ if n_chunks > 0:
85
+ n_tokens = n_chunks * self.seq_len
86
+ indices = torch.randperm(n_chunks, generator=g).tolist()
87
+ self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=torch.long).view(n_chunks, -1)
88
+ self.tokens = self.tokens[n_tokens:]
89
+ for i in indices:
90
+ yield {'input_ids': self.buffer[i]}
91
+
92
+ def tokenize(self, data, batch_size: int = 64):
93
+ texts, states = [], []
94
+ for sample in data:
95
+ texts.append(sample['text'])
96
+ states.append(self.data.state_dict())
97
+ if len(texts) == batch_size:
98
+ for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
99
+ self.states = s
100
+ yield tokenized
101
+ texts, states = [], []
102
+ if len(texts) > 0:
103
+ for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
104
+ self.states = s
105
+ yield tokenized
106
+
107
+ def sample(self, indices):
108
+ n_tokens = (len(self.tokens) // self.seq_len) * self.seq_len
109
+ while self.token_id < n_tokens:
110
+ i = next(indices)
111
+ start, end = self.token_id, self.token_id + self.seq_len
112
+ self.token_id += self.seq_len
113
+ yield {'input_ids': self.buffer[i].to(torch.long)}
114
+ self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
115
+ self.token_id = 0
116
+ self.tokens = self.tokens[n_tokens:]
117
+
118
+ def randint(self, low: int, high: int, buffer_size: int = 1024, g: torch.Generator = torch.Generator()) -> Iterable[int]:
119
+ indices = torch.empty(buffer_size, dtype=torch.long)
120
+ while True:
121
+ # record the generator states before sampling
122
+ self.rng_state = g.get_state()
123
+ indices = torch.randint(low, high, (buffer_size,), out=indices, generator=g)
124
+ for i in indices[self.rand_id:].tolist():
125
+ self.rand_id += 1
126
+ yield i
127
+ self.rand_id = 0
128
+
129
+ def set_epoch(self, epoch):
130
+ self._epoch = epoch
131
+ if hasattr(self.dataset, 'set_epoch'):
132
+ self.dataset.set_epoch(epoch)
133
+
134
+ def state_dict(self):
135
+ return {
136
+ 'states': self.states,
137
+ 'buffer': self.buffer.clone(),
138
+ 'tokens': deepcopy(self.tokens),
139
+ 'rand_id': self.rand_id,
140
+ 'token_id': self.token_id,
141
+ 'rng_state': self.rng_state,
142
+ 'epoch': self._epoch,
143
+ }
144
+
145
+ def load_state_dict(self, state_dict):
146
+ self.states = state_dict['states']
147
+ self.buffer = state_dict['buffer'].clone()
148
+ self.tokens = deepcopy(state_dict['tokens'])
149
+ self.rand_id = state_dict['rand_id']
150
+ self.token_id = state_dict['token_id']
151
+ self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
152
+ self._epoch = state_dict['epoch']
153
+
154
+
155
+ class OnlineTokenizedIterableDataset(IterableDataset):
156
+ def __init__(
157
+ self, dataset: Dataset, tokenizer: PreTrainedTokenizer, seq_len: int = 2048, rank: int = 0, world_size: int = 1
158
+ ) -> OnlineTokenizedIterableDataset:
159
+ self.dataset = dataset
160
+ self.tokenizer = tokenizer
161
+
162
+ self.data = dataset.shard(world_size, rank)
163
+ self.seq_len = seq_len
164
+ self.rank = rank
165
+ self.world_size = world_size
166
+
167
+ self.states = None
168
+ self.tokens = []
169
+
170
+ def __iter__(self):
171
+ if self.states is not None:
172
+ self.data.load_state_dict(self.states)
173
+
174
+ while True:
175
+ for sample in self.tokenize(self.data):
176
+ # keep appending the samples to the token buffer
177
+ self.tokens += sample
178
+
179
+ while len(self.tokens) >= self.seq_len:
180
+ input_ids = torch.tensor(self.tokens[:self.seq_len], dtype=torch.long)
181
+ self.tokens = self.tokens[self.seq_len:]
182
+ yield {'input_ids': input_ids}
183
+
184
+ def tokenize(self, data, buffer_size: int = 64):
185
+ buffer, states = [], []
186
+ for sample in data:
187
+ if sample.get('text', None) is not None:
188
+ buffer.append(sample['text'])
189
+ elif sample.get('content', None) is not None:
190
+ buffer.append(sample['content'])
191
+ else:
192
+ raise ValueError(f"No 'text' or 'content' field found in sample:\n{sample}")
193
+ states.append(self.data.state_dict())
194
+ if len(buffer) == buffer_size:
195
+ for s, tokenized in zip(states, self.tokenizer(buffer, return_attention_mask=False)['input_ids']):
196
+ self.states = s
197
+ yield tokenized
198
+ buffer, states = [], []
199
+ if len(buffer) > 0:
200
+ for s, tokenized in zip(states, self.tokenizer(buffer, return_attention_mask=False)['input_ids']):
201
+ self.states = s
202
+ yield tokenized
203
+
204
+ def state_dict(self):
205
+ return {'states': self.states, 'tokens': deepcopy(self.tokens)}
206
+
207
+ def load_state_dict(self, state_dict):
208
+ self.states = state_dict['states']
209
+ self.tokens = deepcopy(state_dict['tokens'])
210
+
211
+
212
+ class BufferShuffledExamplesIterable(datasets.iterable_dataset.BufferShuffledExamplesIterable):
213
+ def __init__(self, *args, **kwargs):
214
+ super().__init__(*args, **kwargs)
215
+
216
+ def _init_state_dict(self) -> dict:
217
+ self._state_dict = self.ex_iterable._init_state_dict()
218
+ self._state_dict['mem_buffer'] = ([],)
219
+ self._state_dict['bit_generator_state'] = self.generator.bit_generator.state
220
+ self._state_dict['bit_generator_index_offset'] = 0
221
+ self._state_dict['bit_generator_index_offset_shuffle'] = 0
222
+ return self._state_dict
223
+
224
+ def __iter__(self):
225
+ buffer_size = self.buffer_size
226
+ rng = deepcopy(self.generator)
227
+ # this is the shuffle buffer that we keep in memory
228
+ mem_buffer = self._state_dict['mem_buffer'][0]
229
+ # this is an infinite iterator that randomly samples the index of the source to pick examples from
230
+ index_offset = self._state_dict['bit_generator_index_offset'] if self._state_dict else 0
231
+ if self._state_dict:
232
+ rng.bit_generator.state = self._state_dict['bit_generator_state']
233
+ indices_iterator = self._iter_random_indices(rng, buffer_size, random_batch_size=buffer_size)
234
+ # skip already consumed ones
235
+ for _ in range(index_offset):
236
+ i = next(indices_iterator)
237
+
238
+ for x in self.ex_iterable:
239
+ if len(mem_buffer) < buffer_size: # if the buffer is not full, keep filling the buffer
240
+ mem_buffer.append(x)
241
+ else: # otherwise, pick an example from it
242
+ i = next(indices_iterator)
243
+ index_offset = (index_offset + 1) % buffer_size
244
+ if self._state_dict:
245
+ self._state_dict['bit_generator_index_offset'] = index_offset
246
+ if index_offset == 0:
247
+ self._state_dict['bit_generator_state'] = rng.bit_generator.state
248
+ selected = mem_buffer[i]
249
+ mem_buffer[i] = x # replace the picked example by a new one
250
+ yield selected
251
+
252
+ index_offset = self._state_dict['bit_generator_index_offset_shuffle'] if self._state_dict else 0
253
+ if self._state_dict:
254
+ rng.bit_generator.state = self._state_dict['bit_generator_state']
255
+
256
+ # when we run out of examples, we shuffle the remaining examples in the buffer and yield them
257
+ for i in rng.permutation(len(mem_buffer))[index_offset:].tolist():
258
+ index_offset = index_offset + 1
259
+ if self._state_dict:
260
+ self._state_dict['bit_generator_index_offset_shuffle'] = index_offset
261
+ yield mem_buffer[i]
262
+
263
+ def shuffle_data_sources(self, generator: np.random.Generator) -> BufferShuffledExamplesIterable:
264
+ """Shuffle the wrapped examples iterable as well as the shuffling buffer."""
265
+ return BufferShuffledExamplesIterable(
266
+ self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator
267
+ )
268
+
269
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> BufferShuffledExamplesIterable:
270
+ """Keep only the requested shard."""
271
+ return BufferShuffledExamplesIterable(
272
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
273
+ buffer_size=self.buffer_size,
274
+ generator=self.generator,
275
+ )
276
+
277
+ def load_state_dict(self, state_dict: dict) -> dict:
278
+ def _inner_load_state_dict(state, new_state):
279
+ if new_state is not None and isinstance(state, dict):
280
+ for key in new_state:
281
+ state[key] = _inner_load_state_dict(state[key], new_state[key])
282
+ return state
283
+ elif new_state is not None and isinstance(state, list):
284
+ for i in range(len(state)):
285
+ state[i] = _inner_load_state_dict(state[i], new_state[i])
286
+ return state
287
+ return new_state
288
+
289
+ return _inner_load_state_dict(self._state_dict, state_dict)
290
+
291
+
292
+ def shuffle(
293
+ dataset: IterableDataset,
294
+ seed: int = 42,
295
+ generator: np.random.Generator = None,
296
+ buffer_size: int = 1024,
297
+ ):
298
+ generator = np.random.default_rng(seed) if generator is None else deepcopy(generator)
299
+ return IterableDataset(
300
+ ex_iterable=BufferShuffledExamplesIterable(dataset._ex_iterable, buffer_size=buffer_size, generator=generator),
301
+ info=dataset._info.copy(),
302
+ split=dataset._split,
303
+ formatting=dataset._formatting,
304
+ shuffling=ShufflingConfig(generator=generator, _original_seed=seed),
305
+ distributed=copy.deepcopy(dataset._distributed),
306
+ token_per_repo_id=dataset._token_per_repo_id,
307
+ )
308
+
309
+
310
+ @dataclass
311
+ class DataCollatorForLanguageModeling:
312
+ """
313
+ Data collator used for language modeling. Inputs are dynamically padded if `varlen=False`.
314
+ If `varlen=True`, sequences are expected to be concatenated, and labels match inputs.
315
+
316
+ Args:
317
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
318
+ The tokenizer used for encoding the data.
319
+ context_len (`int`, optional):
320
+ When `varlen=True`, sequences longer than this length within a document
321
+ (as determined by `cu_seqlens`) will be further chunked.
322
+ varlen (`bool`):
323
+ Whether to handle variable length concatenated sequences (`True`) or padded batches (`False`).
324
+
325
+ Returns:
326
+ A dictionary with the following keys:
327
+ - `input_ids`: Tensor of input IDs. Shape `[batch_size, seq_len]` if `varlen=False`, `[1, total_len]` if `varlen=True`.
328
+ - `labels`: Tensor of labels. Shape matches `input_ids`. Padding positions are masked with -100 if `varlen=False`.
329
+ - `attention_mask`: Tensor indicating non-padding tokens (only if `varlen=False`). Shape matches `input_ids`.
330
+ - `cu_seqlens`: Tensor of cumulative sequence lengths (only if `varlen=True`). Shape `[1, num_sequences + 1]`.
331
+
332
+ NOTE: When `varlen=True`, the `batch_size` must be 1.
333
+ """
334
+
335
+ tokenizer: PreTrainedTokenizer
336
+ context_len: Optional[int] = None
337
+ varlen: bool = False
338
+
339
+ def __call__(self, examples: List[Union[List[int], Dict[str, Any]]]) -> Dict[str, Any]:
340
+ if not isinstance(examples[0], Dict):
341
+ examples = [{'input_ids': example} for example in examples]
342
+
343
+ def tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
344
+ tensorized = {}
345
+ for key in ['input_ids', 'cu_seqlens']:
346
+ if key not in example:
347
+ continue
348
+ if isinstance(example[key], List):
349
+ tensorized[key] = torch.tensor(example[key], dtype=torch.long)
350
+ elif isinstance(example[key], np.ndarray):
351
+ tensorized[key] = torch.from_numpy(example[key])
352
+ else:
353
+ tensorized[key] = example[key]
354
+ return tensorized
355
+
356
+ examples = list(map(tensorize, examples))
357
+
358
+ if not self.varlen:
359
+ # --- Handling for varlen=False (Batch Padding) ---
360
+ length_of_first = examples[0]['input_ids'].size(0)
361
+ needs_padding = not all(example['input_ids'].size(0) == length_of_first for example in examples)
362
+
363
+ if needs_padding:
364
+ # Check for pad token if padding is actually required
365
+ if self.tokenizer.pad_token_id is None:
366
+ raise ValueError(
367
+ f'You are attempting to pad samples but the tokenizer you are using '
368
+ f'({self.tokenizer.__class__.__name__}) does not have a pad token.'
369
+ )
370
+ # Pad using the tokenizer, ensuring attention_mask is returned
371
+ batch = self.tokenizer.pad(examples, return_tensors='pt', return_attention_mask=True)
372
+ else:
373
+ # No padding needed, stack directly and create a full attention mask
374
+ input_ids = torch.stack([example['input_ids'] for example in examples], dim=0)
375
+ batch = {
376
+ 'input_ids': input_ids,
377
+ # Create attention mask of all ones
378
+ 'attention_mask': torch.ones_like(input_ids),
379
+ }
380
+
381
+ # Create labels by cloning input_ids
382
+ labels = batch['input_ids'].clone()
383
+ # Mask labels only where attention_mask is 0 (padding positions)
384
+ if 'attention_mask' in batch:
385
+ labels[batch['attention_mask'] == 0] = -100
386
+ batch['labels'] = labels
387
+
388
+ else:
389
+ # --- Handling for varlen=True (Concatenated Sequences) ---
390
+ if len(examples) > 1:
391
+ raise ValueError('The batch size must be 1 for inputs with variable lengths (varlen=True).')
392
+
393
+ batch = {'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)}
394
+
395
+ # --- cu_seqlens calculation logic remains the same ---
396
+ if 'cu_seqlens' in examples[0]:
397
+ batch['cu_seqlens'] = (
398
+ torch.cat([example['cu_seqlens'] for example in examples], dim=0).unsqueeze(0).to(dtype=torch.int32)
399
+ ) # Ensure int32
400
+ else:
401
+ # determine boundaries by bos/eos positions
402
+ # Check for bos_token_id first
403
+ if self.tokenizer.bos_token_id is not None:
404
+ cu_seqlens = []
405
+ # Handle case where the sequence doesn't start with BOS
406
+ if batch['input_ids'][0, 0] != self.tokenizer.bos_token_id:
407
+ cu_seqlens.append(torch.tensor([0], device=batch['input_ids'].device)) # Match device
408
+ # Find all BOS token positions
409
+ bos_positions = torch.where(batch['input_ids'].eq(self.tokenizer.bos_token_id))[1]
410
+ # Ensure bos_positions is on the correct device if empty
411
+ if bos_positions.numel() == 0 and len(cu_seqlens) > 0:
412
+ cu_seqlens.append(bos_positions.to(cu_seqlens[0].device))
413
+ elif bos_positions.numel() > 0:
414
+ cu_seqlens.append(bos_positions)
415
+ # Add the end of the entire batch
416
+ cu_seqlens.append(
417
+ torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
418
+ ) # Match device and use size(1)
419
+ # Filter out empty tensors before cat
420
+ cu_seqlens = [t for t in cu_seqlens if t.numel() > 0]
421
+ if not cu_seqlens: # Handle case where input is empty or has no BOS
422
+ batch['cu_seqlens'] = torch.tensor(
423
+ [0, batch['input_ids'].size(1)], dtype=torch.int32, device=batch['input_ids'].device
424
+ )
425
+ else:
426
+ batch['cu_seqlens'] = torch.cat(cu_seqlens, dim=0).to(dtype=torch.int32)
427
+
428
+ # Else, check for eos_token_id
429
+ elif self.tokenizer.eos_token_id is not None:
430
+ cu_seqlens = [torch.tensor([0], device=batch['input_ids'].device)] # Match device
431
+ # Find positions *after* EOS tokens
432
+ eos_positions = torch.where(batch['input_ids'].eq(self.tokenizer.eos_token_id))[1] + 1
433
+ # Ensure eos_positions is on the correct device if empty
434
+ if eos_positions.numel() > 0:
435
+ cu_seqlens.append(eos_positions)
436
+ # Handle case where the sequence doesn't end with EOS
437
+ if batch['input_ids'][0, -1] != self.tokenizer.eos_token_id:
438
+ # Only add the final length if the last found EOS wasn't already the end
439
+ if eos_positions.numel() == 0 or eos_positions[-1] != batch['input_ids'].size(1):
440
+ cu_seqlens.append(
441
+ torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
442
+ ) # Match device and use size(1)
443
+ # Filter out empty tensors before cat
444
+ cu_seqlens = [t for t in cu_seqlens if t.numel() > 0]
445
+ if not cu_seqlens: # Handle case where input is empty or has no EOS
446
+ batch['cu_seqlens'] = torch.tensor(
447
+ [0, batch['input_ids'].size(1)], dtype=torch.int32, device=batch['input_ids'].device
448
+ )
449
+ else:
450
+ batch['cu_seqlens'] = torch.cat(cu_seqlens, dim=0).to(dtype=torch.int32)
451
+ # Else, neither BOS nor EOS is usable
452
+ else:
453
+ raise ValueError(
454
+ 'For varlen=True without precomputed cu_seqlens, the tokenizer must have either a bos_token_id '
455
+ 'or an eos_token_id defined to act as sequence separators.'
456
+ )
457
+
458
+ # --- cu_seqlens validation checks remain the same ---
459
+ if batch['cu_seqlens'].numel() < 2:
460
+ raise ValueError(f'Calculated cu_seqlens must have at least start and end: {batch["cu_seqlens"]}')
461
+ if not torch.all(batch['cu_seqlens'][1:] >= batch['cu_seqlens'][:-1]):
462
+ raise ValueError(f'Calculated cu_seqlens are not monotonically increasing: {batch["cu_seqlens"]}')
463
+ if batch['cu_seqlens'][0] != 0:
464
+ raise ValueError(f'Calculated cu_seqlens do not start at 0: {batch["cu_seqlens"]}')
465
+ if batch['cu_seqlens'][-1] != batch['input_ids'].size(1):
466
+ # Allow empty sequence case where cu_seqlens=[0, 0] and input_ids.size(1)=0
467
+ if not (batch['cu_seqlens'].tolist() == [0, 0] and batch['input_ids'].size(1) == 0):
468
+ raise ValueError(
469
+ f'Calculated cu_seqlens do not end at total length {batch["input_ids"].size(1)}: '
470
+ f'{batch["cu_seqlens"]}'
471
+ )
472
+
473
+ # --- context_len splitting logic remains the same ---
474
+ if self.context_len is not None:
475
+ # This logic splits sequences based on context_len *after* initial boundaries are found
476
+ bos = batch['cu_seqlens'][:-1].tolist()
477
+ eos = batch['cu_seqlens'][1:].tolist()
478
+ # Handle empty sequences between boundaries
479
+ split_boundaries = []
480
+ for i, j in zip(bos, eos):
481
+ if i < j: # Only process non-empty sequences
482
+ split_boundaries.append(torch.arange(i, j, self.context_len, device=batch['input_ids'].device))
483
+ # Add the final end point if it wasn't included by arange
484
+ final_end_point = torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
485
+ # Concatenate all boundaries
486
+ if not split_boundaries: # Handle case of completely empty input
487
+ batch['cu_seqlens'] = torch.tensor([0, 0], dtype=torch.int32, device=batch['input_ids'].device)
488
+ else:
489
+ batch['cu_seqlens'] = torch.cat(split_boundaries + [final_end_point]).to(dtype=torch.int32)
490
+ # Ensure uniqueness and sort, as arange might duplicate the endpoint
491
+ batch['cu_seqlens'] = torch.unique(batch['cu_seqlens'])
492
+
493
+ # Create labels directly from input_ids, NO padding mask needed for varlen
494
+ labels = batch['input_ids'].clone()
495
+ batch['labels'] = labels
496
+
497
+ return batch
498
+
499
+
500
+ class ParallelAwareDataLoader(StatefulDataLoader, Stateful):
501
+ """
502
+ A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
503
+ """
504
+
505
+ def __init__(
506
+ self,
507
+ rank: int,
508
+ dataset: IterableDataset,
509
+ batch_size: int,
510
+ collate_fn: Callable,
511
+ num_workers: int = 0,
512
+ pin_memory: bool = False,
513
+ prefetch_factor: int = 2,
514
+ persistent_workers: bool = False,
515
+ snapshot_every_n_steps: Optional[int] = 1,
516
+ ):
517
+ super().__init__(
518
+ dataset=dataset,
519
+ batch_size=batch_size,
520
+ collate_fn=collate_fn,
521
+ num_workers=num_workers,
522
+ pin_memory=pin_memory,
523
+ prefetch_factor=prefetch_factor,
524
+ persistent_workers=persistent_workers,
525
+ snapshot_every_n_steps=snapshot_every_n_steps,
526
+ )
527
+ self.rank = rank
528
+
529
+ def state_dict(self) -> Dict[str, Any]:
530
+ # Store state only for dp rank to avoid replicating the same state across other dimensions
531
+ return {f'rank_{self.rank}': pickle.dumps(super().state_dict())}
532
+
533
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
534
+ # State being empty is valid
535
+ if not state_dict:
536
+ return
537
+
538
+ if f'rank_{self.rank}' not in state_dict:
539
+ logger.warning(f'DataLoader state is empty for dp rank {self.rank}, expected key rank_{self.rank}')
540
+ return
541
+ super().load_state_dict(pickle.loads(state_dict[f'rank_{self.rank}']))
542
+
543
+
544
+ def build_dataloader(
545
+ dataset: IterableDataset,
546
+ tokenizer: PreTrainedTokenizer,
547
+ rank: int,
548
+ world_size: int,
549
+ batch_size: int,
550
+ seq_len: int,
551
+ context_len: Optional[int] = None,
552
+ varlen: bool = False,
553
+ num_workers: int = 0,
554
+ pin_memory: bool = False,
555
+ persistent_workers: bool = False,
556
+ snapshot_every_n_steps: Optional[int] = 1,
557
+ ):
558
+ dataset = OnlineTokenizedIterableDataset(
559
+ dataset=dataset, tokenizer=tokenizer, seq_len=seq_len, rank=rank, world_size=world_size
560
+ )
561
+ return ParallelAwareDataLoader(
562
+ rank=rank,
563
+ dataset=dataset,
564
+ batch_size=batch_size,
565
+ collate_fn=DataCollatorForLanguageModeling(tokenizer=tokenizer, context_len=context_len, varlen=varlen),
566
+ num_workers=num_workers,
567
+ pin_memory=pin_memory,
568
+ persistent_workers=persistent_workers,
569
+ snapshot_every_n_steps=snapshot_every_n_steps,
570
+ )
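For reference, a minimal sketch of how the dataloader above is intended to be driven; the dataset and tokenizer names below are placeholders (not this run's data), and the single-process rank/world_size values are for illustration only:

# Hedged usage sketch of build_dataloader; all names and values here are placeholders.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
dataset = load_dataset("HuggingFaceFW/fineweb-edu", split="train", streaming=True)  # placeholder dataset

loader = build_dataloader(
    dataset=dataset,
    tokenizer=tokenizer,
    rank=0,
    world_size=1,
    batch_size=2,
    seq_len=4096,
    varlen=True,  # ask the collator to emit cu_seqlens document boundaries
)
batch = next(iter(loader))
print(list(batch.keys()), batch["input_ids"].shape)

Because ParallelAwareDataLoader stores its state under a per-rank key, loader.state_dict() / loader.load_state_dict() round-trip the data position independently on each DP rank, which is what the checkpointing code in flame/train.py relies on.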
flame/models/__init__.py ADDED
File without changes
flame/tools/utils.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import nn
8
+ from torchtitan.tools.logging import logger
9
+
10
+
11
+ def get_nparams_and_flops(model: nn.Module, model_config, seq_len: int) -> tuple[int, int]:
12
+ nparams = sum(p.numel() for p in model.parameters())
13
+ nparams_embedding = sum(
14
+ sum(p.numel() for p in m.parameters())
15
+ for m in model.children()
16
+ if isinstance(m, nn.Embedding)
17
+ )
18
+
19
+ if hasattr(model_config, "num_heads"):
20
+ num_heads = model_config.num_heads
21
+ elif hasattr(model_config, "num_attention_heads"):
22
+ num_heads = model_config.num_attention_heads
23
+ else:
24
+ num_heads = 1
25
+ logger.warning("num_heads not found in model_config, defaulting to 1. ")
26
+
27
+ l, h, q, t = (
28
+ model_config.num_hidden_layers,
29
+ num_heads,
30
+ model_config.hidden_size // num_heads,
31
+ seq_len,
32
+ )
33
+ # Reasoning behind the factor of 12 for the self-attention part of the formula:
34
+ # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
35
+ # 2. the flash attention does 1 more matmul recomputation in the backward
36
+ # but recomputation should not be counted in calculating MFU (+0)
37
+ # 3. each matmul performs 1 multiplication and 1 addition (*2)
38
+ # 4. we follow the convention and do not account for sparsity in causal attention
39
+ num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
40
+
41
+ return nparams, num_flops_per_token
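To make the estimate above concrete, here is a small standalone computation with hypothetical 340M-scale numbers (24 layers, 16 heads, hidden size 1024, 32k vocab, 4096-token context); the parameter count is a rough placeholder, not a measured value:

# Illustration of the num_flops_per_token formula with made-up numbers.
nparams = 340_000_000                     # placeholder total parameter count
nparams_embedding = 32_000 * 1_024        # vocab_size * hidden_size
l, h, q, t = 24, 16, 1_024 // 16, 4_096   # layers, heads, head_dim, seq_len

dense_flops = 6 * (nparams - nparams_embedding)
attn_flops = 12 * l * h * q * t
num_flops_per_token = dense_flops + attn_flops
print(f"{num_flops_per_token / 1e9:.2f} GFLOPs/token "
      f"(attention share: {attn_flops / num_flops_per_token:.0%})")

The metrics processor uses this per-token estimate together with the device peak FLOPS reported by get_peak_flops (see flame/train.py) to report MFU.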
flame/train.py ADDED
@@ -0,0 +1,897 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ import time
10
+ from datetime import timedelta
11
+ from collections import defaultdict
12
+ import dataclasses
13
+
14
+ import torch
15
+ from datasets import interleave_datasets, load_dataset
16
+ from torch.distributed.elastic.multiprocessing.errors import record
17
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
18
+
19
+ import fla # noqa
20
+ from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
21
+ from fla.ops.common.utils import prepare_position_ids
22
+ from flame.components.checkpoint import TrainState
23
+ from flame.config_manager import JobConfig
24
+ from flame.data import build_dataloader, shuffle
25
+ from flame.models.parallelize_fla import parallelize_fla
26
+ from flame.models.pipeline_fla import pipeline_fla
27
+ from flame.tools.utils import get_nparams_and_flops
28
+ from flame.utils.checkpoint import cleanup_local_checkpoints
29
+ from flame.utils.convert_dcp_to_hf import save_pretrained
30
+ from flame.utils.hf_utils import upload_checkpoint_to_hf
31
+ from datetime import datetime
32
+ from torchtitan.components.checkpoint import CheckpointManager
33
+ from torchtitan.components.ft import FTParallelDims, init_ft_manager
34
+ from torchtitan.components.loss import build_cross_entropy_loss
35
+ from torchtitan.components.lr_scheduler import build_lr_schedulers
36
+ from torchtitan.components.metrics import build_device_memory_monitor, build_metrics_processor, ensure_pp_loss_visible
37
+ from torchtitan.components.optimizer import build_optimizers
38
+ from torchtitan.distributed import ParallelDims
39
+ from torchtitan.distributed import utils as dist_utils
40
+ from torchtitan.protocols.model_converter import build_model_converters
41
+ from torchtitan.protocols.train_spec import TrainSpec, get_train_spec, register_train_spec
42
+ from torchtitan.tools import utils
43
+ from torchtitan.tools.logging import init_logger, logger
44
+ from torchtitan.tools.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
45
+
46
+ from dotenv import load_dotenv
47
+ load_dotenv()
48
+
49
+ import wandb
50
+ wandb.login(key=os.environ["WANDB_API_KEY"])
51
+
52
+ import huggingface_hub
53
+ huggingface_hub.login(token=os.environ["HF_TOKEN"])
54
+
55
+
56
+ def build_tokenizer(job_config: JobConfig) -> AutoTokenizer:
57
+ return AutoTokenizer.from_pretrained(job_config.model.tokenizer_path)
58
+
59
+
60
+ register_train_spec(
61
+ TrainSpec(
62
+ name="fla",
63
+ cls=AutoModelForCausalLM,
64
+ config=AutoConfig,
65
+ parallelize_fn=parallelize_fla,
66
+ pipelining_fn=pipeline_fla,
67
+ build_optimizers_fn=build_optimizers,
68
+ build_lr_schedulers_fn=build_lr_schedulers,
69
+ build_dataloader_fn=build_dataloader,
70
+ build_tokenizer_fn=build_tokenizer,
71
+ build_loss_fn=build_cross_entropy_loss,
72
+ )
73
+ )
74
+
75
+
76
+ # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
77
+ @record
78
+ def main(job_config: JobConfig):
79
+ logger.info(f"Starting job: {job_config.job.description}")
80
+
81
+ if job_config.experimental.custom_model_path:
82
+ utils.import_module_from_path(job_config.experimental.custom_model_path)
83
+
84
+ # used for colorful printing
85
+ color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color
86
+
87
+ if job_config.job.print_args:
88
+ logger.info(
89
+ f"{color.green}{json.dumps(job_config.to_dict(), indent=2, sort_keys=True)}{color.reset}"
90
+ )
91
+
92
+ # take control of garbage collection to avoid stragglers
93
+ gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
94
+
95
+ device_module, device_type = utils.device_module, utils.device_type
96
+ device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
97
+ # Device has to be set before creating TorchFT manager.
98
+ device_module.set_device(device)
99
+ ft_manager = init_ft_manager(job_config)
100
+
101
+ run_specific_repo_id = None
102
+ if getattr(job_config.checkpoint, "hf_upload_enabled", False):
103
+ hf_repo_base = getattr(job_config.checkpoint, "hf_repo_base_name", None)
104
+ if hf_repo_base:
105
+ # Generate timestamp (adjust format if desired)
106
+ timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
107
+ run_specific_repo_id = f"{hf_repo_base}-{timestamp}"
108
+ logger.info(f"Target Hugging Face repository for this run: {run_specific_repo_id}")
109
+ else:
110
+ logger.warning("HF Hub upload enabled, but 'checkpoint.hf_repo_base_name' is not set.")
111
+ # Disable upload if base name is missing
112
+ job_config.checkpoint.hf_upload_enabled = False
113
+
114
+ # init distributed
115
+ world_size = int(os.environ["WORLD_SIZE"])
116
+ if not ft_manager.enabled:
117
+ parallel_dims = ParallelDims(
118
+ dp_shard=job_config.training.data_parallel_shard_degree,
119
+ dp_replicate=job_config.training.data_parallel_replicate_degree,
120
+ cp=job_config.experimental.context_parallel_degree,
121
+ tp=job_config.training.tensor_parallel_degree,
122
+ pp=job_config.experimental.pipeline_parallel_degree,
123
+ world_size=world_size,
124
+ enable_loss_parallel=not job_config.training.disable_loss_parallel,
125
+ )
126
+ else:
127
+ parallel_dims = FTParallelDims(
128
+ dp_shard=job_config.training.data_parallel_shard_degree,
129
+ dp_replicate=job_config.training.data_parallel_replicate_degree,
130
+ cp=job_config.experimental.context_parallel_degree,
131
+ tp=job_config.training.tensor_parallel_degree,
132
+ pp=job_config.experimental.pipeline_parallel_degree,
133
+ world_size=world_size,
134
+ enable_loss_parallel=not job_config.training.disable_loss_parallel,
135
+ ft_manager=ft_manager,
136
+ )
137
+ dist_utils.init_distributed(job_config)
138
+ # initialize device memory monitor and get peak flops for MFU calculation
139
+ device_memory_monitor = build_device_memory_monitor()
140
+ gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
141
+ logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
142
+
143
+ # build meshes
144
+ world_mesh = parallel_dims.build_mesh(device_type=device_type)
145
+ if parallel_dims.dp_enabled:
146
+ dp_mesh = world_mesh["dp"]
147
+ dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
148
+ else:
149
+ dp_degree, dp_rank = 1, 0
150
+
151
+ if parallel_dims.pp_enabled:
152
+ raise NotImplementedError(
153
+ "Pipeline parallelism is not supported in this version"
154
+ )
155
+ """
156
+ ! TODO[flame]: We need to fix the pipeline parallelism for flame
157
+ [x] Match the key of models' components with the actual naming
158
+ [ ] Fix the post-init and tie-embedding for pipeline parallelism, HF's transformer automatically
159
+ forces to tie if head is None, we need to handle this case
160
+ [ ]
161
+ """
162
+ pp_mesh = world_mesh["pp"]
163
+
164
+ # Set random seed, and maybe enable deterministic mode (mainly for debugging, expect perf loss)
165
+ dist_utils.set_determinism(
166
+ world_mesh, device, job_config.training.seed, job_config.training.deterministic
167
+ )
168
+ train_spec = get_train_spec(job_config.model.name)
169
+
170
+ logger.info("Loading tokenizer...")
171
+ tokenizer = AutoTokenizer.from_pretrained(
172
+ job_config.model.tokenizer_path,
173
+ trust_remote_code=True,
174
+ model_max_length=int(1e10),
175
+ )
176
+ logger.info(f"{tokenizer}")
177
+ logger.info(
178
+ f"Loading dataset {job_config.training.dataset}"
179
+ f":{job_config.training.dataset_name}"
180
+ if job_config.training.dataset_name is not None
181
+ else ""
182
+ )
183
+
184
+ min_num_shards = dp_degree * job_config.training.num_workers
185
+ if len(job_config.training.dataset.split(",")) == 1:
186
+ dataset = load_dataset(
187
+ path=job_config.training.dataset,
188
+ name=getattr(job_config.training, "dataset_name", None),
189
+ data_dir=getattr(job_config.training, "data_dir", None),
190
+ data_files=getattr(job_config.training, "data_files", None),
191
+ split=job_config.training.dataset_split or "train",
192
+ trust_remote_code=True,
193
+ streaming=job_config.training.streaming,
194
+ num_proc=(
195
+ job_config.training.num_workers
196
+ if not job_config.training.streaming
197
+ else None
198
+ ),
199
+ )
200
+ logger.info(f"{dataset}")
201
+
202
+ logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
203
+ if not job_config.training.streaming:
204
+ # the state of a map-style dataset is recoverable after shuffling
205
+ dataset = dataset.shuffle(
206
+ seed=job_config.training.seed
207
+ ).to_iterable_dataset(num_shards=min_num_shards)
208
+ else:
209
+ if dataset.num_shards < min_num_shards:
210
+ logger.warning(
211
+ f"{color.red}"
212
+ f"Dataset {job_config.training.dataset} has insufficient shards ({dataset.num_shards}). "
213
+ f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
214
+ f"{job_config.training.num_workers} dataloader workers. "
215
+ f"Disabling the streaming mode and resharding dataset to {min_num_shards} shards."
216
+ f"{color.reset}"
217
+ )
218
+ dataset = (
219
+ load_dataset(
220
+ path=job_config.training.dataset,
221
+ name=getattr(job_config.training, "dataset_name", None),
222
+ data_dir=getattr(job_config.training, "data_dir", None),
223
+ data_files=getattr(job_config.training, "data_files", None),
224
+ split=job_config.training.dataset_split or "train",
225
+ trust_remote_code=True,
226
+ streaming=False,
227
+ num_proc=job_config.training.num_workers,
228
+ )
229
+ .shuffle(seed=job_config.training.seed)
230
+ .to_iterable_dataset(num_shards=min_num_shards)
231
+ )
232
+ else:
233
+ dataset = shuffle(dataset, seed=job_config.training.seed)
234
+ else:
235
+ datasets = job_config.training.dataset.split(",")
236
+ if job_config.training.dataset_name is not None:
237
+ dataset_names = [
238
+ name or None for name in job_config.training.dataset_name.split(",")
239
+ ]
240
+ assert len(dataset_names) == len(datasets), (
241
+ "The number of dataset names must match the number of datasets"
242
+ )
243
+ else:
244
+ dataset_names = [None] * len(datasets)
245
+ if job_config.training.dataset_split is not None:
246
+ dataset_splits = [
247
+ split or "train"
248
+ for split in job_config.training.dataset_split.split(",")
249
+ ]
250
+ assert len(dataset_splits) == len(datasets), (
251
+ "The number of dataset splits must match the number of datasets"
252
+ )
253
+ else:
254
+ dataset_splits = ["train"] * len(datasets)
255
+ if job_config.training.data_dir is not None:
256
+ data_dirs = [
257
+ data_dir or None for data_dir in job_config.training.data_dir.split(",")
258
+ ]
259
+ assert len(data_dirs) == len(datasets), (
260
+ "The number of data dirs must match the number of datasets"
261
+ )
262
+ else:
263
+ data_dirs = [None] * len(datasets)
264
+ if job_config.training.data_files is not None:
265
+ data_files = job_config.training.data_files.split(",")
266
+ assert len(data_files) == len(datasets), (
267
+ "The number of data files must match the number of datasets"
268
+ )
269
+ else:
270
+ data_files = [None] * len(datasets)
271
+ if job_config.training.data_probs is not None:
272
+ data_probs = [float(p) for p in job_config.training.data_probs.split(",")]
273
+ assert len(data_probs) == len(datasets), (
274
+ "The number of data probabilities must match the number of datasets"
275
+ )
276
+ else:
277
+ raise ValueError(
278
+ "Data sampling probabilities are required if using multiple datasets"
279
+ )
280
+
281
+ subsets = []
282
+ for i, prob in enumerate(data_probs):
283
+ subset = load_dataset(
284
+ path=datasets[i],
285
+ name=dataset_names[i],
286
+ data_dir=data_dirs[i],
287
+ data_files=data_files[i],
288
+ split=dataset_splits[i],
289
+ trust_remote_code=True,
290
+ streaming=job_config.training.streaming,
291
+ num_proc=(
292
+ job_config.training.num_workers
293
+ if not job_config.training.streaming
294
+ else None
295
+ ),
296
+ )
297
+ logger.info(
298
+ f"Subset {color.cyan}{datasets[i]}"
299
+ + (f":{dataset_names[i]} " if dataset_names[i] else " ")
300
+ + f"(p = {prob:.3f}){color.reset}:\n"
301
+ + f"{subset}"
302
+ )
303
+
304
+ logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
305
+ if not job_config.training.streaming:
306
+ # the state of a map-style dataset is recoverable after shuffling
307
+ subset = subset.shuffle(
308
+ seed=job_config.training.seed
309
+ ).to_iterable_dataset(num_shards=min_num_shards)
310
+ else:
311
+ if subset.num_shards < min_num_shards:
312
+ logger.warning(
313
+ f"{color.red}"
314
+ f"Dataset {datasets[i]} has insufficient shards ({subset.num_shards}). "
315
+ f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
316
+ f"{job_config.training.num_workers} dataloader workers. "
317
+ f"Resharding dataset to {min_num_shards} shards and disabling streaming mode."
318
+ f"{color.reset}"
319
+ )
320
+ # again, it's ok to directly shuffle the map-style dataset
321
+ # we expect an error to be raised if the map-style dataset still does not have enough data shards
322
+ subset = (
323
+ load_dataset(
324
+ path=datasets[i],
325
+ name=dataset_names[i],
326
+ data_dir=data_dirs[i],
327
+ data_files=data_files[i],
328
+ split=dataset_splits[i],
329
+ trust_remote_code=True,
330
+ streaming=False,
331
+ num_proc=job_config.training.num_workers,
332
+ )
333
+ .shuffle(seed=job_config.training.seed)
334
+ .to_iterable_dataset(min_num_shards)
335
+ )
336
+ else:
337
+ # we set relatively small buffer size here as interleaving could provide some randomness
338
+ subset = shuffle(
339
+ subset,
340
+ seed=job_config.training.seed,
341
+ buffer_size=max(128, 1024 // len(datasets)),
342
+ )
343
+
344
+ if "text" in subset.column_names:
345
+ subset = subset.select_columns("text")
346
+ elif "content" in subset.column_names:
347
+ subset = subset.select_columns("content")
348
+ else:
349
+ raise ValueError(
350
+ f"Subset {datasets[i]} has no 'text' or 'content' column"
351
+ )
352
+ subsets.append(subset)
353
+
354
+ logger.info(
355
+ f"Interleaving {len(subsets)} datasets with probabilities {data_probs}"
356
+ )
357
+ dataset = interleave_datasets(
358
+ datasets=subsets,
359
+ probabilities=data_probs,
360
+ stopping_strategy="all_exhausted",
361
+ seed=job_config.training.seed,
362
+ )
363
+ logger.info(f"{dataset}")
364
+
365
+
366
+ logger.info(f"Loading model config from {job_config.model.config}")
367
+ model_config = AutoConfig.from_pretrained(job_config.model.config)
368
+
369
+ logger.info("Building dataloader...")
370
+ dataloader = build_dataloader(
371
+ dataset=dataset,
372
+ tokenizer=tokenizer,
373
+ rank=dp_rank,
374
+ world_size=dp_degree,
375
+ batch_size=job_config.training.batch_size,
376
+ # TODO: Make this more modular
377
+ # seq_len=job_config.training.seq_len if not model_config.use_top_loss else job_config.training.seq_len*2,
378
+ seq_len=job_config.training.seq_len * 2,
379
+ context_len=job_config.training.context_len,
380
+ varlen=job_config.training.varlen,
381
+ num_workers=job_config.training.num_workers,
382
+ pin_memory=job_config.training.pin_memory,
383
+ persistent_workers=job_config.training.persistent_workers,
384
+ snapshot_every_n_steps=job_config.checkpoint.interval,
385
+ )
386
+
387
+ # set the model configs from training inputs:
388
+ # 1. norm type to decide which norm layer to use
389
+ # 2. disable fused norm if TP is enabled
390
+ # 3. vocab size from tokenizer
391
+ # 4. context_len based on inputs
392
+ if parallel_dims.tp_enabled:
393
+ if model_config.fuse_norm:
394
+ logger.warning(
395
+ f"{color.red}"
396
+ f"Fused norm is not compatible with tensor parallelism. "
397
+ f"Disabling it for now."
398
+ f"{color.reset}"
399
+ )
400
+ model_config.fuse_norm = False
401
+ if parallel_dims.loss_parallel_enabled:
402
+ if model_config.fuse_cross_entropy:
403
+ logger.warning(
404
+ f"{color.red}"
405
+ f"Loss parallel enabled. Disabling fused cross entropy for now."
406
+ f"{color.reset}"
407
+ )
408
+ model_config.fuse_cross_entropy = False
409
+ model_config.vocab_size = max(tokenizer.vocab_size, model_config.vocab_size)
410
+
411
+ logger.info(
412
+ f"Building model from the config\n{color.green}{model_config}{color.reset}"
413
+ )
414
+ with torch.device("meta"):
415
+ model = AutoModelForCausalLM.from_config(model_config)
416
+ if (
417
+ getattr(model_config, "fuse_cross_entropy", False)
418
+ and FusedLinearCrossEntropyLoss is not None
419
+ ):
420
+ model.criterion = FusedLinearCrossEntropyLoss(
421
+ num_chunks=8 // parallel_dims.tp
422
+ )
423
+ # defer weight initialization until after parallelisms are applied
424
+ model.apply(lambda m: setattr(m, "_is_hf_initialized", False))
425
+ logger.info(f"{color.blue}\n{model}{color.reset}\n")
426
+
427
+ # Build the collection of model converters. No-op if `model.converters` empty
428
+ model_converters = build_model_converters(job_config, parallel_dims)
429
+ model_converters.convert(model)
430
+
431
+ # calculate model size and flops per token
432
+ model_param_count, num_flops_per_token = get_nparams_and_flops(
433
+ model, model_config, job_config.training.context_len
434
+ )
435
+
436
+ # move sharded model to CPU/GPU and initialize weights via DTensor
437
+ if job_config.checkpoint.create_seed_checkpoint:
438
+ init_device = "cpu"
439
+ elif job_config.training.enable_cpu_offload:
440
+ init_device = "cpu"
441
+ else:
442
+ init_device = device_type
443
+
444
+ # apply parallelisms and initialization
445
+ if parallel_dims.pp_enabled:
446
+ # apply PT-D Pipeline Parallel
447
+ (
448
+ pp_schedule,
449
+ model_parts,
450
+ has_first_stage,
451
+ has_last_stage,
452
+ ) = train_spec.pipelining_fn(
453
+ model,
454
+ pp_mesh,
455
+ parallel_dims,
456
+ job_config,
457
+ device,
458
+ model_config,
459
+ train_spec.loss_fn,
460
+ )
461
+ # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
462
+ del model
463
+
464
+ # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
465
+ # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
466
+ # optimizer, and checkpointing
467
+ for m in model_parts:
468
+ # apply SPMD-style PT-D techniques
469
+ train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
470
+ m.to_empty(device=init_device)
471
+ with torch.no_grad():
472
+ m.post_init()
473
+ m.train()
474
+
475
+ # confirm that user will be able to view loss metrics on the console
476
+ ensure_pp_loss_visible(parallel_dims, job_config, color)
477
+ else:
478
+ # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
479
+ train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
480
+ model.to_empty(device=init_device)
481
+ with torch.no_grad():
482
+ model.post_init()
483
+ model.train()
484
+
485
+ model_parts = [model]
486
+
487
+ device_mem_stats = device_memory_monitor.get_peak_stats()
488
+ logger.info(
489
+ f"{device_type.upper()} memory usage for model: "
490
+ f"{device_mem_stats.max_reserved_gib:.2f}GiB"
491
+ f"({device_mem_stats.max_reserved_pct:.2f}%)"
492
+ )
493
+
494
+ # build optimizer after applying parallelisms to the model
495
+ optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager)
496
+ lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config)
497
+ # Post optimizer step model converters hook.
498
+ # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
499
+ # where it issues a single all-reduce for all parameters at once for better performance
500
+ optimizers.register_step_post_hook(
501
+ lambda *args, **kwargs: model_converters.post_optimizer_hook(model_parts)
502
+ )
503
+
504
+ train_state = TrainState()
505
+
506
+ # load initial checkpoint
507
+ checkpoint = CheckpointManager(
508
+ dataloader=dataloader,
509
+ model_parts=model_parts,
510
+ optimizers=optimizers,
511
+ lr_schedulers=lr_schedulers,
512
+ states={"train_state": train_state},
513
+ job_config=job_config,
514
+ ft_manager=ft_manager,
515
+ )
516
+
517
+ if job_config.checkpoint.create_seed_checkpoint:
518
+ assert world_size == 1, (
519
+ "Must create seed checkpoint using a single device, to disable sharding"
520
+ )
521
+ assert job_config.checkpoint.enable_checkpoint, (
522
+ "Must enable checkpointing when creating a seed checkpoint"
523
+ )
524
+ checkpoint.save(curr_step=0, force=True)
525
+ logger.info("Created seed checkpoint")
526
+ return
527
+
528
+ checkpoint.load(step=job_config.checkpoint.load_step)
529
+ metric_logger = build_metrics_processor(job_config, parallel_dims)
530
+ # Set dependent attributes for metric_logger
531
+ metric_logger.num_flops_per_token = num_flops_per_token
532
+ metric_logger.optimizers = optimizers # Pass optimizers if needed by logger logic
533
+ metric_logger.lr_schedulers = (
534
+ lr_schedulers # Pass schedulers if needed by logger logic
535
+ )
536
+
537
+ # plot losses loaded from checkpoint (if any) to TensorBoard
538
+ # NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
539
+ # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
540
+ if train_state.step > 0 and len(metric_logger.data_loading_times) > 0:
541
+ for idx, step in enumerate(train_state.log_steps):
542
+ metric_logger.log(
543
+ step,
544
+ global_avg_loss=train_state.global_avg_losses[idx],
545
+ global_max_loss=train_state.global_max_losses[idx],
546
+ )
547
+
548
+ data_iterator = iter(dataloader)
549
+
550
+ train_context = dist_utils.get_train_context(
551
+ parallel_dims.loss_parallel_enabled,
552
+ job_config.experimental.enable_compiled_autograd,
553
+ )
554
+
555
+ # variables used to keep info for metrics logging
556
+ device_memory_monitor.reset_peak_stats()
557
+
558
+ global_batch_size = (
559
+ job_config.training.batch_size
560
+ * dp_degree
561
+ * job_config.training.gradient_accumulation_steps
562
+ )
563
+ num_tokens_per_step = global_batch_size * job_config.training.seq_len
564
+ # train loop
565
+ logger.info(f"{color.red}***** Running training *****{color.reset}")
566
+ logger.info(f"{color.green} Training starts at step {train_state.step + 1}")
567
+ logger.info(
568
+ f"{color.green} Number of tokens per sequence = {job_config.training.seq_len:,}"
569
+ )
570
+ logger.info(
571
+ f"{color.green} Gradient Accumulation steps = {job_config.training.gradient_accumulation_steps}"
572
+ )
573
+ logger.info(
574
+ f"{color.green} Instantaneous batch size (per device) = {job_config.training.batch_size:,}"
575
+ )
576
+ logger.info(
577
+ f"{color.green} Global batch size (w. parallel, distributed & accumulation) = {global_batch_size:,}"
578
+ f" ({num_tokens_per_step:,} tokens)"
579
+ )
580
+ logger.info(
581
+ f"{color.green} Total optimization steps = {job_config.training.steps:,} "
582
+ f"({job_config.training.steps * num_tokens_per_step:,} tokens)"
583
+ )
584
+ logger.info(
585
+ f"{color.green} Warmup steps = {job_config.lr_scheduler.warmup_steps:,}"
586
+ f" ({job_config.lr_scheduler.warmup_steps * num_tokens_per_step:,} tokens)"
587
+ )
588
+ logger.info(
589
+ f"{color.green} Number of parameters = {model_param_count:,} {color.reset}"
590
+ )
591
+
592
+ with (
593
+ maybe_enable_profiling(
594
+ job_config, global_step=train_state.step
595
+ ) as torch_profiler,
596
+ maybe_enable_memory_snapshot(
597
+ job_config, global_step=train_state.step
598
+ ) as memory_profiler,
599
+ ):
600
+ while train_state.step < job_config.training.steps:
601
+ train_state.step += 1
602
+ gc_handler.run(train_state.step)
603
+
604
+ optimizers.zero_grad()
605
+
606
+ losses = defaultdict(list)
607
+ actual_loss = []
608
+ # do gradient accumulation if enabled
609
+ for _ in range(job_config.training.gradient_accumulation_steps):
610
+ # get batch
611
+ data_load_start = time.perf_counter()
612
+ batch = next(data_iterator)
613
+ # Recall that, for top and MTP models, the batch will be
614
+ # input_ids : (B, seq_len)
615
+ # labels : (B, seq_len * 2)
616
+ input_ids, labels = batch["input_ids"][:, :job_config.training.seq_len], batch["labels"]
617
+
618
+ # Update metrics processor state before forward/backward
619
+ metric_logger.ntokens_since_last_log += input_ids.numel()
620
+ metric_logger.data_loading_times.append(
621
+ time.perf_counter() - data_load_start
622
+ )
623
+
624
+ input_ids = input_ids.to(device_type)
625
+
626
+ """
627
+ TODO[flame]: We need to carefully handle the position_ids for TP/CP
628
+ Depending on the model's PE, the position_ids might be different.
629
+
630
+ e.g. for TP
631
+ For RoPE, all ranks have the same position_ids. [FOR HF model]
632
+ For sinusoidal, each rank has the corresponding chunked position_ids. [FOR HF model]
633
+
634
+ e.g. for CP, [optional_context_parallel_ctx should automatically distribute the position_ids]
635
+ Each rank has the corresponding chunked position_ids. [FOR all models]
636
+
637
+ """
638
+ labels = labels.to(device_type)
639
+ cu_seqlens = (
640
+ batch["cu_seqlens"].to(device_type)
641
+ if "cu_seqlens" in batch
642
+ else None
643
+ )
644
+ if cu_seqlens is not None:
645
+ position_ids = prepare_position_ids(cu_seqlens).to(torch.int32)
646
+ else:
647
+ position_ids = (
648
+ torch.arange(0, input_ids.shape[1], device=device_type)
649
+ .repeat(input_ids.shape[0], 1)
650
+ .to(torch.int32)
651
+ )
652
+ # apply context parallelism if cp is enabled
653
+ # ensure CP handles the separate freqs_cis buffer for each pp stage
654
+ optional_context_parallel_ctx = (
655
+ dist_utils.create_context_parallel_ctx(
656
+ cp_mesh=world_mesh["cp"],
657
+ cp_buffers=[input_ids, labels, position_ids],
658
+ cp_seq_dims=[1, 1, 1],
659
+ cp_no_restore_buffers={input_ids, labels, position_ids},
660
+ cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
661
+ )
662
+ if parallel_dims.cp_enabled
663
+ else None
664
+ )
665
+
666
+ # #! TODO[flame], we should distribute the position_ids as well with CP
667
+ if parallel_dims.pp_enabled:
668
+ raise NotImplementedError(
669
+ "Pipeline parallelism is not supported in this version"
670
+ )
671
+ # Pipeline Parallel forward / backward inside step() call
672
+ with train_context(optional_context_parallel_ctx):
673
+ targets, losses = (
674
+ (labels, []) if has_last_stage else (None, None)
675
+ )
676
+
677
+ if has_first_stage:
678
+ pp_schedule.step(input_ids, target=targets, losses=losses)
679
+ else:
680
+ pp_schedule.step(target=targets, losses=losses)
681
+
682
+ # accumulate losses across pipeline microbatches
683
+ # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
684
+ loss = (
685
+ torch.mean(torch.stack(losses)).to(device)
686
+ if has_last_stage
687
+ else torch.tensor([-1.0], device=device)
688
+ )
689
+ else:
690
+ # Non-PP forward / backward
691
+ with train_context(optional_context_parallel_ctx):
692
+ output = model(
693
+ input_ids=input_ids,
694
+ labels=labels,
695
+ position_ids=position_ids,
696
+ cu_seqlens=cu_seqlens,
697
+ )
698
+ output_attributes = [field.name for field in dataclasses.fields(output)]
699
+ losses_attributes = [x for x in output_attributes if "loss" in x and x != "loss"]
700
+ loss = (
701
+ output.loss
702
+ / job_config.training.gradient_accumulation_steps
703
+ )
704
+ loss.backward()
705
+
706
+ actual_loss.append(loss)
707
+ for loss_attr in losses_attributes:
708
+ custom_loss = getattr(output, loss_attr, None)
709
+ if custom_loss is not None:
710
+ custom_loss = custom_loss / job_config.training.gradient_accumulation_steps
712
+ losses[loss_attr].append(custom_loss)
713
+
714
+ loss = sum(actual_loss)
715
+ for loss_attr, loss_values in losses.items():
716
+ losses[loss_attr] = sum(loss_values)
717
+
718
+ # clip gradients
719
+ grad_norm = dist_utils.clip_grad_norm_(
720
+ [p for m in model_parts for p in m.parameters()],
721
+ job_config.training.max_norm,
722
+ foreach=True,
723
+ pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
724
+ )
725
+
726
+ # optimizer step
727
+ checkpoint.maybe_wait_for_staging()
728
+ if job_config.training.skip_nan_inf and (
729
+ grad_norm.isnan() or grad_norm.isinf()
730
+ ):
731
+ logger.warning(
732
+ f"Skipping optimizer step - detected invalid gradient norm: {grad_norm:.4f}"
733
+ )
734
+ optimizers.zero_grad()
735
+ train_state.skipped_step += 1
736
+ else:
737
+ optimizers.step()
738
+ lr_schedulers.step()
739
+
740
+ # log metrics - Use MetricsProcessor
741
+ global_avg_custom_loss = {}
742
+ global_max_custom_loss = {}
743
+ if metric_logger.should_log(train_state.step):
744
+ if (
745
+ parallel_dims.dp_replicate_enabled
746
+ or parallel_dims.dp_shard_enabled
747
+ or parallel_dims.cp_enabled
748
+ ):
749
+ loss = loss.detach()
750
+ # Use dist_mean/max on the accumulated loss for the step
751
+ global_avg_loss, global_max_loss = (
752
+ dist_utils.dist_mean(
753
+ loss,
754
+ world_mesh["dp_cp"],
755
+ ),
756
+ dist_utils.dist_max(
757
+ loss,
758
+ world_mesh["dp_cp"],
759
+ ),
760
+ )
761
+ for loss_attr, loss_value in losses.items():
762
+ global_avg_custom_loss[loss_attr] = dist_utils.dist_mean(
763
+ loss_value, world_mesh["dp_cp"]
764
+ )
765
+ global_max_custom_loss[loss_attr] = dist_utils.dist_max(
766
+ loss_value, world_mesh["dp_cp"]
767
+ )
768
+ else:
769
+ # Scale back the loss before logging
770
+ global_avg_loss = global_max_loss = loss.item()
771
+ for loss_attr, loss_value in losses.items():
772
+ global_avg_custom_loss[loss_attr] = global_max_custom_loss[
773
+ loss_attr
774
+ ] = loss_value.item()
775
+
776
+ # Update train state tokens and elapsed time
777
+ time_now = time.perf_counter()
778
+ time_delta = (
779
+ time_now - metric_logger.time_last_log
780
+ ) # Use metric_logger's time
781
+ train_state.token += (
782
+ metric_logger.ntokens_since_last_log # Use tokens tracked by metric_logger
783
+ * parallel_dims.world_size
784
+ / parallel_dims.non_data_parallel_size
785
+ )
786
+ train_state.elapsed += timedelta(seconds=time_delta)
787
+ train_state.log_steps.append(train_state.step)
788
+ train_state.global_avg_losses.append(global_avg_loss)
789
+ train_state.global_max_losses.append(global_max_loss)
790
+
791
+ # Log using the metric processor
792
+ last_lr = lr_schedulers.schedulers[0].get_last_lr()[0]
793
+ eta = (
794
+ train_state.elapsed
795
+ * (job_config.training.steps - train_state.step)
796
+ / train_state.step
797
+ )
798
+ extra_metrics = {
799
+ "optimizer/lr": last_lr,
800
+ "optimizer/grad_norm": grad_norm.item(),
801
+ "optimizer/skipped_step": train_state.skipped_step,
802
+ }
803
+ for loss_attr, loss_value in global_avg_custom_loss.items():
804
+ extra_metrics[f"loss_metrics/global_avg_{loss_attr}"] = loss_value.item() if isinstance(loss_value, torch.Tensor) else loss_value
805
+ metric_logger.log(
806
+ train_state.step,
807
+ global_avg_loss,
808
+ global_max_loss,
809
+ extra_metrics=extra_metrics,
810
+ )
811
+
812
+ logger.info(
813
+ f"{color.blue}lr: {last_lr:.4e} gnorm: {grad_norm:5.2f} "
814
+ f"{color.magenta}[{str(train_state.elapsed).split('.')[0]:>8}<{str(eta).split('.')[0]:>8}]{color.reset}"
815
+ )
816
+
817
+ checkpoint.save(
818
+ train_state.step, force=(train_state.step == job_config.training.steps)
819
+ )
820
+
821
+ if torch.distributed.get_rank() == 0:
822
+ if job_config.checkpoint.enable_checkpoint:
823
+ hf_target_path = None
824
+ dcp_save_path = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder, f"step-{train_state.step}")
825
+
826
+ # TODO: Haven't tested this one yet
827
+ if getattr(job_config.checkpoint, "convert_to_hf_on_save", False):
828
+ try:
829
+ # Get the path where DCP was just saved
830
+ # Check CheckpointManager API for the best way, assuming get_save_path exists
831
+ hf_target_path = f"{dcp_save_path}" # e.g., .../checkpoint/step-1000-hf
832
+
833
+ logger.info(f"Converting step {train_state.step} DCP checkpoint to HF format at: {hf_target_path}")
834
+ save_pretrained( # Call the imported function
835
+ path=hf_target_path, # Pass target HF path as 'path'
836
+ step=train_state.step,
837
+ config=job_config.model.config, # Pass model config path/id
838
+ tokenizer=job_config.model.tokenizer_path # Pass tokenizer path/id
839
+ )
840
+ logger.info(f"Successfully converted step {train_state.step} to HF format.")
841
+
842
+ except Exception as e:
843
+ logger.error(f"Failed to convert checkpoint step {train_state.step} to HF format: {e}", exc_info=True)
844
+
845
+ base_checkpoint_dir = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder)
846
+ if getattr(job_config.checkpoint, "hf_upload_enabled", True):
847
+ upload_format = getattr(job_config.checkpoint, "hf_upload_format", "hf")
848
+ keep_k_hub = getattr(job_config.checkpoint, "hf_keep_latest_k", 5)
849
+
850
+ local_path_to_upload = None
851
+ if upload_format == "hf":
852
+ if hf_target_path and os.path.isdir(hf_target_path):
853
+ local_path_to_upload = hf_target_path
854
+ elif upload_format == "dcp":
855
+ if dcp_save_path and os.path.isdir(dcp_save_path):
856
+ local_path_to_upload = dcp_save_path
857
+
858
+ if local_path_to_upload:
859
+ try:
860
+ upload_checkpoint_to_hf(
861
+ local_path=local_path_to_upload,
862
+ step=train_state.step,
863
+ hf_repo_id_for_run=run_specific_repo_id,
864
+ upload_format=upload_format,
865
+ hf_keep_latest_k=keep_k_hub,
866
+ )
867
+ except Exception as e:
868
+ logger.error(f"Failed during HF Hub upload for step {train_state.step}: {e}", exc_info=True)
869
+
870
+ # signal the profiler that the next profiling step has started
871
+ if torch_profiler:
872
+ torch_profiler.step()
873
+ if memory_profiler:
874
+ memory_profiler.step()
875
+
876
+ # reduce timeout after first train step for faster signal
877
+ # (assuming lazy init and compilation are finished)
878
+ if train_state.step == 1:
879
+ dist_utils.set_pg_timeouts(
880
+ timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
881
+ world_mesh=world_mesh,
882
+ )
883
+
884
+ if torch.distributed.get_rank() == 0:
885
+ logger.info("Sleeping 2 seconds for other ranks to complete")
886
+ time.sleep(2)
887
+
888
+ metric_logger.close()
889
+ logger.info("Training completed")
890
+
891
+
892
+ if __name__ == "__main__":
893
+ init_logger()
894
+ config = JobConfig()
895
+ config.parse_args()
896
+ main(config)
897
+ torch.distributed.destroy_process_group()
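In the non-pipeline, single-device case the inner loop above reduces to a standard accumulate-clip-step skeleton. Below is a self-contained toy sketch of that skeleton, including the skip_nan_inf guard (toy model, toy data, placeholder hyperparameters; not the repo's training entry point):

# Toy sketch of the accumulate -> clip -> maybe-skip pattern used above.
import torch
from torch import nn

model = nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
grad_accum_steps, max_norm = 4, 1.0

optimizer.zero_grad()
for _ in range(grad_accum_steps):
    x = torch.randn(8, 16)
    # scale each micro-batch loss so the accumulated gradient matches the full batch
    loss = model(x).pow(2).mean() / grad_accum_steps
    loss.backward()

grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
if grad_norm.isnan() or grad_norm.isinf():
    optimizer.zero_grad()  # analogous to training.skip_nan_inf: drop the whole step
else:
    optimizer.step()
print(f"grad norm: {grad_norm.item():.4f}")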
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 2,
6
+ "transformers_version": "4.50.3"
7
+ }
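These generation defaults travel with the exported checkpoint and are picked up automatically when it is loaded through transformers. A minimal check (the checkpoint path is a placeholder):

from transformers import GenerationConfig

gen_config = GenerationConfig.from_pretrained("path/to/exported-checkpoint")  # placeholder path
print(gen_config.bos_token_id, gen_config.eos_token_id, gen_config.pad_token_id)  # expected: 1 2 2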
model.safetensors.index.json ADDED
@@ -0,0 +1,298 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 7101489152
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00002-of-00002.safetensors",
7
+ "model.embeddings.weight": "model-00001-of-00002.safetensors",
8
+ "model.extra_heads.0.attn.k_proj.weight": "model-00002-of-00002.safetensors",
9
+ "model.extra_heads.0.attn.o_proj.weight": "model-00002-of-00002.safetensors",
10
+ "model.extra_heads.0.attn.q_proj.weight": "model-00002-of-00002.safetensors",
11
+ "model.extra_heads.0.attn.v_proj.weight": "model-00002-of-00002.safetensors",
12
+ "model.extra_heads.0.attn_norm.weight": "model-00002-of-00002.safetensors",
13
+ "model.extra_heads.0.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
14
+ "model.extra_heads.0.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
15
+ "model.extra_heads.0.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
16
+ "model.extra_heads.0.mlp_norm.weight": "model-00002-of-00002.safetensors",
17
+ "model.extra_heads.1.attn.k_proj.weight": "model-00002-of-00002.safetensors",
18
+ "model.extra_heads.1.attn.o_proj.weight": "model-00002-of-00002.safetensors",
19
+ "model.extra_heads.1.attn.q_proj.weight": "model-00002-of-00002.safetensors",
20
+ "model.extra_heads.1.attn.v_proj.weight": "model-00002-of-00002.safetensors",
21
+ "model.extra_heads.1.attn_norm.weight": "model-00002-of-00002.safetensors",
22
+ "model.extra_heads.1.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
23
+ "model.extra_heads.1.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
24
+ "model.extra_heads.1.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
25
+ "model.extra_heads.1.mlp_norm.weight": "model-00002-of-00002.safetensors",
26
+ "model.extra_heads.2.attn.k_proj.weight": "model-00002-of-00002.safetensors",
27
+ "model.extra_heads.2.attn.o_proj.weight": "model-00002-of-00002.safetensors",
28
+ "model.extra_heads.2.attn.q_proj.weight": "model-00002-of-00002.safetensors",
29
+ "model.extra_heads.2.attn.v_proj.weight": "model-00002-of-00002.safetensors",
30
+ "model.extra_heads.2.attn_norm.weight": "model-00002-of-00002.safetensors",
31
+ "model.extra_heads.2.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
32
+ "model.extra_heads.2.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
33
+ "model.extra_heads.2.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
34
+ "model.extra_heads.2.mlp_norm.weight": "model-00002-of-00002.safetensors",
35
+ "model.extra_heads.3.attn.k_proj.weight": "model-00002-of-00002.safetensors",
36
+ "model.extra_heads.3.attn.o_proj.weight": "model-00002-of-00002.safetensors",
37
+ "model.extra_heads.3.attn.q_proj.weight": "model-00002-of-00002.safetensors",
38
+ "model.extra_heads.3.attn.v_proj.weight": "model-00002-of-00002.safetensors",
39
+ "model.extra_heads.3.attn_norm.weight": "model-00002-of-00002.safetensors",
40
+ "model.extra_heads.3.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
41
+ "model.extra_heads.3.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
42
+ "model.extra_heads.3.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
43
+ "model.extra_heads.3.mlp_norm.weight": "model-00002-of-00002.safetensors",
44
+ "model.layers.0.attn.k_proj.weight": "model-00001-of-00002.safetensors",
45
+ "model.layers.0.attn.o_proj.weight": "model-00001-of-00002.safetensors",
46
+ "model.layers.0.attn.q_proj.weight": "model-00001-of-00002.safetensors",
47
+ "model.layers.0.attn.v_proj.weight": "model-00001-of-00002.safetensors",
48
+ "model.layers.0.attn_norm.weight": "model-00001-of-00002.safetensors",
49
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
50
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
51
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
52
+ "model.layers.0.mlp_norm.weight": "model-00001-of-00002.safetensors",
53
+ "model.layers.1.attn.k_proj.weight": "model-00001-of-00002.safetensors",
54
+ "model.layers.1.attn.o_proj.weight": "model-00001-of-00002.safetensors",
55
+ "model.layers.1.attn.q_proj.weight": "model-00001-of-00002.safetensors",
56
+ "model.layers.1.attn.v_proj.weight": "model-00001-of-00002.safetensors",
57
+ "model.layers.1.attn_norm.weight": "model-00001-of-00002.safetensors",
58
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
59
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
60
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
61
+ "model.layers.1.mlp_norm.weight": "model-00001-of-00002.safetensors",
62
+ "model.layers.10.attn.k_proj.weight": "model-00001-of-00002.safetensors",
63
+ "model.layers.10.attn.o_proj.weight": "model-00001-of-00002.safetensors",
64
+ "model.layers.10.attn.q_proj.weight": "model-00001-of-00002.safetensors",
65
+ "model.layers.10.attn.v_proj.weight": "model-00001-of-00002.safetensors",
66
+ "model.layers.10.attn_norm.weight": "model-00001-of-00002.safetensors",
67
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
68
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
69
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
70
+ "model.layers.10.mlp_norm.weight": "model-00001-of-00002.safetensors",
71
+ "model.layers.11.attn.k_proj.weight": "model-00001-of-00002.safetensors",
72
+ "model.layers.11.attn.o_proj.weight": "model-00001-of-00002.safetensors",
73
+ "model.layers.11.attn.q_proj.weight": "model-00001-of-00002.safetensors",
74
+ "model.layers.11.attn.v_proj.weight": "model-00001-of-00002.safetensors",
75
+ "model.layers.11.attn_norm.weight": "model-00001-of-00002.safetensors",
76
+ "model.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
77
+ "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
78
+ "model.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
79
+ "model.layers.11.mlp_norm.weight": "model-00001-of-00002.safetensors",
80
+ "model.layers.12.attn.k_proj.weight": "model-00001-of-00002.safetensors",
81
+ "model.layers.12.attn.o_proj.weight": "model-00001-of-00002.safetensors",
82
+ "model.layers.12.attn.q_proj.weight": "model-00001-of-00002.safetensors",
83
+ "model.layers.12.attn.v_proj.weight": "model-00001-of-00002.safetensors",
84
+ "model.layers.12.attn_norm.weight": "model-00001-of-00002.safetensors",
85
+ "model.layers.12.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
86
+ "model.layers.12.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
87
+ "model.layers.12.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
88
+ "model.layers.12.mlp_norm.weight": "model-00001-of-00002.safetensors",
89
+ "model.layers.13.attn.k_proj.weight": "model-00001-of-00002.safetensors",
90
+ "model.layers.13.attn.o_proj.weight": "model-00001-of-00002.safetensors",
91
+ "model.layers.13.attn.q_proj.weight": "model-00001-of-00002.safetensors",
92
+ "model.layers.13.attn.v_proj.weight": "model-00001-of-00002.safetensors",
93
+ "model.layers.13.attn_norm.weight": "model-00001-of-00002.safetensors",
94
+ "model.layers.13.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
95
+ "model.layers.13.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
96
+ "model.layers.13.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
97
+ "model.layers.13.mlp_norm.weight": "model-00001-of-00002.safetensors",
98
+ "model.layers.14.attn.k_proj.weight": "model-00001-of-00002.safetensors",
99
+ "model.layers.14.attn.o_proj.weight": "model-00001-of-00002.safetensors",
100
+ "model.layers.14.attn.q_proj.weight": "model-00001-of-00002.safetensors",
101
+ "model.layers.14.attn.v_proj.weight": "model-00001-of-00002.safetensors",
102
+ "model.layers.14.attn_norm.weight": "model-00001-of-00002.safetensors",
103
+ "model.layers.14.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
104
+ "model.layers.14.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
105
+ "model.layers.14.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
106
+ "model.layers.14.mlp_norm.weight": "model-00001-of-00002.safetensors",
107
+ "model.layers.15.attn.k_proj.weight": "model-00001-of-00002.safetensors",
108
+ "model.layers.15.attn.o_proj.weight": "model-00001-of-00002.safetensors",
109
+ "model.layers.15.attn.q_proj.weight": "model-00001-of-00002.safetensors",
110
+ "model.layers.15.attn.v_proj.weight": "model-00001-of-00002.safetensors",
111
+ "model.layers.15.attn_norm.weight": "model-00001-of-00002.safetensors",
112
+ "model.layers.15.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
113
+ "model.layers.15.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
114
+ "model.layers.15.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
115
+ "model.layers.15.mlp_norm.weight": "model-00001-of-00002.safetensors",
116
+ "model.layers.16.attn.k_proj.weight": "model-00001-of-00002.safetensors",
117
+ "model.layers.16.attn.o_proj.weight": "model-00001-of-00002.safetensors",
118
+ "model.layers.16.attn.q_proj.weight": "model-00001-of-00002.safetensors",
119
+ "model.layers.16.attn.v_proj.weight": "model-00001-of-00002.safetensors",
120
+ "model.layers.16.attn_norm.weight": "model-00001-of-00002.safetensors",
121
+ "model.layers.16.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
122
+ "model.layers.16.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
123
+ "model.layers.16.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
124
+ "model.layers.16.mlp_norm.weight": "model-00001-of-00002.safetensors",
125
+ "model.layers.17.attn.k_proj.weight": "model-00001-of-00002.safetensors",
126
+ "model.layers.17.attn.o_proj.weight": "model-00001-of-00002.safetensors",
127
+ "model.layers.17.attn.q_proj.weight": "model-00001-of-00002.safetensors",
128
+ "model.layers.17.attn.v_proj.weight": "model-00001-of-00002.safetensors",
129
+ "model.layers.17.attn_norm.weight": "model-00001-of-00002.safetensors",
130
+ "model.layers.17.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
131
+ "model.layers.17.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
132
+ "model.layers.17.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
133
+ "model.layers.17.mlp_norm.weight": "model-00001-of-00002.safetensors",
134
+ "model.layers.18.attn.k_proj.weight": "model-00001-of-00002.safetensors",
135
+ "model.layers.18.attn.o_proj.weight": "model-00001-of-00002.safetensors",
136
+ "model.layers.18.attn.q_proj.weight": "model-00001-of-00002.safetensors",
137
+ "model.layers.18.attn.v_proj.weight": "model-00001-of-00002.safetensors",
138
+ "model.layers.18.attn_norm.weight": "model-00001-of-00002.safetensors",
139
+ "model.layers.18.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
140
+ "model.layers.18.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
141
+ "model.layers.18.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
142
+ "model.layers.18.mlp_norm.weight": "model-00001-of-00002.safetensors",
143
+ "model.layers.19.attn.k_proj.weight": "model-00001-of-00002.safetensors",
144
+ "model.layers.19.attn.o_proj.weight": "model-00001-of-00002.safetensors",
145
+ "model.layers.19.attn.q_proj.weight": "model-00001-of-00002.safetensors",
146
+ "model.layers.19.attn.v_proj.weight": "model-00001-of-00002.safetensors",
147
+ "model.layers.19.attn_norm.weight": "model-00001-of-00002.safetensors",
148
+ "model.layers.19.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
149
+ "model.layers.19.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
150
+ "model.layers.19.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
151
+ "model.layers.19.mlp_norm.weight": "model-00001-of-00002.safetensors",
152
+ "model.layers.2.attn.k_proj.weight": "model-00001-of-00002.safetensors",
153
+ "model.layers.2.attn.o_proj.weight": "model-00001-of-00002.safetensors",
154
+ "model.layers.2.attn.q_proj.weight": "model-00001-of-00002.safetensors",
155
+ "model.layers.2.attn.v_proj.weight": "model-00001-of-00002.safetensors",
156
+ "model.layers.2.attn_norm.weight": "model-00001-of-00002.safetensors",
157
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
158
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
159
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
160
+ "model.layers.2.mlp_norm.weight": "model-00001-of-00002.safetensors",
161
+ "model.layers.20.attn.k_proj.weight": "model-00001-of-00002.safetensors",
162
+ "model.layers.20.attn.o_proj.weight": "model-00001-of-00002.safetensors",
163
+ "model.layers.20.attn.q_proj.weight": "model-00001-of-00002.safetensors",
164
+ "model.layers.20.attn.v_proj.weight": "model-00001-of-00002.safetensors",
165
+ "model.layers.20.attn_norm.weight": "model-00001-of-00002.safetensors",
166
+ "model.layers.20.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
167
+ "model.layers.20.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
168
+ "model.layers.20.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
169
+ "model.layers.20.mlp_norm.weight": "model-00001-of-00002.safetensors",
170
+ "model.layers.21.attn.k_proj.weight": "model-00001-of-00002.safetensors",
171
+ "model.layers.21.attn.o_proj.weight": "model-00001-of-00002.safetensors",
172
+ "model.layers.21.attn.q_proj.weight": "model-00001-of-00002.safetensors",
173
+ "model.layers.21.attn.v_proj.weight": "model-00001-of-00002.safetensors",
174
+ "model.layers.21.attn_norm.weight": "model-00001-of-00002.safetensors",
175
+ "model.layers.21.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
176
+ "model.layers.21.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
177
+ "model.layers.21.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
178
+ "model.layers.21.mlp_norm.weight": "model-00001-of-00002.safetensors",
179
+ "model.layers.22.attn.k_proj.weight": "model-00001-of-00002.safetensors",
180
+ "model.layers.22.attn.o_proj.weight": "model-00001-of-00002.safetensors",
181
+ "model.layers.22.attn.q_proj.weight": "model-00001-of-00002.safetensors",
182
+ "model.layers.22.attn.v_proj.weight": "model-00001-of-00002.safetensors",
183
+ "model.layers.22.attn_norm.weight": "model-00001-of-00002.safetensors",
184
+ "model.layers.22.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
185
+ "model.layers.22.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
186
+ "model.layers.22.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
187
+ "model.layers.22.mlp_norm.weight": "model-00001-of-00002.safetensors",
188
+ "model.layers.23.attn.k_proj.weight": "model-00002-of-00002.safetensors",
189
+ "model.layers.23.attn.o_proj.weight": "model-00002-of-00002.safetensors",
190
+ "model.layers.23.attn.q_proj.weight": "model-00002-of-00002.safetensors",
191
+ "model.layers.23.attn.v_proj.weight": "model-00002-of-00002.safetensors",
192
+ "model.layers.23.attn_norm.weight": "model-00001-of-00002.safetensors",
193
+ "model.layers.23.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
194
+ "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
195
+ "model.layers.23.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
196
+ "model.layers.23.mlp_norm.weight": "model-00002-of-00002.safetensors",
197
+ "model.layers.24.attn.k_proj.weight": "model-00002-of-00002.safetensors",
198
+ "model.layers.24.attn.o_proj.weight": "model-00002-of-00002.safetensors",
199
+ "model.layers.24.attn.q_proj.weight": "model-00002-of-00002.safetensors",
200
+ "model.layers.24.attn.v_proj.weight": "model-00002-of-00002.safetensors",
201
+ "model.layers.24.attn_norm.weight": "model-00002-of-00002.safetensors",
202
+ "model.layers.24.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
203
+ "model.layers.24.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
204
+ "model.layers.24.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
205
+ "model.layers.24.mlp_norm.weight": "model-00002-of-00002.safetensors",
206
+ "model.layers.25.attn.k_proj.weight": "model-00002-of-00002.safetensors",
207
+ "model.layers.25.attn.o_proj.weight": "model-00002-of-00002.safetensors",
208
+ "model.layers.25.attn.q_proj.weight": "model-00002-of-00002.safetensors",
209
+ "model.layers.25.attn.v_proj.weight": "model-00002-of-00002.safetensors",
210
+ "model.layers.25.attn_norm.weight": "model-00002-of-00002.safetensors",
211
+ "model.layers.25.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
212
+ "model.layers.25.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
213
+ "model.layers.25.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
214
+ "model.layers.25.mlp_norm.weight": "model-00002-of-00002.safetensors",
215
+ "model.layers.26.attn.k_proj.weight": "model-00002-of-00002.safetensors",
216
+ "model.layers.26.attn.o_proj.weight": "model-00002-of-00002.safetensors",
217
+ "model.layers.26.attn.q_proj.weight": "model-00002-of-00002.safetensors",
218
+ "model.layers.26.attn.v_proj.weight": "model-00002-of-00002.safetensors",
219
+ "model.layers.26.attn_norm.weight": "model-00002-of-00002.safetensors",
220
+ "model.layers.26.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
221
+ "model.layers.26.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
222
+ "model.layers.26.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
223
+ "model.layers.26.mlp_norm.weight": "model-00002-of-00002.safetensors",
224
+ "model.layers.27.attn.k_proj.weight": "model-00002-of-00002.safetensors",
225
+ "model.layers.27.attn.o_proj.weight": "model-00002-of-00002.safetensors",
226
+ "model.layers.27.attn.q_proj.weight": "model-00002-of-00002.safetensors",
227
+ "model.layers.27.attn.v_proj.weight": "model-00002-of-00002.safetensors",
228
+ "model.layers.27.attn_norm.weight": "model-00002-of-00002.safetensors",
229
+ "model.layers.27.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
230
+ "model.layers.27.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
231
+ "model.layers.27.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
232
+ "model.layers.27.mlp_norm.weight": "model-00002-of-00002.safetensors",
233
+ "model.layers.3.attn.k_proj.weight": "model-00001-of-00002.safetensors",
234
+ "model.layers.3.attn.o_proj.weight": "model-00001-of-00002.safetensors",
235
+ "model.layers.3.attn.q_proj.weight": "model-00001-of-00002.safetensors",
236
+ "model.layers.3.attn.v_proj.weight": "model-00001-of-00002.safetensors",
237
+ "model.layers.3.attn_norm.weight": "model-00001-of-00002.safetensors",
238
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
239
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
240
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
241
+ "model.layers.3.mlp_norm.weight": "model-00001-of-00002.safetensors",
242
+ "model.layers.4.attn.k_proj.weight": "model-00001-of-00002.safetensors",
243
+ "model.layers.4.attn.o_proj.weight": "model-00001-of-00002.safetensors",
244
+ "model.layers.4.attn.q_proj.weight": "model-00001-of-00002.safetensors",
245
+ "model.layers.4.attn.v_proj.weight": "model-00001-of-00002.safetensors",
246
+ "model.layers.4.attn_norm.weight": "model-00001-of-00002.safetensors",
247
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
248
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
249
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
250
+ "model.layers.4.mlp_norm.weight": "model-00001-of-00002.safetensors",
251
+ "model.layers.5.attn.k_proj.weight": "model-00001-of-00002.safetensors",
252
+ "model.layers.5.attn.o_proj.weight": "model-00001-of-00002.safetensors",
253
+ "model.layers.5.attn.q_proj.weight": "model-00001-of-00002.safetensors",
254
+ "model.layers.5.attn.v_proj.weight": "model-00001-of-00002.safetensors",
255
+ "model.layers.5.attn_norm.weight": "model-00001-of-00002.safetensors",
256
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
257
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
258
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
259
+ "model.layers.5.mlp_norm.weight": "model-00001-of-00002.safetensors",
260
+ "model.layers.6.attn.k_proj.weight": "model-00001-of-00002.safetensors",
261
+ "model.layers.6.attn.o_proj.weight": "model-00001-of-00002.safetensors",
262
+ "model.layers.6.attn.q_proj.weight": "model-00001-of-00002.safetensors",
263
+ "model.layers.6.attn.v_proj.weight": "model-00001-of-00002.safetensors",
264
+ "model.layers.6.attn_norm.weight": "model-00001-of-00002.safetensors",
265
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
266
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
267
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
268
+ "model.layers.6.mlp_norm.weight": "model-00001-of-00002.safetensors",
269
+ "model.layers.7.attn.k_proj.weight": "model-00001-of-00002.safetensors",
270
+ "model.layers.7.attn.o_proj.weight": "model-00001-of-00002.safetensors",
271
+ "model.layers.7.attn.q_proj.weight": "model-00001-of-00002.safetensors",
272
+ "model.layers.7.attn.v_proj.weight": "model-00001-of-00002.safetensors",
273
+ "model.layers.7.attn_norm.weight": "model-00001-of-00002.safetensors",
274
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
275
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
276
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
277
+ "model.layers.7.mlp_norm.weight": "model-00001-of-00002.safetensors",
278
+ "model.layers.8.attn.k_proj.weight": "model-00001-of-00002.safetensors",
279
+ "model.layers.8.attn.o_proj.weight": "model-00001-of-00002.safetensors",
280
+ "model.layers.8.attn.q_proj.weight": "model-00001-of-00002.safetensors",
281
+ "model.layers.8.attn.v_proj.weight": "model-00001-of-00002.safetensors",
282
+ "model.layers.8.attn_norm.weight": "model-00001-of-00002.safetensors",
283
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
284
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
285
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
286
+ "model.layers.8.mlp_norm.weight": "model-00001-of-00002.safetensors",
287
+ "model.layers.9.attn.k_proj.weight": "model-00001-of-00002.safetensors",
288
+ "model.layers.9.attn.o_proj.weight": "model-00001-of-00002.safetensors",
289
+ "model.layers.9.attn.q_proj.weight": "model-00001-of-00002.safetensors",
290
+ "model.layers.9.attn.v_proj.weight": "model-00001-of-00002.safetensors",
291
+ "model.layers.9.attn_norm.weight": "model-00001-of-00002.safetensors",
292
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
293
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
294
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
295
+ "model.layers.9.mlp_norm.weight": "model-00001-of-00002.safetensors",
296
+ "model.norm.weight": "model-00002-of-00002.safetensors"
297
+ }
298
+ }
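
For reference, the `weight_map` in the index above records which shard each parameter lives in. A minimal sketch (not part of this commit; the file names are the ones listed above) of resolving and loading a single tensor from the sharded checkpoint:

```python
# Sketch: look up a parameter's shard in model.safetensors.index.json and load it.
import json
from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "model.norm.weight"                   # any key from the weight_map above
shard = index["weight_map"][name]            # e.g. "model-00002-of-00002.safetensors"

with safe_open(shard, framework="pt") as f:  # opens only the shard that holds this tensor
    tensor = f.get_tensor(name)

print(name, tuple(tensor.shape))
```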
pyproject.toml ADDED
@@ -0,0 +1,43 @@
1
+ [project]
2
+ name = "flame"
3
+ dynamic = ["version"]
4
+ description = "A minimal training framework for scaling FLA models"
5
+ readme = "README.md"
6
+ authors = [
7
+ { name = "Songlin Yang", email = "[email protected]" },
8
+ { name = "Yu Zhang", email = "[email protected]" },
9
+ ]
10
+ license = { file = "LICENSE" }
11
+ classifiers = [
12
+ "Programming Language :: Python :: 3",
13
+ "License :: OSI Approved :: MIT License",
14
+ "Operating System :: OS Independent",
15
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
16
+ ]
17
+ requires-python = ">=3.10"
18
+ dependencies = [
19
+ 'torch==2.6',
20
+ 'torchdata',
21
+ 'transformers==4.51.3',
22
+ 'triton>=3.0',
23
+ 'datasets>=3.3.0',
24
+ 'einops',
25
+ 'ninja',
26
+ 'wandb',
27
+ 'tiktoken',
28
+ 'tensorboard',
29
+ 'python-dotenv'
30
+ ]
31
+
32
+ [project.optional-dependencies]
33
+ dev = ["pytest"]
34
+
35
+ [project.urls]
36
+ Homepage = "https://github.com/fla-org/flame"
37
+
38
+ [build-system]
39
+ requires = ["setuptools>=45", "wheel", "ninja", "torch"]
40
+
41
+ [tool.isort]
42
+ line_length = 127
43
+ multi_line_output = 3
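
With the project metadata above in place, the package can typically be installed in editable mode with `pip install -e .` (or `pip install -e '.[dev]'` to pull in the pytest extra); the `torch==2.6` and `transformers==4.51.3` pins come straight from the dependency list above.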
setup.py ADDED
@@ -0,0 +1,51 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import ast
4
+ import os
5
+ import re
6
+ from pathlib import Path
7
+
8
+ from setuptools import find_packages, setup
9
+
10
+ with open('README.md') as f:
11
+ long_description = f.read()
12
+
13
+
14
+ def get_package_version():
15
+ with open(Path(os.path.dirname(os.path.abspath(__file__))) / 'flame' / '__init__.py') as f:
16
+ version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
17
+ return ast.literal_eval(version_match.group(1))
18
+
19
+
20
+ setup(
21
+ name='flame',
22
+ version=get_package_version(),
23
+ description='A minimal training framework for scaling FLA models',
24
+ long_description=long_description,
25
+ long_description_content_type='text/markdown',
26
+ author='Songlin Yang, Yu Zhang',
27
28
+ url='https://github.com/fla-org/flame',
29
+ packages=find_packages(),
30
+ license='MIT',
31
+ classifiers=[
32
+ 'Programming Language :: Python :: 3',
33
+ 'License :: OSI Approved :: MIT License',
34
+ 'Operating System :: OS Independent',
35
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence'
36
+ ],
37
+ python_requires='>=3.10',
38
+ install_requires=[
39
+ 'torch==2.6',
40
+ 'torchdata',
41
+ 'transformers==4.51.3',
42
+ 'triton>=3.0',
43
+ 'datasets>=3.3.0',
44
+ 'einops',
45
+ 'ninja',
46
+ 'wandb',
47
+ 'tiktoken',
48
+ 'tensorboard',
49
+ 'python-dotenv'
50
+ ],
51
+ )
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tb/20250716-2210/wandb/debug-internal.log ADDED
@@ -0,0 +1,90 @@
1
+ {"time":"2025-07-16T22:10:00.785425491Z","level":"INFO","msg":"stream: starting","core version":"0.21.0"}
2
+ {"time":"2025-07-16T22:10:01.508654924Z","level":"INFO","msg":"stream: created new stream","id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
3
+ {"time":"2025-07-16T22:10:01.508690211Z","level":"INFO","msg":"stream: started","id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
4
+ {"time":"2025-07-16T22:10:01.508739999Z","level":"INFO","msg":"writer: Do: started","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
5
+ {"time":"2025-07-16T22:10:01.508759314Z","level":"INFO","msg":"handler: started","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
6
+ {"time":"2025-07-16T22:10:01.508803829Z","level":"INFO","msg":"sender: started","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
7
+ {"time":"2025-07-16T23:09:45.740737848Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
8
+ {"time":"2025-07-16T23:18:29.56428269Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
9
+ {"time":"2025-07-16T23:19:01.917480335Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
10
+ {"time":"2025-07-16T23:19:36.868918826Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
11
+ {"time":"2025-07-16T23:20:16.297827588Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
12
+ {"time":"2025-07-16T23:20:18.619477493Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp: lookup api.wandb.ai on 127.0.0.53:53: read udp 127.0.0.1:46470->127.0.0.53:53: i/o timeout"}
13
+ {"time":"2025-07-16T23:20:30.740650327Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp: lookup api.wandb.ai on 127.0.0.53:53: read udp 127.0.0.1:47482->127.0.0.53:53: i/o timeout"}
14
+ {"time":"2025-07-16T23:21:04.536690541Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
15
+ {"time":"2025-07-16T23:21:49.291673175Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
16
+ {"time":"2025-07-16T23:22:07.542159208Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
17
+ {"time":"2025-07-16T23:23:23.103733736Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
18
+ {"time":"2025-07-16T23:23:37.543151076Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
19
+ {"time":"2025-07-16T23:25:07.544031298Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
20
+ {"time":"2025-07-16T23:26:37.545971769Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
21
+ {"time":"2025-07-16T23:27:42.194377246Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"}
22
+ {"time":"2025-07-16T23:27:59.564813743Z","level":"WARN","msg":"sender: taking a long time","seconds":600.000912631,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"ft8cf3fgtodg\" connection_id:\"1(@)\")"}
23
+ {"time":"2025-07-16T23:28:07.547697617Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
24
+ {"time":"2025-07-16T23:29:37.549836886Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded (Client.Timeout exceeded while awaiting headers)"}
25
+ {"time":"2025-07-16T23:31:01.930916994Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000672411,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
26
+ {"time":"2025-07-16T23:31:02.101966833Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000995925,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
27
+ {"time":"2025-07-16T23:31:07.103368571Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000796336,"work":"WorkRecord(*service_go_proto.Request_PartialHistory); Control(local:true connection_id:\"1(@)\")"}
28
+ {"time":"2025-07-16T23:31:07.551682713Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
29
+ {"time":"2025-07-16T23:32:37.553473869Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
30
+ {"time":"2025-07-16T23:33:58.248779065Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": unexpected EOF"}
31
+ {"time":"2025-07-16T23:34:58.351555112Z","level":"INFO","msg":"sender: succeeded after taking longer than expected","seconds":1018.787711083,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"ft8cf3fgtodg\" connection_id:\"1(@)\")"}
32
+ {"time":"2025-07-16T23:34:58.351650283Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":836.421498346,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
33
+ {"time":"2025-07-16T23:34:58.351778293Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":831.249242004,"work":"WorkRecord(*service_go_proto.Request_PartialHistory); Control(local:true connection_id:\"1(@)\")"}
34
+ {"time":"2025-07-16T23:34:58.351785775Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":836.250829923,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
35
+ {"time":"2025-07-17T01:31:13.353253854Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"}
36
+ {"time":"2025-07-17T08:06:16.748740406Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
37
+ {"time":"2025-07-17T09:50:19.526737851Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": read tcp 10.0.2.15:54882->35.186.228.49:443: read: connection reset by peer"}
38
+ {"time":"2025-07-17T09:52:30.348552703Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
39
+ {"time":"2025-07-17T09:53:02.422139335Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
40
+ {"time":"2025-07-17T09:53:36.600890938Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
41
+ {"time":"2025-07-17T09:54:16.203516351Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
42
+ {"time":"2025-07-17T09:55:05.357439477Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
43
+ {"time":"2025-07-17T09:56:15.05960959Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
44
+ {"time":"2025-07-17T09:57:45.061688428Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
45
+ {"time":"2025-07-17T09:59:15.063226591Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
46
+ {"time":"2025-07-17T10:00:45.065259852Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
47
+ {"time":"2025-07-17T10:01:04.518171545Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"}
48
+ {"time":"2025-07-17T10:02:00.347889757Z","level":"WARN","msg":"sender: taking a long time","seconds":600.000372919,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"it0uq1ptdf5l\" connection_id:\"1(@)\")"}
49
+ {"time":"2025-07-17T10:02:15.066174619Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
50
+ {"time":"2025-07-17T10:03:45.067145051Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
51
+ {"time":"2025-07-17T10:05:02.098970791Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000073665,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
52
+ {"time":"2025-07-17T10:05:07.474477054Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000841939,"work":"WorkRecord(*service_go_proto.Request_PartialHistory); Control(local:true connection_id:\"1(@)\")"}
53
+ {"time":"2025-07-17T10:05:15.068468165Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
54
+ {"time":"2025-07-17T10:05:16.930808745Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000229861,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
55
+ {"time":"2025-07-17T10:06:07.008582668Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"}
56
+ {"time":"2025-07-17T10:06:45.070340311Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
57
+ {"time":"2025-07-17T10:07:57.799911415Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": unexpected EOF"}
58
+ {"time":"2025-07-17T10:08:57.969386735Z","level":"INFO","msg":"sender: succeeded after taking longer than expected","seconds":1017.621908973,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"it0uq1ptdf5l\" connection_id:\"1(@)\")"}
59
+ {"time":"2025-07-17T10:08:57.969579361Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":835.870728331,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
60
+ {"time":"2025-07-17T10:08:57.969680501Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":821.039158554,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
61
+ {"time":"2025-07-17T10:08:57.969682134Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":830.496074059,"work":"WorkRecord(*service_go_proto.Request_PartialHistory); Control(local:true connection_id:\"1(@)\")"}
62
+ {"time":"2025-07-17T12:53:12.780364188Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
63
+ {"time":"2025-07-17T16:43:31.998287109Z","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
64
+ {"time":"2025-07-18T00:01:06.015630566Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
65
+ {"time":"2025-07-18T06:56:24.118529653Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
66
+ {"time":"2025-07-18T14:32:12.830145916Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
67
+ {"time":"2025-07-18T19:51:31.703829065Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
68
+ {"time":"2025-07-19T03:35:03.743864446Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
69
+ {"time":"2025-07-19T21:22:32.639517404Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": read tcp 10.0.2.15:51870->35.186.228.49:443: read: connection reset by peer"}
70
+ {"time":"2025-07-19T21:31:32.643369264Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": read tcp 10.0.2.15:38482->35.186.228.49:443: read: connection reset by peer"}
71
+ {"time":"2025-07-20T00:27:42.221361901Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
72
+ {"time":"2025-07-20T09:40:16.319872482Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
73
+ {"time":"2025-07-20T09:45:18.218885403Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
74
+ {"time":"2025-07-20T19:19:37.674808147Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
75
+ {"time":"2025-07-20T20:26:46.102126738Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
76
+ {"time":"2025-07-20T21:40:42.245223721Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
77
+ {"time":"2025-07-20T21:42:31.526229193Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
78
+ {"time":"2025-07-20T22:42:07.859288654Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
79
+ {"time":"2025-07-21T03:41:28.397742169Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
80
+ {"time":"2025-07-21T04:49:16.742257697Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
81
+ {"time":"2025-07-21T05:48:28.62347913Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
82
+ {"time":"2025-07-21T06:22:31.529351974Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
83
+ {"time":"2025-07-21T14:47:44.545628902Z","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
84
+ {"time":"2025-07-21T21:19:44.840025606Z","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
85
+ {"time":"2025-07-21T21:19:44.94975041Z","level":"INFO","msg":"handler: operation stats","stats":{}}
86
+ {"time":"2025-07-21T21:19:44.958211652Z","level":"INFO","msg":"stream: closing","id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
87
+ {"time":"2025-07-21T21:19:44.958407771Z","level":"INFO","msg":"writer: Close: closed","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
88
+ {"time":"2025-07-21T21:19:44.958426934Z","level":"INFO","msg":"handler: closed","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
89
+ {"time":"2025-07-21T21:19:44.958428316Z","level":"INFO","msg":"sender: closed","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
90
+ {"time":"2025-07-21T21:19:44.958480192Z","level":"INFO","msg":"stream: closed","id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
tb/20250716-2210/wandb/debug.log ADDED
@@ -0,0 +1,28 @@
1
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_setup.py:_flush():80] Current SDK version is 0.21.0
2
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_setup.py:_flush():80] Configure stats pid to 1336753
3
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_setup.py:_flush():80] Loading settings from /home/cvm/.config/wandb/settings
4
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_setup.py:_flush():80] Loading settings from /home/cvm/flame/wandb/settings
5
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_setup.py:_flush():80] Loading settings from environment variables
6
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_init.py:setup_run_log_directory():703] Logging user logs to exp/mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine/tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/logs/debug.log
7
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to exp/mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine/tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/logs/debug-internal.log
8
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_init.py:init():830] calling init triggers
9
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_init.py:init():835] wandb.init called with sweep_config: {}
10
+ config: {'_wandb': {}}
11
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_init.py:init():871] starting backend
12
+ 2025-07-16 22:10:00,777 INFO MainThread:1336753 [wandb_init.py:init():874] sending inform_init request
13
+ 2025-07-16 22:10:00,781 INFO MainThread:1336753 [wandb_init.py:init():882] backend started and connected
14
+ 2025-07-16 22:10:00,782 INFO MainThread:1336753 [wandb_init.py:init():953] updated telemetry
15
+ 2025-07-16 22:10:00,786 INFO MainThread:1336753 [wandb_init.py:init():977] communicating run to backend with 90.0 second timeout
16
+ 2025-07-16 22:10:01,926 INFO MainThread:1336753 [wandb_init.py:init():1029] starting run threads in backend
17
+ 2025-07-16 22:10:02,009 INFO MainThread:1336753 [wandb_run.py:_console_start():2458] atexit reg
18
+ 2025-07-16 22:10:02,009 INFO MainThread:1336753 [wandb_run.py:_redirect():2306] redirect: wrap_raw
19
+ 2025-07-16 22:10:02,009 INFO MainThread:1336753 [wandb_run.py:_redirect():2375] Wrapping output streams.
20
+ 2025-07-16 22:10:02,009 INFO MainThread:1336753 [wandb_run.py:_redirect():2398] Redirects installed.
21
+ 2025-07-16 22:10:02,011 INFO MainThread:1336753 [wandb_init.py:init():1075] run started, returning control to user process
22
+ 2025-07-21 21:19:44,102 INFO MainThread:1336753 [wandb_run.py:_finish():2224] finishing run zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201
23
+ 2025-07-21 21:19:44,103 INFO MainThread:1336753 [wandb_run.py:_atexit_cleanup():2423] got exitcode: 0
24
+ 2025-07-21 21:19:44,104 INFO MainThread:1336753 [wandb_run.py:_restore():2405] restore
25
+ 2025-07-21 21:19:44,104 INFO MainThread:1336753 [wandb_run.py:_restore():2411] restore done
26
+ 2025-07-21 21:19:44,955 INFO MainThread:1336753 [wandb_run.py:_footer_history_summary_info():3903] rendering history
27
+ 2025-07-21 21:19:44,956 INFO MainThread:1336753 [wandb_run.py:_footer_history_summary_info():3935] rendering summary
28
+ 2025-07-21 21:19:44,957 INFO MainThread:1336753 [wandb_run.py:_footer_sync_info():3864] logging synced files
tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/files/config.yaml ADDED
@@ -0,0 +1,150 @@
1
+ _wandb:
2
+ value:
3
+ cli_version: 0.21.0
4
+ e:
5
+ ynnjkeia1kakdpk58ub5v7vb16scnioi:
6
+ args:
7
+ - --job.config_file
8
+ - flame/models/fla.toml
9
+ - --job.dump_folder
10
+ - exp/mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine
11
+ - --model.config
12
+ - configs/mtp_transformer_1B.json
13
+ - --model.tokenizer_path
14
+ - fla-hub/transformer-1.3B-100B
15
+ - --optimizer.name
16
+ - AdamW
17
+ - --optimizer.eps
18
+ - "1e-15"
19
+ - --optimizer.lr
20
+ - "2e-4"
21
+ - --lr_scheduler.warmup_steps
22
+ - "2000"
23
+ - --lr_scheduler.lr_min
24
+ - "0.1"
25
+ - --lr_scheduler.decay_type
26
+ - cosine
27
+ - --training.batch_size
28
+ - "16"
29
+ - --training.seq_len
30
+ - "4096"
31
+ - --training.context_len
32
+ - "4096"
33
+ - --training.gradient_accumulation_steps
34
+ - "1"
35
+ - --training.steps
36
+ - "200000"
37
+ - --training.max_norm
38
+ - "1.0"
39
+ - --training.skip_nan_inf
40
+ - --training.dataset
41
+ - /home/cvm/.cache/HuggingFaceFW___fineweb-edu/sample-100BT
42
+ - --training.dataset_split
43
+ - train
44
+ - --training.num_workers
45
+ - "32"
46
+ - --training.prefetch_factor
47
+ - "2"
48
+ - --training.seed
49
+ - "79"
50
+ - --training.compile
51
+ - --checkpoint.interval
52
+ - "10000"
53
+ - --checkpoint.load_step
54
+ - "-1"
55
+ - --metrics.log_freq
56
+ - "5"
57
+ - --checkpoint.hf_upload_enabled
58
+ - --checkpoint.hf_repo_base_name
59
+ - zaydzuhri/mtp-1B-4096-batch16x1-steps200000
60
+ - --comm.init_timeout_seconds
61
+ - "1800"
62
+ - --comm.train_timeout_seconds
63
+ - "1800"
64
+ cpu_count: 64
65
+ cpu_count_logical: 128
66
+ cudaVersion: "12.8"
67
+ disk:
68
+ /:
69
+ total: "3242363822080"
70
+ used: "1518440218624"
71
72
+ executable: /home/cvm/miniconda3/envs/flame-env/bin/python3.12
73
+ git:
74
+ commit: aa4d5932e54fad8a568e10aa6895e69e0664fcf1
75
+ remote: https://github.com/zaydzuhri/flame.git
76
+ gpu: NVIDIA H200
77
+ gpu_count: 8
78
+ gpu_nvidia:
79
+ - architecture: Hopper
80
+ cudaCores: 16896
81
+ memoryTotal: "150754820096"
82
+ name: NVIDIA H200
83
+ uuid: GPU-eddf9f4c-ffde-5f10-3c76-12ebce1f042b
84
+ - architecture: Hopper
85
+ cudaCores: 16896
86
+ memoryTotal: "150754820096"
87
+ name: NVIDIA H200
88
+ uuid: GPU-b532c850-7343-8f67-7eb1-a69024695a99
89
+ - architecture: Hopper
90
+ cudaCores: 16896
91
+ memoryTotal: "150754820096"
92
+ name: NVIDIA H200
93
+ uuid: GPU-751a6bdf-72f3-4f5a-fefd-d2b98c338579
94
+ - architecture: Hopper
95
+ cudaCores: 16896
96
+ memoryTotal: "150754820096"
97
+ name: NVIDIA H200
98
+ uuid: GPU-0cd9d3c7-1d2e-1925-91eb-8ec99a4ed277
99
+ - architecture: Hopper
100
+ cudaCores: 16896
101
+ memoryTotal: "150754820096"
102
+ name: NVIDIA H200
103
+ uuid: GPU-fba7e7ab-8340-13b0-b893-c3686cfec728
104
+ - architecture: Hopper
105
+ cudaCores: 16896
106
+ memoryTotal: "150754820096"
107
+ name: NVIDIA H200
108
+ uuid: GPU-12ca11c0-9080-3877-2bd5-3775573a4134
109
+ - architecture: Hopper
110
+ cudaCores: 16896
111
+ memoryTotal: "150754820096"
112
+ name: NVIDIA H200
113
+ uuid: GPU-32b3ec8b-9dc8-c6f6-5c19-74fa2ce10ffd
114
+ - architecture: Hopper
115
+ cudaCores: 16896
116
+ memoryTotal: "150754820096"
117
+ name: NVIDIA H200
118
+ uuid: GPU-d0021141-e4f4-14ab-c2ab-0ef3e30d6dd5
119
+ host: mbzuai-2
120
+ memory:
121
+ total: "1913833029632"
122
+ os: Linux-6.8.0-63-generic-x86_64-with-glibc2.39
123
+ program: -m flame.train
124
+ python: CPython 3.12.11
125
+ root: exp/mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine/tb/20250716-2210
126
+ startedAt: "2025-07-16T22:10:00.535907Z"
127
+ writerId: ynnjkeia1kakdpk58ub5v7vb16scnioi
128
+ m: []
129
+ python_version: 3.12.11
130
+ t:
131
+ "1":
132
+ - 1
133
+ - 11
134
+ - 49
135
+ - 51
136
+ "2":
137
+ - 1
138
+ - 11
139
+ - 49
140
+ - 51
141
+ "3":
142
+ - 2
143
+ - 13
144
+ - 14
145
+ - 61
146
+ "4": 3.12.11
147
+ "5": 0.21.0
148
+ "6": 4.50.3
149
+ "12": 0.21.0
150
+ "13": linux-x86_64
tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/files/requirements.txt ADDED
@@ -0,0 +1,101 @@
1
+ flame==0.1.0
2
+ triton==3.2.0
3
+ sympy==1.13.1
4
+ nvidia-cusolver-cu12==11.6.1.9
5
+ idna==3.10
6
+ regex==2024.11.6
7
+ wandb==0.21.0
8
+ nvidia-cuda-cupti-cu12==12.4.127
9
+ protobuf==6.31.1
10
+ Jinja2==3.1.6
11
+ packaging==25.0
12
+ Markdown==3.8.2
13
+ hf-xet==1.1.5
14
+ sentry-sdk==2.33.0
15
+ networkx==3.5
16
+ certifi==2025.7.14
17
+ ninja==1.11.1.4
18
+ PyYAML==6.0.2
19
+ smmap==5.0.2
20
+ numpy==2.3.1
21
+ tiktoken==0.9.0
22
+ nvidia-cuda-nvrtc-cu12==12.4.127
23
+ frozenlist==1.7.0
24
+ tzdata==2025.2
25
+ tokenizers==0.21.2
26
+ nvidia-nvjitlink-cu12==12.4.127
27
+ nvidia-cusparse-cu12==12.3.1.170
28
+ pandas==2.3.1
29
+ attrs==25.3.0
30
+ tensorboard-data-server==0.7.2
31
+ aiohappyeyeballs==2.6.1
32
+ aiosignal==1.4.0
33
+ platformdirs==4.3.8
34
+ python-dotenv==1.1.1
35
+ charset-normalizer==3.4.2
36
+ requests==2.32.4
37
+ MarkupSafe==3.0.2
38
+ GitPython==3.1.44
39
+ nvidia-cufft-cu12==11.2.1.3
40
+ click==8.2.1
41
+ wheel==0.45.1
42
+ nvidia-nccl-cu12==2.21.5
43
+ nvidia-cuda-runtime-cu12==12.4.127
44
+ typing-inspection==0.4.1
45
+ gitdb==4.0.12
46
+ datasets==4.0.0
47
+ multidict==6.6.3
48
+ Werkzeug==3.1.3
49
+ grpcio==1.73.1
50
+ tqdm==4.67.1
51
+ absl-py==2.3.1
52
+ multiprocess==0.70.16
53
+ fsspec==2025.3.0
54
+ dill==0.3.8
55
+ propcache==0.3.2
56
+ yarl==1.20.1
57
+ transformers==4.50.3
58
+ mpmath==1.3.0
59
+ pydantic_core==2.33.2
60
+ flame==0.1.0
61
+ pip==25.1
62
+ torch==2.6.0
63
+ pytz==2025.2
64
+ python-dateutil==2.9.0.post0
65
+ safetensors==0.5.3
66
+ nvidia-curand-cu12==10.3.5.147
67
+ pyarrow==20.0.0
68
+ nvidia-cusparselt-cu12==0.6.2
69
+ einops==0.8.1
70
+ torchdata==0.11.0
71
+ six==1.17.0
72
+ aiohttp==3.12.14
73
+ urllib3==2.5.0
74
+ nvidia-cublas-cu12==12.4.5.8
75
+ filelock==3.18.0
76
+ flash-attn==2.7.3
77
+ nvidia-nvtx-cu12==12.4.127
78
+ xxhash==3.5.0
79
+ tensorboard==2.19.0
80
+ pydantic==2.11.7
81
+ nvidia-cudnn-cu12==9.1.0.70
82
+ typing_extensions==4.14.1
83
+ setuptools==78.1.1
84
+ huggingface-hub==0.33.4
85
+ annotated-types==0.7.0
86
+ jaraco.context==5.3.0
87
+ autocommand==2.2.2
88
+ more-itertools==10.3.0
89
+ tomli==2.0.1
90
+ jaraco.functools==4.0.1
91
+ zipp==3.19.2
92
+ backports.tarfile==1.2.0
93
+ wheel==0.45.1
94
+ platformdirs==4.2.2
95
+ inflect==7.3.1
96
+ typing_extensions==4.12.2
97
+ jaraco.text==3.12.1
98
+ typeguard==4.3.0
99
+ importlib_metadata==8.0.0
100
+ packaging==24.2
101
+ jaraco.collections==5.1.0
tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/files/wandb-metadata.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "os": "Linux-6.8.0-63-generic-x86_64-with-glibc2.39",
3
+ "python": "CPython 3.12.11",
4
+ "startedAt": "2025-07-16T22:10:00.535907Z",
5
+ "args": [
6
+ "--job.config_file",
7
+ "flame/models/fla.toml",
8
+ "--job.dump_folder",
9
+ "exp/mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine",
10
+ "--model.config",
11
+ "configs/mtp_transformer_1B.json",
12
+ "--model.tokenizer_path",
13
+ "fla-hub/transformer-1.3B-100B",
14
+ "--optimizer.name",
15
+ "AdamW",
16
+ "--optimizer.eps",
17
+ "1e-15",
18
+ "--optimizer.lr",
19
+ "2e-4",
20
+ "--lr_scheduler.warmup_steps",
21
+ "2000",
22
+ "--lr_scheduler.lr_min",
23
+ "0.1",
24
+ "--lr_scheduler.decay_type",
25
+ "cosine",
26
+ "--training.batch_size",
27
+ "16",
28
+ "--training.seq_len",
29
+ "4096",
30
+ "--training.context_len",
31
+ "4096",
32
+ "--training.gradient_accumulation_steps",
33
+ "1",
34
+ "--training.steps",
35
+ "200000",
36
+ "--training.max_norm",
37
+ "1.0",
38
+ "--training.skip_nan_inf",
39
+ "--training.dataset",
40
+ "/home/cvm/.cache/HuggingFaceFW___fineweb-edu/sample-100BT",
41
+ "--training.dataset_split",
42
+ "train",
43
+ "--training.num_workers",
44
+ "32",
45
+ "--training.prefetch_factor",
46
+ "2",
47
+ "--training.seed",
48
+ "79",
49
+ "--training.compile",
50
+ "--checkpoint.interval",
51
+ "10000",
52
+ "--checkpoint.load_step",
53
+ "-1",
54
+ "--metrics.log_freq",
55
+ "5",
56
+ "--checkpoint.hf_upload_enabled",
57
+ "--checkpoint.hf_repo_base_name",
58
+ "zaydzuhri/mtp-1B-4096-batch16x1-steps200000",
59
+ "--comm.init_timeout_seconds",
60
+ "1800",
61
+ "--comm.train_timeout_seconds",
62
+ "1800"
63
+ ],
64
+ "program": "-m flame.train",
65
+ "git": {
66
+ "remote": "https://github.com/zaydzuhri/flame.git",
67
+ "commit": "aa4d5932e54fad8a568e10aa6895e69e0664fcf1"
68
+ },
69
+ "email": "[email protected]",
70
+ "root": "exp/mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine/tb/20250716-2210",
71
+ "host": "mbzuai-2",
72
+ "executable": "/home/cvm/miniconda3/envs/flame-env/bin/python3.12",
73
+ "cpu_count": 64,
74
+ "cpu_count_logical": 128,
75
+ "gpu": "NVIDIA H200",
76
+ "gpu_count": 8,
77
+ "disk": {
78
+ "/": {
79
+ "total": "3242363822080",
80
+ "used": "1518440218624"
81
+ }
82
+ },
83
+ "memory": {
84
+ "total": "1913833029632"
85
+ },
86
+ "gpu_nvidia": [
87
+ {
88
+ "name": "NVIDIA H200",
89
+ "memoryTotal": "150754820096",
90
+ "cudaCores": 16896,
91
+ "architecture": "Hopper",
92
+ "uuid": "GPU-eddf9f4c-ffde-5f10-3c76-12ebce1f042b"
93
+ },
94
+ {
95
+ "name": "NVIDIA H200",
96
+ "memoryTotal": "150754820096",
97
+ "cudaCores": 16896,
98
+ "architecture": "Hopper",
99
+ "uuid": "GPU-b532c850-7343-8f67-7eb1-a69024695a99"
100
+ },
101
+ {
102
+ "name": "NVIDIA H200",
103
+ "memoryTotal": "150754820096",
104
+ "cudaCores": 16896,
105
+ "architecture": "Hopper",
106
+ "uuid": "GPU-751a6bdf-72f3-4f5a-fefd-d2b98c338579"
107
+ },
108
+ {
109
+ "name": "NVIDIA H200",
110
+ "memoryTotal": "150754820096",
111
+ "cudaCores": 16896,
112
+ "architecture": "Hopper",
113
+ "uuid": "GPU-0cd9d3c7-1d2e-1925-91eb-8ec99a4ed277"
114
+ },
115
+ {
116
+ "name": "NVIDIA H200",
117
+ "memoryTotal": "150754820096",
118
+ "cudaCores": 16896,
119
+ "architecture": "Hopper",
120
+ "uuid": "GPU-fba7e7ab-8340-13b0-b893-c3686cfec728"
121
+ },
122
+ {
123
+ "name": "NVIDIA H200",
124
+ "memoryTotal": "150754820096",
125
+ "cudaCores": 16896,
126
+ "architecture": "Hopper",
127
+ "uuid": "GPU-12ca11c0-9080-3877-2bd5-3775573a4134"
128
+ },
129
+ {
130
+ "name": "NVIDIA H200",
131
+ "memoryTotal": "150754820096",
132
+ "cudaCores": 16896,
133
+ "architecture": "Hopper",
134
+ "uuid": "GPU-32b3ec8b-9dc8-c6f6-5c19-74fa2ce10ffd"
135
+ },
136
+ {
137
+ "name": "NVIDIA H200",
138
+ "memoryTotal": "150754820096",
139
+ "cudaCores": 16896,
140
+ "architecture": "Hopper",
141
+ "uuid": "GPU-d0021141-e4f4-14ab-c2ab-0ef3e30d6dd5"
142
+ }
143
+ ],
144
+ "cudaVersion": "12.8",
145
+ "writerId": "ynnjkeia1kakdpk58ub5v7vb16scnioi"
146
+ }
tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
1
+ {"memory/num_ooms":0,"loss_metrics/global_avg_loss":14.232244491577148,"tflops":431.0192442993196,"_step":200000,"throughput(tps)":31067.90008791172,"loss_metrics/global_avg_ntp_loss":2.0935492515563965,"loss_metrics/global_max_loss":14.996297836303711,"optimizer/grad_norm":1.1316726207733154,"memory/max_reserved(%)":85.28248535606636,"memory/max_active(%)":83.05995444766508,"memory/num_alloc_retries":0,"time_metrics/end_to_end(s)":2.109444146999158,"optimizer/skipped_step":0,"time_metrics/data_loading(%)":0.22571920630939155,"loss_metrics/global_avg_mtp_loss":12.13869571685791,"mfu(%)":43.58131893825275,"memory/max_reserved(GiB)":118.845703125,"_runtime":428982.182300563,"_wandb":{"runtime":428982},"time_metrics/data_loading(s)":0.004761420586146414,"optimizer/lr":2e-05,"_timestamp":1.7531325156953037e+09,"memory/max_active(GiB)":115.74848747253418}
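
The summary is roughly consistent with the run arguments recorded above; a quick check (assuming `--training.batch_size 16` is the global batch, which is how the args read) is sketched below.

```python
# Rough consistency check between the recorded run args and wandb-summary.json.
batch_size, seq_len, steps = 16, 4096, 200_000   # from the args above
tokens_planned = batch_size * seq_len * steps     # ~1.31e10 tokens

tps, runtime_s = 31_067.9, 428_982                # throughput(tps) and _runtime from the summary
tokens_observed = tps * runtime_s                 # ~1.33e10 tokens

print(f"planned {tokens_planned:.2e} vs observed {tokens_observed:.2e}")
```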
tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/logs/debug-core.log ADDED
@@ -0,0 +1,16 @@
1
+ {"time":"2025-07-16T22:10:00.603271728Z","level":"INFO","msg":"main: starting server","port-filename":"/tmp/tmp4i2mlclc/port-1336753.txt","pid":1336753,"log-level":0,"disable-analytics":false,"shutdown-on-parent-exit":false,"enable-dcgm-profiling":false}
2
+ {"time":"2025-07-16T22:10:00.604042762Z","level":"INFO","msg":"server: will exit if parent process dies","ppid":1336753}
3
+ {"time":"2025-07-16T22:10:00.604005987Z","level":"INFO","msg":"server: accepting connections","addr":{"Name":"/tmp/wandb-1336753-1346688-770015975/socket","Net":"unix"}}
4
+ {"time":"2025-07-16T22:10:00.768139222Z","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"1(@)"}
5
+ {"time":"2025-07-16T22:10:00.785307739Z","level":"INFO","msg":"handleInformInit: received","streamId":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201","id":"1(@)"}
6
+ {"time":"2025-07-16T22:10:01.508694644Z","level":"INFO","msg":"handleInformInit: stream started","streamId":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201","id":"1(@)"}
7
+ {"time":"2025-07-21T21:19:44.957947072Z","level":"INFO","msg":"handleInformFinish: finish message received","streamId":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201","id":"1(@)"}
8
+ {"time":"2025-07-21T21:19:44.959492358Z","level":"INFO","msg":"handleInformFinish: stream closed","streamId":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201","id":"1(@)"}
9
+ {"time":"2025-07-21T21:20:04.73362632Z","level":"INFO","msg":"handleInformTeardown: server teardown initiated","id":"1(@)"}
10
+ {"time":"2025-07-21T21:20:04.73396049Z","level":"INFO","msg":"handleInformTeardown: server shutdown complete","id":"1(@)"}
11
+ {"time":"2025-07-21T21:20:04.733969545Z","level":"INFO","msg":"server is shutting down"}
12
+ {"time":"2025-07-21T21:20:04.734079592Z","level":"INFO","msg":"connection: closing","id":"1(@)"}
13
+ {"time":"2025-07-21T21:20:04.734239368Z","level":"INFO","msg":"connection: closed successfully","id":"1(@)"}
14
+ {"time":"2025-07-21T21:20:04.734245487Z","level":"INFO","msg":"connection: ManageConnectionData: connection closed","id":"1(@)"}
15
+ {"time":"2025-07-21T21:20:04.734574344Z","level":"INFO","msg":"server: listener closed","addr":{"Name":"/tmp/wandb-1336753-1346688-770015975/socket","Net":"unix"}}
16
+ {"time":"2025-07-21T21:20:04.734618146Z","level":"INFO","msg":"server is closed"}
tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/logs/debug.log ADDED
@@ -0,0 +1,28 @@
1
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_setup.py:_flush():80] Current SDK version is 0.21.0
2
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_setup.py:_flush():80] Configure stats pid to 1336753
3
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_setup.py:_flush():80] Loading settings from /home/cvm/.config/wandb/settings
4
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_setup.py:_flush():80] Loading settings from /home/cvm/flame/wandb/settings
5
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_setup.py:_flush():80] Loading settings from environment variables
6
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_init.py:setup_run_log_directory():703] Logging user logs to exp/mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine/tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/logs/debug.log
7
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to exp/mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine/tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/logs/debug-internal.log
8
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_init.py:init():830] calling init triggers
9
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_init.py:init():835] wandb.init called with sweep_config: {}
10
+ config: {'_wandb': {}}
11
+ 2025-07-16 22:10:00,536 INFO MainThread:1336753 [wandb_init.py:init():871] starting backend
12
+ 2025-07-16 22:10:00,777 INFO MainThread:1336753 [wandb_init.py:init():874] sending inform_init request
13
+ 2025-07-16 22:10:00,781 INFO MainThread:1336753 [wandb_init.py:init():882] backend started and connected
14
+ 2025-07-16 22:10:00,782 INFO MainThread:1336753 [wandb_init.py:init():953] updated telemetry
15
+ 2025-07-16 22:10:00,786 INFO MainThread:1336753 [wandb_init.py:init():977] communicating run to backend with 90.0 second timeout
16
+ 2025-07-16 22:10:01,926 INFO MainThread:1336753 [wandb_init.py:init():1029] starting run threads in backend
17
+ 2025-07-16 22:10:02,009 INFO MainThread:1336753 [wandb_run.py:_console_start():2458] atexit reg
18
+ 2025-07-16 22:10:02,009 INFO MainThread:1336753 [wandb_run.py:_redirect():2306] redirect: wrap_raw
19
+ 2025-07-16 22:10:02,009 INFO MainThread:1336753 [wandb_run.py:_redirect():2375] Wrapping output streams.
20
+ 2025-07-16 22:10:02,009 INFO MainThread:1336753 [wandb_run.py:_redirect():2398] Redirects installed.
21
+ 2025-07-16 22:10:02,011 INFO MainThread:1336753 [wandb_init.py:init():1075] run started, returning control to user process
22
+ 2025-07-21 21:19:44,102 INFO MainThread:1336753 [wandb_run.py:_finish():2224] finishing run zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201
23
+ 2025-07-21 21:19:44,103 INFO MainThread:1336753 [wandb_run.py:_atexit_cleanup():2423] got exitcode: 0
24
+ 2025-07-21 21:19:44,104 INFO MainThread:1336753 [wandb_run.py:_restore():2405] restore
25
+ 2025-07-21 21:19:44,104 INFO MainThread:1336753 [wandb_run.py:_restore():2411] restore done
26
+ 2025-07-21 21:19:44,955 INFO MainThread:1336753 [wandb_run.py:_footer_history_summary_info():3903] rendering history
27
+ 2025-07-21 21:19:44,956 INFO MainThread:1336753 [wandb_run.py:_footer_history_summary_info():3935] rendering summary
28
+ 2025-07-21 21:19:44,957 INFO MainThread:1336753 [wandb_run.py:_footer_sync_info():3864] logging synced files
tokenizer.json ADDED
The diff for this file is too large to render.
 
tokenizer_config.json ADDED
@@ -0,0 +1,44 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": null,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ }
30
+ },
31
+ "additional_special_tokens": [],
32
+ "bos_token": "<s>",
33
+ "clean_up_tokenization_spaces": false,
34
+ "eos_token": "</s>",
35
+ "extra_special_tokens": {},
36
+ "legacy": true,
37
+ "model_max_length": 1000000000000000019884624838656,
38
+ "pad_token": null,
39
+ "sp_model_kwargs": {},
40
+ "spaces_between_special_tokens": false,
41
+ "tokenizer_class": "LlamaTokenizer",
42
+ "unk_token": "<unk>",
43
+ "use_default_system_prompt": false
44
+ }
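
For reference, a minimal sketch (not part of the upload) of how the tokenizer files above are typically consumed; the local path "./" is an assumption and stands for any directory holding tokenizer.json, tokenizer_config.json and special_tokens_map.json:

from transformers import AutoTokenizer

# Loads the Llama tokenizer described by the files above.
tokenizer = AutoTokenizer.from_pretrained("./")

# add_bos_token=true / add_eos_token=false: a BOS id is prepended, no EOS is appended.
ids = tokenizer("hello world").input_ids
assert ids[0] == tokenizer.bos_token_id
print(tokenizer.bos_token, tokenizer.eos_token, tokenizer.unk_token)  # <s> </s> <unk>
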
torchtitan/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
8
+
9
+ # Import to register Float8Converter.
10
+ import torchtitan.components.float8 # noqa: F401
11
+
12
+ # Import the built-in models here so that the corresponding register_model_spec()
13
+ # will be called.
14
+ import torchtitan.experiments # noqa: F401
15
+ import torchtitan.models # noqa: F401
torchtitan/config_manager.py ADDED
@@ -0,0 +1,947 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import importlib
9
+ import inspect
10
+ import os
11
+ import sys
12
+ from collections import defaultdict
13
+ from typing import Tuple, Union
14
+
15
+ import torch
16
+
17
+ try:
18
+ import tomllib
19
+ except ModuleNotFoundError:
20
+ import tomli as tomllib
21
+
22
+ from torchtitan.tools.logging import logger
23
+
24
+ TORCH_DTYPE_MAP = {
25
+ "float16": torch.float16,
26
+ "float32": torch.float32,
27
+ "bfloat16": torch.bfloat16,
28
+ }
29
+
30
+
31
+ def string_list(raw_arg):
32
+ """Comma-separated string list argument."""
33
+ return [s.strip() for s in raw_arg.split(",") if s.strip()]
34
+
35
+
36
+ def check_string_list_argument(args_dict: dict[str, any], fullargname: str):
37
+ section, name = fullargname.split(".")
38
+ # Split string list which are still raw strings.
39
+ if (
40
+ section in args_dict
41
+ and name in args_dict[section]
42
+ and isinstance(args_dict[section][name], str)
43
+ ):
44
+ sec = args_dict[section]
45
+ sec[name] = string_list(sec[name])
46
+
47
+
48
+ class JobConfig:
49
+ """
50
+ A helper class to manage the train configuration.
51
+ Semantics:
52
+ - Default config is loaded from a toml file. If no toml file is provided,
53
+ then the default config is loaded from argparse defaults.
54
+ - if toml file has missing keys, they are filled with argparse defaults.
55
+ - if additional explicit cmd args are provided in addition to the toml
56
+ file, they will override the toml config and the argparse defaults
57
+
58
+ precedence order: cmdline > toml > argparse default
59
+
60
+ Arg parsing semantics:
61
+
62
+ Each argument starts with <prefix>. where <prefix> is the section name in the toml file
63
+ followed by name of the option in the toml file. For ex,
64
+ model.name translates to:
65
+ [model]
66
+ name
67
+ in the toml file
68
+ """
69
+
70
+ def __init__(self):
71
+ self.args_dict = None
72
+ # main parser
73
+ self.parser = argparse.ArgumentParser(description="torchtitan arg parser.")
74
+
75
+ self.parser.add_argument(
76
+ "--job.config_file",
77
+ type=str,
78
+ default=None,
79
+ help="Job config file",
80
+ )
81
+
82
+ # job level configs
83
+ self.parser.add_argument(
84
+ "--job.dump_folder",
85
+ type=str,
86
+ default="./torchtitan/outputs",
87
+ help="Folder to dump job outputs",
88
+ )
89
+ self.parser.add_argument(
90
+ "--job.description",
91
+ type=str,
92
+ default="default job",
93
+ help="Description of the job",
94
+ )
95
+ self.parser.add_argument(
96
+ "--job.use_for_integration_test",
97
+ action="store_true",
98
+ help="Add this config to the integration test suite",
99
+ )
100
+ self.parser.add_argument(
101
+ "--job.print_args",
102
+ action="store_true",
103
+ help="Print the args to terminal",
104
+ )
105
+
106
+ # profiling configs
107
+ self.parser.add_argument(
108
+ "--profiling.enable_profiling",
109
+ action="store_true",
110
+ help="Whether to enable pytorch profiler",
111
+ )
112
+ self.parser.add_argument(
113
+ "--profiling.save_traces_folder",
114
+ type=str,
115
+ default="profile_traces",
116
+ help="Trace files location",
117
+ )
118
+ self.parser.add_argument(
119
+ "--profiling.profile_freq",
120
+ type=int,
121
+ default=10,
122
+ help="How often to collect profiler traces, in iterations",
123
+ )
124
+ self.parser.add_argument(
125
+ "--profiling.enable_memory_snapshot",
126
+ action="store_true",
127
+ help="Whether to dump memory snapshot",
128
+ )
129
+ self.parser.add_argument(
130
+ "--profiling.save_memory_snapshot_folder",
131
+ type=str,
132
+ default="memory_snapshot",
133
+ help="Memeory snapshot files location",
134
+ )
135
+
136
+ # metrics configs
137
+ self.parser.add_argument(
138
+ "--metrics.log_freq",
139
+ type=int,
140
+ default=10,
141
+ help="How often to log metrics to TensorBoard, in iterations",
142
+ )
143
+ self.parser.add_argument(
144
+ "--metrics.enable_tensorboard",
145
+ action="store_true",
146
+ help="Whether to log metrics to TensorBoard",
147
+ )
148
+ self.parser.add_argument(
149
+ "--metrics.disable_color_printing",
150
+ action="store_true",
151
+ help="Whether to disable color printing in logs",
152
+ )
153
+ self.parser.add_argument(
154
+ "--metrics.save_tb_folder",
155
+ type=str,
156
+ default="tb",
157
+ help="Folder to dump TensorBoard states",
158
+ )
159
+ self.parser.add_argument(
160
+ "--metrics.save_for_all_ranks",
161
+ action="store_true",
162
+ default=False,
163
+ help="""
164
+ Whether to save TensorBoard/Wandb metrics only for rank 0 or for all ranks.
165
+ When this option is False and pipeline_parallel_degree is > 1, the metrics
166
+ component uses the 0th rank of the last stage pipeline group, which is the
167
+ only stage that computes loss metrics.
168
+ """,
169
+ )
170
+ self.parser.add_argument(
171
+ "--metrics.enable_wandb",
172
+ action="store_true",
173
+ help="Whether to log metrics to Weights & Biases",
174
+ )
175
+
176
+ # model configs
177
+ self.parser.add_argument(
178
+ "--model.name",
179
+ type=str,
180
+ default="llama3",
181
+ help="Which model to train",
182
+ )
183
+ self.parser.add_argument(
184
+ "--model.flavor",
185
+ type=str,
186
+ default="debugmodel",
187
+ help="Which model config to train",
188
+ )
189
+ self.parser.add_argument(
190
+ "--model.norm_type",
191
+ type=str,
192
+ default="rmsnorm",
193
+ choices=["layernorm", "np_layernorm", "rmsnorm"],
194
+ help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm]",
195
+ )
196
+ self.parser.add_argument(
197
+ "--model.use_flex_attn",
198
+ action="store_true",
199
+ help="""
200
+ Whether to use Flex Attention.
201
+ Mixed usage of SDPA and FlexAttention is not supported yet.
202
+ """,
203
+ )
204
+ self.parser.add_argument(
205
+ "--model.attn_mask_type",
206
+ type=str,
207
+ default="causal",
208
+ choices=["causal", "block_causal"],
209
+ help="""
210
+ Specifies the type of bias/mask used for attention. If SDPA is used,
211
+ only the causal mask is supported by default. If FlexAttention is used,
212
+ both causal and block_causal masks are supported.
213
+ """,
214
+ )
215
+ self.parser.add_argument(
216
+ "--model.tokenizer_path",
217
+ type=str,
218
+ default="./assets/tokenizer/original/tokenizer.model",
219
+ help="Tokenizer path",
220
+ )
221
+ self.parser.add_argument(
222
+ "--model.converters",
223
+ type=string_list,
224
+ nargs="+",
225
+ default=[],
226
+ help="""
227
+ Comma separated list of converters to apply to the model.
228
+
229
+ For instance, the `float8` converter swaps `torch.nn.Linear`
230
+ with `Float8Linear`. This feature requires you to install 'torchao'
231
+ which can be found here: https://github.com/pytorch/ao
232
+ """,
233
+ )
234
+ self.parser.add_argument(
235
+ "--model.print_after_conversion",
236
+ action="store_true",
237
+ help="""
238
+ If true, model definition will be printed to stdout after all model
239
+ converters have been applied.
240
+ """,
241
+ )
242
+
243
+ # optimizer configs
244
+ self.parser.add_argument(
245
+ "--optimizer.name", type=str, default="AdamW", help="Optimizer to use"
246
+ )
247
+ self.parser.add_argument(
248
+ "--optimizer.lr", type=float, default=8e-4, help="Learning rate to use"
249
+ )
250
+ self.parser.add_argument(
251
+ "--optimizer.eps", type=float, default=1e-8, help="Epsilon value to use"
252
+ )
253
+ self.parser.add_argument(
254
+ "--optimizer.implementation",
255
+ type=str,
256
+ default="fused",
257
+ choices=["for-loop", "foreach", "fused"],
258
+ help="""
259
+ Specify which optimizer implementation to use:
260
+ - 'fused': Use fused implementation (CUDA only) for best performance.
261
+ - 'foreach': Use some horizontal fusion of tensors for better performance.
262
+ - 'for-loop': Use the default implementation for the optimizer (slowest).
263
+ - more info: https://pytorch.org/docs/stable/optim.html
264
+ """,
265
+ )
266
+ self.parser.add_argument(
267
+ "--optimizer.early_step_in_backward",
268
+ action="store_true",
269
+ help="""
270
+ Whether to apply optimizer in the backward. Caution, optimizer_in_backward
271
+ is not compatible with gradient clipping; users should not call
272
+ register_post_accumulate_grad_hook after the optimizer is built.""",
273
+ )
274
+
275
+ # lr scheduler configs
276
+ self.parser.add_argument(
277
+ "--lr_scheduler.warmup_steps",
278
+ type=int,
279
+ default=200,
280
+ help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
281
+ )
282
+ self.parser.add_argument(
283
+ "--lr_scheduler.decay_ratio",
284
+ type=float,
285
+ default=None,
286
+ help="""
287
+ Controls the proportion of the training steps allocated to the learning rate decay phase.
288
+
289
+ If `None`, the learning rate will begin decaying immediately after the warmup period.
290
+ Otherwise, the learning rate will remain stable after the warmup period and
291
+ only start decaying during the last `decay_ratio` portion of the total training steps.
292
+
293
+ This is known as the Warmup-Stable-Decay (WSD) schedule, as described in https://arxiv.org/abs/2404.06395.
294
+ """,
295
+ )
296
+ self.parser.add_argument(
297
+ "--lr_scheduler.decay_type",
298
+ type=str,
299
+ default="linear",
300
+ choices=["linear", "sqrt", "cosine"],
301
+ help="""
302
+ Learning rate decay type to use during training:
303
+ - 'linear': linearly decays learning rate from initial to final value
304
+ - 'sqrt': decays learning rate following a 1 minus square root curve
305
+ - 'cosine': smoothly decays learning rate following a cosine curve
306
+ """,
307
+ )
308
+ self.parser.add_argument(
309
+ "--lr_scheduler.lr_min",
310
+ type=float,
311
+ default=0.0,
312
+ help="""
313
+ Min lr ratio for lr scheduler.
314
+
315
+ If provided, the range of decay factor is scaled from 1 to `lr_min`
316
+ to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`.
317
+ """,
318
+ )
319
+
320
+ # training configs
321
+ self.parser.add_argument(
322
+ "--training.dataset", type=str, default="c4_test", help="Dataset to use"
323
+ )
324
+ self.parser.add_argument(
325
+ "--training.dataset_path",
326
+ type=str,
327
+ help="""
328
+ Path to the dataset in the file system. If provided, data will be
329
+ loaded from this path instead of downloaded.""",
330
+ )
331
+ self.parser.add_argument(
332
+ "--training.batch_size", type=int, default=8, help="Batch size"
333
+ )
334
+ self.parser.add_argument(
335
+ "--training.seq_len", type=int, default=2048, help="Sequence length"
336
+ )
337
+ self.parser.add_argument(
338
+ "--training.max_norm",
339
+ type=Union[float, int],
340
+ default=1.0,
341
+ help="Max norm for gradient clipping",
342
+ )
343
+ self.parser.add_argument(
344
+ "--training.steps",
345
+ type=int,
346
+ default=10000,
347
+ help="How many train steps to run",
348
+ )
349
+ self.parser.add_argument(
350
+ "--training.enable_cpu_offload",
351
+ action="store_true",
352
+ help="""
353
+ Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
354
+ )
355
+ self.parser.add_argument(
356
+ "--training.mixed_precision_param",
357
+ type=str,
358
+ default="bfloat16",
359
+ choices=["bfloat16", "float32"],
360
+ help="""
361
+ torch dtype to use for parameters when applying mixed precision via FSDP.
362
+ This feature only takes effect when data_parallel_shard_degree > 1
363
+ """,
364
+ )
365
+ self.parser.add_argument(
366
+ "--training.mixed_precision_reduce",
367
+ type=str,
368
+ default="float32",
369
+ choices=["float32"],
370
+ help="""
371
+ torch dtype to use for reductions when applying mixed precision via FSDP.
372
+ This feature only takes effect when data_parallel_shard_degree > 1
373
+ """,
374
+ )
375
+ self.parser.add_argument(
376
+ "--training.compile",
377
+ action="store_true",
378
+ help="Whether to compile the model",
379
+ )
380
+ self.parser.add_argument(
381
+ "--training.gc_freq",
382
+ type=int,
383
+ default=50,
384
+ help="Python garbage control scheduling interval, in steps",
385
+ )
386
+ self.parser.add_argument(
387
+ "--training.seed",
388
+ type=int,
389
+ default=None,
390
+ help="Choose the base RNG seed used for training",
391
+ )
392
+ self.parser.add_argument(
393
+ "--training.deterministic",
394
+ action="store_true",
395
+ help="Use deterministic algorithms wherever possible, may be slower",
396
+ )
397
+
398
+ # parallelism configs
399
+ self.parser.add_argument(
400
+ "--parallelism.data_parallel_replicate_degree",
401
+ type=int,
402
+ default=1,
403
+ help="""
404
+ The `data_parallel_replicate_degree` argument specifies the degree of
405
+ data parallelism for weight replication. When this value is greater
406
+ than 1, weights will be replicated across `data_parallel_replicate_degree`
407
+ ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism
408
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
409
+ parallelism method used is DDP (Distributed Data Parallelism).
410
+ 1 means disabled.""",
411
+ )
412
+ self.parser.add_argument(
413
+ "--parallelism.enable_compiled_autograd",
414
+ action="store_true",
415
+ help="Enable CompiledAutograd to compile the backward.",
416
+ )
417
+ self.parser.add_argument(
418
+ "--parallelism.data_parallel_shard_degree",
419
+ type=int,
420
+ default=-1,
421
+ help="""
422
+ The `data_parallel_shard_degree` argument specifies the degree of data
423
+ parallelism for weight sharding. When this value is greater than 1, weights
424
+ will be sharded across `data_parallel_shard_degree` ranks. If
425
+ `data_parallel_replicate_degree` is also greater than 1, the parallelism
426
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
427
+ parallelism method used is FSDP (Fully Sharded Data Parallelism).
428
+
429
+ -1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that
430
+ only `data_parallel_shard_degree` can be negative. 1 means disabled.""",
431
+ )
432
+ self.parser.add_argument(
433
+ "--parallelism.fsdp_reshard_after_forward",
434
+ type=str,
435
+ default="default",
436
+ choices=["default", "always", "never"],
437
+ help="""
438
+ `reshard_after_forward` specifies the policy for applying `reshard_after_forward`
439
+ within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward,
440
+ trading off memory and communication. See torch's `fully_shard` API for more documentation
441
+ on `reshard_after_forward`.
442
+ The supported policies include "default", "always" and "never":
443
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal
444
+ scenarios.
445
+ - "always" will enable `reshard_after_forward` for all forward passes.
446
+ - "never" will disable `reshard_after_forward` for all forward passes.
447
+ """,
448
+ )
449
+ self.parser.add_argument(
450
+ "--parallelism.tensor_parallel_degree",
451
+ type=int,
452
+ default=1,
453
+ help="Tensor Parallelism degree. 1 means disabled.",
454
+ )
455
+ self.parser.add_argument(
456
+ "--parallelism.disable_loss_parallel",
457
+ action="store_true",
458
+ help="Whether to apply loss parallel when sequence parallel is enabled",
459
+ )
460
+ self.parser.add_argument(
461
+ "--parallelism.enable_async_tensor_parallel",
462
+ action="store_true",
463
+ help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
464
+ )
465
+ self.parser.add_argument(
466
+ "--parallelism.pipeline_parallel_degree",
467
+ type=int,
468
+ default=1,
469
+ help="""
470
+ Pipeline Parallelism degree, or number of ranks. 1 means disabled.
471
+ If using looped schedules, this still specifies the number of physical ranks, not the number
472
+ of stages. Stages per rank are inferred from split points degree, and schedule.""",
473
+ )
474
+ self.parser.add_argument(
475
+ "--parallelism.pipeline_parallel_split_points",
476
+ type=string_list,
477
+ nargs="+",
478
+ default=[],
479
+ help="""
480
+ Specify comma-separated names of modules to use as the beginning of a split point.
481
+
482
+ e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
483
+ the first containing all the layers up to layers.0,
484
+ the second containing layers.0 and up to layers.2,
485
+ the third containing layers.2 and all the remaining layers.
486
+
487
+ Note: fully-automated splitting may be enabled in the future,
488
+ but currently the split points must be specified manually.""",
489
+ )
490
+ self.parser.add_argument(
491
+ "--parallelism.pipeline_parallel_layers_per_stage",
492
+ type=int,
493
+ default=None,
494
+ help="""
495
+ The number of layers per stage. If specified, the split points will be calculated from
496
+ the number of layers and pipeline_parallel_degree. If not specified, the layers per stage will
497
+ be inferred from the model, schedule, and pipeline_parallel_degree.""",
498
+ )
499
+ self.parser.add_argument(
500
+ "--parallelism.pipeline_parallel_schedule",
501
+ type=str,
502
+ default="1F1B",
503
+ help="""
504
+ Specify the Pipeline Parallel schedule to use. The supported schedules are:
505
+ https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
506
+ The schedule must be compatible with the split points and stages_per_rank.
507
+
508
+ Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks,
509
+ and split_points = number of stages - 1
510
+ """,
511
+ )
512
+ self.parser.add_argument(
513
+ "--parallelism.pipeline_parallel_schedule_csv",
514
+ type=str,
515
+ default="",
516
+ help="""
517
+ Specify the path to the pipeline parallel schedule csv file to use.
518
+ The pipeline_parallel_schedule argument must be either
519
+ PipelineScheduleSingle, PipelineScheduleMulti, or _PipelineScheduleRuntime.
520
+ """,
521
+ )
522
+ self.parser.add_argument(
523
+ "--parallelism.pipeline_parallel_microbatch_size",
524
+ type=int,
525
+ default=1,
526
+ help="""
527
+ The size of each pipeline parallel microbatch (default 1).
528
+
529
+ This value is used to compute the total number of microbatches by dividing batch_size with
530
+ pipeline_parallel_microbatch_size.
531
+
532
+ The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size.
533
+ """,
534
+ )
535
+ self.parser.add_argument(
536
+ "--parallelism.context_parallel_degree",
537
+ type=int,
538
+ default=1,
539
+ help="Context parallelism degree. 1 means disabled.",
540
+ )
541
+ self.parser.add_argument(
542
+ "--parallelism.context_parallel_rotate_method",
543
+ type=str,
544
+ default="allgather",
545
+ help="""
546
+ The collective to use in context parallel SDPA for kv shards exchange.
547
+
548
+ 'allgather' means to all-gather all kv shards on ranks after the first sub-SDPA computation,
549
+
550
+ 'alltoall' means to all-to-all shuffle the kv shards.
551
+
552
+ The default value is 'allgather'.
553
+ """,
554
+ )
555
+
556
+ # checkpointing configs
557
+ self.parser.add_argument(
558
+ "--checkpoint.enable_checkpoint",
559
+ action="store_true",
560
+ help="Whether to enable checkpoint",
561
+ )
562
+ self.parser.add_argument(
563
+ "--checkpoint.folder",
564
+ type=str,
565
+ default="checkpoint",
566
+ help="""
567
+ The folder to store the checkpoints.
568
+ When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
569
+ """,
570
+ )
571
+ self.parser.add_argument(
572
+ "--checkpoint.interval",
573
+ type=int,
574
+ default=500,
575
+ help="Checkpointing interval in steps.",
576
+ )
577
+ self.parser.add_argument(
578
+ "--checkpoint.model_weights_only",
579
+ action="store_true",
580
+ help="""
581
+ When model_weights_only=True, only model weights will be saved at the end of training.
582
+ With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion.
583
+ When model_weights_only=False, the full checkpoint will be saved.
584
+ A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
585
+ The default value is false.
586
+ """,
587
+ )
588
+ self.parser.add_argument(
589
+ "--checkpoint.export_dtype",
590
+ type=str,
591
+ default="float32",
592
+ choices=["float16", "bfloat16", "float32"],
593
+ help="""
594
+ Converts to the specified precision when training completes and model_weights_only=true.
595
+ Currently supports float32, float16, and bfloat16.
596
+ The default value is float32.
597
+ """,
598
+ )
599
+ self.parser.add_argument(
600
+ "--checkpoint.create_seed_checkpoint",
601
+ action="store_true",
602
+ help="""
603
+ Initializes the full model without applying parallelisms, and then saves it as a seed checkpoint.
604
+ Note: requires user to call train.py without specifying any parallelisms, e.g. NGPU=1.
605
+ Could be implemented as a separate script, but this way shares more code.
606
+ """,
607
+ )
608
+ self.parser.add_argument(
609
+ "--checkpoint.async_mode",
610
+ type=str,
611
+ default="disabled",
612
+ help="""
613
+ Which async checkpoint mode to use. Currently there are 3 different modes.
614
+ 1. "disabled": synchronized checkpointing will be used.
615
+ 2. "async": torch.distributed.checkpoint.async_save will be used.
616
+ 3. "async_with_pinned_mem": this option utilizes a dedicated pinned memory
617
+ space and creates a separate process for faster GPU->CPU transfer
618
+ performance and eliminating GIL contention. The cost is increased CPU
619
+ memory usage. If insufficient CPU memory is available, performance may
620
+ degrade due to memory paging. For most users, "async" should suffice as
621
+ the performance overhead is typically small (on the order of tens of
622
+ seconds) compared to checkpointing frequency. This mode can be employed
623
+ to pursue near-zero checkpointing times (e.g., < 1 second) given
624
+ appropriate hardware support such as ample CPU memory and fast PCIe.
625
+
626
+ "disabled" is the default mode.
627
+ """,
628
+ )
629
+ self.parser.add_argument(
630
+ "--checkpoint.keep_latest_k",
631
+ type=int,
632
+ default=10,
633
+ help="""
634
+ Keeps only the latest k checkpoints, purging older ones. If 0, keep all checkpoints.
635
+ K cannot be 1 as the last one may be in the process of being saved. As a result,
636
+ the metadata of the last one may not be ready yet. The default value is 10 to avoid
637
+ filling up the disk.
638
+ """,
639
+ )
640
+ self.parser.add_argument(
641
+ "--checkpoint.load_step",
642
+ type=int,
643
+ default=-1,
644
+ help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
645
+ )
646
+ self.parser.add_argument(
647
+ "--checkpoint.exclude_from_loading",
648
+ type=string_list,
649
+ nargs="*",
650
+ default=[],
651
+ help="""
652
+ Exclude specific keys from being loaded from the checkpoint.
653
+ Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
654
+ This will load the model only, excluding the specified keys.
655
+ """,
656
+ )
657
+
658
+ # activation checkpointing configs
659
+ self.parser.add_argument(
660
+ "--activation_checkpoint.mode",
661
+ type=str,
662
+ default="selective",
663
+ help="Type of activation checkpointing to use ['none', 'full', 'selective']",
664
+ )
665
+ self.parser.add_argument(
666
+ "--activation_checkpoint.selective_ac_option",
667
+ type=str,
668
+ default="2", # 2 = checkpoint every other layer
669
+ help="""
670
+ Selective activation checkpointing options ['int', 'op'].
671
+ 'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
672
+ """,
673
+ )
674
+
675
+ # float8 configs
676
+ self.parser.add_argument(
677
+ "--float8.enable_fsdp_float8_all_gather",
678
+ action="store_true",
679
+ help="Whether enable float8 all-gather in FSDP, recommended for tensorwise scaling",
680
+ )
681
+ self.parser.add_argument(
682
+ "--float8.precompute_float8_dynamic_scale_for_fsdp",
683
+ action="store_true",
684
+ help="Whether precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling",
685
+ )
686
+ self.parser.add_argument(
687
+ "--float8.force_recompute_fp8_weight_in_bwd",
688
+ action="store_true",
689
+ help="""
690
+ Whether to force the recomputation of FP8 weights during backward pass.
691
+ When using FSDP with tensorwise scaling, it is recommended to enable
692
+ `force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights
693
+ for backward computation.
694
+ """,
695
+ )
696
+ self.parser.add_argument(
697
+ "--float8.recipe_name",
698
+ type=str,
699
+ default=None,
700
+ choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"],
701
+ help="""
702
+ If specified, creates float8 config from recipe name, valid choices are
703
+ `tensorwise`, `rowwise` and `rowwise_with_gw_hp`.
704
+ """,
705
+ )
706
+ self.parser.add_argument(
707
+ "--float8.filter_fqns",
708
+ type=string_list,
709
+ default=[],
710
+ nargs="+",
711
+ help="""
712
+ Comma-separated list of fully qualified names of modules to skip applying float8 training to.
713
+ nn.Linear modules with any dim size not divisible by 16 are always skipped due to hardware requirements.
714
+ Example: --float8.filter_fqns "attention.wq,attention.wk,attention.wv,output"
715
+ """,
716
+ )
717
+
718
+ # communications library settings
719
+ self.parser.add_argument(
720
+ "--comm.init_timeout_seconds",
721
+ type=int,
722
+ default=300,
723
+ help="Timeout for communication operations, during initialization and first train step.",
724
+ )
725
+ self.parser.add_argument(
726
+ "--comm.train_timeout_seconds",
727
+ type=int,
728
+ default=100,
729
+ help=(
730
+ "Timeout for communication operations after the first train step -- "
731
+ "usually a tighter bound than during initialization."
732
+ ),
733
+ )
734
+ self.parser.add_argument(
735
+ "--comm.trace_buf_size",
736
+ type=int,
737
+ default=20000,
738
+ help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
739
+ )
740
+
741
+ # memory estimation configs
742
+ self.parser.add_argument(
743
+ "--memory_estimation.enabled",
744
+ help="Whether to estimate memory usage for FSDP",
745
+ action="store_true",
746
+ )
747
+
748
+ self.parser.add_argument(
749
+ "--memory_estimation.disable_fake_mode",
750
+ help="Whether to estimate memory under FakeTensorMode",
751
+ action="store_true",
752
+ )
753
+
754
+ self.parser.add_argument(
755
+ "--fault_tolerance.enable",
756
+ action="store_true",
757
+ help="""
758
+ Enable TorchFT integration. When TorchFT is enabled, HSDP will be used.
759
+ And --fault_tolerance.data_parallel_replicate_degree should be 1 and
760
+ --fault_tolerance.group_size will be used to control the maximum
761
+ replicate group size as the replicate group size is dynamic.
762
+
763
+ Note that this is still an experimental feature.
764
+ """,
765
+ )
766
+
767
+ # torchft configs
768
+ self.parser.add_argument(
769
+ "--fault_tolerance.replica_id",
770
+ type=int,
771
+ default=0,
772
+ help="The TorchFT replica ID of this run.",
773
+ )
774
+ self.parser.add_argument(
775
+ "--fault_tolerance.group_size",
776
+ type=int,
777
+ default=0,
778
+ help="""
779
+ The number of TorchFT replicate groups. This number will be used for
780
+ dataloader to split the dataset across the replicate groups and FSDP
781
+ dimension
782
+ """,
783
+ )
784
+ self.parser.add_argument(
785
+ "--fault_tolerance.min_replica_size",
786
+ type=int,
787
+ default=1,
788
+ help="The minimum number of FT replica for each step.",
789
+ )
790
+
791
+ self.parser.add_argument(
792
+ "--experimental.custom_import",
793
+ type=str,
794
+ default="",
795
+ help="""
796
+ This option enables importing external modules.
797
+ Currently, it only supports dotted import modules (e.g., some_package.model_x).
798
+ It is the user's responsibility to ensure that the specified path can be
799
+ successfully imported. One way to achieve this is to place your module
800
+ inside the ``torchtitan/torchtitan`` folder and execute ``pip install -e .`` to
801
+ make it available for import.
802
+ """,
803
+ )
804
+
805
+ self.parser.add_argument(
806
+ "--experimental.custom_args_module",
807
+ type=str,
808
+ default="",
809
+ help="""
810
+ This option allows users to extend TorchTitan's existing JobConfig by importing
811
+ a customized module. Similar to ``--experimental.custom_import``, the user
812
+ needs to ensure that the path can be imported. The module should contain exactly
813
+ one public function and the function has the signature
814
+ ``def func(parser: argparse.ArgumentParser) -> None:``. The user can use the
815
+ given parser to add new arguments by calling ``parser.add_argument`` as needed.
816
+ """,
817
+ )
818
+
819
+ self._is_parsed = False
820
+ self._allow_unkown_args = False
821
+
822
+ def maybe_add_custom_args(self) -> None:
823
+ """Add custom arguments to the parser if --experimental.custom_args_module is set.
824
+
825
+ Note: This function should be called before the parser is used to parse arguments.
826
+ """
827
+ if self._is_parsed:
828
+ raise RuntimeError(
829
+ "JobConfig has already been parsed. We could not add new arguments."
830
+ )
831
+
832
+ self._allow_unkown_args = True
833
+ self.parse_args(sys.argv[1:])
834
+ self._allow_unkown_args = False
835
+
836
+ if self.experimental.custom_args_module:
837
+ module = importlib.import_module(self.experimental.custom_args_module)
838
+ public_functions = [
839
+ name
840
+ for name, func in inspect.getmembers(module)
841
+ if inspect.isfunction(func) and not name.startswith("_")
842
+ ]
843
+ func = getattr(module, public_functions[0])
844
+ func(self.parser)
845
+
846
+ def to_dict(self):
847
+ return self.args_dict
848
+
849
+ def parse_args(self, args_list: list = sys.argv[1:]):
850
+ self._is_parsed = True
851
+ args, cmd_args = self.parse_args_from_command_line(args_list)
852
+ config_file = getattr(args, "job.config_file", None)
853
+ # build up a two level dict
854
+ args_dict = self._args_to_two_level_dict(args)
855
+ if config_file is not None:
856
+ try:
857
+ with open(config_file, "rb") as f:
858
+ for k, v in tomllib.load(f).items():
859
+ # to prevent overwrite of non-specified keys
860
+ args_dict[k] |= v
861
+ except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
862
+ logger.exception(
863
+ f"Error while loading the configuration file: {config_file}"
864
+ )
865
+ logger.exception(f"Error details: {str(e)}")
866
+ raise e
867
+
868
+ # Checking string-list arguments are properly split into a list
869
+ # if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
870
+ string_list_argnames = self._get_string_list_argument_names()
871
+ for n in string_list_argnames:
872
+ check_string_list_argument(args_dict, n)
873
+
874
+ # override args dict with cmd_args
875
+ cmd_args_dict = self._args_to_two_level_dict(cmd_args)
876
+ for section, section_args in cmd_args_dict.items():
877
+ for k, v in section_args.items():
878
+ args_dict[section][k] = v
879
+
880
+ self.args_dict = args_dict
881
+
882
+ for k, v in args_dict.items():
883
+ class_type = type(k.title(), (), v)
884
+ setattr(self, k, class_type())
885
+ self._validate_config()
886
+
887
+ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
888
+ args_dict = defaultdict(defaultdict)
889
+ for k, v in vars(args).items():
890
+ first_level_key, second_level_key = k.split(".", 1)
891
+ args_dict[first_level_key][second_level_key] = v
892
+ return args_dict
893
+
894
+ def _validate_config(self) -> None:
895
+ # TODO: temporary mitigation of BC breaking change in
896
+ # tokenizer default path, need to remove later
897
+ if not os.path.exists(self.model.tokenizer_path):
898
+ logger.warning(
899
+ f"Tokenizer path {self.model.tokenizer_path} does not exist!"
900
+ )
901
+ old_tokenizer_path = (
902
+ "torchtitan/datasets/tokenizer/original/tokenizer.model"
903
+ )
904
+ if os.path.exists(old_tokenizer_path):
905
+ self.model.tokenizer_path = old_tokenizer_path
906
+ logger.warning(
907
+ f"Temporarily switching to previous default tokenizer path {old_tokenizer_path}. "
908
+ "Please update your config."
909
+ )
910
+
911
+ def _get_string_list_argument_names(self) -> list[str]:
912
+ """Get the parser argument names of type `string_list`."""
913
+ string_list_args = [
914
+ v.dest for v in self.parser._actions if v.type is string_list
915
+ ]
916
+ return string_list_args
917
+
918
+ def parse_args_from_command_line(
919
+ self, args_list
920
+ ) -> Tuple[argparse.Namespace, argparse.Namespace]:
921
+ """
922
+ Parse command line arguments and return the parsed args and the command line only args
923
+ """
924
+ if self._allow_unkown_args:
925
+ args, _ = self.parser.parse_known_args(args_list)
926
+ else:
927
+ args = self.parser.parse_args(args_list)
928
+ string_list_argnames = set(self._get_string_list_argument_names())
929
+
930
+ # aux parser to parse the command line only args, with no defaults from main parser
931
+ aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
932
+ for arg, val in vars(args).items():
933
+ if isinstance(val, bool):
934
+ aux_parser.add_argument(
935
+ "--" + arg, action="store_true" if val else "store_false"
936
+ )
937
+ elif arg in string_list_argnames:
938
+ # without this special case, type inference breaks here,
939
+ # since the inferred type is just 'list' and it ends up flattening
940
+ # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
941
+ aux_parser.add_argument("--" + arg, type=string_list)
942
+ else:
943
+ aux_parser.add_argument("--" + arg, type=type(val))
944
+
945
+ cmd_args, _ = aux_parser.parse_known_args(args_list)
946
+
947
+ return args, cmd_args
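
To make the JobConfig semantics above concrete, here is a minimal sketch (not part of the upload) of the documented precedence order cmdline > toml > argparse default; the file name job.toml is only an illustration:

from torchtitan.config_manager import JobConfig

# A tiny config file covering two of the sections defined above.
with open("job.toml", "w") as f:
    f.write('[model]\nname = "llama3"\n\n[training]\nsteps = 1000\n')

config = JobConfig()
# Command-line values win over the toml, which wins over the argparse defaults.
config.parse_args(["--job.config_file", "job.toml", "--training.steps", "2000"])

print(config.model.name)      # "llama3" (from the toml)
print(config.training.steps)  # 2000     (cmdline overrides the toml's 1000)
print(config.optimizer.lr)    # 0.0008   (argparse default, never specified)
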
torchtitan/train.py ADDED
@@ -0,0 +1,482 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import importlib
8
+ import os
9
+ import time
10
+ from datetime import timedelta
11
+ from typing import Any, Generator, Iterable, Optional
12
+
13
+ import torch
14
+ from torch.distributed.elastic.multiprocessing.errors import record
15
+
16
+ import torchtitan.components.ft as ft
17
+ import torchtitan.protocols.train_spec as train_spec_module
18
+
19
+ from torchtitan.components.checkpoint import CheckpointManager
20
+ from torchtitan.components.metrics import (
21
+ build_metrics_processor,
22
+ ensure_pp_loss_visible,
23
+ )
24
+ from torchtitan.config_manager import JobConfig
25
+ from torchtitan.distributed import ParallelDims, utils as dist_utils
26
+ from torchtitan.protocols.model_converter import build_model_converters
27
+ from torchtitan.tools import utils
28
+ from torchtitan.tools.logging import init_logger, logger
29
+ from torchtitan.tools.profiling import (
30
+ maybe_enable_memory_snapshot,
31
+ maybe_enable_profiling,
32
+ )
33
+
34
+
35
+ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
36
+ job_config: JobConfig
37
+ gc_handler: utils.GarbageCollection
38
+
39
+ parallel_dims: ParallelDims
40
+ train_spec: train_spec_module.TrainSpec
41
+ world_mesh: torch.distributed.DeviceMesh
42
+
43
+ dataloader: train_spec_module.BaseDataLoader
44
+ metrics_processor: train_spec_module.MetricsProcessor
45
+ checkpointer: CheckpointManager
46
+ train_context: Generator[None, None, None]
47
+
48
+ model_parts: list[torch.nn.Module]
49
+ loss_fn: train_spec_module.LossFunction
50
+ optimizers: train_spec_module.OptimizersContainer
51
+ lr_schedulers: train_spec_module.LRSchedulersContainer
52
+
53
+ pp_has_first_stage: bool
54
+ pp_has_last_stage: bool
55
+
56
+ device: torch.device
57
+
58
+ # states
59
+ step: int
60
+
61
+ # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
62
+ @record
63
+ def __init__(self, job_config: JobConfig):
64
+ self.job_config = job_config
65
+
66
+ logger.info(f"Starting job: {job_config.job.description}")
67
+
68
+ if job_config.experimental.custom_import:
69
+ importlib.import_module(job_config.experimental.custom_import)
70
+
71
+ if job_config.job.print_args:
72
+ logger.info(f"Running with args: {job_config.to_dict()}")
73
+
74
+ # take control of garbage collection to avoid stragglers
75
+ self.gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
76
+
77
+ device_module, device_type = utils.device_module, utils.device_type
78
+ self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
79
+ # Device has to be set before creating TorchFT manager.
80
+ device_module.set_device(self.device)
81
+ ft_manager = ft.init_ft_manager(job_config)
82
+
83
+ # init distributed
84
+ world_size = int(os.environ["WORLD_SIZE"])
85
+ parallelism_config = job_config.parallelism
86
+ if not ft_manager.enabled:
87
+ self.parallel_dims = parallel_dims = ParallelDims(
88
+ dp_shard=parallelism_config.data_parallel_shard_degree,
89
+ dp_replicate=parallelism_config.data_parallel_replicate_degree,
90
+ cp=parallelism_config.context_parallel_degree,
91
+ tp=parallelism_config.tensor_parallel_degree,
92
+ pp=parallelism_config.pipeline_parallel_degree,
93
+ world_size=world_size,
94
+ enable_loss_parallel=not parallelism_config.disable_loss_parallel,
95
+ )
96
+ else:
97
+ self.parallel_dims = parallel_dims = ft.FTParallelDims(
98
+ dp_shard=parallelism_config.data_parallel_shard_degree,
99
+ dp_replicate=parallelism_config.data_parallel_replicate_degree,
100
+ cp=parallelism_config.context_parallel_degree,
101
+ tp=parallelism_config.tensor_parallel_degree,
102
+ pp=parallelism_config.pipeline_parallel_degree,
103
+ world_size=world_size,
104
+ enable_loss_parallel=not parallelism_config.disable_loss_parallel,
105
+ ft_manager=ft_manager,
106
+ )
107
+ dist_utils.init_distributed(job_config)
108
+
109
+ # build meshes
110
+ self.world_mesh = world_mesh = parallel_dims.build_mesh(device_type=device_type)
111
+ if parallel_dims.dp_enabled:
112
+ dp_mesh = world_mesh["dp"]
113
+ dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
114
+ else:
115
+ dp_degree, dp_rank = 1, 0
116
+
117
+ # Set random seed, and maybe enable deterministic mode
118
+ # (mainly for debugging, expect perf loss).
119
+ dist_utils.set_determinism(
120
+ world_mesh,
121
+ self.device,
122
+ job_config.training.seed,
123
+ job_config.training.deterministic,
124
+ )
125
+ self.train_spec = train_spec_module.get_train_spec(job_config.model.name)
126
+
127
+ # build dataloader
128
+ tokenizer = (
129
+ self.train_spec.build_tokenizer_fn(job_config)
130
+ if self.train_spec.build_tokenizer_fn is not None
131
+ else None
132
+ )
133
+
134
+ # If TorchFT is enabled, the dp_rank and dp_degree, which are used for
135
+ # dataloader must be changed.
136
+ if ft_manager.enabled:
137
+ dp_degree, dp_rank = ft_manager.get_dp_info(dp_degree, dp_rank)
138
+
139
+ self.dataloader = self.train_spec.build_dataloader_fn(
140
+ dp_world_size=dp_degree,
141
+ dp_rank=dp_rank,
142
+ tokenizer=tokenizer,
143
+ job_config=job_config,
144
+ )
145
+
146
+ # build model (using meta init)
147
+ model_cls = self.train_spec.cls
148
+ model_args = self.train_spec.config[job_config.model.flavor]
149
+ # set the model args from training job configs
150
+ model_args.update_from_config(job_config, tokenizer)
151
+
152
+ logger.info(
153
+ f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
154
+ )
155
+ with torch.device("meta"):
156
+ model = model_cls.from_model_args(model_args)
157
+
158
+ # Build the collection of model converters. No-op if `model.converters` empty
159
+ model_converters = build_model_converters(job_config, parallel_dims)
160
+ model_converters.convert(model)
161
+
162
+ # metrics logging
163
+ build_metrics_processor_fn = (
164
+ build_metrics_processor
165
+ if self.train_spec.build_metrics_processor_fn is None
166
+ else self.train_spec.build_metrics_processor_fn
167
+ )
168
+ self.metrics_processor = build_metrics_processor_fn(job_config, parallel_dims)
169
+ color = self.metrics_processor.color
170
+
171
+ # calculate model size and flops per token
172
+ (
173
+ model_param_count,
174
+ self.metrics_processor.num_flops_per_token,
175
+ ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)
176
+
177
+ logger.info(
178
+ f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} "
179
+ f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
180
+ )
181
+
182
+ # move sharded model to CPU/GPU and initialize weights via DTensor
183
+ if job_config.checkpoint.create_seed_checkpoint:
184
+ init_device = "cpu"
185
+ buffer_device = None
186
+ elif job_config.training.enable_cpu_offload:
187
+ init_device = "cpu"
188
+ buffer_device = device_type
189
+ else:
190
+ init_device = device_type
191
+ buffer_device = None
192
+
193
+ self.loss_fn = self.train_spec.build_loss_fn(job_config)
194
+
195
+ # apply parallelisms and initialization
196
+ if parallel_dims.pp_enabled:
197
+ if not self.train_spec.pipelining_fn:
198
+ raise RuntimeError(
199
+ f"Pipeline Parallel is enabled but {self.train_spec.name} "
200
+ f"does not support pipelining"
201
+ )
202
+
203
+ # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques
204
+ (
205
+ self.pp_schedule,
206
+ self.model_parts,
207
+ self.pp_has_first_stage,
208
+ self.pp_has_last_stage,
209
+ ) = self.train_spec.pipelining_fn(
210
+ model,
211
+ world_mesh,
212
+ parallel_dims,
213
+ job_config,
214
+ self.device,
215
+ model_args,
216
+ self.train_spec.parallelize_fn,
217
+ self.loss_fn,
218
+ )
219
+ # when PP is enabled, `model` obj is no longer used after this point,
220
+ # model_parts is used instead
221
+ del model
222
+
223
+ for m in self.model_parts:
224
+ m.to_empty(device=init_device)
225
+ with torch.no_grad():
226
+ m.init_weights(buffer_device=buffer_device)
227
+ m.train()
228
+
229
+ # confirm that user will be able to view loss metrics on the console
230
+ ensure_pp_loss_visible(parallel_dims, job_config, color)
231
+ else:
232
+ # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
233
+ model = self.train_spec.parallelize_fn(
234
+ model, world_mesh, parallel_dims, job_config
235
+ )
236
+
237
+ model.to_empty(device=init_device)
238
+ with torch.no_grad():
239
+ model.init_weights(buffer_device=buffer_device)
240
+ model.train()
241
+
242
+ self.model_parts = [model]
243
+
244
+ # initialize device memory monitor and get peak flops for MFU calculation
245
+ device_memory_monitor = self.metrics_processor.device_memory_monitor
246
+ gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
247
+ logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
248
+ device_mem_stats = device_memory_monitor.get_peak_stats()
249
+ logger.info(
250
+ f"{device_type.upper()} memory usage for model: "
251
+ f"{device_mem_stats.max_reserved_gib:.2f}GiB"
252
+ f"({device_mem_stats.max_reserved_pct:.2f}%)"
253
+ )
254
+
255
+ # build optimizer after applying parallelisms to the model
256
+ self.optimizers = self.train_spec.build_optimizers_fn(
257
+ self.model_parts, job_config, ft_manager
258
+ )
259
+ self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
260
+ self.optimizers, job_config
261
+ )
262
+ # Post optimizer step model converters hook.
263
+ # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
264
+ # where it issues a single all-reduce for all parameters at once for better performance
265
+ self.optimizers.register_step_post_hook(
266
+ lambda *args, **kwargs: model_converters.post_optimizer_hook(
267
+ self.model_parts
268
+ )
269
+ )
270
+ self.metrics_processor.optimizers = self.optimizers
271
+
272
+ # Initialize trainer states that will be saved in checkpoint.
273
+ # These attributes must be initialized before checkpoint loading.
274
+ self.step = 0
275
+
276
+ self.checkpointer = CheckpointManager(
277
+ dataloader=self.dataloader,
278
+ model_parts=self.model_parts,
279
+ optimizers=self.optimizers,
280
+ lr_schedulers=self.lr_schedulers,
281
+ states={"train_state": self},
282
+ job_config=job_config,
283
+ ft_manager=ft_manager,
284
+ )
285
+
286
+ self.train_context = dist_utils.get_train_context(
287
+ parallel_dims.loss_parallel_enabled,
288
+ parallelism_config.enable_compiled_autograd,
289
+ )
290
+
291
+ logger.info(
292
+ "Trainer is initialized with "
293
+ f"local batch size {job_config.training.batch_size}, "
294
+ f"global batch size {job_config.training.batch_size * dp_degree}, "
295
+ f"sequence length {job_config.training.seq_len}, "
296
+ f"total steps {job_config.training.steps} "
297
+ f"(warmup {job_config.lr_scheduler.warmup_steps})."
298
+ )
299
+
300
+ def next_batch(
301
+ self, data_iterator: Iterable
302
+ ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
303
+ data_load_start = time.perf_counter()
304
+ batch = next(data_iterator)
305
+ input_dict, labels = batch
306
+ self.metrics_processor.ntokens_since_last_log += labels.numel()
307
+ self.metrics_processor.data_loading_times.append(
308
+ time.perf_counter() - data_load_start
309
+ )
310
+
311
+ device_type = utils.device_type
312
+ for k, _ in input_dict.items():
313
+ input_dict[k] = input_dict[k].to(device_type)
314
+ labels = labels.to(device_type)
315
+ return input_dict, labels
316
+
317
+ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
318
+ self.optimizers.zero_grad()
319
+
320
+ # Keep these variables local to shorten the code as these are
321
+ # the major variables that are used in the training loop.
322
+ model_parts = self.model_parts
323
+ world_mesh = self.world_mesh
324
+ parallel_dims = self.parallel_dims
325
+
326
+ # apply context parallelism if cp is enabled
327
+ # ensure CP handles the separate freqs_cis buffer for each pp stage
328
+ inputs = input_dict["input"]
329
+ optional_context_parallel_ctx = (
330
+ dist_utils.create_context_parallel_ctx(
331
+ cp_mesh=world_mesh["cp"],
332
+ cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
333
+ cp_seq_dims=[1, 1] + [0 for _ in model_parts],
334
+ cp_no_restore_buffers={inputs, labels},
335
+ cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
336
+ )
337
+ if parallel_dims.cp_enabled
338
+ else None
339
+ )
340
+
341
+ if parallel_dims.pp_enabled:
342
+ # Pipeline Parallel forward / backward inside step() call
343
+ with self.train_context(optional_context_parallel_ctx):
344
+ targets, losses = (
345
+ (labels, []) if self.pp_has_last_stage else (None, None)
346
+ )
347
+ if self.pp_has_first_stage:
348
+ self.pp_schedule.step(inputs, target=targets, losses=losses)
349
+ else:
350
+ self.pp_schedule.step(target=targets, losses=losses)
351
+
352
+ # accumulate losses across pipeline microbatches
353
+ # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
354
+ loss = (
355
+ torch.mean(torch.stack(losses)).to(self.device)
356
+ if self.pp_has_last_stage
357
+ else torch.tensor([-1.0], device=self.device)
358
+ )
359
+ else:
360
+ # Non-PP forward / backward
361
+ with self.train_context(optional_context_parallel_ctx):
362
+ assert len(model_parts) == 1
363
+ pred = model_parts[0](inputs)
364
+ loss = self.loss_fn(pred, labels)
365
+ # pred.shape=(bs, seq_len, vocab_size)
366
+ # need to free pred before bwd to avoid a memory peak
367
+ del pred
368
+ loss.backward()
369
+
370
+ dist_utils.clip_grad_norm_(
371
+ [p for m in model_parts for p in m.parameters()],
372
+ self.job_config.training.max_norm,
373
+ foreach=True,
374
+ pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
375
+ )
376
+ self.checkpointer.maybe_wait_for_staging()
377
+ self.optimizers.step()
378
+ self.lr_schedulers.step()
379
+
380
+ # log metrics
381
+ if not self.metrics_processor.should_log(self.step):
382
+ return
383
+
384
+ if (
385
+ parallel_dims.dp_replicate_enabled
386
+ or parallel_dims.dp_shard_enabled
387
+ or parallel_dims.cp_enabled
388
+ ):
389
+ loss = loss.detach()
390
+ global_avg_loss, global_max_loss = (
391
+ dist_utils.dist_mean(loss, world_mesh["dp_cp"]),
392
+ dist_utils.dist_max(loss, world_mesh["dp_cp"]),
393
+ )
394
+ else:
395
+ global_avg_loss = global_max_loss = loss.item()
396
+
397
+ self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)
398
+
399
+ @record
400
+ def train(self):
401
+ job_config = self.job_config
402
+
403
+ self.checkpointer.load(step=job_config.checkpoint.load_step)
404
+ logger.info(f"Training starts at step {self.step + 1}.")
405
+
406
+ with maybe_enable_profiling(
407
+ job_config, global_step=self.step
408
+ ) as torch_profiler, maybe_enable_memory_snapshot(
409
+ job_config, global_step=self.step
410
+ ) as memory_profiler:
411
+ data_iterator = iter(self.dataloader)
412
+ while self.step < job_config.training.steps:
413
+ self.step += 1
414
+ self.gc_handler.run(self.step)
415
+ inputs, labels = self.next_batch(data_iterator)
416
+ self.train_step(inputs, labels)
417
+ self.checkpointer.save(
418
+ self.step, force=(self.step == job_config.training.steps)
419
+ )
420
+
421
+ # signal the profiler that the next profiling step has started
422
+ if torch_profiler:
423
+ torch_profiler.step()
424
+ if memory_profiler:
425
+ memory_profiler.step()
426
+
427
+ # reduce timeout after first train step for faster signal
428
+ # (assuming lazy init and compilation are finished)
429
+ if self.step == 1:
430
+ dist_utils.set_pg_timeouts(
431
+ timeout=timedelta(
432
+ seconds=job_config.comm.train_timeout_seconds
433
+ ),
434
+ world_mesh=self.world_mesh,
435
+ )
436
+
437
+ if torch.distributed.get_rank() == 0:
438
+ logger.info("Sleeping 2 seconds for other ranks to complete")
439
+ time.sleep(2)
440
+
441
+ self.metrics_processor.close()
442
+ logger.info("Training completed")
443
+
444
+ def state_dict(self) -> dict[str, Any]:
445
+ return {"step": self.step}
446
+
447
+ def load_state_dict(self, state_dict: dict[str, Any]):
448
+ self.step = state_dict["step"]
449
+
450
+ def close(self) -> None:
451
+ if self.checkpointer:
452
+ self.checkpointer.close()
453
+
454
+
455
+ if __name__ == "__main__":
456
+ init_logger()
457
+ config = JobConfig()
458
+ config.maybe_add_custom_args()
459
+ config.parse_args()
460
+ trainer: Optional[Trainer] = None
461
+
462
+ try:
463
+ trainer = Trainer(config)
464
+
465
+ if config.checkpoint.create_seed_checkpoint:
466
+ assert int(
467
+ os.environ["WORLD_SIZE"]
468
+ ), "Must create seed checkpoint using a single device, to disable sharding."
469
+ assert (
470
+ config.checkpoint.enable_checkpoint
471
+ ), "Must enable checkpointing when creating a seed checkpoint."
472
+ trainer.checkpointer.save(curr_step=0, force=True)
473
+ logger.info("Created seed checkpoint")
474
+ else:
475
+ trainer.train()
476
+ finally:
477
+ if trainer:
478
+ trainer.close()
479
+
480
+ if torch.distributed.is_initialized():
481
+ torch.distributed.destroy_process_group()
482
+ logger.info("Process group destroyed.")
train.sh ADDED
@@ -0,0 +1,121 @@
1
+ #!/usr/bin/bash
2
+
3
+ params=""
4
+ if [ $# -ne 0 ]; then
5
+ params="$*"
6
+ fi
7
+
8
+ # use envs as local params for convenience
9
+ # e.g.
10
+ # NNODE=1 NGPU=8 LOG_RANK=0 ./train.sh
11
+ NNODE=${NNODE:-"1"}
12
+ NGPU=${NGPU:-"8"}
13
+ LOG_RANK=${LOG_RANK:-0}
14
+
15
+ if [[ -z "${MASTER_ADDR}" ]]; then
16
+ export MASTER_ADDR="localhost"
17
+ fi
18
+ if [[ -z "${MASTER_PORT}" ]]; then
19
+ export MASTER_PORT="0"
20
+ fi
21
+
22
+ : '
23
+ Usage:
24
+
25
+ bash train.sh -h
26
+
27
+ Training a 340M model:
28
+
29
+ NNODE=1 NGPU=8 LOG_RANK=0 bash train.sh \
30
+ --job.config_file flame/models/fla.toml \
31
+ --job.dump_folder exp/transformer-340M-10B/batch32.seqlen2048.warmup1024.update1.steps20480.lr3e-4 \
32
+ --model.config configs/transformer_340M.json \
33
+ --model.tokenizer_path fla-hub/transformer-1.3B-100B \
34
+ --optimizer.name AdamW \
35
+ --optimizer.eps 1e-15 \
36
+ --optimizer.lr 3e-4 \
37
+ --lr_scheduler.warmup_steps 1024 \
38
+ --lr_scheduler.lr_min 0.1 \
39
+ --lr_scheduler.decay_type cosine \
40
+ --training.batch_size 32 \
41
+ --training.seq_len 2048 \
42
+ --training.gradient_accumulation_steps 1 \
43
+ --training.steps 20480 \
44
+ --training.max_norm 1.0 \
45
+ --training.skip_nan_inf \
46
+ --training.dataset HuggingFaceFW/fineweb-edu \
47
+ --training.dataset_name default \
48
+ --training.dataset_split train \
49
+ --training.streaming \
50
+ --training.num_workers 32 \
51
+ --training.prefetch_factor 2 \
52
+ --training.seed 42 \
53
+ --training.compile \
54
+ --training.tensor_parallel_degree 1 \
55
+ --training.disable_loss_parallel \
56
+ --checkpoint.interval 2048 \
57
+ --checkpoint.load_step -1 \
58
+ --metrics.log_freq 1
59
+ '
60
+
61
+ echo "Launching training..."
62
+
63
+ set -x
64
+ path=$(grep -oP '(?<=--job.dump_folder )[^ ]+' <<< "$params")
65
+ steps=$(grep -oP '(?<=--training.steps )[^ ]+' <<< "$params")
66
+ config=$(grep -oP '(?<=--model.config )[^ ]+' <<< "$params")
67
+ tokenizer=$(grep -oP '(?<=--model.tokenizer_path )[^ ]+' <<< "$params")
68
+ model=$(
69
+ python -c "import fla, sys; from transformers import AutoConfig; print(AutoConfig.from_pretrained(sys.argv[1]).to_json_string())" "$config" | jq -r '.model_type'
70
+ )
71
+
72
+ mkdir -p $path
73
+ cp * $path
74
+ cp -r configs $path
75
+ cp -r flame $path
76
+ cp -r 3rdparty/flash-linear-attention/fla $path
77
+ cp -r 3rdparty/torchtitan/torchtitan $path
78
+
79
+ # for offline systems
80
+ # export TRANSFORMERS_OFFLINE=1
81
+ # export HF_DATASETS_OFFLINE=1
82
+ # export HF_HUB_OFFLINE=1
83
+ if [ "$date" == "" ]; then
84
+ date=$(date +%Y%m%d%H%M)
85
+ fi
86
+ RUN_NAME="$model-$(basename $path)"
87
+ RUN_ID="$RUN_NAME-$date"
88
+
89
+ export WANDB_RESUME=allow
90
+ if [[ -z "${WANDB_PROJECT}" ]]; then
91
+ export WANDB_PROJECT="fla"
92
+ fi
93
+ if [[ -z "${WANDB_NAME}" ]]; then
94
+ export WANDB_NAME="$RUN_NAME"
95
+ fi
96
+ if [[ -z "${WANDB_RUN_ID}" ]]; then
97
+ export WANDB_RUN_ID="$RUN_ID"
98
+ fi
99
+
100
+ PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
101
+ torchrun --nnodes=${NNODE} \
102
+ --nproc_per_node=${NGPU} \
103
+ --rdzv_backend c10d \
104
+ --rdzv_endpoint "${MASTER_ADDR}:${MASTER_PORT}" \
105
+ --local-ranks-filter ${LOG_RANK} \
106
+ --role rank \
107
+ --tee 3 \
108
+ --log-dir $path/logs \
109
+ -m flame.train \
110
+ $params
111
+
112
+ echo "TRAINING DONE!"
113
+ echo "Converting the DCP checkpoints to HF format..."
114
+
115
+ python -m flame.utils.convert_dcp_to_hf \
116
+ --path $path \
117
+ --step $steps \
118
+ --config $config \
119
+ --tokenizer $tokenizer
120
+
121
+ echo "RUNNING DONE!"