zaydzuhri committed (verified)
Commit 2c5a30a · Parent: 214e0ab

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. LICENSE +21 -0
  2. README.md +471 -0
  3. config.json +34 -0
  4. configs/delta_net_1B.json +29 -0
  5. configs/delta_net_340M.json +27 -0
  6. configs/gla_340M.json +24 -0
  7. configs/gla_7B.json +25 -0
  8. configs/gsa_340M.json +29 -0
  9. configs/hgrn2_340M.json +20 -0
  10. configs/mtp_transformer_120M.json +19 -0
  11. configs/mtp_transformer_1B.json +23 -0
  12. configs/mtp_transformer_340M.json +19 -0
  13. configs/mtp_transformer_7B.json +22 -0
  14. configs/top_transformer_120M.json +20 -0
  15. configs/top_transformer_1B.json +24 -0
  16. configs/top_transformer_340M.json +20 -0
  17. configs/top_transformer_7B.json +23 -0
  18. configs/transformer_120M.json +18 -0
  19. configs/transformer_1B.json +22 -0
  20. configs/transformer_340M.json +18 -0
  21. configs/transformer_7B.json +21 -0
  22. download_checkpoint.py +35 -0
  23. fla/__init__.py +110 -0
  24. fla/utils.py +223 -0
  25. flame/__init__.py +1 -0
  26. flame/__pycache__/__init__.cpython-312.pyc +0 -0
  27. flame/__pycache__/config_manager.cpython-312.pyc +0 -0
  28. flame/__pycache__/data.cpython-312.pyc +0 -0
  29. flame/__pycache__/train.cpython-312.pyc +0 -0
  30. flame/components/__init__.py +0 -0
  31. flame/components/__pycache__/__init__.cpython-312.pyc +0 -0
  32. flame/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
  33. flame/components/checkpoint.py +59 -0
  34. flame/config_manager.py +940 -0
  35. flame/data.py +570 -0
  36. flame/models/__init__.py +0 -0
  37. flame/models/__pycache__/__init__.cpython-312.pyc +0 -0
  38. flame/models/__pycache__/parallelize_fla.cpython-312.pyc +0 -0
  39. flame/models/__pycache__/pipeline_fla.cpython-312.pyc +0 -0
  40. flame/models/activation_offloading.py +447 -0
  41. flame/models/fla.toml +67 -0
  42. flame/models/parallelize_fla.py +550 -0
  43. flame/models/pipeline_fla.py +162 -0
  44. flame/tools/__init__.py +0 -0
  45. flame/tools/__pycache__/__init__.cpython-312.pyc +0 -0
  46. flame/tools/__pycache__/utils.cpython-312.pyc +0 -0
  47. flame/tools/utils.py +41 -0
  48. flame/train.py +897 -0
  49. flame/utils/__init__.py +0 -0
  50. flame/utils/__pycache__/__init__.cpython-312.pyc +0 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023-2025 Songlin Yang, Yu Zhang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md ADDED
@@ -0,0 +1,471 @@
<div align="center">

# 🔥 Flame: Flash Linear Attention Made Easy

</div>

Welcome to 🔥 `flame`, a minimal and efficient framework built on `torchtitan` for training Flash Linear Attention (FLA) models (and, more broadly, arbitrary autoregressive language models) with blazing efficiency.

**Feature Highlights:**

- 🚀 Minimal, easy-to-use, extensible training framework
- 🤗 Seamless integration with `fla` and `transformers`
- 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and multiple-dataset support
- 🔮 4D parallelism (coming soon)

## Setup

To get started, clone the `flame` repository and install the required dependencies:

```bash
git clone https://github.com/fla-org/flame.git
cd flame
pip install .
```

`flame` keeps its dependencies minimal, including only `fla` and `torchtitan` as submodules.
After installation, initialize and update the submodules:
```sh
git submodule update --init --recursive
```

## Dataset Preparation
To download the dataset to your local disk, create a new Python file with the following content and execute it:

```py
from datasets import load_dataset

# load fineweb-edu with parallel processing
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")

# or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
```

## Training Recipes

Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus in streaming mode.

> [!WARNING]
> If the dataset is not downloaded beforehand, streaming mode will attempt to fetch it from a remote server and download it on the fly, which can be highly unstable during training due to network issues.
> For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)); use on-the-fly streaming only for quick tests of a new corpus.

```sh
bash train.sh \
  --job.config_file flame/models/fla.toml \
  --job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr3e-4.cosine \
  --model.config configs/transformer_340M.json \
  --model.tokenizer_path fla-hub/transformer-1.3B-100B \
  --optimizer.name AdamW \
  --optimizer.eps 1e-15 \
  --optimizer.lr 3e-4 \
  --lr_scheduler.warmup_steps 1024 \
  --lr_scheduler.lr_min 0.1 \
  --lr_scheduler.decay_type cosine \
  --training.batch_size 1 \
  --training.seq_len 65536 \
  --training.context_len 4096 \
  --training.varlen \
  --training.gradient_accumulation_steps 1 \
  --training.steps 20480 \
  --training.max_norm 1.0 \
  --training.skip_nan_inf \
  --training.dataset HuggingFaceFW/fineweb-edu \
  --training.dataset_name sample-100BT \
  --training.dataset_split train \
  --training.streaming \
  --training.num_workers 32 \
  --training.prefetch_factor 2 \
  --training.seed 42 \
  --training.compile \
  --checkpoint.interval 2048 \
  --checkpoint.load_step -1 \
  --checkpoint.keep_latest_k 2 \
  --metrics.log_freq 1
```

You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
**For single-GPU debugging, set `NGPU=1`.**

We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
By default, the learning rate is set to 3e-4 with a cosine scheduler. Other schedulers, such as warmup-stable-decay (`wsd`), are also supported.

**Key parameters:**
- `--lr_scheduler.decay_ratio`: The proportion of steps allocated to the decay phase. The learning rate stays constant after warmup and only starts decaying during the last `decay_ratio` portion of the total training steps; this is known as the Warmup-Stable-Decay (WSD) schedule.
- `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
- `--training.steps`: Total number of training steps.
- `--training.batch_size`: Batch size per device; must be 1 if `--training.varlen` is set.
- `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
- `--training.context_len`: The maximum allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
- `--training.varlen`: Whether to conduct variable-length sequence training.
- `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.

> [!WARNING]
> The total number of sequences processed per optimizer step, referred to as `global_batch_size`, is calculated as `batch_size × gradient_accumulation_steps × num_gpus`.
> Each step therefore processes `global_batch_size * seq_len` tokens.
> Monitor the values of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of these hyperparameters!

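To make the bookkeeping concrete, here is the arithmetic for the example recipe above (a quick sanity check assuming the default `NGPU=8`):

```py
batch_size, grad_accum, num_gpus, seq_len, steps = 1, 1, 8, 65536, 20480
global_batch_size = batch_size * grad_accum * num_gpus  # 8 sequences per optimizer step
tokens_per_step = global_batch_size * seq_len            # 524,288 tokens (~0.5M)
total_tokens = tokens_per_step * steps                   # ~10.7B tokens, i.e. the "10B" run
print(global_batch_size, tokens_per_step, total_tokens)
```
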
For a detailed explanation of all parameters, run:

```sh
bash train.sh -h
```

<details>
<summary>Usage</summary>

```py
options:
  -h, --help            show this help message and exit
  --job.config_file JOB.CONFIG_FILE
        Job config file
  --job.dump_folder JOB.DUMP_FOLDER
        Folder to dump job outputs
  --job.description JOB.DESCRIPTION
        Description of the job
  --job.use_for_integration_test
        Add this config to the integration test suite
  --job.print_args      Print the args to terminal
  --model.config MODEL.CONFIG
        Path to the model config
  --model.norm_type MODEL.NORM_TYPE
        Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]
  --model.tokenizer_path MODEL.TOKENIZER_PATH
        Tokenizer path
  --profiling.enable_profiling
        Whether to enable pytorch profiler
  --profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
        Trace files location
  --profiling.profile_freq PROFILING.PROFILE_FREQ
        How often to collect profiler traces, in iterations
  --profiling.enable_memory_snapshot
        Whether to dump memory snapshot
  --profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
        Memory snapshot files location
  --optimizer.name OPTIMIZER.NAME
        Optimizer to use
  --optimizer.eps OPTIMIZER.EPS
        Epsilon value for the optimizer.
  --optimizer.fused     Whether the fused implementation (CUDA only) is used.
  --optimizer.scheduler {wsd,cosine,linear}
        Scheduler to use. Currently supported: wsd, cosine, and linear.
  --optimizer.lr OPTIMIZER.LR
        Learning rate to use
  --optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
        Min lr ratio for lr scheduler
  --optimizer.early_step_in_backward
        Whether to apply optimizer in the backward. Caution: optimizer_in_backward is not
        compatible with gradient clipping; users should not call
        register_post_accumulate_grad_hook after the optimizer is built.
  --training.batch_size TRAINING.BATCH_SIZE
        Batch size
  --training.seq_len TRAINING.SEQ_LEN
        Sequence length
  --training.context_len TRAINING.CONTEXT_LEN
        Max length allowed for each sequence
  --training.varlen     Whether to take sequences of variable length as input
  --training.warmup_steps TRAINING.WARMUP_STEPS
        Steps for lr scheduler warmup, normally 1/5 of --training.steps
  --training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
        Number of steps to accumulate gradients before updating parameters
  --training.steps TRAINING.STEPS
        How many train steps to run
  --training.max_norm TRAINING.MAX_NORM
        Max norm for gradient clipping
  --training.skip_nan_inf
        Skip batch updates when NaN or INF gradients are encountered during training
  --training.dataset TRAINING.DATASET
        Dataset to use, with comma separated values
  --training.dataset_name TRAINING.DATASET_NAME
        The name of the dataset config, with comma separated values if provided
  --training.dataset_split TRAINING.DATASET_SPLIT
        Dataset split to use, with comma separated values if provided
  --training.data_dir TRAINING.DATA_DIR
        Data dirs to use, with comma separated values if provided
  --training.data_files TRAINING.DATA_FILES
        Data files to use, with comma separated values if provided
  --training.data_probs TRAINING.DATA_PROBS
        Data sampling probabilities, with comma separated values if provided
  --training.streaming  Whether to load dataset in streaming mode, used for huge dataset
  --training.num_workers TRAINING.NUM_WORKERS
        Number of subprocesses to use for data loading. 0 means that the data will be
        loaded in the main process.
  --training.prefetch_factor TRAINING.PREFETCH_FACTOR
        Number of batches loaded in advance by each worker. 2 means there will be a total
        of 2 * num_workers batches prefetched across all workers.
  --training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
        The `data_parallel_replicate_degree` argument specifies the degree of data
        parallelism for weight replication. When this value is greater than 1, weights
        will be replicated across `data_parallel_replicate_degree` ranks. If
        `data_parallel_shard_degree` is also greater than 1, the parallelism method used
        is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the parallelism method used
        is DDP (Distributed Data Parallelism). 1 means disabled.
  --training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
        The `data_parallel_shard_degree` argument specifies the degree of data parallelism
        for weight sharding. When this value is greater than 1, weights will be sharded
        across `data_parallel_shard_degree` ranks. If `data_parallel_replicate_degree` is
        also greater than 1, the parallelism method used is HSDP (Hybrid Sharded Data
        Parallelism). Otherwise, the parallelism method used is FSDP (Fully Sharded Data
        Parallelism). -1 means leftover ranks will be used (after DP_REPLICATE/SP/PP).
        Note that only `data_parallel_shard_degree` can be negative. 1 means disabled.
  --training.enable_cpu_offload
        Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP
  --training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
        Tensor Parallelism degree. 1 means disabled.
  --training.disable_loss_parallel
        Whether to apply loss parallel when sequence parallel is enabled
  --training.mixed_precision_param {bfloat16,float32}
        torch dtype to use for parameters when applying mixed precision via FSDP. This
        feature only takes effect when data_parallel_shard_degree > 1
  --training.mixed_precision_reduce {float32}
        torch dtype to use for reductions when applying mixed precision via FSDP. This
        feature only takes effect when data_parallel_shard_degree > 1
  --training.compile    Whether to compile the model
  --training.gc_freq TRAINING.GC_FREQ
        Python garbage control scheduling interval, in steps
  --training.seed TRAINING.SEED
        Choose the base RNG seed used for training
  --training.deterministic
        Use deterministic algorithms wherever possible, may be slower
  --metrics.log_freq METRICS.LOG_FREQ
        How often to log metrics to TensorBoard, in iterations
  --metrics.enable_tensorboard
        Whether to log metrics to TensorBoard
  --metrics.disable_color_printing
        Whether to disable color printing in logs
  --metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
        Folder to dump TensorBoard states
  --metrics.rank_0_only
        Whether to save TensorBoard metrics only for rank 0 or for all ranks. When
        pipeline_parallel_degree is > 1, this option uses the 0th rank of the last stage
        pipeline group, which is the only stage that computes loss metrics.
  --metrics.enable_wandb
        Whether to log metrics to Weights & Biases
  --experimental.enable_async_tensor_parallel
        Whether to apply async tensor parallel (currently only effective when compile is enabled)
  --experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
        Pipeline Parallelism degree, or number of ranks. 1 means disabled. If using looped
        schedules, this still specifies the number of physical ranks, not the number of
        stages. Stages per rank are inferred from split points and schedule.
  --experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
        Specify comma-separated names of modules to use as the beginning of a split point.
        e.g. "layers.0,layers.2" will cause the model to be split into 3 stages, the first
        containing all the layers up to layers.0, the second containing layers.0 and up to
        layers.2, the third containing layers.2 and all the remaining layers. Note:
        fully-automated splitting may be enabled in the future, but currently the split
        points must be specified manually.
  --experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
        Specify the Pipeline Parallel schedule to use. The supported schedules are:
        https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
        The schedule must be compatible with the split points and stages_per_rank. Looped
        schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree =
        number of ranks, and split_points = number of stages - 1
  --experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
        Specify the path to the pipeline parallel schedule csv file to use. The
        pipeline_parallel_schedule argument must be either PipelineScheduleSingle,
        PipelineScheduleMulti, or _PipelineScheduleRuntime.
  --experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
        How many microbatches to split the global training batch into when using pipeline
        parallelism. The global training batch size must be evenly divisible by the number
        of microbatches. The default value will be the number of pipeline stages, if unspecified.
  --experimental.enable_compiled_autograd
        Enable CompiledAutograd to compile the backward.
  --experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
        Context parallelism degree. 1 means disabled.
  --experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
        The collective to use in context parallel SDPA for kv shards exchange. 'allgather'
        means to all-gather all kv shards on ranks after the first sub-SDPA computation,
        'alltoall' means to all-to-all shuffle the kv shards. The default value is 'allgather'.
  --checkpoint.enable_checkpoint
        Whether to enable checkpoint
  --checkpoint.folder CHECKPOINT.FOLDER
        The folder to store the checkpoints. When enable_checkpoint is set to true,
        checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
  --checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
        Checkpointing interval unit of measurement ['step', 'seconds']
  --checkpoint.interval CHECKPOINT.INTERVAL
        Checkpointing interval, in steps or seconds depending on --checkpoint.interval_type
  --checkpoint.model_weights_only
        When model_weights_only=True, only model weights will be saved at the end of
        training. With this, checkpoints can be loaded using `torch.load(..., weights_only=True)`
        after conversion. When model_weights_only=False, the full checkpoint will be saved.
        A full checkpoint includes model, optimizer and train_state, which can be used to
        resume training. The default value is false.
  --checkpoint.export_dtype {float16,bfloat16,float32}
        Converts to the specified precision when training completes and
        model_weights_only=true. Currently supports float32, float16, and bfloat16. The
        default value is float32.
  --checkpoint.create_seed_checkpoint
        Initializes the full model without applying parallelisms, and then saves it as a
        seed checkpoint. Note: requires user to call train.py without specifying any
        parallelisms, e.g. NGPU=1. Could be implemented as a separate script, but this way
        shares more code.
  --checkpoint.async_mode CHECKPOINT.ASYNC_MODE
        Which async checkpoint mode to use. Currently there are 3 different modes.
        1. "disabled": synchronized checkpointing will be used.
        2. "async": torch.distributed.checkpoint.async_save will be used.
        3. "async_with_pinned_mem": this option utilizes a dedicated pinned memory space
        and creates a separate process for faster GPU->CPU transfer performance and
        eliminating GIL contention. The cost is increased CPU memory usage. If
        insufficient CPU memory is available, performance may degrade due to memory
        paging. For most users, "async" should suffice as the performance overhead is
        typically small (on the order of tens of seconds) compared to checkpointing
        frequency. This mode can be employed to pursue near-zero checkpointing times
        (e.g., < 1 second) given appropriate hardware support such as ample CPU memory
        and fast PCIe. "disabled" is the default mode.
  --checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
        Keeps only the latest k checkpoints, purging older ones. If 0, keep all
        checkpoints. 0 is the default value.
  --checkpoint.load_step CHECKPOINT.LOAD_STEP
        Load the checkpoint at the specified step. If -1, load the latest checkpoint.
  --float8.enable_float8_linear
        If true, swaps `torch.nn.Linear` with `Float8Linear`. This feature requires you to
        install 'torchao' which can be found here: https://github.com/pytorch/ao
  --float8.enable_fsdp_float8_all_gather
        Whether to enable float8 all-gather in FSDP
  --float8.precompute_float8_dynamic_scale_for_fsdp
        Whether to precompute float8 scales dynamically for FSDP
  --float8.scaling_type_input {dynamic,delayed}
        float8 scaling for input, dynamic (default) or delayed
  --float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
        float8 scaling for weight, dynamic (default) or delayed
  --float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
        float8 scaling for grad_output, dynamic (default) or delayed
  --comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
        Timeout for communication operations, during initialization and first train step.
  --comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
        Timeout for communication operations after the first train step -- usually a
        tighter bound than during initialization.
  --comm.trace_buf_size COMM.TRACE_BUF_SIZE
        Flight recorder ring buffer size, >0 means recording by default, 0 means disabled
  --memory_estimation.enabled
        Whether to estimate memory usage for FSDP
  --memory_estimation.disable_fake_mode
        Whether to estimate memory under FakeTensorMode
```
</details>

### Training with `torch.compile`

Starting from `torch 2.0`, `torch.compile` has been introduced as a new feature to seamlessly accelerate training.
In `flame`, you can enable `torch.compile` simply by adding the `--training.compile` flag to your training script.

However, `fla` integrates numerous fused kernels for acceleration, which may potentially conflict with `torch.compile`.
We are actively working on resolving these issues to make compilation transparent to users.
In the meantime, please ensure you are using the latest dependencies.

Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.

### Training with multiple datasets

If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
`flame` allows training with multiple datasets easily.
For example, you can specify the following arguments to train on 6 datasets with different proportions:

```sh
--training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
--training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
```

### ~Finalizing training~

> [!NOTE]
> We have done this conversion automatically in the training script since our latest updates.

Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
To facilitate this, we provide a straightforward conversion script:

```sh
python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
```
After this, your model will be in the 🤗 format, ready to be shared or deployed.
You can then easily publish your model using the `huggingface_hub` for wider accessibility.

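For instance, once converted, the checkpoint can be loaded like any other 🤗 model. This is a minimal sketch: the path is a placeholder, and it assumes `fla` (or the repository's custom model code) is importable so the architecture is registered with `transformers`:

```py
import fla  # assumption: importing fla registers the FLA architectures with transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "<path_to_converted_model>"  # placeholder: the converted 🤗-format folder
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path)
```
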
### Continual training

If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
This allows you to seamlessly resume training with `flame`.
```sh
python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
```
Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.

Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.

## Multi-node training

If you have access to multi-node GPUs, consider leveraging them for optimal performance.
This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).

To set up multi-node training (see the sketch below):
* Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes.
* If you're using a job scheduler like Slurm, it will handle these variables for you.

`torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
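As a concrete illustration, a manual two-node launch might look like the following on each node (a sketch only: the IP, port, and training arguments are placeholders, and how node rank and world size are propagated depends on your launcher or Slurm setup):

```sh
# Run on every node; point MASTER_ADDR at the rank-0 node.
export MASTER_ADDR=<ip_of_rank0_node>
export MASTER_PORT=29500
NGPU=8 bash train.sh --job.config_file flame/models/fla.toml <other training args>
```
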
config.json ADDED
@@ -0,0 +1,34 @@
{
  "architectures": [
    "MTPTransformerForCausalLM"
  ],
  "attention_bias": false,
  "bos_token_id": 1,
  "elementwise_affine": true,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "fuse_swiglu": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "max_position_embeddings": 8192,
  "model_type": "mtp_transformer",
  "n_future_tokens": 4,
  "norm_eps": 1e-06,
  "num_heads": 16,
  "num_hidden_layers": 24,
  "num_kv_heads": null,
  "qk_norm": false,
  "qkv_bias": false,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.50.3",
  "use_cache": true,
  "use_custom_backward": false,
  "vocab_size": 32000,
  "window_size": null
}
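This `config.json` describes the repository's multi-token-prediction (MTP) transformer. For reference, its fields can be inspected directly with plain Python before deciding how to load the model (an illustrative snippet; nothing is assumed beyond the file above):

```py
import json

with open("config.json") as f:
    cfg = json.load(f)

# e.g. a 1024-dim, 24-layer model whose MTP head predicts 4 future tokens per position
print(cfg["model_type"], cfg["hidden_size"], cfg["num_hidden_layers"], cfg["n_future_tokens"])
```
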
configs/delta_net_1B.json ADDED
@@ -0,0 +1,29 @@
{
  "attn": null,
  "attn_mode": "chunk",
  "bos_token_id": 1,
  "conv_size": 4,
  "eos_token_id": 2,
  "expand_k": 1,
  "expand_v": 1,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 2048,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "model_type": "delta_net",
  "norm_eps": 1e-06,
  "num_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 2,
  "qk_activation": "silu",
  "qk_norm": "l2",
  "tie_word_embeddings": false,
  "use_beta": true,
  "use_cache": true,
  "use_gate": false,
  "use_output_norm": true,
  "use_short_conv": true
}
configs/delta_net_340M.json ADDED
@@ -0,0 +1,27 @@
{
  "attn_mode": "chunk",
  "bos_token_id": 1,
  "conv_size": 4,
  "eos_token_id": 2,
  "expand_k": 1,
  "expand_v": 1,
  "fuse_cross_entropy": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "model_type": "delta_net",
  "norm_eps": 1e-06,
  "norm_first": false,
  "num_heads": 8,
  "num_hidden_layers": 24,
  "qk_activation": "silu",
  "qk_norm": "l2",
  "tie_word_embeddings": false,
  "use_beta": true,
  "use_cache": true,
  "use_gate": false,
  "use_output_norm": true,
  "use_short_conv": true
}
configs/gla_340M.json ADDED
@@ -0,0 +1,24 @@
{
  "attn_mode": "chunk",
  "bos_token_id": 1,
  "clamp_min": null,
  "eos_token_id": 2,
  "expand_k": 0.5,
  "expand_v": 1,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "model_type": "gla",
  "num_heads": 4,
  "num_hidden_layers": 24,
  "norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "use_cache": true,
  "use_gk": true,
  "use_gv": false,
  "vocab_size": 32000
}
configs/gla_7B.json ADDED
@@ -0,0 +1,25 @@
{
  "attn": null,
  "attn_mode": "chunk",
  "bos_token_id": 1,
  "eos_token_id": 2,
  "expand_k": 0.5,
  "expand_v": 1,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 4096,
  "initializer_range": 0.006,
  "intermediate_size": 11008,
  "model_type": "gla",
  "norm_eps": 1e-06,
  "num_heads": 16,
  "num_hidden_layers": 32,
  "tie_word_embeddings": false,
  "use_cache": true,
  "use_gk": true,
  "use_gv": false,
  "use_output_gate": true,
  "use_short_conv": false
}
configs/gsa_340M.json ADDED
@@ -0,0 +1,29 @@
{
  "bos_token_id": 1,
  "conv_size": 4,
  "eos_token_id": 2,
  "expand_k": 1,
  "expand_v": 1,
  "elementwise_affine": false,
  "feature_map": "swish",
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "gate_logit_normalizer": 4,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "model_type": "gsa",
  "num_heads": 4,
  "num_hidden_layers": 24,
  "num_slots": 64,
  "norm_eps": 1e-06,
  "share_conv_kernel": true,
  "tie_word_embeddings": false,
  "use_cache": true,
  "use_norm": true,
  "use_output_gate": true,
  "use_rope": false,
  "use_short_conv": false
}
configs/hgrn2_340M.json ADDED
@@ -0,0 +1,20 @@
{
  "attn_mode": "chunk",
  "bos_token_id": 1,
  "eos_token_id": 2,
  "expand_ratio": 128,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "model_type": "hgrn2",
  "num_heads": 8,
  "num_hidden_layers": 24,
  "norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "use_cache": true,
  "vocab_size": 32000
}
configs/mtp_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": false,
  "hidden_act": "swish",
  "hidden_size": 768,
  "initializer_range": 0.02,
  "max_position_embeddings": 4096,
  "model_type": "mtp_transformer",
  "num_heads": 12,
  "num_hidden_layers": 14,
  "norm_eps": 1e-06,
  "tie_word_embeddings": true,
  "use_cache": true,
  "vocab_size": 32000,
  "n_future_tokens": 4
}
configs/mtp_transformer_1B.json ADDED
@@ -0,0 +1,23 @@
{
  "bos_token_id": 1,
  "elementwise_affine": true,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "fuse_swiglu": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 2048,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "max_position_embeddings": 8192,
  "model_type": "mtp_transformer",
  "norm_eps": 1e-06,
  "num_heads": 32,
  "num_hidden_layers": 32,
  "num_kv_heads": null,
  "pad_token_id": 2,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "n_future_tokens": 4
}
configs/mtp_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "max_position_embeddings": 8192,
  "model_type": "mtp_transformer",
  "num_heads": 16,
  "num_hidden_layers": 24,
  "norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "use_cache": true,
  "vocab_size": 32000,
  "n_future_tokens": 4
}
configs/mtp_transformer_7B.json ADDED
@@ -0,0 +1,22 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 4096,
  "initializer_range": 0.006,
  "intermediate_size": 14336,
  "model_type": "mtp_transformer",
  "norm_eps": 1e-06,
  "num_heads": 32,
  "num_hidden_layers": 30,
  "num_kv_heads": 8,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "use_cache": true,
  "window_size": null,
  "n_future_tokens": 4
}
configs/top_transformer_120M.json ADDED
@@ -0,0 +1,20 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": false,
  "hidden_act": "swish",
  "hidden_size": 768,
  "initializer_range": 0.02,
  "max_position_embeddings": 4096,
  "model_type": "top_transformer",
  "num_heads": 12,
  "num_hidden_layers": 14,
  "norm_eps": 1e-06,
  "tie_word_embeddings": true,
  "use_cache": true,
  "vocab_size": 32000,
  "use_top_loss": true,
  "top_window_size": 2048
}
configs/top_transformer_1B.json ADDED
@@ -0,0 +1,24 @@
{
  "bos_token_id": 1,
  "elementwise_affine": true,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "fuse_swiglu": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 2048,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "max_position_embeddings": 8192,
  "model_type": "top_transformer",
  "norm_eps": 1e-06,
  "num_heads": 32,
  "num_hidden_layers": 32,
  "num_kv_heads": null,
  "pad_token_id": 2,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "use_top_loss": true,
  "top_window_size": 4096
}
configs/top_transformer_340M.json ADDED
@@ -0,0 +1,20 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "max_position_embeddings": 8192,
  "model_type": "top_transformer",
  "num_heads": 16,
  "num_hidden_layers": 24,
  "norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "use_cache": true,
  "vocab_size": 32000,
  "use_top_loss": true,
  "top_window_size": 4096
}
configs/top_transformer_7B.json ADDED
@@ -0,0 +1,23 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 4096,
  "initializer_range": 0.006,
  "intermediate_size": 14336,
  "model_type": "top_transformer",
  "norm_eps": 1e-06,
  "num_heads": 32,
  "num_hidden_layers": 30,
  "num_kv_heads": 8,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "use_cache": true,
  "window_size": null,
  "use_top_loss": true,
  "top_window_size": 4096
}
configs/transformer_120M.json ADDED
@@ -0,0 +1,18 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": false,
  "hidden_act": "swish",
  "hidden_size": 768,
  "initializer_range": 0.02,
  "max_position_embeddings": 4096,
  "model_type": "transformer",
  "num_heads": 12,
  "num_hidden_layers": 14,
  "norm_eps": 1e-06,
  "tie_word_embeddings": true,
  "use_cache": true,
  "vocab_size": 32000
}
configs/transformer_1B.json ADDED
@@ -0,0 +1,22 @@
{
  "bos_token_id": 1,
  "elementwise_affine": true,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "fuse_swiglu": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 2048,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "max_position_embeddings": 8192,
  "model_type": "transformer",
  "norm_eps": 1e-06,
  "num_heads": 32,
  "num_hidden_layers": 32,
  "num_kv_heads": null,
  "pad_token_id": 2,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false
}
configs/transformer_340M.json ADDED
@@ -0,0 +1,18 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "max_position_embeddings": 8192,
  "model_type": "transformer",
  "num_heads": 16,
  "num_hidden_layers": 24,
  "norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "use_cache": true,
  "vocab_size": 32000
}
configs/transformer_7B.json ADDED
@@ -0,0 +1,21 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 4096,
  "initializer_range": 0.006,
  "intermediate_size": 14336,
  "model_type": "transformer",
  "norm_eps": 1e-06,
  "num_heads": 32,
  "num_hidden_layers": 30,
  "num_kv_heads": 8,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "use_cache": true,
  "window_size": null
}
download_checkpoint.py ADDED
@@ -0,0 +1,35 @@
import argparse
import os
from huggingface_hub import HfApi, HfFolder, snapshot_download

def main(args):
    api = HfApi()
    token = HfFolder.get_token()
    experiment_checkpoint_folder = os.path.join(args.experiment_checkpoint_folder, "checkpoint")
    os.makedirs(
        experiment_checkpoint_folder,
        exist_ok=True
    )

    snapshot_download(
        repo_id=args.repo_id,
        token=token,
        local_dir=experiment_checkpoint_folder,
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download a checkpoint from Hugging Face Hub.")
    parser.add_argument(
        "--repo_id",
        type=str,
        required=True,
        help="The repository ID on Hugging Face Hub.",
    )
    parser.add_argument(
        "--experiment_checkpoint_folder",
        type=str,
        required=True,
        help="The local directory to save the downloaded checkpoint.",
    )
    args = parser.parse_args()
    main(args)
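For example, the script above can be invoked as follows (a usage sketch based on its argparse flags; both values are placeholders):

```sh
python download_checkpoint.py \
  --repo_id <user>/<checkpoint_repo> \
  --experiment_checkpoint_folder exp/<run_name>
```

The files end up under `exp/<run_name>/checkpoint/`, which lines up with the `{--job.dump_folder}/{--checkpoint.folder}` layout described in the training help text (assuming the default folder name `checkpoint`).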
fla/__init__.py ADDED
@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-

from fla.layers import (
    ABCAttention,
    Attention,
    BasedLinearAttention,
    BitAttention,
    DeltaNet,
    GatedDeltaNet,
    GatedDeltaProduct,
    GatedLinearAttention,
    GatedSlotAttention,
    HGRN2Attention,
    HGRNAttention,
    LightNetAttention,
    LinearAttention,
    MultiScaleRetention,
    NativeSparseAttention,
    ReBasedLinearAttention,
    RWKV6Attention,
    RWKV7Attention
)
from fla.models import (
    ABCForCausalLM,
    ABCModel,
    BitNetForCausalLM,
    BitNetModel,
    DeltaNetForCausalLM,
    DeltaNetModel,
    GatedDeltaNetForCausalLM,
    GatedDeltaNetModel,
    GatedDeltaProductForCausalLM,
    GatedDeltaProductModel,
    GLAForCausalLM,
    GLAModel,
    GSAForCausalLM,
    GSAModel,
    HGRN2ForCausalLM,
    HGRN2Model,
    HGRNForCausalLM,
    HGRNModel,
    LightNetForCausalLM,
    LightNetModel,
    LinearAttentionForCausalLM,
    LinearAttentionModel,
    NSAForCausalLM,
    NSAModel,
    RetNetForCausalLM,
    RetNetModel,
    RWKV6ForCausalLM,
    RWKV6Model,
    RWKV7ForCausalLM,
    RWKV7Model,
    TransformerForCausalLM,
    TransformerModel
)

__all__ = [
    'ABCAttention',
    'Attention',
    'BasedLinearAttention',
    'BitAttention',
    'DeltaNet',
    'GatedDeltaNet',
    'GatedDeltaProduct',
    'GatedLinearAttention',
    'GatedSlotAttention',
    'HGRNAttention',
    'HGRN2Attention',
    'LightNetAttention',
    'LinearAttention',
    'MultiScaleRetention',
    'NativeSparseAttention',
    'ReBasedLinearAttention',
    'RWKV6Attention',
    'RWKV7Attention',
    'ABCForCausalLM',
    'ABCModel',
    'BitNetForCausalLM',
    'BitNetModel',
    'DeltaNetForCausalLM',
    'DeltaNetModel',
    'GatedDeltaNetForCausalLM',
    'GatedDeltaNetModel',
    'GatedDeltaProductForCausalLM',
    'GatedDeltaProductModel',
    'GLAForCausalLM',
    'GLAModel',
    'GSAForCausalLM',
    'GSAModel',
    'HGRNForCausalLM',
    'HGRNModel',
    'HGRN2ForCausalLM',
    'HGRN2Model',
    'LightNetForCausalLM',
    'LightNetModel',
    'LinearAttentionForCausalLM',
    'LinearAttentionModel',
    'NSAForCausalLM',
    'NSAModel',
    'RetNetForCausalLM',
    'RetNetModel',
    'RWKV6ForCausalLM',
    'RWKV6Model',
    'RWKV7ForCausalLM',
    'RWKV7Model',
    'TransformerForCausalLM',
    'TransformerModel',
]

__version__ = '0.1.2'
fla/utils.py ADDED
@@ -0,0 +1,223 @@
# -*- coding: utf-8 -*-

import contextlib
import functools
import os
from enum import Enum
from functools import lru_cache
from typing import Any, Callable, Dict, Literal, Optional, Tuple

import torch
import triton
from packaging import version


def tensor_cache(
    fn: Callable[..., torch.Tensor]
) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent result of a function with tensor inputs.

    This decorator will store the output of the decorated function for the most recent set of input tensors.
    If the function is called again with the same input tensors, it will return the cached result.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with single-entry caching.
    """
    last_args: Optional[Tuple] = None
    last_kwargs: Optional[Dict] = None
    last_result: Any = None

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal last_args, last_kwargs, last_result

        if last_args is not None and last_kwargs is not None:
            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
                if all(a is b for a, b in zip(args, last_args)) and \
                        all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()):
                    return last_result

        result = fn(*args, **kwargs)
        last_args, last_kwargs, last_result = args, kwargs, result
        return result

    return wrapper
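
# Example (illustrative): because `tensor_cache` compares arguments by identity,
# repeated calls with the *same* tensor objects reuse the previous result:
#
#     @tensor_cache
#     def causal_mask(x: torch.Tensor) -> torch.Tensor:
#         return x.new_ones(x.shape[0], x.shape[0]).tril()
#
#     m1 = causal_mask(q)
#     m2 = causal_mask(q)  # cache hit: m2 is m1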


def input_guard(
    fn: Callable[..., torch.Tensor]
) -> Callable[..., torch.Tensor]:
    """
    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
        contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}

        tensor = None
        for arg in args:
            if isinstance(arg, torch.Tensor):
                tensor = arg
                break
        if tensor is None:
            for value in kwargs.values():
                if isinstance(value, torch.Tensor):
                    tensor = value
                    break

        if tensor is not None:
            ctx = custom_device_ctx(tensor.device.index)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper


contiguous = input_guard


def require_version(version, hint):
    """
    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
    """
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(ctx, *args, **kwargs):
            from transformers.utils.versions import require_version
            require_version(version, hint)
            return fn(ctx,
                      *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),
                      **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})
        return wrapper
    return decorator


def checkpoint(fn):
    def wrapper(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
    return wrapper


@lru_cache(maxsize=None)
def check_pytorch_version(version_s: str = '2.4') -> bool:
    return version.parse(torch.__version__) >= version.parse(version_s)


def _cpu_device_warning():
    import warnings
    warnings.warn(('Triton is not supported on current platform, roll back to CPU.'), stacklevel=1)


@lru_cache(maxsize=None)
def get_multiprocessor_count(tensor_idx: int = 0) -> int:
    try:
        # Only works on homogeneous hardware.
        # TEMPORARY FIX since older versions introduce a graph break
        return torch.cuda.get_device_properties().multi_processor_count
    except BaseException:
        _cpu_device_warning()
        return -1


@lru_cache(maxsize=None)
def get_available_device() -> str:
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except BaseException:
        _cpu_device_warning()
        return 'cpu'


@lru_cache(maxsize=None)
def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
    device = get_available_device()
    if device == 'cuda':
        return 'nvidia'
    elif device == 'hip':
        return 'amd'
    elif device == 'xpu':
        return 'intel'
    else:
        return device


# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = get_available_device() if get_available_device() != 'hip' else 'cuda'
device_torch_lib = getattr(torch, device)
device_platform = _check_platform()

is_amd = (device_platform == 'amd')
is_intel = (device_platform == 'intel')
is_nvidia = (device_platform == 'nvidia')
is_intel_alchemist = (is_intel and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
is_nvidia_hopper = (is_nvidia and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
use_cuda_graph = (is_nvidia and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')

# Nvidia Ampere or newer; AMD and Intel have not been checked yet.
is_tf32_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8)
is_gather_supported = hasattr(triton.language, 'gather')


def get_all_max_shared_mem():
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)['max_shared_mem']
            for i in range(device_torch_lib.device_count())
        ]
    except BaseException:
        _cpu_device_warning()
        return [-1]


class Backend(Enum):
    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        try:
            return cls[arch.upper()].value
        except KeyError:
            return cls.DEFAULT.value


@lru_cache(maxsize=None)
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    try:
        device_shared_mem_list = get_all_max_shared_mem()
        max_shared_memory = device_shared_mem_list[tensor_idx]
        return max_shared_memory >= Backend.get_shared_memory(arch)
    except Exception:
        return False


if check_pytorch_version('2.4'):
    device = 'cuda' if device == 'cpu' else device
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)

    def custom_device_ctx(index: int):
        return device_torch_lib.device(index)
else:
    assert device == 'cuda', 'Only cuda device is supported for PyTorch version < 2.4.0.'
    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
    autocast_custom_bwd = device_torch_lib.amp.custom_bwd

    def custom_device_ctx(index: int):
        return torch.cuda.device(index)
flame/__init__.py ADDED
@@ -0,0 +1 @@
__version__ = "0.1.0"
flame/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (156 Bytes).
flame/__pycache__/config_manager.cpython-312.pyc ADDED
Binary file (36.9 kB).
flame/__pycache__/data.cpython-312.pyc ADDED
Binary file (31.3 kB).
flame/__pycache__/train.cpython-312.pyc ADDED
Binary file (38.1 kB).
flame/components/__init__.py ADDED
File without changes
flame/components/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (141 Bytes).
flame/components/__pycache__/checkpoint.cpython-312.pyc ADDED
Binary file (3.21 kB).
flame/components/checkpoint.py ADDED
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from datetime import timedelta
from io import BytesIO
from typing import Any, Dict, List

import torch
from torch.distributed.checkpoint.stateful import Stateful


@dataclass
class TrainState(Stateful):
    step: int = 0
    skipped_step: int = 0
    token: int = 0
    elapsed: timedelta = timedelta(0)
    global_avg_losses: List[float] = field(default_factory=list)
    global_max_losses: List[float] = field(default_factory=list)
    log_steps: List[int] = field(default_factory=list)

    def state_dict(self) -> Dict[str, Any]:
        # Only checkpoint global_avg_losses and global_max_losses per log frequency
        # to avoid sync overhead in every iteration.
        global_avg_losses_bytes = BytesIO()
        torch.save(self.global_avg_losses, global_avg_losses_bytes)
        global_max_losses_bytes = BytesIO()
        torch.save(self.global_max_losses, global_max_losses_bytes)
        log_steps_bytes = BytesIO()
        torch.save(self.log_steps, log_steps_bytes)
        return {
            "step": torch.tensor(self.step, dtype=torch.int32),
            "skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
            "token": torch.tensor(self.token, dtype=torch.int64),
            "elapsed": self.elapsed,
            "global_avg_losses": global_avg_losses_bytes,
            "global_max_losses": global_max_losses_bytes,
            "log_steps": log_steps_bytes,
        }

    def load_state_dict(self, state_dict) -> None:
        self.step = state_dict["step"].item()
        self.skipped_step = state_dict.get("skipped_step", 0).item()
        self.token = state_dict["token"].item()
        self.elapsed = state_dict["elapsed"]
        state_dict["global_avg_losses"].seek(0)
        self.global_avg_losses = torch.load(
            state_dict["global_avg_losses"], weights_only=False
        )
        state_dict["global_max_losses"].seek(0)
        self.global_max_losses = torch.load(
            state_dict["global_max_losses"], weights_only=False
        )
        state_dict["log_steps"].seek(0)
        self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
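
# Example (illustrative): TrainState is a torch.distributed.checkpoint `Stateful`, so it
# round-trips through state_dict()/load_state_dict() alongside the model and optimizer:
#
#     state = TrainState(step=2048, token=1_073_741_824)
#     blob = state.state_dict()
#     restored = TrainState()
#     restored.load_state_dict(blob)   # restored.step == 2048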
flame/config_manager.py ADDED
@@ -0,0 +1,940 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import sys
from collections import defaultdict
from typing import Tuple

import torch

try:
    import tomllib
except ModuleNotFoundError:
    import tomli as tomllib

from torchtitan.tools.logging import logger

TORCH_DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def string_list(raw_arg):
    """Comma-separated string list argument."""
    return [s.strip() for s in raw_arg.split(",") if s.strip()]


def check_string_list_argument(args_dict: dict[str, any], fullargname: str):
    section, name = fullargname.split(".")
    # Split string lists which are still raw strings.
    if (
        section in args_dict
        and name in args_dict[section]
        and isinstance(args_dict[section][name], str)
    ):
        sec = args_dict[section]
        sec[name] = string_list(sec[name])


class JobConfig:
    """
    A helper class to manage the train configuration.
    Semantics:
    - The default config is loaded from a toml file. If no toml file is provided,
      the default config is loaded from argparse defaults.
    - If the toml file has missing keys, they are filled with argparse defaults.
    - If additional explicit cmd args are provided in addition to the toml file,
      they will override the toml config and the argparse defaults.

    Precedence order: cmdline > toml > argparse default

    Arg parsing semantics:

    Each argument starts with <prefix>_, which is the section name in the toml file,
    followed by the name of the option in the toml file. For example,
    model.name translates to:
        [model]
        name
    in the toml file.
    """

    def __init__(self):
67
+ def __init__(self):
68
+ self.args_dict = None
69
+ # main parser
70
+ self.parser = argparse.ArgumentParser(description="torchtitan arg parser.")
71
+
72
+ self.parser.add_argument(
73
+ "--job.config_file",
74
+ type=str,
75
+ default=None,
76
+ help="Job config file",
77
+ )
78
+
79
+ # job level configs
80
+ self.parser.add_argument(
81
+ "--job.dump_folder",
82
+ type=str,
83
+ default="./torchtitan/outputs",
84
+ help="Folder to dump job outputs",
85
+ )
86
+ self.parser.add_argument(
87
+ "--job.description",
88
+ type=str,
89
+ default="default job",
90
+ help="Description of the job",
91
+ )
92
+ self.parser.add_argument(
93
+ "--job.use_for_integration_test",
94
+ action="store_true",
95
+ help="Add this config to the integration test suite",
96
+ )
97
+ self.parser.add_argument(
98
+ "--job.print_args",
99
+ action="store_true",
100
+ help="Print the args to terminal",
101
+ )
102
+
103
+ # model configs
104
+ self.parser.add_argument(
105
+ "--model.name",
106
+ type=str,
107
+ default="fla",
108
+ help="Which model to train",
109
+ )
110
+ self.parser.add_argument(
111
+ "--model.config",
112
+ type=str,
113
+ default="fla-hub/transformer-1.3B-100B",
114
+ help="Path to the model config",
115
+ )
116
+ self.parser.add_argument(
117
+ "--model.tokenizer_path",
118
+ type=str,
119
+ default="fla-hub/transformer-1.3B-100B",
120
+ help="Tokenizer path",
121
+ )
122
+ self.parser.add_argument(
123
+ "--model.converters",
124
+ type=string_list,
125
+ nargs="+",
126
+ default=[],
127
+ help="""
128
+ Comma separated list of converters to apply to the model.
129
+ For instance, the `float8` converter swaps `torch.nn.Linear`
130
+ with `Float8Linear`. This feature requires you to install 'torchao'
131
+ which can be found here: https://github.com/pytorch/ao
132
+ """,
133
+ )
134
+ self.parser.add_argument(
135
+ "--model.print_after_conversion",
136
+ action="store_true",
137
+ help="""
138
+ If true, model definition will be printed to stdout after all model
139
+ converters have been applied.
140
+ """,
141
+ )
142
+
143
+ # profiling configs
144
+ self.parser.add_argument(
145
+ "--profiling.enable_profiling",
146
+ action="store_true",
147
+ help="Whether to enable pytorch profiler",
148
+ )
149
+ self.parser.add_argument(
150
+ "--profiling.save_traces_folder",
151
+ type=str,
152
+ default="profile_traces",
153
+ help="Trace files location",
154
+ )
155
+ self.parser.add_argument(
156
+ "--profiling.profile_freq",
157
+ type=int,
158
+ default=10,
159
+ help="How often to collect profiler traces, in iterations",
160
+ )
161
+ self.parser.add_argument(
162
+ "--profiling.enable_memory_snapshot",
163
+ action="store_true",
164
+ help="Whether to dump memory snapshot",
165
+ )
166
+ self.parser.add_argument(
167
+ "--profiling.save_memory_snapshot_folder",
168
+ type=str,
169
+ default="memory_snapshot",
170
+ help="Memeory snapshot files location",
171
+ )
172
+
173
+ # optimizer configs
174
+ self.parser.add_argument(
175
+ "--optimizer.name", type=str, default="AdamW", help="Optimizer to use"
176
+ )
177
+ self.parser.add_argument(
178
+ "--optimizer.eps",
179
+ type=float,
180
+ default=1e-8,
181
+ help="Epsilon value for the optimizer.",
182
+ )
183
+ self.parser.add_argument(
184
+ "--optimizer.lr", type=float, default=8e-4, help="Learning rate to use"
185
+ )
186
+ self.parser.add_argument(
187
+ "--optimizer.implementation",
188
+ type=str,
189
+ default="fused",
190
+ choices=["for-loop", "foreach", "fused"],
191
+ help="""
192
+ Specify which optimizer implementation to use:
193
+ - 'fused': Use fused implementation (CUDA only) for best performance.
194
+ - 'foreach': Use some horizontal fusion of tensors for better performance.
195
+ - 'for-loop': Use the default implementation for the optimizer (slowest).
196
+ - more info: https://pytorch.org/docs/stable/optim.html
197
+ """,
198
+ )
199
+ self.parser.add_argument(
200
+ "--optimizer.early_step_in_backward",
201
+ action="store_true",
202
+ help="""
203
+ Whether to apply optimizer in the backward. Caution, optimizer_in_backward
204
+ is not compatible with gradients clipping, users should not call
205
+ register_post_accumulate_grad_hook after the optimizer is built.""",
206
+ )
207
+
208
+ # lr scheduler configs
209
+ self.parser.add_argument(
210
+ "--lr_scheduler.warmup_steps",
211
+ type=int,
212
+ default=200,
213
+ help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
214
+ )
215
+ self.parser.add_argument(
216
+ "--lr_scheduler.decay_ratio",
217
+ type=float,
218
+ default=None,
219
+ help="""
220
+ Controls the proportion of the training steps allocated to the learning rate decay phase.
221
+
222
+ If `None`, the learning rate will begin decaying immediately after the warmup period.
223
+ Otherwise, the learning rate will remain stable after the warmup period and
224
+ only start decaying during the last `decay_ratio` portion of the total training steps.
225
+
226
+ This is known as the Warmup-Stable-Decay (WSD) schedule, as described in https://arxiv.org/abs/2404.06395.
227
+ """,
228
+ )
229
+ self.parser.add_argument(
230
+ "--lr_scheduler.decay_type",
231
+ type=str,
232
+ default="linear",
233
+ choices=["linear", "sqrt", "cosine"],
234
+ help="""
235
+ Learning rate decay type to use during training:
236
+ - 'linear': linearly decays learning rate from initial to final value
237
+ - 'sqrt': decays learning rate following a 1 minus square root curve
238
+ - 'cosine': smoothly decays learning rate following a cosine curve
239
+ """,
240
+ )
241
+ self.parser.add_argument(
242
+ "--lr_scheduler.lr_min",
243
+ type=float,
244
+ default=0.0,
245
+ help="""
246
+ Min lr ratio for lr scheduler.
247
+
248
+ If provided, the range of decay factor is scaled from 1 to `lr_min`
249
+ to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`.
250
+ """,
251
+ )
252
+
253
+ # training configs
254
+ self.parser.add_argument(
255
+ "--training.batch_size", type=int, default=8, help="Batch size"
256
+ )
257
+ self.parser.add_argument(
258
+ "--training.seq_len", type=int, default=2048, help="Sequence length"
259
+ )
260
+ self.parser.add_argument(
261
+ "--training.context_len",
262
+ type=int,
263
+ default=2048,
264
+ help="Max length allowed for each sequence",
265
+ )
266
+ self.parser.add_argument(
267
+ "--training.varlen",
268
+ action="store_true",
269
+ help="Whether to take sequences of variable length as input",
270
+ )
271
+ self.parser.add_argument(
272
+ "--training.gradient_accumulation_steps",
273
+ type=int,
274
+ default=1,
275
+ help="Number of steps to accumulate gradients before updating parameters",
276
+ )
277
+ self.parser.add_argument(
278
+ "--training.steps",
279
+ type=int,
280
+ default=10000,
281
+ help="How many train steps to run",
282
+ )
283
+ self.parser.add_argument(
284
+ "--training.max_norm",
285
+ type=float,
286
+ default=1.0,
287
+ help="Max norm for gradient clipping",
288
+ )
289
+ self.parser.add_argument(
290
+ "--training.skip_nan_inf",
291
+ action="store_true",
292
+ help="Skip batch updates when NaN or INF gradients are encountered during training",
293
+ )
294
+ self.parser.add_argument(
295
+ "--training.dataset",
296
+ default="HuggingFaceFW/fineweb-edu",
297
+ help="Dataset to use, with comma separated values",
298
+ )
299
+ self.parser.add_argument(
300
+ "--training.dataset_name",
301
+ default=None,
302
+ help="The name of the dataset config, with comma separated values if provided",
303
+ )
304
+ self.parser.add_argument(
305
+ "--training.dataset_split",
306
+ default=None,
307
+ help="Dataset split to use, with comma separated values if provided",
308
+ )
309
+ self.parser.add_argument(
310
+ "--training.data_dir",
311
+ default=None,
312
+ help="Data dirs to use, with comma separated values if provided",
313
+ )
314
+ self.parser.add_argument(
315
+ "--training.data_files",
316
+ default=None,
317
+ help="Data files to use, with comma separated values if provided",
318
+ )
319
+ self.parser.add_argument(
320
+ "--training.data_probs",
321
+ default=None,
322
+ help="Data sampling probabilities, with comma separated values if provided",
323
+ )
324
+ self.parser.add_argument(
325
+ "--training.streaming",
326
+ action="store_true",
327
+ help="Whether to load dataset in streaming mode, used for huge dataset",
328
+ )
329
+ self.parser.add_argument(
330
+ "--training.num_workers",
331
+ type=int,
332
+ default=32,
333
+ help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
334
+ )
335
+ self.parser.add_argument(
336
+ "--training.prefetch_factor",
337
+ type=int,
338
+ default=2,
339
+ help="Number of batches loaded in advance by each worker."
340
+ "2 means there will be a total of 2 * num_workers batches prefetched across all workers.",
341
+ )
342
+ self.parser.add_argument(
343
+ "--training.data_parallel_replicate_degree",
344
+ type=int,
345
+ default=1,
346
+ help="""
347
+ The `data_parallel_replicate_degree` argument specifies the degree of
348
+ data parallelism for weight replication. When this value is greater
349
+ than 1, weights will be replicated across `data_parallel_replicate_degree`
350
+ ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism
351
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
352
+ parallelism method used is DDP (Distributed Data Parallelism).
353
+ 1 means disabled.""",
354
+ )
355
+ self.parser.add_argument(
356
+ "--training.data_parallel_shard_degree",
357
+ type=int,
358
+ default=-1,
359
+ help="""
360
+ The `data_parallel_shard_degree` argument specifies the degree of data
361
+ parallelism for weight sharding. When this value is greater than 1, weights
362
+ will be sharded across `data_parallel_shard_degree` ranks. If
363
+ `data_parallel_replicate_degree` is also greater than 1, the parallelism
364
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
365
+ parallelism method used is FSDP (Fully Sharded Data Parallelism).
366
+
367
+ -1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that
368
+ only `data_parallel_shard_degree` can be negative. 1 means disabled.""",
369
+ )
370
+ self.parser.add_argument(
371
+ "--training.enable_cpu_offload",
372
+ action="store_true",
373
+ help="""
374
+ Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
375
+ )
376
+ self.parser.add_argument(
377
+ "--training.tensor_parallel_degree",
378
+ type=int,
379
+ default=1,
380
+ help="Tensor Parallelism degree. 1 means disabled.",
381
+ )
382
+ self.parser.add_argument(
383
+ "--training.disable_loss_parallel",
384
+ action="store_true",
385
+ help="Whether to apply loss parallel when sequence parallel is enabled",
386
+ )
387
+ self.parser.add_argument(
388
+ "--training.fsdp_reshard_after_forward",
389
+ type=str,
390
+ default="default",
391
+ choices=["default", "always", "never"],
392
+ help="""
393
+ `reshard_after_forward` specifies the policy for applying `reshard_after_forward`
394
+ within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward,
395
+ trading off memory and communication. See torch's `fully_shard` API for more documentation
396
+ on `reshard_after_forward`.
397
+ The supported policies include "default", "always" and "never":
398
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal
399
+ scenarios.
400
+ - "always" will enable `reshard_after_forward` for all forward passes.
401
+ - "never" will disable `reshard_after_forward` for all forward passes.
402
+ """,
403
+ )
404
+ self.parser.add_argument(
405
+ "--training.mixed_precision_param",
406
+ type=str,
407
+ default="bfloat16",
408
+ choices=["bfloat16", "float32"],
409
+ help="""
410
+ torch dtype to use for parameters when applying mixed precision via FSDP.
411
+ This feature only takes effect when data_parallel_shard_degree > 1
412
+ """,
413
+ )
414
+ self.parser.add_argument(
415
+ "--training.mixed_precision_reduce",
416
+ type=str,
417
+ default="float32",
418
+ choices=["float32"],
419
+ help="""
420
+ torch dtype to use for reductions when applying mixed precision via FSDP.
421
+ This feature only takes effect when data_parallel_shard_degree > 1
422
+ """,
423
+ )
424
+ self.parser.add_argument(
425
+ "--training.compile",
426
+ action="store_true",
427
+ help="Whether to compile the model",
428
+ )
429
+ self.parser.add_argument(
430
+ "--training.gc_freq",
431
+ type=int,
432
+ default=50,
433
+ help="Python garbage control scheduling interval, in steps",
434
+ )
435
+ self.parser.add_argument(
436
+ "--training.seed",
437
+ type=int,
438
+ default=42,
439
+ help="Choose the base RNG seed used for training",
440
+ )
441
+ self.parser.add_argument(
442
+ "--training.deterministic",
443
+ action="store_true",
444
+ help="Use deterministic algorithms wherever possible, may be slower",
445
+ )
446
+ # metrics configs
447
+ self.parser.add_argument(
448
+ "--metrics.log_freq",
449
+ type=int,
450
+ default=10,
451
+ help="How often to log metrics to TensorBoard, in iterations",
452
+ )
453
+ self.parser.add_argument(
454
+ "--metrics.enable_tensorboard",
455
+ action="store_true",
456
+ help="Whether to log metrics to TensorBoard",
457
+ )
458
+ self.parser.add_argument(
459
+ "--metrics.disable_color_printing",
460
+ action="store_true",
461
+ help="Whether to disable color printing in logs",
462
+ )
463
+ self.parser.add_argument(
464
+ "--metrics.save_tb_folder",
465
+ type=str,
466
+ default="tb",
467
+ help="Folder to dump TensorBoard states",
468
+ )
469
+ self.parser.add_argument(
470
+ "--metrics.save_for_all_ranks",
471
+ action="store_true",
472
+ default=False,
473
+ help="""
474
+ Whether to save TensorBoard/Wandb metrics only for rank 0 or for all ranks.
475
+ When this option is False and pipeline_parallel_degree is > 1, the metrics
476
+ component uses the 0th rank of the last stage pipeline group, which is the
477
+ only stage that computes loss metrics.
478
+ """,
479
+ )
480
+ self.parser.add_argument(
481
+ "--metrics.enable_wandb",
482
+ action="store_true",
483
+ help="Whether to log metrics to Weights & Biases",
484
+ )
485
+
486
+ self.parser.add_argument(
487
+ "--experimental.enable_async_tensor_parallel",
488
+ action="store_true",
489
+ help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
490
+ )
491
+ self.parser.add_argument(
492
+ "--experimental.pipeline_parallel_degree",
493
+ type=int,
494
+ default=1,
495
+ help="""
496
+ Pipeline Parallelism degree, or number of ranks. 1 means disabled.
497
+ If using looped schedules, this still specifies the number of physical ranks, not the number
498
+ of stages. Stages per rank are inferred from split points degree, and schedule.""",
499
+ )
500
+ self.parser.add_argument(
501
+ "--experimental.pipeline_parallel_split_points",
502
+ type=string_list,
503
+ nargs="+",
504
+ default=[],
505
+ help="""
506
+ Specify comma-separated names of modules to use as the beginning of a split point.
507
+
508
+ e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
509
+ the first containing all the layers up to layers.0,
510
+ the second containing layers.0 and up to layers.2,
511
+ the third containing layers.2 and all the remaining layers.
512
+
513
+ Note: fully-automated splitting may be enabled in the future,
514
+ but currently the split points must be specified manually.""",
515
+ )
516
+ self.parser.add_argument(
517
+ "--experimental.pipeline_parallel_schedule",
518
+ type=str,
519
+ default="1F1B",
520
+ help="""
521
+ Specify the Pipeline Parallel schedule to use. The supported schedules are:
522
+ https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
523
+ The schedule must be compatible with the split points and stages_per_rank.
524
+
525
+ Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks,
526
+ and split_points = number of stages - 1
527
+ """,
528
+ )
529
+ self.parser.add_argument(
530
+ "--experimental.pipeline_parallel_schedule_csv",
531
+ type=str,
532
+ default="",
533
+ help="""
534
+ Specify the path to the pipeline parallel schedule csv file to use.
535
+ The pipeline_parallel_schedule argument must be either
536
+ PipelineScheduleSingle, PipelineScheduleMulti, or _PipelineScheduleRuntime.
537
+ """,
538
+ )
539
+
540
+ self.parser.add_argument(
541
+ "--experimental.pipeline_parallel_microbatches",
542
+ type=int,
543
+ default=None,
544
+ help="""
545
+ How many microbatches to split the global training batch into when using pipeline parallelism.
546
+
547
+ The global training batch size must be evenly divisible by the number of microbatches.
548
+
549
+ The default value will be the number of pipeline stages, if unspecified.
550
+ """,
551
+ )
552
+ self.parser.add_argument(
553
+ "--experimental.enable_compiled_autograd",
554
+ action="store_true",
555
+ help="Enable CompiledAutograd to compile the backward.",
556
+ )
557
+ self.parser.add_argument(
558
+ "--experimental.context_parallel_degree",
559
+ type=int,
560
+ default=1,
561
+ help="Context parallelism degree. 1 means disabled.",
562
+ )
563
+ self.parser.add_argument(
564
+ "--experimental.context_parallel_rotate_method",
565
+ type=str,
566
+ default="allgather",
567
+ help="""
568
+ The collective to use in context parallel SDPA for kv shards exchange.
569
+
570
+ 'allgather' means to all-gather all kv shards on ranks after the first sub-SDPA computation,
571
+
572
+ 'alltoall' means to all-to-all shuffle the kv shards.
573
+
574
+ The default value is 'allgather'.
575
+ """,
576
+ )
577
+ # I'm not particularly fond of this. Users can choose to write their own wrapper
578
+ # module and import TorchTitan training loop and execute it, which look cleaner.
579
+ # One reason to provide this option is to allow users to use the existing run script.
580
+ # While the script is pretty trivial now, we may add more logic when integrating
581
+ # with TorchFT.
582
+ # This option is subject to change and may be deleted in the future.
583
+ self.parser.add_argument(
584
+ "--experimental.custom_model_path",
585
+ type=str,
586
+ default="",
587
+ help="""
588
+ The --custom_model_path option allows you to specify a custom path to a model module
589
+ that is not natively implemented within TorchTitan.
590
+ Acceptable values are the file system path to the module (e.g., my_models/model_x)
591
+ or a dotted import module (e.g., some_package.model_x).
592
+ """,
593
+ )
594
+ # checkpointing configs
595
+ self.parser.add_argument(
596
+ "--checkpoint.enable_checkpoint",
597
+ action="store_true",
598
+ help="Whether to enable checkpoint",
599
+ )
600
+ self.parser.add_argument(
601
+ "--checkpoint.folder",
602
+ type=str,
603
+ default="checkpoint",
604
+ help="""
605
+ The folder to store the checkpoints.
606
+ When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
607
+ """,
608
+ )
609
+ self.parser.add_argument(
610
+ "--checkpoint.interval",
611
+ type=int,
612
+ default=500,
613
+ help="Checkpointing interval in steps.",
614
+ )
615
+ self.parser.add_argument(
616
+ "--checkpoint.model_weights_only",
617
+ action="store_true",
618
+ help="""
619
+ When model_weights_only=True, only model weights will be saved at the end of training.
620
+ With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion.
621
+ When model_weights_only=False, the full checkpoint will be saved.
622
+ A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
623
+ The default value is false.
624
+ """,
625
+ )
626
+ self.parser.add_argument(
627
+ "--checkpoint.export_dtype",
628
+ type=str,
629
+ default="float32",
630
+ choices=["float16", "bfloat16", "float32"],
631
+ help="""
632
+ Converts to the specified precision when training completes and model_weights_only=true.
633
+ Currently supports float32, float16, and bfloat16.
634
+ The default value is float32.
635
+ """,
636
+ )
637
+ self.parser.add_argument(
638
+ "--checkpoint.create_seed_checkpoint",
639
+ action="store_true",
640
+ help="""
641
+ Initializes the full model without applying parallelisms, and then saves it as a seed checkpoint.
642
+ Note: requires the user to call train.py without specifying any parallelisms, e.g. NGPU=1.
643
+ Could be implemented as a separate script, but this way shares more code.
644
+ """,
645
+ )
646
+ self.parser.add_argument(
647
+ "--checkpoint.async_mode",
648
+ type=str,
649
+ default="disabled",
650
+ help="""
651
+ Which async checkpoint mode to use. Currently there are 3 different modes.
652
+ 1. "disabled": synchronized checkpointing will be used.
653
+ 2. "async": torch.distributed.checkpoint.async_save will be used.
654
+ 3. "async_with_pinned_mem": this option utilizes a dedicated pinned memory
655
+ space and creates a separate process for faster GPU->CPU transfer
656
+ performance and eliminating GIL contention. The cost is increased CPU
657
+ memory usage. If insufficient CPU memory is available, performance may
658
+ degrade due to memory paging. For most users, "async" should suffice as
659
+ the performance overhead is typically small (on the order of tens of
660
+ seconds) compared to checkpointing frequency. This mode can be employed
661
+ to pursue near-zero checkpointing times (e.g., < 1 second) given
662
+ appropriate hardware support such as ample CPU memory and fast PCIe.
663
+
664
+ "disabled" is the default mode.
665
+ """,
666
+ )
667
+ self.parser.add_argument(
668
+ "--checkpoint.keep_latest_k",
669
+ type=int,
670
+ default=0,
671
+ help="""
672
+ Keeps only the latest k checkpoints, purging older ones. If 0, keep all checkpoints.
673
+ 0 is the default value. k cannot be 1 as the last one may be in the process of being
674
+ saved. As a result, the metadata of the last one may not be ready yet.
675
+ """,
676
+ )
677
+ self.parser.add_argument(
678
+ "--checkpoint.load_step",
679
+ type=int,
680
+ default=-1,
681
+ help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
682
+ )
683
+ self.parser.add_argument(
684
+ "--checkpoint.exclude_from_loading",
685
+ type=string_list,
686
+ nargs="*",
687
+ default=[],
688
+ help="""
689
+ Exclude specific keys from being loaded from the checkpoint.
690
+ Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
691
+ This will load the model only, excluding the specified keys.
692
+ """,
693
+ )
694
+ self.parser.add_argument(
695
+ "--checkpoint.convert_to_hf_on_save",
696
+ action="store_true",
697
+ help="""
698
+ If true, automatically convert the saved DCP checkpoint to Hugging Face format
699
+ in a parallel directory (e.g., step-1000-hf) after each save.
700
+ """,
701
+ )
702
+ self.parser.add_argument(
703
+ "--checkpoint.hf_upload_enabled",
704
+ action="store_true",
705
+ help="Enable uploading converted Hugging Face checkpoints to the Hub.",
706
+ )
707
+ self.parser.add_argument(
708
+ "--checkpoint.hf_repo_base_name",
709
+ type=str,
710
+ default=None,
711
+ help="Hugging Face Hub repository ID to upload checkpoints to (e.g., 'username/repo').",
712
+ )
713
+ self.parser.add_argument(
714
+ "--checkpoint.hf_upload_format",
715
+ type=str,
716
+ default="dcp",
717
+ choices=["dcp", "hf"],
718
+ help="""
719
+ Format to upload to Hugging Face Hub. 'dcp' for DCP format, 'hf' for Hugging Face format.
720
+ Note: 'hf' is only supported for models with a single pipeline stage.
721
+ """,
722
+ )
723
+ # activation checkpointing configs
724
+ self.parser.add_argument(
725
+ "--activation_checkpoint.mode",
726
+ type=str,
727
+ default="selective",
728
+ help="Type of activation checkpointing to use ['none', 'full', 'selective']",
729
+ )
730
+ self.parser.add_argument(
731
+ "--activation_checkpoint.selective_ac_option",
732
+ type=str,
733
+ default="2", # 2 = checkpoint every other layer
734
+ help="""
735
+ Selective activation checkpointing options ['int', 'op'].
736
+ 'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
737
+ """,
738
+ )
739
+
740
+ self.parser.add_argument(
741
+ "--activation_offload.mode",
742
+ type=str,
743
+ default="none",
744
+ help="""
745
+ Whether to use activation offloading or not. Options are ['none', 'full'].
746
+ """,
747
+ )
748
+
749
+ # float8 configs
750
+ self.parser.add_argument(
751
+ "--float8.enable_fsdp_float8_all_gather",
752
+ action="store_true",
753
+ help="Whether enable float8 all-gather in FSDP, recommended for tensorwise scaling",
754
+ )
755
+ self.parser.add_argument(
756
+ "--float8.precompute_float8_dynamic_scale_for_fsdp",
757
+ action="store_true",
758
+ help="Whether precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling",
759
+ )
760
+ self.parser.add_argument(
761
+ "--float8.force_recompute_fp8_weight_in_bwd",
762
+ action="store_true",
763
+ help="""
764
+ Whether to force the recomputation of FP8 weights during backward pass.
765
+ When using FSDP with tensorwise scaling, it is recommended to enable
766
+ `force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights
767
+ for backward computation.
768
+ """,
769
+ )
770
+ self.parser.add_argument(
771
+ "--float8.recipe_name",
772
+ type=str,
773
+ default=None,
774
+ choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"],
775
+ help="""
776
+ If specified, creates float8 config from recipe name, valid choices are
777
+ `tensorwise`, `rowwise` and `rowwise_with_gw_hp`.
778
+ """,
779
+ )
780
+
781
+ # communications library settings
782
+ self.parser.add_argument(
783
+ "--comm.init_timeout_seconds",
784
+ type=int,
785
+ default=300,
786
+ help="Timeout for communication operations, during initialization and first train step.",
787
+ )
788
+ self.parser.add_argument(
789
+ "--comm.train_timeout_seconds",
790
+ type=int,
791
+ default=100,
792
+ help=(
793
+ "Timeout for communication operations after the first train step -- "
794
+ "usually a tighter bound than during initialization."
795
+ ),
796
+ )
797
+ self.parser.add_argument(
798
+ "--comm.trace_buf_size",
799
+ type=int,
800
+ default=20000,
801
+ help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
802
+ )
803
+
804
+ # memory estimation settings
805
+ self.parser.add_argument(
806
+ "--memory_estimation.enabled",
807
+ help="Whether to estimate memory usage for FSDP",
808
+ action="store_true",
809
+ )
810
+
811
+ self.parser.add_argument(
812
+ "--memory_estimation.disable_fake_mode",
813
+ help="Whether to estimate memory under FakeTensorMode",
814
+ action="store_true",
815
+ )
816
+
817
+ self.parser.add_argument(
818
+ "--fault_tolerance.enable",
819
+ action="store_true",
820
+ help="""
821
+ Enable TorchFT integration. When TorchFT is enabled, HSDP will be used.
822
+ And --fault_tolerance.data_parallel_replicate_degree should be 1 and
823
+ --fault_tolerance.group_size will be used to control the maximum
824
+ replicate group size as the replicate group size is dynamic.
825
+
826
+ Note that this is still an experimental feature.
827
+ """,
828
+ )
829
+
830
+ self.parser.add_argument(
831
+ "--fault_tolerance.replica_id",
832
+ type=int,
833
+ default=0,
834
+ help="The TorchFT replica ID of this run.",
835
+ )
836
+
837
+ self.parser.add_argument(
838
+ "--fault_tolerance.group_size",
839
+ type=int,
840
+ default=0,
841
+ help="""
842
+ The number of TorchFT replicate groups. This number will be used by the
843
+ dataloader to split the dataset across the replicate groups and the FSDP
844
+ dimension.
845
+ """,
846
+ )
847
+
848
+ self.parser.add_argument(
849
+ "--fault_tolerance.min_replica_size",
850
+ type=int,
851
+ default=1,
852
+ help="The minimum number of FT replica for each step.",
853
+ )
854
+
855
+ def to_dict(self):
856
+ return self.args_dict
857
+
858
+ def parse_args(self, args_list: list = sys.argv[1:]):
859
+ args, cmd_args = self.parse_args_from_command_line(args_list)
860
+ config_file = getattr(args, "job.config_file", None)
861
+ # build up a two level dict
862
+ args_dict = self._args_to_two_level_dict(args)
863
+ if config_file is not None:
864
+ try:
865
+ with open(config_file, "rb") as f:
866
+ for k, v in tomllib.load(f).items():
867
+ # to prevent overwrite of non-specified keys
868
+ args_dict[k] |= v
869
+ except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
870
+ logger.exception(
871
+ f"Error while loading the configuration file: {config_file}"
872
+ )
873
+ logger.exception(f"Error details: {str(e)}")
874
+ raise e
875
+
876
+ # Checking string-list arguments are properly split into a list
877
+ # if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
878
+ string_list_argnames = self._get_string_list_argument_names()
879
+ for n in string_list_argnames:
880
+ check_string_list_argument(args_dict, n)
881
+
882
+ # override args dict with cmd_args
883
+ cmd_args_dict = self._args_to_two_level_dict(cmd_args)
884
+ for section, section_args in cmd_args_dict.items():
885
+ for k, v in section_args.items():
886
+ args_dict[section][k] = v
887
+
888
+ self.args_dict = args_dict
889
+
890
+ for k, v in args_dict.items():
891
+ class_type = type(k.title(), (), v)
892
+ setattr(self, k, class_type())
893
+ self._validate_config()
894
+
895
+ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
896
+ args_dict = defaultdict(defaultdict)
897
+ for k, v in vars(args).items():
898
+ first_level_key, second_level_key = k.split(".", 1)
899
+ args_dict[first_level_key][second_level_key] = v
900
+ return args_dict
901
+
902
+ def _validate_config(self) -> None:
903
+ # TODO: Add more mandatory validations
904
+ assert self.model.config
905
+ assert self.model.tokenizer_path
906
+
907
+ def _get_string_list_argument_names(self) -> list[str]:
908
+ """Get the parser argument names of type `string_list`."""
909
+ string_list_args = [
910
+ v.dest for v in self.parser._actions if v.type is string_list
911
+ ]
912
+ return string_list_args
913
+
914
+ def parse_args_from_command_line(
915
+ self, args_list
916
+ ) -> Tuple[argparse.Namespace, argparse.Namespace]:
917
+ """
918
+ Parse command line arguments and return the parsed args and the command line only args
919
+ """
920
+ args = self.parser.parse_args(args_list)
921
+ string_list_argnames = set(self._get_string_list_argument_names())
922
+
923
+ # aux parser to parse the command line only args, with no defaults from main parser
924
+ aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
925
+ for arg, val in vars(args).items():
926
+ if isinstance(val, bool):
927
+ aux_parser.add_argument(
928
+ "--" + arg, action="store_true" if val else "store_false"
929
+ )
930
+ elif arg in string_list_argnames:
931
+ # without this special case, type inference breaks here,
932
+ # since the inferred type is just 'list' and it ends up flattening
933
+ # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
934
+ aux_parser.add_argument("--" + arg, type=string_list)
935
+ else:
936
+ aux_parser.add_argument("--" + arg, type=type(val))
937
+
938
+ cmd_args, _ = aux_parser.parse_known_args(args_list)
939
+
940
+ return args, cmd_args
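As the `JobConfig` docstring above states, values are resolved with precedence cmdline > toml > argparse default. A minimal, hypothetical illustration of that layering (the `toy_job.toml` file and its values are made up; `configs/transformer_340M.json` is one of the config files present in this repo):

# toy_job.toml (hypothetical)
# [model]
# config = "configs/transformer_340M.json"
# [training]
# batch_size = 16

from flame.config_manager import JobConfig

config = JobConfig()
config.parse_args([
    "--job.config_file", "toy_job.toml",
    "--training.batch_size", "32",   # explicit cmdline arg overrides the toml value (16)
])
print(config.model.config)           # configs/transformer_340M.json, from the toml file
print(config.training.batch_size)    # 32, from the command line
print(config.optimizer.lr)           # 8e-4, the argparse default (set by neither source)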
flame/data.py ADDED
@@ -0,0 +1,570 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ import pickle
7
+ from copy import deepcopy
8
+ from dataclasses import dataclass
9
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Union
10
+
11
+ import datasets
12
+ import numpy as np
13
+ import torch
14
+ from datasets import Dataset, IterableDataset
15
+ from datasets.iterable_dataset import ShufflingConfig
16
+ from torch.distributed.checkpoint.stateful import Stateful
17
+ from torchdata.stateful_dataloader import StatefulDataLoader
18
+ from transformers import PreTrainedTokenizer
19
+
20
+ from torchtitan.tools.logging import logger
21
+
22
+
23
+ class BufferShuffledIterableDataset(IterableDataset):
24
+ def __init__(
25
+ self,
26
+ dataset: Dataset,
27
+ tokenizer: PreTrainedTokenizer,
28
+ seq_len: int = 2048,
29
+ rank: int = 0,
30
+ world_size: int = 1,
31
+ buffer_size: int = 1024,
32
+ ) -> BufferShuffledIterableDataset:
33
+ self.dataset = dataset
34
+ self.tokenizer = tokenizer
35
+
36
+ self.data = dataset.shard(world_size, rank)
37
+ self.seq_len = seq_len
38
+
39
+ self.rank = rank
40
+ self.world_size = world_size
41
+ self.buffer_size = buffer_size
42
+
43
+ if tokenizer.vocab_size < torch.iinfo(torch.int16).max:
44
+ self.dtype = torch.int16
45
+ elif tokenizer.vocab_size < torch.iinfo(torch.int32).max:
46
+ self.dtype = torch.int32
47
+ else:
48
+ self.dtype = torch.int64
49
+ self.states = None
50
+ self.buffer = torch.tensor([], dtype=self.dtype)
51
+ self.tokens = []
52
+ self.rand_id = 0
53
+ self.token_id = 0
54
+ self.rng_state = None
55
+ self._epoch = 0
56
+
57
+ def __iter__(self):
58
+ g = torch.Generator()
59
+ g.manual_seed(self._epoch + self.rank)
60
+ if self.rng_state is not None:
61
+ g.set_state(self.rng_state)
62
+
63
+ rand_it = self.randint(0, self.buffer_size, g=g)
64
+ if self.states is not None:
65
+ self.data.load_state_dict(self.states)
66
+
67
+ # max number of tokens allowed in the chunk buffer
68
+ n_tokens = self.buffer_size * self.seq_len
69
+
70
+ while True:
71
+ for sample in self.tokenize(self.data):
72
+ # keep appending the samples to the token buffer
73
+ self.tokens += sample
74
+ # if the token buffer is full, start sampling
75
+ # NOTE: we first convert the token ids to a tensor of shape [n_chunks, seq_len] for efficiency
76
+ if len(self.buffer) == 0 and len(self.tokens) >= n_tokens:
77
+ self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=self.dtype).view(self.buffer_size, -1)
78
+ self.tokens = self.tokens[n_tokens:]
79
+ if len(self.buffer) == self.buffer_size:
80
+ yield from self.sample(rand_it)
81
+
82
+ n_chunks = len(self.tokens) // self.seq_len
83
+ # handle the left tokens in the buffer
84
+ if n_chunks > 0:
85
+ n_tokens = n_chunks * self.seq_len
86
+ indices = torch.randperm(n_chunks, generator=g).tolist()
87
+ self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=torch.long).view(n_chunks, -1)
88
+ self.tokens = self.tokens[n_tokens:]
89
+ for i in indices:
90
+ yield {'input_ids': self.buffer[i]}
91
+
92
+ def tokenize(self, data, batch_size: int = 64):
93
+ texts, states = [], []
94
+ for sample in data:
95
+ texts.append(sample['text'])
96
+ states.append(self.data.state_dict())
97
+ if len(texts) == batch_size:
98
+ for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
99
+ self.states = s
100
+ yield tokenized
101
+ texts, states = [], []
102
+ if len(texts) > 0:
103
+ for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
104
+ self.states = s
105
+ yield tokenized
106
+
107
+ def sample(self, indices):
108
+ n_tokens = (len(self.tokens) // self.seq_len) * self.seq_len
109
+ while self.token_id < n_tokens:
110
+ i = next(indices)
111
+ start, end = self.token_id, self.token_id + self.seq_len
112
+ self.token_id += self.seq_len
113
+ yield {'input_ids': self.buffer[i].to(torch.long)}
114
+ self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
115
+ self.token_id = 0
116
+ self.tokens = self.tokens[n_tokens:]
117
+
118
+ def randint(self, low: int, high: int, buffer_size: int = 1024, g: torch.Generator = torch.Generator()) -> Iterable[int]:
119
+ indices = torch.empty(buffer_size, dtype=torch.long)
120
+ while True:
121
+ # record the generator states before sampling
122
+ self.rng_state = g.get_state()
123
+ indices = torch.randint(low, high, (buffer_size,), out=indices, generator=g)
124
+ for i in indices[self.rand_id:].tolist():
125
+ self.rand_id += 1
126
+ yield i
127
+ self.rand_id = 0
128
+
129
+ def set_epoch(self, epoch):
130
+ self._epoch = epoch
131
+ if hasattr(self.dataset, 'set_epoch'):
132
+ self.dataset.set_epoch(epoch)
133
+
134
+ def state_dict(self):
135
+ return {
136
+ 'states': self.states,
137
+ 'buffer': self.buffer.clone(),
138
+ 'tokens': deepcopy(self.tokens),
139
+ 'rand_id': self.rand_id,
140
+ 'token_id': self.token_id,
141
+ 'rng_state': self.rng_state,
142
+ 'epoch': self._epoch,
143
+ }
144
+
145
+ def load_state_dict(self, state_dict):
146
+ self.states = state_dict['states']
147
+ self.buffer = state_dict['buffer'].clone()
148
+ self.tokens = deepcopy(state_dict['tokens'])
149
+ self.rand_id = state_dict['rand_id']
150
+ self.token_id = state_dict['token_id']
151
+ self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
152
+ self._epoch = state_dict['epoch']
153
+
154
+
155
+ class OnlineTokenizedIterableDataset(IterableDataset):
156
+ def __init__(
157
+ self, dataset: Dataset, tokenizer: PreTrainedTokenizer, seq_len: int = 2048, rank: int = 0, world_size: int = 1
158
+ ) -> OnlineTokenizedIterableDataset:
159
+ self.dataset = dataset
160
+ self.tokenizer = tokenizer
161
+
162
+ self.data = dataset.shard(world_size, rank)
163
+ self.seq_len = seq_len
164
+ self.rank = rank
165
+ self.world_size = world_size
166
+
167
+ self.states = None
168
+ self.tokens = []
169
+
170
+ def __iter__(self):
171
+ if self.states is not None:
172
+ self.data.load_state_dict(self.states)
173
+
174
+ while True:
175
+ for sample in self.tokenize(self.data):
176
+ # keep appending the samples to the token buffer
177
+ self.tokens += sample
178
+
179
+ while len(self.tokens) >= self.seq_len:
180
+ input_ids = torch.tensor(self.tokens[:self.seq_len], dtype=torch.long)
181
+ self.tokens = self.tokens[self.seq_len:]
182
+ yield {'input_ids': input_ids}
183
+
184
+ def tokenize(self, data, buffer_size: int = 64):
185
+ buffer, states = [], []
186
+ for sample in data:
187
+ if sample.get('text', None) is not None:
188
+ buffer.append(sample['text'])
189
+ elif sample.get('content', None) is not None:
190
+ buffer.append(sample['content'])
191
+ else:
192
+ raise ValueError(f"No 'text' or 'content' field found in sample:\n{sample}")
193
+ states.append(self.data.state_dict())
194
+ if len(buffer) == buffer_size:
195
+ for s, tokenized in zip(states, self.tokenizer(buffer, return_attention_mask=False)['input_ids']):
196
+ self.states = s
197
+ yield tokenized
198
+ buffer, states = [], []
199
+ if len(buffer) > 0:
200
+ for s, tokenized in zip(states, self.tokenizer(buffer, return_attention_mask=False)['input_ids']):
201
+ self.states = s
202
+ yield tokenized
203
+
204
+ def state_dict(self):
205
+ return {'states': self.states, 'tokens': deepcopy(self.tokens)}
206
+
207
+ def load_state_dict(self, state_dict):
208
+ self.states = state_dict['states']
209
+ self.tokens = deepcopy(state_dict['tokens'])
210
+
211
+
212
+ class BufferShuffledExamplesIterable(datasets.iterable_dataset.BufferShuffledExamplesIterable):
213
+ def __init__(self, *args, **kwargs):
214
+ super().__init__(*args, **kwargs)
215
+
216
+ def _init_state_dict(self) -> dict:
217
+ self._state_dict = self.ex_iterable._init_state_dict()
218
+ self._state_dict['mem_buffer'] = ([],)
219
+ self._state_dict['bit_generator_state'] = self.generator.bit_generator.state
220
+ self._state_dict['bit_generator_index_offset'] = 0
221
+ self._state_dict['bit_generator_index_offset_shuffle'] = 0
222
+ return self._state_dict
223
+
224
+ def __iter__(self):
225
+ buffer_size = self.buffer_size
226
+ rng = deepcopy(self.generator)
227
+ # this is the shuffle buffer that we keep in memory
228
+ mem_buffer = self._state_dict['mem_buffer'][0]
229
+ # this is an infinite iterator that randomly samples the index of the source to pick examples from
230
+ index_offset = self._state_dict['bit_generator_index_offset'] if self._state_dict else 0
231
+ if self._state_dict:
232
+ rng.bit_generator.state = self._state_dict['bit_generator_state']
233
+ indices_iterator = self._iter_random_indices(rng, buffer_size, random_batch_size=buffer_size)
234
+ # skip already consumed ones
235
+ for _ in range(index_offset):
236
+ i = next(indices_iterator)
237
+
238
+ for x in self.ex_iterable:
239
+ if len(mem_buffer) < buffer_size: # if the buffer is not full, keep filling the buffer
240
+ mem_buffer.append(x)
241
+ else: # otherwise, pick an example from it
242
+ i = next(indices_iterator)
243
+ index_offset = (index_offset + 1) % buffer_size
244
+ if self._state_dict:
245
+ self._state_dict['bit_generator_index_offset'] = index_offset
246
+ if index_offset == 0:
247
+ self._state_dict['bit_generator_state'] = rng.bit_generator.state
248
+ selected = mem_buffer[i]
249
+ mem_buffer[i] = x # replace the picked example by a new one
250
+ yield selected
251
+
252
+ index_offset = self._state_dict['bit_generator_index_offset_shuffle'] if self._state_dict else 0
253
+ if self._state_dict:
254
+ rng.bit_generator.state = self._state_dict['bit_generator_state']
255
+
256
+ # when we run out of examples, we shuffle the remaining examples in the buffer and yield them
257
+ for i in rng.permutation(len(mem_buffer))[index_offset:].tolist():
258
+ index_offset = index_offset + 1
259
+ if self._state_dict:
260
+ self._state_dict['bit_generator_index_offset_shuffle'] = index_offset
261
+ yield mem_buffer[i]
262
+
263
+ def shuffle_data_sources(self, generator: np.random.Generator) -> BufferShuffledExamplesIterable:
264
+ """Shuffle the wrapped examples iterable as well as the shuffling buffer."""
265
+ return BufferShuffledExamplesIterable(
266
+ self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator
267
+ )
268
+
269
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> BufferShuffledExamplesIterable:
270
+ """Keep only the requested shard."""
271
+ return BufferShuffledExamplesIterable(
272
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
273
+ buffer_size=self.buffer_size,
274
+ generator=self.generator,
275
+ )
276
+
277
+ def load_state_dict(self, state_dict: dict) -> dict:
278
+ def _inner_load_state_dict(state, new_state):
279
+ if new_state is not None and isinstance(state, dict):
280
+ for key in new_state:
281
+ state[key] = _inner_load_state_dict(state[key], new_state[key])
282
+ return state
283
+ elif new_state is not None and isinstance(state, list):
284
+ for i in range(len(state)):
285
+ state[i] = _inner_load_state_dict(state[i], new_state[i])
286
+ return state
287
+ return new_state
288
+
289
+ return _inner_load_state_dict(self._state_dict, state_dict)
290
+
291
+
292
+ def shuffle(
293
+ dataset: IterableDataset,
294
+ seed: int = 42,
295
+ generator: np.random.Generator = None,
296
+ buffer_size: int = 1024,
297
+ ):
298
+ generator = np.random.default_rng(seed) if generator is None else deepcopy(generator)
299
+ return IterableDataset(
300
+ ex_iterable=BufferShuffledExamplesIterable(dataset._ex_iterable, buffer_size=buffer_size, generator=generator),
301
+ info=dataset._info.copy(),
302
+ split=dataset._split,
303
+ formatting=dataset._formatting,
304
+ shuffling=ShufflingConfig(generator=generator, _original_seed=seed),
305
+ distributed=copy.deepcopy(dataset._distributed),
306
+ token_per_repo_id=dataset._token_per_repo_id,
307
+ )
308
+
309
+
310
+ @dataclass
311
+ class DataCollatorForLanguageModeling:
312
+ """
313
+ Data collator used for language modeling. Inputs are dynamically padded if `varlen=False`.
314
+ If `varlen=True`, sequences are expected to be concatenated, and labels match inputs.
315
+
316
+ Args:
317
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
318
+ The tokenizer used for encoding the data.
319
+ context_len (`int`, optional):
320
+ When `varlen=True`, sequences longer than this length within a document
321
+ (as determined by `cu_seqlens`) will be further chunked.
322
+ varlen (`bool`):
323
+ Whether to handle variable length concatenated sequences (`True`) or padded batches (`False`).
324
+
325
+ Returns:
326
+ A dictionary with the following keys:
327
+ - `input_ids`: Tensor of input IDs. Shape `[batch_size, seq_len]` if `varlen=False`, `[1, total_len]` if `varlen=True`.
328
+ - `labels`: Tensor of labels. Shape matches `input_ids`. Padding positions are masked with -100 if `varlen=False`.
329
+ - `attention_mask`: Tensor indicating non-padding tokens (only if `varlen=False`). Shape matches `input_ids`.
330
+ - `cu_seqlens`: Tensor of cumulative sequence lengths (only if `varlen=True`). Shape `[1, num_sequences + 1]`.
331
+
332
+ NOTE: When `varlen=True`, the `batch_size` must be 1.
333
+ """
334
+
335
+ tokenizer: PreTrainedTokenizer
336
+ context_len: Optional[int] = None
337
+ varlen: bool = False
338
+
339
+ def __call__(self, examples: List[Union[List[int], Dict[str, Any]]]) -> Dict[str, Any]:
340
+ if not isinstance(examples[0], Dict):
341
+ examples = [{'input_ids': example} for example in examples]
342
+
343
+ def tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
344
+ tensorized = {}
345
+ for key in ['input_ids', 'cu_seqlens']:
346
+ if key not in example:
347
+ continue
348
+ if isinstance(example[key], List):
349
+ tensorized[key] = torch.tensor(example[key], dtype=torch.long)
350
+ elif isinstance(example[key], np.ndarray):
351
+ tensorized[key] = torch.from_numpy(example[key])
352
+ else:
353
+ tensorized[key] = example[key]
354
+ return tensorized
355
+
356
+ examples = list(map(tensorize, examples))
357
+
358
+ if not self.varlen:
359
+ # --- Handling for varlen=False (Batch Padding) ---
360
+ length_of_first = examples[0]['input_ids'].size(0)
361
+ needs_padding = not all(example['input_ids'].size(0) == length_of_first for example in examples)
362
+
363
+ if needs_padding:
364
+ # Check for pad token if padding is actually required
365
+ if self.tokenizer.pad_token_id is None:
366
+ raise ValueError(
367
+ f'You are attempting to pad samples but the tokenizer you are using '
368
+ f'({self.tokenizer.__class__.__name__}) does not have a pad token.'
369
+ )
370
+ # Pad using the tokenizer, ensuring attention_mask is returned
371
+ batch = self.tokenizer.pad(examples, return_tensors='pt', return_attention_mask=True)
372
+ else:
373
+ # No padding needed, stack directly and create a full attention mask
374
+ input_ids = torch.stack([example['input_ids'] for example in examples], dim=0)
375
+ batch = {
376
+ 'input_ids': input_ids,
377
+ # Create attention mask of all ones
378
+ 'attention_mask': torch.ones_like(input_ids),
379
+ }
380
+
381
+ # Create labels by cloning input_ids
382
+ labels = batch['input_ids'].clone()
383
+ # Mask labels only where attention_mask is 0 (padding positions)
384
+ if 'attention_mask' in batch:
385
+ labels[batch['attention_mask'] == 0] = -100
386
+ batch['labels'] = labels
387
+
388
+ else:
389
+ # --- Handling for varlen=True (Concatenated Sequences) ---
390
+ if len(examples) > 1:
391
+ raise ValueError('The batch size must be 1 for inputs with variable lengths (varlen=True).')
392
+
393
+ batch = {'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)}
394
+
395
+ # --- cu_seqlens calculation logic remains the same ---
396
+ if 'cu_seqlens' in examples[0]:
397
+ batch['cu_seqlens'] = (
398
+ torch.cat([example['cu_seqlens'] for example in examples], dim=0).unsqueeze(0).to(dtype=torch.int32)
399
+ ) # Ensure int32
400
+ else:
401
+ # determine boundaries by bos/eos positions
402
+ # Check for bos_token_id first
403
+ if self.tokenizer.bos_token_id is not None:
404
+ cu_seqlens = []
405
+ # Handle case where the sequence doesn't start with BOS
406
+ if batch['input_ids'][0, 0] != self.tokenizer.bos_token_id:
407
+ cu_seqlens.append(torch.tensor([0], device=batch['input_ids'].device)) # Match device
408
+ # Find all BOS token positions
409
+ bos_positions = torch.where(batch['input_ids'].eq(self.tokenizer.bos_token_id))[1]
410
+ # Ensure bos_positions is on the correct device if empty
411
+ if bos_positions.numel() == 0 and len(cu_seqlens) > 0:
412
+ cu_seqlens.append(bos_positions.to(cu_seqlens[0].device))
413
+ elif bos_positions.numel() > 0:
414
+ cu_seqlens.append(bos_positions)
415
+ # Add the end of the entire batch
416
+ cu_seqlens.append(
417
+ torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
418
+ ) # Match device and use size(1)
419
+ # Filter out empty tensors before cat
420
+ cu_seqlens = [t for t in cu_seqlens if t.numel() > 0]
421
+ if not cu_seqlens: # Handle case where input is empty or has no BOS
422
+ batch['cu_seqlens'] = torch.tensor(
423
+ [0, batch['input_ids'].size(1)], dtype=torch.int32, device=batch['input_ids'].device
424
+ )
425
+ else:
426
+ batch['cu_seqlens'] = torch.cat(cu_seqlens, dim=0).to(dtype=torch.int32)
427
+
428
+ # Else, check for eos_token_id
429
+ elif self.tokenizer.eos_token_id is not None:
430
+ cu_seqlens = [torch.tensor([0], device=batch['input_ids'].device)] # Match device
431
+ # Find positions *after* EOS tokens
432
+ eos_positions = torch.where(batch['input_ids'].eq(self.tokenizer.eos_token_id))[1] + 1
433
+ # Ensure eos_positions is on the correct device if empty
434
+ if eos_positions.numel() > 0:
435
+ cu_seqlens.append(eos_positions)
436
+ # Handle case where the sequence doesn't end with EOS
437
+ if batch['input_ids'][0, -1] != self.tokenizer.eos_token_id:
438
+ # Only add the final length if the last found EOS wasn't already the end
439
+ if eos_positions.numel() == 0 or eos_positions[-1] != batch['input_ids'].size(1):
440
+ cu_seqlens.append(
441
+ torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
442
+ ) # Match device and use size(1)
443
+ # Filter out empty tensors before cat
444
+ cu_seqlens = [t for t in cu_seqlens if t.numel() > 0]
445
+ if not cu_seqlens: # Handle case where input is empty or has no EOS
446
+ batch['cu_seqlens'] = torch.tensor(
447
+ [0, batch['input_ids'].size(1)], dtype=torch.int32, device=batch['input_ids'].device
448
+ )
449
+ else:
450
+ batch['cu_seqlens'] = torch.cat(cu_seqlens, dim=0).to(dtype=torch.int32)
451
+ # Else, neither BOS nor EOS is usable
452
+ else:
453
+ raise ValueError(
454
+ 'For varlen=True without precomputed cu_seqlens, the tokenizer must have either a bos_token_id '
455
+ 'or an eos_token_id defined to act as sequence separators.'
456
+ )
457
+
458
+ # --- cu_seqlens validation checks remain the same ---
459
+ if batch['cu_seqlens'].numel() < 2:
460
+ raise ValueError(f'Calculated cu_seqlens must have at least start and end: {batch["cu_seqlens"]}')
461
+ if not torch.all(batch['cu_seqlens'][1:] >= batch['cu_seqlens'][:-1]):
462
+ raise ValueError(f'Calculated cu_seqlens are not monotonically increasing: {batch["cu_seqlens"]}')
463
+ if batch['cu_seqlens'][0] != 0:
464
+ raise ValueError(f'Calculated cu_seqlens do not start at 0: {batch["cu_seqlens"]}')
465
+ if batch['cu_seqlens'][-1] != batch['input_ids'].size(1):
466
+ # Allow empty sequence case where cu_seqlens=[0, 0] and input_ids.size(1)=0
467
+ if not (batch['cu_seqlens'].tolist() == [0, 0] and batch['input_ids'].size(1) == 0):
468
+ raise ValueError(
469
+ f'Calculated cu_seqlens do not end at total length {batch["input_ids"].size(1)}: '
470
+ f'{batch["cu_seqlens"]}'
471
+ )
472
+
473
+ # --- context_len splitting logic remains the same ---
474
+ if self.context_len is not None:
475
+ # This logic splits sequences based on context_len *after* initial boundaries are found
476
+ bos = batch['cu_seqlens'][:-1].tolist()
477
+ eos = batch['cu_seqlens'][1:].tolist()
478
+ # Handle empty sequences between boundaries
479
+ split_boundaries = []
480
+ for i, j in zip(bos, eos):
481
+ if i < j: # Only process non-empty sequences
482
+ split_boundaries.append(torch.arange(i, j, self.context_len, device=batch['input_ids'].device))
483
+ # Add the final end point if it wasn't included by arange
484
+ final_end_point = torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
485
+ # Concatenate all boundaries
486
+ if not split_boundaries: # Handle case of completely empty input
487
+ batch['cu_seqlens'] = torch.tensor([0, 0], dtype=torch.int32, device=batch['input_ids'].device)
488
+ else:
489
+ batch['cu_seqlens'] = torch.cat(split_boundaries + [final_end_point]).to(dtype=torch.int32)
490
+ # Ensure uniqueness and sort, as arange might duplicate the endpoint
491
+ batch['cu_seqlens'] = torch.unique(batch['cu_seqlens'])
492
+
493
+ # Create labels directly from input_ids, NO padding mask needed for varlen
494
+ labels = batch['input_ids'].clone()
495
+ batch['labels'] = labels
496
+
497
+ return batch
498
+
499
+
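For reference, here is a minimal standalone sketch of the EOS branch above (not part of the collator), assuming a single packed row and a hypothetical `eos_token_id=2`; it reproduces the same boundary arithmetic:

```python
import torch

# One packed row holding two EOS-terminated documents plus an unterminated tail.
input_ids = torch.tensor([[5, 7, 2, 9, 4, 2, 8, 3]])
eos_token_id = 2

boundaries = [torch.tensor([0])]
eos_positions = torch.where(input_ids.eq(eos_token_id))[1] + 1  # positions right after each EOS
if eos_positions.numel() > 0:
    boundaries.append(eos_positions)
if input_ids[0, -1] != eos_token_id:
    # The row does not end with EOS, so close the final document at the total length.
    boundaries.append(torch.tensor([input_ids.size(1)]))

cu_seqlens = torch.cat(boundaries).to(torch.int32)
print(cu_seqlens)  # tensor([0, 3, 6, 8], dtype=torch.int32)
```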
500
+ class ParallelAwareDataLoader(StatefulDataLoader, Stateful):
501
+ """
502
+ A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
503
+ """
504
+
505
+ def __init__(
506
+ self,
507
+ rank: int,
508
+ dataset: IterableDataset,
509
+ batch_size: int,
510
+ collate_fn: Callable,
511
+ num_workers: int = 0,
512
+ pin_memory: bool = False,
513
+ prefetch_factor: int = 2,
514
+ persistent_workers: bool = False,
515
+ snapshot_every_n_steps: Optional[int] = 1,
516
+ ):
517
+ super().__init__(
518
+ dataset=dataset,
519
+ batch_size=batch_size,
520
+ collate_fn=collate_fn,
521
+ num_workers=num_workers,
522
+ pin_memory=pin_memory,
523
+ prefetch_factor=prefetch_factor,
524
+ persistent_workers=persistent_workers,
525
+ snapshot_every_n_steps=snapshot_every_n_steps,
526
+ )
527
+ self.rank = rank
528
+
529
+ def state_dict(self) -> Dict[str, Any]:
530
+ # Store state only for dp rank to avoid replicating the same state across other dimensions
531
+ return {f'rank_{self.rank}': pickle.dumps(super().state_dict())}
532
+
533
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
534
+ # State being empty is valid
535
+ if not state_dict:
536
+ return
537
+
538
+ if f'rank_{self.rank}' not in state_dict:
539
+ logger.warning(f'DataLoader state has no entry for dp rank {self.rank} (expected key rank_{self.rank})')
540
+ return
541
+ super().load_state_dict(pickle.loads(state_dict[f'rank_{self.rank}']))
542
+
543
+
544
+ def build_dataloader(
545
+ dataset: IterableDataset,
546
+ tokenizer: PreTrainedTokenizer,
547
+ rank: int,
548
+ world_size: int,
549
+ batch_size: int,
550
+ seq_len: int,
551
+ context_len: Optional[int] = None,
552
+ varlen: bool = False,
553
+ num_workers: int = 0,
554
+ pin_memory: bool = False,
555
+ persistent_workers: bool = False,
556
+ snapshot_every_n_steps: Optional[int] = 1,
557
+ ):
558
+ dataset = OnlineTokenizedIterableDataset(
559
+ dataset=dataset, tokenizer=tokenizer, seq_len=seq_len, rank=rank, world_size=world_size
560
+ )
561
+ return ParallelAwareDataLoader(
562
+ rank=rank,
563
+ dataset=dataset,
564
+ batch_size=batch_size,
565
+ collate_fn=DataCollatorForLanguageModeling(tokenizer=tokenizer, context_len=context_len, varlen=varlen),
566
+ num_workers=num_workers,
567
+ pin_memory=pin_memory,
568
+ persistent_workers=persistent_workers,
569
+ snapshot_every_n_steps=snapshot_every_n_steps,
570
+ )
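A hedged usage sketch of `build_dataloader`: the dataset and tokenizer names below are the ones referenced in `flame/models/fla.toml`, while batch size and worker counts are purely illustrative.

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("fla-hub/transformer-1.3B-100B")
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", split="train", streaming=True)

loader = build_dataloader(
    dataset=dataset,
    tokenizer=tokenizer,
    rank=0,
    world_size=1,
    batch_size=8,
    seq_len=2048,
    context_len=2048,
    varlen=False,
    num_workers=2,
)
batch = next(iter(loader))  # dict with 'input_ids' and 'labels', typically of shape [8, 2048]
```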
flame/models/__init__.py ADDED
File without changes
flame/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (137 Bytes).
 
flame/models/__pycache__/parallelize_fla.cpython-312.pyc ADDED
Binary file (22.1 kB).
 
flame/models/__pycache__/pipeline_fla.cpython-312.pyc ADDED
Binary file (5.75 kB).
 
flame/models/activation_offloading.py ADDED
@@ -0,0 +1,447 @@
1
+ # Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/training/_activation_offloading.py
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import contextlib
9
+ from typing import Union
10
+ from warnings import warn
11
+
12
+ import psutil
13
+ import torch
14
+ from torch import nn
15
+ from torch.autograd.graph import saved_tensors_hooks
16
+
17
+ from torchtitan.tools.logging import logger
18
+
19
+ try:
20
+ import torchao
21
+ from torchao.dtypes.nf4tensor import NF4Tensor
22
+ except ImportError:
23
+ torchao = None
24
+ NF4Tensor = None
25
+ logger.warning("torchao not found; NF4 tensor offloading support is disabled.")
26
+
27
+ # from torchtune.modules import TiedLinear
28
+
29
+
30
+ class OffloadActivations(saved_tensors_hooks):
31
+ """Context manager under which activation tensors created in the forward pass will be offloaded.
32
+
33
+ Enable the memory efficiency technique of activation offloading, where activations bigger than
34
+ min_offload_size bytes will be offloaded to CPU in the forward and brought back in the backward.
35
+ This is in contrast to maintaining the activation on GPU VRAM throughout the program.
36
+
37
+ This manager contains the option of using one additional CUDA stream to handle the communication
38
+ between CUDA and CPU, which is intended to overlap with the default computation stream to improve
39
+ runtime. We designed synchronization with a few heuristics for optimizing the tradeoff between
40
+ runtime vs memory usage.
41
+
42
+ Args:
43
+ use_pin_memory (bool): Whether or not the offloaded Tensor will be placed in pinned
44
+ memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly
45
+ but is a limited resource. Default: True.
46
+
47
+ use_streams (bool): Whether or not to use streams for performance optimization where
48
+ the communications get overlapped with the computation. Requires a torch build
49
+ after torch-2.5.0. Default: True.
50
+
51
+ max_fwd_stash_size (int): The maximum size of the forward stash, or the maximum number of
52
+ consecutive activations to keep alive during the forward pass. This number must be at
53
+ least 1. Keeping alive more activations will potentially allow more overlap between the
54
+ communication and compute streams at the cost of increasing memory usage. Keeping alive
55
+ fewer activations will conserve memory, but may cause poor overlap between the streams,
56
+ increasing runtime. Default: 5.
57
+
58
+ min_offload_size (int): The minimum number of bytes a Tensor must be in order to qualify
59
+ for offloading. If the tensor is too small, we do not want to waste bandwidth and resources
60
+ moving it to CPU and back. Default: 1024 bytes.
61
+
62
+ Raises:
63
+ ValueError: if max_fwd_stash_size is not at least 1.
64
+
65
+ Example:
66
+ >>> with OffloadActivations():
67
+ >>> logits = model(inputs)
68
+ >>> loss = ...
69
+ >>> loss.backward()
70
+ """
71
+
72
+ def __init__(
73
+ self,
74
+ use_pin_memory: bool = True,
75
+ use_streams: bool = True,
76
+ max_fwd_stash_size: int = 5,
77
+ min_offload_size: int = 1024,
78
+ ) -> None:
79
+
80
+ self.use_streams: bool = use_streams
81
+
82
+ self.min_tensor_size_bytes = (
83
+ min_offload_size # we don't want to bother with small tensors
84
+ )
85
+ self.tracker = (
86
+ {}
87
+ ) # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where
88
+ self.tensor_id: int = 0
89
+ self.is_first_forward_call = True
90
+ self.is_first_backward_call = True
91
+ self.is_first_forward_pass = True
92
+
93
+ # managing cpu memory
94
+ self.use_pin_memory: bool = use_pin_memory
95
+ self.virtual_memory_safe_pct = (
96
+ 60 # we should not exceed this percentage of memory
97
+ )
98
+
99
+ self.s0 = torch.cuda.default_stream() # comp stream
100
+
101
+ # for streaming
102
+ if self.use_streams:
103
+ self.s1 = torch.cuda.Stream() # comms stream
104
+ self.fwd_stash = {} # tensor_id => (activation, ev1)
105
+ if max_fwd_stash_size < 1:
106
+ raise ValueError(
107
+ f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}"
108
+ )
109
+ self.max_fwd_stash_size = max_fwd_stash_size
110
+ self.bwd_tensor_stash = {} # tensor_id => activation
111
+ self.bwd_ev_stash = {} # tensor_id => ev0
112
+ self.curr_graph_id = None
113
+ self.curr_autograd_node = None
114
+
115
+ # -------- platform util functions -------- #
116
+ def verify_sufficient_virtual_memory():
117
+ curr_pct = get_cpu_ram_pct()
118
+ if curr_pct > self.virtual_memory_safe_pct:
119
+ warn(
120
+ f"***** WARNING: {curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used"
121
+ )
122
+
123
+ def get_cpu_ram_pct() -> float:
124
+ # get the percentage of memory used by the system
125
+ return psutil.virtual_memory().percent
126
+
127
+ def get_tensor_id() -> int:
128
+ # create a unique id for each tensor we are managing
129
+ self.tensor_id += 1
130
+ return self.tensor_id
131
+
132
+ def get_num_bytes_tensor(x: torch.Tensor) -> int:
133
+ # get the number of bytes in a tensor, for memory management purposes
134
+ return (
135
+ x.element_size() * x.nelement()
136
+ ) # x.element_size() * x._base_storage().nbytes()
137
+
138
+ # -------- core pack / unpack work -------- #
139
+ def pack_tensor(activation: torch.Tensor) -> int:
140
+ # activations are passed in during forward pass - from here we take over and return a unique id
141
+ if self.is_first_forward_call:
142
+ assert (
143
+ len(self.tracker) == 0
144
+ ), "backward pass should have cleared tracker of all tensors"
145
+
146
+ # set training phase trackers
147
+ self.is_first_forward_call = False
148
+ self.is_first_backward_call = True
149
+
150
+ # query for basic tensor info
151
+ num_bytes = get_num_bytes_tensor(activation)
152
+ tensor_id = get_tensor_id()
153
+
154
+ # only offload hefty bois if they're activations on CUDA (our heuristic
155
+ # for that is to check if they're not params or buffers)!
156
+ if (
157
+ activation.is_cuda
158
+ and num_bytes >= self.min_tensor_size_bytes
159
+ and (
160
+ not isinstance(activation, torch.nn.Parameter)
161
+ and not isinstance(activation, torch.nn.Buffer)
162
+ )
163
+ ):
164
+ if self.use_streams:
165
+ # First, sync back and dereference previously offloaded tensors
166
+ # as the offloading should be done sufficiently long ago.
167
+ for id in [k for k in self.fwd_stash.keys()]:
168
+ if id <= tensor_id - self.max_fwd_stash_size:
169
+ _, ev = self.fwd_stash[id]
170
+ self.s0.wait_event(ev)
171
+ del self.fwd_stash[id]
172
+ else:
173
+ break
174
+
175
+ # Sync in, offload, and add an event to sync back later
176
+ self.s1.wait_stream(self.s0)
177
+
178
+ stream = self.s1 if self.use_streams else self.s0
179
+ with torch.cuda.stream(stream):
180
+ try:
181
+ cpu_tensor = torch.empty_like(
182
+ activation, pin_memory=self.use_pin_memory, device="cpu"
183
+ )
184
+ except NotImplementedError as e:
185
+ if (
186
+ isinstance(activation, NF4Tensor)
187
+ and torchao.__version__ < "0.6.0.dev20240917"
188
+ ):
189
+ raise RuntimeError(
190
+ "Offloading NF4Tensors requires torchao-0.6.0.dev20240917 or later"
191
+ ) from e
192
+ raise e
193
+ cpu_tensor.copy_(activation, non_blocking=True)
194
+ self.tracker[tensor_id] = (
195
+ cpu_tensor,
196
+ True,
197
+ ) # True = (in future) modified
198
+
199
+ if self.use_streams:
200
+ event = self.s1.record_event()
201
+
202
+ # Stash to keep activation alive til s1 is done
203
+ self.fwd_stash[tensor_id] = (activation, event)
204
+ else:
205
+ self.tracker[tensor_id] = (
206
+ activation,
207
+ False,
208
+ ) # False = not modified, tensor is as is
209
+
210
+ return tensor_id
211
+
212
+ def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
213
+ # backward pass - we are called with the tensor_id, which
214
+ # we will use to retrieve the saved/offloaded tensor
215
+ if self.is_first_backward_call:
216
+ if self.is_first_forward_pass:
217
+ self.is_first_forward_pass = False
218
+ if self.use_pin_memory:
219
+ verify_sufficient_virtual_memory()
220
+
221
+ self.is_first_backward_call = False
222
+ self.is_first_forward_call = True
223
+
224
+ assert (
225
+ unpack_tensor_id in self.tracker
226
+ ), f"untracked tensor with id {unpack_tensor_id}"
227
+
228
+ maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
229
+ if modified:
230
+ gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
231
+ maybe_gpu_tensor = gpu_tensor
232
+
233
+ # clear tensor from tracking
234
+ del self.tracker[unpack_tensor_id]
235
+ return maybe_gpu_tensor
236
+
237
+ def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
238
+ # backward pass - we are called with the tensor_id, which
239
+ # we will use to retrieve the saved/offloaded tensor
240
+ if self.is_first_backward_call:
241
+ self.curr_graph_id = torch._C._current_graph_task_id()
242
+
243
+ def wait_and_del_remaining_references() -> None:
244
+ for id in [k for k in self.bwd_tensor_stash.keys()]:
245
+ event = self.bwd_ev_stash[id]
246
+ self.s1.wait_event(event)
247
+ del self.bwd_tensor_stash[id]
248
+
249
+ # Register a callback to the end of autograd to clean everything up
250
+ torch.autograd.variable.Variable._execution_engine.queue_callback(
251
+ wait_and_del_remaining_references
252
+ )
253
+
254
+ if self.is_first_forward_pass:
255
+ self.is_first_forward_pass = False
256
+ if self.use_pin_memory:
257
+ verify_sufficient_virtual_memory()
258
+
259
+ self.is_first_backward_call = False
260
+ self.is_first_forward_call = True
261
+
262
+ assert (
263
+ unpack_tensor_id in self.tracker
264
+ ), f"untracked tensor with id {unpack_tensor_id}"
265
+
266
+ maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
267
+ if modified:
268
+ # Get data on the current autograd node
269
+ graph_id = torch._C._current_graph_task_id()
270
+ node = torch._C._current_autograd_node()
271
+ prev_node_ids = []
272
+
273
+ # If we're on a new node, mark prev node's tensors to be freed later
274
+ if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
275
+ self.curr_autograd_node = node
276
+ prev_node_ids = [id for id in self.bwd_tensor_stash.keys()]
277
+
278
+ brought_back_from_cpu = True
279
+ if unpack_tensor_id in self.fwd_stash:
280
+ maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0]
281
+ brought_back_from_cpu = False
282
+ else:
283
+ # Kick off the process to bring tensors back
284
+ with torch.cuda.stream(self.s1):
285
+ gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
286
+ maybe_gpu_tensor = gpu_tensor
287
+
288
+ # Tell comp stream to wait for the info to be loaded before executing
289
+ self.s0.wait_stream(self.s1)
290
+
291
+ # Stash the tensor to keep memory alive until compute stream is complete
292
+ self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor
293
+
294
+ # Note: [Track views of the unpacked]
295
+ # Why do we get the use count of the unpacked tensor here? We want an
296
+ # initial count to compare to later, during the post-hook of the
297
+ # backward node, when we need to decide whether we're allowed to free
298
+ # the tensor yet. In what obscure cases must we delay freeing the
299
+ # tensor (and thus call record_stream)?
300
+ # 1. Any of the outputs of the backward node is a view of the unpacked
301
+ # tensor.
302
+ # 2. In the case that this unpacked tensor will be used in a
303
+ # checkpointed region, if one of the recomputed saved tensors ends
304
+ # up as a view of the unpacked tensor.
305
+ # 3. The user abuses the system somehow and manually relies on the
306
+ # unpacked tensor to exist after the backward node has executed.
307
+ storage_refcount = torch._C._storage_Use_Count(
308
+ maybe_gpu_tensor.untyped_storage()._cdata
309
+ )
310
+
311
+ def hook(outputs, inputs):
312
+ # create events for the current node inputs/outputs if they were streamed in
313
+ if brought_back_from_cpu:
314
+ # See Note: [Track views of the unpacked]
315
+ # IF any of the outputs is a view of the tensor, OR if a view of
316
+ # the tensor has been saved as a part of checkpoint's recompute
317
+ # process, OR the user has abusedly incurred a reference on the
318
+ # unpacked tensor, THEN the tensor might be used later and we
319
+ # cannot presume to delete it after only the current node is
320
+ # done! So we use our frenemy, record_stream, to ensure the
321
+ # Tensor stays unmessed with until it's done getting used in the
322
+ # compute stream (s0 here). Note that the con here is we introduce
323
+ # non-deterministic (thus higher) memory usage, but this case
324
+ # should not happen often.
325
+ unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
326
+ if (
327
+ torch._C._storage_Use_Count(
328
+ unpacked_tensor.untyped_storage()._cdata
329
+ )
330
+ > storage_refcount
331
+ ):
332
+ unpacked_tensor.record_stream(self.s0)
333
+ del self.bwd_tensor_stash[unpack_tensor_id]
334
+ else:
335
+ event = self.s0.record_event()
336
+ self.bwd_ev_stash[unpack_tensor_id] = event
337
+
338
+ # if there are still things in the fwd_stash, get rid of them as we're in bwd now
339
+ for id in [k for k in self.fwd_stash.keys()]:
340
+ _, ev = self.fwd_stash[id]
341
+ self.s0.wait_event(ev)
342
+ del self.fwd_stash[id]
343
+
344
+ # wait on prev node's events and del those
345
+ for id in prev_node_ids:
346
+ event = self.bwd_ev_stash[id]
347
+ self.s1.wait_event(event)
348
+ del self.bwd_tensor_stash[id]
349
+
350
+ return outputs
351
+
352
+ node.register_hook(hook)
353
+
354
+ # clear tensor from tracking
355
+ del self.tracker[unpack_tensor_id]
356
+ return maybe_gpu_tensor
357
+
358
+ unpack_tensor = (
359
+ unpack_tensor_with_streams
360
+ if self.use_streams
361
+ else unpack_tensor_single_stream
362
+ )
363
+ super().__init__(pack_tensor, unpack_tensor)
364
+
365
+
366
+ class NoOpManager(saved_tensors_hooks):
367
+ """
368
+ A saved_tensors_hook manager used to disable any other saved_tensors_hook manager
369
+ applied before. This relies on the behavior that only the most recently registered
370
+ saved_tensors_hook will run.
371
+
372
+ One example usage is to opt a local region of code out of activations offloading,
373
+ which is usually applied globally to best track state.
374
+ """
375
+
376
+ def __init__(self) -> None:
377
+ def noop(tensor):
378
+ return tensor
379
+
380
+ super().__init__(noop, noop)
381
+
382
+
383
+ def get_act_offloading_ctx_manager(
384
+ model: nn.Module, enable_activation_offloading: bool
385
+ ) -> Union[OffloadActivations, contextlib.nullcontext]:
386
+ """Returns the activation offloading context manager for the model, which will be
387
+ a null context if enable_activation_offloading is False.
388
+
389
+ If activation offloading is enabled, we return the OffloadActivations context manager.
390
+ If activation offloading is disabled, we return a NoOpManager context manager.
391
+
392
+ Args:
393
+ model (nn.Module): the model to wrap with the activation offloading context manager.
394
+ enable_activation_offloading (bool): whether or not to enable activation offloading
395
+ for the model.
396
+
397
+ Returns:
398
+ contextlib.ContextDecorator: the activation offloading context manager for the model.
399
+
400
+ Raises:
401
+ NotImplementedError: If the model is a multimodal model and activation offloading is enabled.
402
+ """
403
+ if enable_activation_offloading:
404
+ activations_handling_ctx = OffloadActivations()
405
+
406
+ # Below is our hack to disable offloading the last output Linear in every
407
+ # step, as the cost for offloading the activation and then soon after bringing
408
+ # it back is expensive. Moreover, due to heuristics in our streaming API,
409
+ # we actually use more memory if we offload it as it interferes with chunkedCE.
410
+ output_head_detected = False
411
+ noop_ctx = NoOpManager()
412
+
413
+ if hasattr(model, "output"):
414
+ if isinstance(model.output, nn.Module):
415
+ model.output.register_forward_pre_hook(
416
+ lambda *args: noop_ctx.__enter__()
417
+ )
418
+ model.output.register_forward_hook(
419
+ lambda *args: noop_ctx.__exit__(), always_call=True
420
+ )
421
+ logger.info("Registered no-op activation-offloading hooks for model.output")
422
+ output_head_detected = True
423
+ # ================================
424
+ # ! TODO[flame] check if we need to deal with TiedLinear
425
+ # The following code appears in `torchtune`
426
+ # elif isinstance(model.output, TiedLinear):
427
+ # model.output.linear.register_forward_pre_hook(
428
+ # lambda *args: noop_ctx.__enter__()
429
+ # )
430
+ # model.output.linear.register_forward_hook(
431
+ # lambda *args: noop_ctx.__exit__(), always_call=True
432
+ # )
433
+ # output_head_detected = True
434
+
435
+ if not output_head_detected:
436
+ logger.warning(
437
+ "During activation offloading, no output head was detected. "
438
+ "If your model has an output head, it will be offloaded. "
439
+ "This usually greatly slows training, given the large vocabulary size. "
440
+ "To change this behavior, set your output head as model.output and make it "
441
+ "an nn.Module."
442
+ )
443
+
444
+ else:
445
+ activations_handling_ctx = contextlib.nullcontext()
446
+
447
+ return activations_handling_ctx
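A minimal usage sketch of the context manager defined above, mirroring the docstring example; it assumes a CUDA device and a toy `nn.Sequential` model rather than a real training setup.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
inputs = torch.randn(8, 1024, device="cuda")

# Activations larger than min_offload_size bytes are copied to CPU during the forward pass.
with OffloadActivations(use_streams=True, max_fwd_stash_size=5, min_offload_size=1024):
    out = model(inputs)
    loss = out.float().pow(2).mean()
    loss.backward()  # offloaded activations are streamed back from CPU as backward needs them
```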
flame/models/fla.toml ADDED
@@ -0,0 +1,67 @@
1
+ [model]
2
+ config = "fla-hub/transformer-1.3B-100B"
3
+ tokenizer_path = "fla-hub/transformer-1.3B-100B"
4
+
5
+ [job]
6
+ dump_folder = "exp"
7
+ print_args = true
8
+
9
+ [training]
10
+ batch_size = 32
11
+ seq_len = 2048
12
+ context_len = 2048
13
+ gradient_accumulation_steps = 1
14
+ steps = 20480
15
+ max_norm = 1.0
16
+ skip_nan_inf = true
17
+ data_parallel_replicate_degree = 1
18
+ data_parallel_shard_degree = -1
19
+ tensor_parallel_degree = 1
20
+ compile = false
21
+ dataset = "HuggingFaceFW/fineweb-edu"
22
+ dataset_name = "default"
23
+ num_workers = 32
24
+ pin_memory = false
25
+ persistent_workers = false
26
+ prefetch_factor = 2
27
+ seed = 42
28
+ varlen = false
29
+
30
+ [optimizer]
31
+ name = "AdamW"
32
+ eps = 1e-15
33
+ lr = 3e-4
34
+
35
+ [lr_scheduler]
36
+ warmup_steps = 1024
37
+ decay_type = "cosine"
38
+ lr_min = 0.1
39
+
40
+ [checkpoint]
41
+ enable_checkpoint = true
42
+ folder = "checkpoint"
43
+ interval_type = "steps"
44
+ interval = 2048
45
+ model_weights_only = false
46
+ export_dtype = "float32"
47
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
48
+
49
+ [profiling]
50
+ enable_profiling = true
51
+ save_traces_folder = "profile_trace"
52
+ profile_freq = 512
53
+
54
+ [metrics]
55
+ log_freq = 32
56
+ enable_wandb = true
57
+
58
+ [experimental]
59
+ context_parallel_degree = 1
60
+ pipeline_parallel_degree = 1
61
+
62
+ [float8]
63
+ enable_fsdp_float8_all_gather = false
64
+ precompute_float8_dynamic_scale_for_fsdp = false
65
+
66
+ [activation_checkpoint]
67
+ mode = "none"
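A quick, hypothetical sanity check of the config above using only the standard library (not part of the training flow; it just parses the TOML and inspects a few fields):

```python
import tomllib  # Python 3.11+; use the `tomli` package on older interpreters

with open("flame/models/fla.toml", "rb") as f:
    cfg = tomllib.load(f)

assert cfg["training"]["seq_len"] == cfg["training"]["context_len"] == 2048
print(cfg["model"]["config"])   # fla-hub/transformer-1.3B-100B
print(cfg["optimizer"]["lr"])   # 0.0003
print(cfg["lr_scheduler"]["decay_type"], cfg["lr_scheduler"]["warmup_steps"])  # cosine 1024
```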
flame/models/parallelize_fla.py ADDED
@@ -0,0 +1,550 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D parallelisms (except pipeline parallelism) and various
8
+ # training techniques (e.g. activation checkpointing and compile) to the Llama model.
9
+
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
16
+ from torch.distributed._composable.replicate import replicate
17
+ from torch.distributed._tensor import Replicate, Shard
18
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
19
+ from torch.distributed.tensor.parallel import (
20
+ ColwiseParallel,
21
+ PrepareModuleInput,
22
+ PrepareModuleOutput,
23
+ RowwiseParallel,
24
+ SequenceParallel,
25
+ parallelize_module
26
+ )
27
+
28
+ from fla.modules.fused_linear_cross_entropy import LinearLossParallel
29
+ from fla.modules.mlp import SwiGLULinearParallel
30
+ from fla.modules.parallel import PrepareModuleWeight
31
+ from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
32
+ from torchtitan.distributed.parallel_dims import ParallelDims
33
+ from torchtitan.tools.logging import logger
34
+
35
+
36
+ def parallelize_fla(
37
+ model: nn.Module,
38
+ world_mesh: DeviceMesh,
39
+ parallel_dims: ParallelDims,
40
+ job_config: JobConfig,
41
+ ):
42
+ """
43
+ Apply tensor parallelism, activation checkpointing, torch.compile, and data
44
+ parallelism to the model.
45
+
46
+ NOTE: The passed-in model preferably should be on meta device. Otherwise,
47
+ the model must fit on GPU or CPU memory.
48
+ """
49
+
50
+ if parallel_dims.tp_enabled:
51
+ if (
52
+ job_config.experimental.enable_async_tensor_parallel
53
+ and not job_config.training.compile
54
+ ):
55
+ raise RuntimeError("Async TP requires --training.compile")
56
+ enable_float8_linear = "float8" in job_config.model.converters
57
+ apply_tp(
58
+ model,
59
+ world_mesh["tp"],
60
+ loss_parallel=parallel_dims.loss_parallel_enabled,
61
+ enable_float8=enable_float8_linear,
62
+ enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
63
+ )
64
+
65
+ if job_config.activation_checkpoint.mode != "none":
66
+ apply_ac(model, job_config.activation_checkpoint)
67
+
68
+ # turn on per-block compile after AC wrapping and before FSDP
69
+ if job_config.training.compile:
70
+ apply_compile(model)
71
+
72
+ if (
73
+ parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
74
+ ): # apply FSDP or HSDP, potentially with Context Parallel
75
+ if parallel_dims.dp_replicate_enabled:
76
+ dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
77
+ else:
78
+ dp_mesh_dim_names = ("dp_shard_cp",)
79
+
80
+ apply_fsdp(
81
+ model,
82
+ world_mesh[tuple(dp_mesh_dim_names)],
83
+ param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
84
+ reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
85
+ pp_enabled=parallel_dims.pp_enabled,
86
+ cpu_offload=job_config.training.enable_cpu_offload,
87
+ reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
88
+ )
89
+
90
+ if parallel_dims.dp_replicate_enabled:
91
+ logger.info("Applied HSDP to the model")
92
+ else:
93
+ logger.info("Applied FSDP to the model")
94
+
95
+ if parallel_dims.cp_enabled:
96
+ logger.info("Applied Context Parallel to the model")
97
+
98
+ if job_config.training.enable_cpu_offload:
99
+ logger.info("Applied CPU Offloading to the model")
100
+ elif parallel_dims.dp_replicate_enabled:
101
+ if world_mesh.ndim > 1:
102
+ raise RuntimeError("DDP has not supported > 1D parallelism")
103
+ apply_ddp(
104
+ model,
105
+ world_mesh,
106
+ enable_compile=job_config.training.compile,
107
+ enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
108
+ )
109
+
110
+
111
+ class TPPlan:
112
+ def __init__(
113
+ self,
114
+ model=None,
115
+ loss_parallel=False,
116
+ enable_float8=False,
117
+ ):
118
+ self.model = model
119
+ self.loss_parallel = loss_parallel
120
+ self.enable_float8 = enable_float8
121
+ self.base_model_prefix = getattr(model, "base_model_prefix", "model")
122
+
123
+ # TODO(vkuzo): once float8 configuration supports delayed scaling,
124
+ # add a check here to enforce supported float8 all-gather configurations
125
+ # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
126
+ try:
127
+ from torchao.float8.float8_tensor_parallel import (
128
+ Float8ColwiseParallel,
129
+ Float8RowwiseParallel,
130
+ PrepareFloat8ModuleInput
131
+ )
132
+ except ImportError:
133
+ Float8ColwiseParallel = None
134
+ Float8RowwiseParallel = None
135
+ PrepareFloat8ModuleInput = None
136
+ if self.enable_float8 and Float8ColwiseParallel is not None:
137
+ self.rowwise_parallel = Float8RowwiseParallel
138
+ self.colwise_parallel = Float8ColwiseParallel
139
+ self.prepare_module_input = PrepareFloat8ModuleInput
140
+ self.prepare_module_output = PrepareModuleOutput
141
+ else:
142
+ self.rowwise_parallel = RowwiseParallel
143
+ self.colwise_parallel = ColwiseParallel
144
+ self.prepare_module_input = PrepareModuleInput
145
+ self.prepare_module_output = PrepareModuleOutput
146
+
147
+ @property
148
+ def model_plan(self):
149
+ plans = {
150
+ f"{self.base_model_prefix}.embeddings": RowwiseParallel(
151
+ input_layouts=Replicate(),
152
+ output_layouts=Shard(1),
153
+ ),
154
+ f"{self.base_model_prefix}.norm": SequenceParallel(),
155
+ }
156
+ if self.loss_parallel:
157
+ plans.update(
158
+ {
159
+ "lm_head": ColwiseParallel(
160
+ input_layouts=Shard(1),
161
+ output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
162
+ use_local_output=not self.loss_parallel,
163
+ ),
164
+ }
165
+ )
166
+ else:
167
+ plans.update(
168
+ {
169
+ "lm_head": PrepareModuleWeight(layouts=Replicate()),
170
+ "criterion": LinearLossParallel(),
171
+ }
172
+ )
173
+ return plans
174
+
175
+ @property
176
+ def layer_plan(self):
177
+ return {
178
+ "attn_norm": SequenceParallel(),
179
+ **self.attn_plan,
180
+ "mlp_norm": SequenceParallel(),
181
+ **self.mlp_plan,
182
+ }
183
+
184
+ @property
185
+ def attn_plan(self):
186
+ raise NotImplementedError(
187
+ f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
188
+ )
189
+
190
+ @property
191
+ def mlp_plan(self):
192
+ return {
193
+ "mlp": self.prepare_module_input(
194
+ input_layouts=(Shard(1),),
195
+ desired_input_layouts=(Replicate(),),
196
+ ),
197
+ "mlp.gate_proj": self.colwise_parallel(),
198
+ "mlp.up_proj": self.colwise_parallel(),
199
+ "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
200
+ "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
201
+ }
202
+
203
+
204
+ class TransformerTPPlan(TPPlan):
205
+
206
+ @property
207
+ def attn_plan(self):
208
+ return {
209
+ "attn": self.prepare_module_input(
210
+ input_kwarg_layouts={"hidden_states": Shard(1)},
211
+ desired_input_kwarg_layouts={"hidden_states": Replicate()},
212
+ ),
213
+ "attn.q_proj": self.colwise_parallel(),
214
+ "attn.k_proj": self.colwise_parallel(),
215
+ "attn.v_proj": self.colwise_parallel(),
216
+ "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
217
+ }
218
+
219
+
220
+ class GLATPPlan(TPPlan):
221
+
222
+ @property
223
+ def attn_plan(self):
224
+ return {
225
+ "attn": self.prepare_module_input(
226
+ input_kwarg_layouts={"hidden_states": Shard(1)},
227
+ desired_input_kwarg_layouts={"hidden_states": Replicate()},
228
+ ),
229
+ "attn.q_proj": self.colwise_parallel(),
230
+ "attn.k_proj": self.colwise_parallel(),
231
+ "attn.v_proj": self.colwise_parallel(),
232
+ "attn.g_proj": self.colwise_parallel(),
233
+ "attn.gk_proj.0": PrepareModuleWeight(layouts=Replicate()),
234
+ "attn.gk_proj.1": self.colwise_parallel(),
235
+ "attn.g_norm": SequenceParallel(sequence_dim=-1),
236
+ "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
237
+ }
238
+
239
+
240
+ TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}
241
+
242
+
243
+ def apply_tp(
244
+ model: nn.Module,
245
+ tp_mesh: DeviceMesh,
246
+ loss_parallel: bool,
247
+ enable_float8: bool,
248
+ enable_async_tp: bool,
249
+ ):
250
+ """Apply tensor parallelism."""
251
+ # 1. Parallelize the embedding and shard its outputs (which are the first
252
+ # transformer block's inputs)
253
+ # 2. Parallelize the root norm layer over the sequence dim
254
+ # 3. Parallelize the final linear output layer
255
+ tp_plan = TP_PLAN_MAP[model.config.model_type](
256
+ model, loss_parallel=loss_parallel, enable_float8=enable_float8
257
+ )
258
+ parallelize_module(model, tp_mesh, tp_plan.model_plan)
259
+
260
+ blocks = get_blocks(model)
261
+ if blocks is None:
262
+ logger.warning("No block found for tensor parallelism")
263
+ else:
264
+ for _, block in enumerate(blocks):
265
+ parallelize_module(
266
+ module=block,
267
+ device_mesh=tp_mesh,
268
+ parallelize_plan=tp_plan.layer_plan,
269
+ )
270
+
271
+ if enable_async_tp:
272
+ from torch.distributed._symmetric_memory import enable_symm_mem_for_group
273
+
274
+ torch._inductor.config._micro_pipeline_tp = True
275
+ enable_symm_mem_for_group(tp_mesh.get_group().group_name)
276
+
277
+ logger.info(
278
+ f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
279
+ "Tensor Parallelism to the model"
280
+ )
281
+
282
+
283
+ # for selective op activation checkpointing
284
+ _save_list = {
285
+ torch.ops.aten.mm.default,
286
+ torch.ops.aten._scaled_dot_product_efficient_attention.default,
287
+ torch.ops.aten._scaled_dot_product_flash_attention.default,
288
+ torch.ops._c10d_functional.reduce_scatter_tensor.default,
289
+ # for low precision training, it's useful to always save
290
+ # the result of max, since the absolute maximum is
291
+ # used to compute the scaling factor for quantization.
292
+ torch.ops.aten.max.default,
293
+ }
294
+
295
+
296
+ def _apply_ac_to_block(module: nn.Module, ac_config):
297
+ valid_ac_modes = ("full", "selective")
298
+ if ac_config.mode not in valid_ac_modes:
299
+ raise ValueError(
300
+ f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
301
+ )
302
+
303
+ if ac_config.mode == "full":
304
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
305
+
306
+ assert ac_config.mode == "selective", f"{ac_config.mode}"
307
+ use_op_sac = ac_config.selective_ac_option == "op"
308
+ use_layer_sac = ac_config.selective_ac_option.isdigit()
309
+ if not use_op_sac and not use_layer_sac:
310
+ raise ValueError(
311
+ f"Invalid selective AC option: {ac_config.selective_ac_option}. "
312
+ f"Valid options: 'op' or a positive int representing layer frequency"
313
+ )
314
+ if use_op_sac:
315
+ from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
316
+
317
+ def _get_custom_policy(meta):
318
+ def _custom_policy(ctx, func, *args, **kwargs):
319
+ mode = "recompute" if ctx.is_recompute else "forward"
320
+ mm_count_key = f"{mode}_mm_count"
321
+ if func == torch.ops.aten.mm.default:
322
+ meta[mm_count_key] += 1
323
+ # Saves output of all compute ops, except every second mm
324
+ to_save = func in _save_list and not (
325
+ func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
326
+ )
327
+ return (
328
+ CheckpointPolicy.MUST_SAVE
329
+ if to_save
330
+ else CheckpointPolicy.PREFER_RECOMPUTE
331
+ )
332
+
333
+ return _custom_policy
334
+
335
+ def selective_checkpointing_context_fn():
336
+ meta = defaultdict(int)
337
+ return create_selective_checkpoint_contexts(_get_custom_policy(meta))
338
+
339
+ return ptd_checkpoint_wrapper(
340
+ module,
341
+ context_fn=selective_checkpointing_context_fn,
342
+ preserve_rng_state=False,
343
+ )
344
+ elif use_layer_sac:
345
+ # Checkpoint every `ac_freq` of the modules passed to this function
346
+ ac_freq = int(ac_config.selective_ac_option)
347
+ ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
348
+ ptd_checkpoint_wrapper._count += 1
349
+ if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
350
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
351
+ else:
352
+ return module
353
+
354
+
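To make the "save every other mm" heuristic above concrete, here is a toy trace (not repo code) of which `aten.mm` outputs the selective policy keeps within one forward pass:

```python
# The policy increments a per-pass counter for every aten.mm call and recomputes
# the ones with an even count, so matmul outputs are kept 1st, 3rd, 5th, ...
mm_count = 0
decisions = []
for _ in range(6):  # six mm calls in one forward pass
    mm_count += 1
    decisions.append("save" if mm_count % 2 == 1 else "recompute")
print(decisions)    # ['save', 'recompute', 'save', 'recompute', 'save', 'recompute']
```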
355
+ def apply_ac(model: nn.Module, ac_config):
356
+ """Apply activation checkpointing to the model."""
357
+ blocks = get_blocks(model)
358
+ if blocks is None:
359
+ logger.warning("No block found for activation checkpointing")
360
+ return
361
+
362
+ for layer_id, block in blocks.named_children():
363
+ block = _apply_ac_to_block(block, ac_config)
364
+ blocks.register_module(layer_id, block)
365
+
366
+ logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
367
+
368
+
369
+ def apply_compile(model: nn.Module):
370
+ """
371
+ Apply torch.compile to each block, which makes compilation efficient due to
372
+ repeated structure. Alternatively one can compile the whole model (after applying DP).
373
+ """
374
+
375
+ blocks = get_blocks(model)
376
+ if blocks is None:
377
+ logger.warning("No block found for torch.compile")
378
+ else:
379
+ for layer_id, block in blocks.named_children():
380
+ block = torch.compile(block)
381
+ blocks.register_module(layer_id, block)
382
+ logger.info("Compiling each block with torch.compile")
383
+
384
+ real_model = get_model(model)
385
+
386
+ logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
387
+ embeddings_key = get_components_name(real_model, "tok_embeddings")
388
+ if embeddings_key is not None:
389
+ embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
390
+ real_model.register_module(embeddings_key, embeddings)
391
+
392
+ norm_key = get_components_name(real_model, "norm")
393
+ if norm_key is not None:
394
+ norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
395
+ real_model.register_module(norm_key, norm)
396
+
397
+ lm_head_key = get_components_name(model, "lm_head")
398
+ if lm_head_key is not None:
399
+ lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
400
+ model.register_module(lm_head_key, lm_head)
401
+
402
+ logger.info("Compiling the entire model with torch.compile")
403
+ model = torch.compile(model)
404
+
405
+
406
+ def apply_fsdp(
407
+ model: nn.Module,
408
+ dp_mesh: DeviceMesh,
409
+ param_dtype: torch.dtype,
410
+ reduce_dtype: torch.dtype,
411
+ pp_enabled: bool,
412
+ cpu_offload: bool = False,
413
+ reshard_after_forward_policy: str = "default",
414
+ ):
415
+ """
416
+ Apply data parallelism (via FSDP2) to the model.
417
+
418
+ Args:
419
+ model (nn.Module): The model to apply data parallelism to.
420
+ dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
421
+ param_dtype (torch.dtype): The data type to use for model parameters.
422
+ reduce_dtype (torch.dtype): The data type to use for reduction operations.
423
+ pp_enabled (bool): Whether pipeline parallelism is enabled.
424
+ cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
425
+ reshard_after_forward_policy (str, optional):
426
+ The policy to use for resharding after forward pass. Defaults to "default".
427
+ Other options: "never", "always".
428
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
429
+ - "always" will enable `reshard_after_forward` for all forward passes.
430
+ - "never" will disable `reshard_after_forward` for all forward passes.
431
+
432
+ """
433
+ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
434
+ fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
435
+ if cpu_offload:
436
+ fsdp_config["offload_policy"] = CPUOffloadPolicy()
437
+
438
+ blocks = get_blocks(model)
439
+ if blocks is None:
440
+ logger.warning("No block found for FSDP")
441
+ else:
442
+ total_blocks = len(blocks)
443
+ for layer_id, block in enumerate(blocks):
444
+ if reshard_after_forward_policy == "always":
445
+ reshard_after_forward = True
446
+ elif reshard_after_forward_policy == "never":
447
+ reshard_after_forward = False
448
+ elif reshard_after_forward_policy == "default":
449
+ if pp_enabled:
450
+ # For PP, do not reshard after forward to avoid per-microbatch
451
+ # all-gathers, which can be expensive and non-overlapped
452
+ reshard_after_forward = False
453
+ else:
454
+ # As an optimization, do not reshard after forward for the last
455
+ # transformer block since FSDP would prefetch it immediately
456
+ reshard_after_forward = int(layer_id) < total_blocks - 1
457
+ else:
458
+ raise ValueError(
459
+ f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
460
+ )
461
+ fully_shard(
462
+ block,
463
+ **fsdp_config,
464
+ reshard_after_forward=reshard_after_forward,
465
+ )
466
+
467
+ fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
468
+
469
+
470
+ def apply_ddp(
471
+ model: nn.Module,
472
+ dp_mesh: DeviceMesh,
473
+ enable_compile: bool,
474
+ enable_compiled_autograd: bool,
475
+ ):
476
+ if enable_compile:
477
+ if enable_compiled_autograd:
478
+ torch._dynamo.config.optimize_ddp = (
479
+ "python_reducer_without_compiled_forward"
480
+ )
481
+ else:
482
+ torch._dynamo.config.optimize_ddp = "ddp_optimizer"
483
+
484
+ replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
485
+
486
+ logger.info("Applied DDP to the model")
487
+
488
+
489
+ def get_model(model):
490
+ base_model_prefix = getattr(model, "base_model_prefix", "model")
491
+ if not hasattr(model, base_model_prefix):
492
+ return None
493
+ model = getattr(model, base_model_prefix)
494
+ return model
495
+
496
+
497
+ def get_blocks(model):
498
+ # TODO[flame]: adapt for network not using 'layers' attribute
499
+ model = get_model(model)
500
+ if not hasattr(model, "layers"):
501
+ logger.warning('no "layers" attribute found in the model')
502
+ return None
503
+ return model.layers
504
+
505
+
506
+ def get_components_name(model, component_name):
507
+ """
508
+ Resolve the attribute names of the token embeddings, the final norm, and the lm_head.
509
+ Layer names inside the blocks are not resolved here; for those, see `get_blocks`.
510
+ We assume the model has the following structure:
511
+ LlamaForCausalLM:
512
+ Model:
513
+ embed_tokens,
514
+ layers,
515
+ norm,
516
+ lm_head
517
+ ***
518
+ so, to look up 'tok_embeddings' and 'norm', pass `get_model(model)`,
519
+ and to look up 'lm_head', pass `model`.
520
+ ***
521
+ """
522
+
523
+ if component_name == "tok_embeddings":
524
+ if hasattr(model, "tok_embeddings"):
525
+ return "tok_embeddings"
526
+ elif hasattr(model, "embed_tokens"):
527
+ return "embed_tokens"
528
+ elif hasattr(model, "embeddings"):
529
+ return "embeddings"
530
+ else:
531
+ logger.warning("No tok_embeddings found in model")
532
+ return None
533
+
534
+ elif component_name == "norm":
535
+ if hasattr(model, "norm"):
536
+ return "norm"
537
+ elif hasattr(model, "norms"):
538
+ return "norms"
539
+ elif hasattr(model, "layernorm"):
540
+ return "layernorm"
541
+ else:
542
+ logger.warning("No norm found in model")
543
+ return None
544
+
545
+ elif component_name == "lm_head":
546
+ if hasattr(model, "lm_head"):
547
+ return "lm_head"
548
+ else:
549
+ logger.warning("No lm_head found in model")
550
+ return None
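An illustrative sketch of the helper functions above on a HuggingFace-style causal LM; the checkpoint name comes from `flame/models/fla.toml`, network access to the Hub is assumed, and the resolved attribute names may differ per model family.

```python
import fla  # noqa: registers fla model classes with transformers
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("fla-hub/transformer-1.3B-100B")
model = AutoModelForCausalLM.from_config(config)

blocks = get_blocks(model)                                          # e.g. model.model.layers
emb_key = get_components_name(get_model(model), "tok_embeddings")   # e.g. "embeddings"
head_key = get_components_name(model, "lm_head")                    # "lm_head"
print(len(blocks), emb_key, head_key)
```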
flame/models/pipeline_fla.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D pipeline parallelism to FLA models.
8
+
9
+ import copy
10
+ from typing import Callable, Optional, Union
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed.pipelining import PipelineStage
16
+ from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble, _PipelineSchedule, get_schedule_class
17
+ from transformers import PretrainedConfig
18
+
19
+ from flame.models.parallelize_fla import get_blocks, get_components_name, get_model
20
+ from torchtitan.config_manager import JobConfig
21
+ from torchtitan.distributed.parallel_dims import ParallelDims
22
+ from torchtitan.distributed.pipeline import build_pipeline_schedule, generate_split_points, stage_ids_this_rank
23
+ from torchtitan.tools.logging import logger
24
+
25
+ DeviceType = Union[int, str, torch.device]
26
+
27
+
28
+ def pipeline_fla(
29
+ model: nn.Module,
30
+ pp_mesh: DeviceMesh,
31
+ parallel_dims: ParallelDims,
32
+ job_config: JobConfig,
33
+ device: DeviceType,
34
+ model_config: PretrainedConfig,
35
+ loss_fn: Callable[..., torch.Tensor],
36
+ ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
37
+ stages, models = pipeline_fla_manual_split(
38
+ model, pp_mesh, parallel_dims, job_config, device, model_config
39
+ )
40
+
41
+ pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
42
+
43
+ # This is used in the train loop to determine whether to pass in the input_ids and labels
44
+ has_first_stage = False
45
+ has_last_stage = False
46
+ for stage in stages:
47
+ if stage.is_first:
48
+ has_first_stage = True
49
+ if stage.is_last:
50
+ has_last_stage = True
51
+
52
+ return pp_schedule, models, has_first_stage, has_last_stage
53
+
54
+
55
+ def pipeline_fla_manual_split(
56
+ whole_model: nn.Module,
57
+ pp_mesh: DeviceMesh,
58
+ parallel_dims: ParallelDims,
59
+ job_config: JobConfig,
60
+ device: DeviceType,
61
+ model_config: PretrainedConfig,
62
+ ) -> tuple[list[PipelineStage], list[nn.Module]]:
63
+ """
64
+ This API extracts one torch.nn.Module object for the part of the model configured to run inside this stage.
65
+
66
+ It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.
67
+
68
+ The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
69
+ parallelism.
70
+ """
71
+ pp_rank = pp_mesh.get_local_rank()
72
+ pp_size = pp_mesh.size()
73
+
74
+ splits = (
75
+ job_config.experimental.pipeline_parallel_split_points
76
+ or generate_split_points(
77
+ job_config, parallel_dims.pp, model_config.num_hidden_layers
78
+ )
79
+ )
80
+
81
+ def _build_stage(
82
+ stage_idx: int,
83
+ start_layer: Optional[str],
84
+ stop_layer: Optional[str],
85
+ is_first: bool = False,
86
+ is_last: bool = False,
87
+ ) -> tuple[PipelineStage, nn.Module]:
88
+ model = copy.deepcopy(whole_model)
89
+ if not is_first:
90
+ # we do `model.tok_embeddings = None` here
91
+ real_model = get_model(model)
92
+ tok_embeddings_name = get_components_name(real_model, "tok_embeddings")
93
+ setattr(real_model, tok_embeddings_name, None)
94
+
95
+ drop_layers = start_layer is not None
96
+ # Get module dictionary from get_blocks(model)
97
+ # and Create a list of keys before modifying dictionary
98
+ module_dict = get_blocks(model)._modules # Store reference
99
+ layer_names = list(module_dict.keys())
100
+
101
+ # Iterate over the list of keys instead of `_modules.items()`
102
+ for name in layer_names:
103
+ # Dynamically determine prefix (blocks.* or layers.*)
104
+ prefix = start_layer.split(".")[0] if start_layer else "layers"
105
+ layer_name = f"{prefix}.{name}" # Construct the correct name format
106
+
107
+ # Ensure `drop_layers` activation is based on actual naming
108
+ if layer_name == start_layer:
109
+ drop_layers = False
110
+ if layer_name == stop_layer:
111
+ drop_layers = True
112
+
113
+ # Delete layer if drop_layers is active
114
+ if drop_layers:
115
+ del module_dict[name] # Safe deletion from stored dictionary
116
+
117
+ if not is_last:
118
+ # we do `model.norm = None` and `model.output = None`
119
+ real_model = get_model(model)
120
+ norm_name = get_components_name(real_model, "norm")
121
+ setattr(real_model, norm_name, None)
122
+
123
+ head_name = get_components_name(model, "lm_head")
124
+ setattr(model, head_name, None)
125
+
126
+ stage = PipelineStage(
127
+ model,
128
+ stage_idx,
129
+ num_stages,
130
+ device,
131
+ group=pp_mesh.get_group("pp"),
132
+ )
133
+ return stage, model
134
+
135
+ num_stages = len(splits) + 1
136
+ stage_idx = pp_rank
137
+
138
+ stages = []
139
+ models = []
140
+
141
+ schedule_class = get_schedule_class(
142
+ job_config.experimental.pipeline_parallel_schedule
143
+ )
144
+ style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
145
+
146
+ for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
147
+ start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
148
+ stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
149
+ stage, model_chunk = _build_stage(
150
+ stage_idx,
151
+ start_layer,
152
+ stop_layer,
153
+ is_first=stage_idx == 0,
154
+ is_last=stage_idx == num_stages - 1,
155
+ )
156
+ logger.info(
157
+ f"PP rank {pp_rank} is building stage_idx {stage_idx}"
158
+ f" with start_layer {start_layer}, stop_layer {stop_layer}"
159
+ )
160
+ stages.append(stage)
161
+ models.append(model_chunk)
162
+ return stages, models
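A toy illustration (not repo code) of how layer-name split points partition a hypothetical 12-layer model into pipeline stages, mirroring the start/stop logic in `_build_stage` above:

```python
splits = ["layers.4", "layers.8"]   # e.g. job_config.experimental.pipeline_parallel_split_points
num_stages = len(splits) + 1        # -> 3 stages

names = [f"layers.{i}" for i in range(12)]
bounds = [None] + splits + [None]
stage_layers = []
for s in range(num_stages):
    start, stop = bounds[s], bounds[s + 1]
    keep, keeping = [], start is None   # the first stage keeps layers from the beginning
    for name in names:
        if name == start:
            keeping = True
        if name == stop:
            keeping = False
        if keeping:
            keep.append(name)
    stage_layers.append(keep)

print([len(layers) for layers in stage_layers])  # [4, 4, 4]
```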
flame/tools/__init__.py ADDED
File without changes
flame/tools/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (136 Bytes).
 
flame/tools/__pycache__/utils.cpython-312.pyc ADDED
Binary file (2.14 kB).
 
flame/tools/utils.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import nn
8
+ from torchtitan.tools.logging import logger
9
+
10
+
11
+ def get_nparams_and_flops(model: nn.Module, model_config, seq_len: int) -> tuple[int, int]:
12
+ nparams = sum(p.numel() for p in model.parameters())
13
+ nparams_embedding = sum(
14
+ sum(p.numel() for p in m.parameters())
15
+ for m in model.children()
16
+ if isinstance(m, nn.Embedding)
17
+ )
18
+
19
+ if hasattr(model_config, "num_heads"):
20
+ num_heads = model_config.num_heads
21
+ elif hasattr(model_config, "num_attention_heads"):
22
+ num_heads = model_config.num_attention_heads
23
+ else:
24
+ num_heads = 1
25
+ logger.warning("num_heads not found in model_config, defaulting to 1.")
26
+
27
+ l, h, q, t = (
28
+ model_config.num_hidden_layers,
29
+ num_heads,
30
+ model_config.hidden_size // num_heads,
31
+ seq_len,
32
+ )
33
+ # Reasoning behind the factor of 12 for the self-attention part of the formula:
34
+ # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
35
+ # 2. the flash attention does 1 more matmul recomputation in the backward
36
+ # but recomputation should not be counted in calculating MFU (+0)
37
+ # 3. each matmul performs 1 multiplication and 1 addition (*2)
38
+ # 4. we follow the convention and do not account for sparsity in causal attention
39
+ num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
40
+
41
+ return nparams, num_flops_per_token
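Plugging hypothetical numbers into the formula above (all values are illustrative and not taken from any config in this repo):

```python
# A ~340M-parameter model with 24 layers, 16 heads, hidden_size=1024, trained at seq_len=2048.
nparams = 340_000_000
nparams_embedding = 32_000 * 1024 * 2        # input + output embeddings (assumed untied)
l, h, q, t = 24, 16, 1024 // 16, 2048

num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
print(f"{num_flops_per_token:.2e}")          # ~2.25e+09 FLOPs per token
```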
flame/train.py ADDED
@@ -0,0 +1,897 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ import time
10
+ from datetime import timedelta
11
+ from collections import defaultdict
12
+ import dataclasses
13
+
14
+ import torch
15
+ from datasets import interleave_datasets, load_dataset
16
+ from torch.distributed.elastic.multiprocessing.errors import record
17
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
18
+
19
+ import fla # noqa
20
+ from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
21
+ from fla.ops.common.utils import prepare_position_ids
22
+ from flame.components.checkpoint import TrainState
23
+ from flame.config_manager import JobConfig
24
+ from flame.data import build_dataloader, shuffle
25
+ from flame.models.parallelize_fla import parallelize_fla
26
+ from flame.models.pipeline_fla import pipeline_fla
27
+ from flame.tools.utils import get_nparams_and_flops
28
+ from flame.utils.checkpoint import cleanup_local_checkpoints
29
+ from flame.utils.convert_dcp_to_hf import save_pretrained
30
+ from flame.utils.hf_utils import upload_checkpoint_to_hf
31
+ from datetime import datetime
32
+ from torchtitan.components.checkpoint import CheckpointManager
33
+ from torchtitan.components.ft import FTParallelDims, init_ft_manager
34
+ from torchtitan.components.loss import build_cross_entropy_loss
35
+ from torchtitan.components.lr_scheduler import build_lr_schedulers
36
+ from torchtitan.components.metrics import build_device_memory_monitor, build_metrics_processor, ensure_pp_loss_visible
37
+ from torchtitan.components.optimizer import build_optimizers
38
+ from torchtitan.distributed import ParallelDims
39
+ from torchtitan.distributed import utils as dist_utils
40
+ from torchtitan.protocols.model_converter import build_model_converters
41
+ from torchtitan.protocols.train_spec import TrainSpec, get_train_spec, register_train_spec
42
+ from torchtitan.tools import utils
43
+ from torchtitan.tools.logging import init_logger, logger
44
+ from torchtitan.tools.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
45
+
46
+ from dotenv import load_dotenv
47
+ load_dotenv()
48
+
49
+ import wandb
50
+ wandb.login(key=os.environ["WANDB_API_KEY"])
51
+
52
+ import huggingface_hub
53
+ huggingface_hub.login(token=os.environ["HF_TOKEN"])
54
+
55
+
56
+ def build_tokenizer(job_config: JobConfig) -> AutoTokenizer:
57
+ return AutoTokenizer.from_pretrained(job_config.model.tokenizer_path)
58
+
59
+
60
+ register_train_spec(
61
+ TrainSpec(
62
+ name="fla",
63
+ cls=AutoModelForCausalLM,
64
+ config=AutoConfig,
65
+ parallelize_fn=parallelize_fla,
66
+ pipelining_fn=pipeline_fla,
67
+ build_optimizers_fn=build_optimizers,
68
+ build_lr_schedulers_fn=build_lr_schedulers,
69
+ build_dataloader_fn=build_dataloader,
70
+ build_tokenizer_fn=build_tokenizer,
71
+ build_loss_fn=build_cross_entropy_loss,
72
+ )
73
+ )
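+ # Note (added for clarity, not part of the original torchtitan docs): the spec
+ # registered above is looked up by name in main() via get_train_spec(job_config.model.name),
+ # so the job config is expected to set the model name to "fla"; the exact config
+ # layout implied here is an illustrative assumption.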
74
+
75
+
76
+ # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
77
+ @record
78
+ def main(job_config: JobConfig):
79
+ logger.info(f"Starting job: {job_config.job.description}")
80
+
81
+ if job_config.experimental.custom_model_path:
82
+ utils.import_module_from_path(job_config.experimental.custom_model_path)
83
+
84
+ # used for colorful printing
85
+ color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color
86
+
87
+ if job_config.job.print_args:
88
+ logger.info(
89
+ f"{color.green}{json.dumps(job_config.to_dict(), indent=2, sort_keys=True)}{color.reset}"
90
+ )
91
+
92
+ # take control of garbage collection to avoid stragglers
93
+ gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
94
+
95
+ device_module, device_type = utils.device_module, utils.device_type
96
+ device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
97
+ # Device has to be set before creating TorchFT manager.
98
+ device_module.set_device(device)
99
+ ft_manager = init_ft_manager(job_config)
100
+
101
+ run_specific_repo_id = None
102
+ if getattr(job_config.checkpoint, "hf_upload_enabled", False):
103
+ hf_repo_base = getattr(job_config.checkpoint, "hf_repo_base_name", None)
104
+ if hf_repo_base:
105
+ # Generate timestamp (adjust format if desired)
106
+ timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
107
+ run_specific_repo_id = f"{hf_repo_base}-{timestamp}"
108
+ logger.info(f"Target Hugging Face repository for this run: {run_specific_repo_id}")
109
+ else:
110
+ logger.warning("HF Hub upload enabled, but 'checkpoint.hf_repo_base_name' is not set.")
111
+ # Disable upload if base name is missing
112
+ job_config.checkpoint.hf_upload_enabled = False
113
+
114
+ # init distributed
115
+ world_size = int(os.environ["WORLD_SIZE"])
116
+ if not ft_manager.enabled:
117
+ parallel_dims = ParallelDims(
118
+ dp_shard=job_config.training.data_parallel_shard_degree,
119
+ dp_replicate=job_config.training.data_parallel_replicate_degree,
120
+ cp=job_config.experimental.context_parallel_degree,
121
+ tp=job_config.training.tensor_parallel_degree,
122
+ pp=job_config.experimental.pipeline_parallel_degree,
123
+ world_size=world_size,
124
+ enable_loss_parallel=not job_config.training.disable_loss_parallel,
125
+ )
126
+ else:
127
+ parallel_dims = FTParallelDims(
128
+ dp_shard=job_config.training.data_parallel_shard_degree,
129
+ dp_replicate=job_config.training.data_parallel_replicate_degree,
130
+ cp=job_config.experimental.context_parallel_degree,
131
+ tp=job_config.training.tensor_parallel_degree,
132
+ pp=job_config.experimental.pipeline_parallel_degree,
133
+ world_size=world_size,
134
+ enable_loss_parallel=not job_config.training.disable_loss_parallel,
135
+ ft_manager=ft_manager,
136
+ )
137
+ dist_utils.init_distributed(job_config)
138
+ # initialize device memory monitor and get peak flops for MFU calculation
139
+ device_memory_monitor = build_device_memory_monitor()
140
+ gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
141
+ logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
142
+
143
+ # build meshes
144
+ world_mesh = parallel_dims.build_mesh(device_type=device_type)
145
+ if parallel_dims.dp_enabled:
146
+ dp_mesh = world_mesh["dp"]
147
+ dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
148
+ else:
149
+ dp_degree, dp_rank = 1, 0
150
+
151
+ if parallel_dims.pp_enabled:
152
+ raise NotImplementedError(
153
+ "Pipeline parallelism is not supported in this version"
154
+ )
155
+ """
156
+ ! TODO[flame]: We need to fix the pipeline parallelism for flame
157
+ [x] Match the keys of the model's components with the actual naming
158
+ [ ] Fix the post-init and tie-embedding for pipeline parallelism; HF transformers automatically
159
+ ties the embeddings when the head is None, so we need to handle this case
160
+ [ ]
161
+ """
162
+ pp_mesh = world_mesh["pp"]
163
+
164
+ # Set random seed, and maybe enable deterministic mode (mainly for debugging, expect perf loss)
165
+ dist_utils.set_determinism(
166
+ world_mesh, device, job_config.training.seed, job_config.training.deterministic
167
+ )
168
+ train_spec = get_train_spec(job_config.model.name)
169
+
170
+ logger.info("Loading tokenizer...")
171
+ tokenizer = AutoTokenizer.from_pretrained(
172
+ job_config.model.tokenizer_path,
173
+ trust_remote_code=True,
174
+ model_max_length=int(1e10),
175
+ )
176
+ logger.info(f"{tokenizer}")
177
+ logger.info(
178
+ f"Loading dataset {job_config.training.dataset}"
179
+ f":{job_config.training.dataset_name}"
180
+ if job_config.training.dataset_name is not None
181
+ else ""
182
+ )
183
+
184
+ min_num_shards = dp_degree * job_config.training.num_workers
185
+ if len(job_config.training.dataset.split(",")) == 1:
186
+ dataset = load_dataset(
187
+ path=job_config.training.dataset,
188
+ name=getattr(job_config.training, "dataset_name", None),
189
+ data_dir=getattr(job_config.training, "data_dir", None),
190
+ data_files=getattr(job_config.training, "data_files", None),
191
+ split=job_config.training.dataset_split or "train",
192
+ trust_remote_code=True,
193
+ streaming=job_config.training.streaming,
194
+ num_proc=(
195
+ job_config.training.num_workers
196
+ if not job_config.training.streaming
197
+ else None
198
+ ),
199
+ )
200
+ logger.info(f"{dataset}")
201
+
202
+ logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
203
+ if not job_config.training.streaming:
204
+ # the state of a map-style dataset is recoverable after shuffling
205
+ dataset = dataset.shuffle(
206
+ seed=job_config.training.seed
207
+ ).to_iterable_dataset(num_shards=min_num_shards)
208
+ else:
209
+ if dataset.num_shards < min_num_shards:
210
+ logger.warning(
211
+ f"{color.red}"
212
+ f"Dataset {job_config.training.dataset} has insufficient shards ({dataset.num_shards}). "
213
+ f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
214
+ f"{job_config.training.num_workers} dataloader workers. "
215
+ f"Disabling the streaming mode and resharding dataset to {min_num_shards} shards."
216
+ f"{color.reset}"
217
+ )
218
+ dataset = (
219
+ load_dataset(
220
+ path=job_config.training.dataset,
221
+ name=getattr(job_config.training, "dataset_name", None),
222
+ data_dir=getattr(job_config.training, "data_dir", None),
223
+ data_files=getattr(job_config.training, "data_files", None),
224
+ split=job_config.training.dataset_split or "train",
225
+ trust_remote_code=True,
226
+ streaming=False,
227
+ num_proc=job_config.training.num_workers,
228
+ )
229
+ .shuffle(seed=job_config.training.seed)
230
+ .to_iterable_dataset(num_shards=min_num_shards)
231
+ )
232
+ else:
233
+ dataset = shuffle(dataset, seed=job_config.training.seed)
234
+ else:
235
+ datasets = job_config.training.dataset.split(",")
236
+ if job_config.training.dataset_name is not None:
237
+ dataset_names = [
238
+ name or None for name in job_config.training.dataset_name.split(",")
239
+ ]
240
+ assert len(dataset_names) == len(datasets), (
241
+ "The number of dataset names must match the number of datasets"
242
+ )
243
+ else:
244
+ dataset_names = [None] * len(datasets)
245
+ if job_config.training.dataset_split is not None:
246
+ dataset_splits = [
247
+ split or "train"
248
+ for split in job_config.training.dataset_split.split(",")
249
+ ]
250
+ assert len(dataset_splits) == len(datasets), (
251
+ "The number of dataset splits must match the number of datasets"
252
+ )
253
+ else:
254
+ dataset_splits = ["train"] * len(datasets)
255
+ if job_config.training.data_dir is not None:
256
+ data_dirs = [
257
+ data_dir or None for data_dir in job_config.training.data_dir.split(",")
258
+ ]
259
+ assert len(data_dirs) == len(datasets), (
260
+ "The number of data dirs must match the number of datasets"
261
+ )
262
+ else:
263
+ data_dirs = [None] * len(datasets)
264
+ if job_config.training.data_files is not None:
265
+ data_files = job_config.training.data_files.split(",")
266
+ assert len(data_files) == len(datasets), (
267
+ "The number of data files must match the number of datasets"
268
+ )
269
+ else:
270
+ data_files = [None] * len(datasets)
271
+ if job_config.training.data_probs is not None:
272
+ data_probs = [float(p) for p in job_config.training.data_probs.split(",")]
273
+ assert len(data_probs) == len(datasets), (
274
+ "The number of data probabilities must match the number of datasets"
275
+ )
276
+ else:
277
+ raise ValueError(
278
+ "Data sampling probabilities are required if using multiple datasets"
279
+ )
280
+
281
+ subsets = []
282
+ for i, prob in enumerate(data_probs):
283
+ subset = load_dataset(
284
+ path=datasets[i],
285
+ name=dataset_names[i],
286
+ data_dir=data_dirs[i],
287
+ data_files=data_files[i],
288
+ split=dataset_splits[i],
289
+ trust_remote_code=True,
290
+ streaming=job_config.training.streaming,
291
+ num_proc=(
292
+ job_config.training.num_workers
293
+ if not job_config.training.streaming
294
+ else None
295
+ ),
296
+ )
297
+ logger.info(
298
+ f"Subset {color.cyan}{datasets[i]}"
299
+ + (f":{dataset_names[i]} " if dataset_names[i] else " ")
300
+ + f"(p = {prob:.3f}){color.reset}:\n"
301
+ + f"{subset}"
302
+ )
303
+
304
+ logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
305
+ if not job_config.training.streaming:
306
+ # the state of a map-style dataset is recoverable after shuffling
307
+ subset = subset.shuffle(
308
+ seed=job_config.training.seed
309
+ ).to_iterable_dataset(num_shards=min_num_shards)
310
+ else:
311
+ if subset.num_shards < min_num_shards:
312
+ logger.warning(
313
+ f"{color.red}"
314
+ f"Dataset {datasets[i]} has insufficient shards ({subset.num_shards}). "
315
+ f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
316
+ f"{job_config.training.num_workers} dataloader workers. "
317
+ f"Resharding dataset to {min_num_shards} shards and disabling streaming mode."
318
+ f"{color.reset}"
319
+ )
320
+ # again, it's ok to directly shuffle the map-style dataset
321
+ # we expect an error to be raised if the map-style dataset still does not have enough shards
322
+ subset = (
323
+ load_dataset(
324
+ path=datasets[i],
325
+ name=dataset_names[i],
326
+ data_dir=data_dirs[i],
327
+ data_files=data_files[i],
328
+ split=dataset_splits[i],
329
+ trust_remote_code=True,
330
+ streaming=False,
331
+ num_proc=job_config.training.num_workers,
332
+ )
333
+ .shuffle(seed=job_config.training.seed)
334
+ .to_iterable_dataset(num_shards=min_num_shards)
335
+ )
336
+ else:
337
+ # we set relatively small buffer size here as interleaving could provide some randomness
338
+ subset = shuffle(
339
+ subset,
340
+ seed=job_config.training.seed,
341
+ buffer_size=max(128, 1024 // len(datasets)),
342
+ )
343
+
344
+ if "text" in subset.column_names:
345
+ subset = subset.select_columns("text")
346
+ elif "content" in subset.column_names:
347
+ subset = subset.select_columns("content")
348
+ else:
349
+ raise ValueError(
350
+ f"Subset {datasets[i]} has no 'text' or 'content' column"
351
+ )
352
+ subsets.append(subset)
353
+
354
+ logger.info(
355
+ f"Interleaving {len(subsets)} datasets with probabilities {data_probs}"
356
+ )
357
+ dataset = interleave_datasets(
358
+ datasets=subsets,
359
+ probabilities=data_probs,
360
+ stopping_strategy="all_exhausted",
361
+ seed=job_config.training.seed,
362
+ )
363
+ logger.info(f"{dataset}")
364
+
365
+
366
+ logger.info(f"Loading model config from {job_config.model.config}")
367
+ model_config = AutoConfig.from_pretrained(job_config.model.config)
368
+
369
+ logger.info("Building dataloader...")
370
+ dataloader = build_dataloader(
371
+ dataset=dataset,
372
+ tokenizer=tokenizer,
373
+ rank=dp_rank,
374
+ world_size=dp_degree,
375
+ batch_size=job_config.training.batch_size,
376
+ # TODO: Make this more modular
377
+ # seq_len=job_config.training.seq_len if not model_config.use_top_loss else job_config.training.seq_len*2,
378
+ seq_len=job_config.training.seq_len * 2,
379
+ context_len=job_config.training.context_len,
380
+ varlen=job_config.training.varlen,
381
+ num_workers=job_config.training.num_workers,
382
+ pin_memory=job_config.training.pin_memory,
383
+ persistent_workers=job_config.training.persistent_workers,
384
+ snapshot_every_n_steps=job_config.checkpoint.interval,
385
+ )
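+ # Clarifying note on seq_len * 2 above (an inference from the commented-out line and
+ # the TOP/MTP handling in the train loop below, not documented behavior): the dataloader
+ # packs 2 * seq_len tokens per sample so that, e.g. with seq_len = 4096, input_ids can be
+ # sliced to the first 4096 tokens while labels keep all 8192 for the multi-token
+ # prediction objective.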
386
+
387
+ # set the model configs from training inputs:
388
+ # 1. norm type to decide which norm layer to use
389
+ # 2. disable fused norm if TP is enabled
390
+ # 3. vocab size from tokenizer
391
+ # 4. context_len based on the training inputs
392
+ if parallel_dims.tp_enabled:
393
+ if model_config.fuse_norm:
394
+ logger.warning(
395
+ f"{color.red}"
396
+ f"Fused norm is not compatible with tensor parallelism. "
397
+ f"Disabling it for now."
398
+ f"{color.reset}"
399
+ )
400
+ model_config.fuse_norm = False
401
+ if parallel_dims.loss_parallel_enabled:
402
+ if model_config.fuse_cross_entropy:
403
+ logger.warning(
404
+ f"{color.red}"
405
+ f"Loss parallel enabled. Disabling fused cross entropy for now."
406
+ f"{color.reset}"
407
+ )
408
+ model_config.fuse_cross_entropy = False
409
+ model_config.vocab_size = max(tokenizer.vocab_size, model_config.vocab_size)
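+ # Taking the max here guards against an embedding/head smaller than the tokenizer's
+ # vocabulary while still allowing the config to pad vocab_size upward (e.g., hypothetically,
+ # a tokenizer with 32,000 tokens and a config padded to 32,768 keeps 32,768).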
410
+
411
+ logger.info(
412
+ f"Building model from the config\n{color.green}{model_config}{color.reset}"
413
+ )
414
+ with torch.device("meta"):
415
+ model = AutoModelForCausalLM.from_config(model_config)
416
+ if (
417
+ getattr(model_config, "fuse_cross_entropy", False)
418
+ and FusedLinearCrossEntropyLoss is not None
419
+ ):
420
+ model.criterion = FusedLinearCrossEntropyLoss(
421
+ num_chunks=8 // parallel_dims.tp
422
+ )
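+ # num_chunks = 8 // parallel_dims.tp is presumably meant to bound peak memory of the
+ # fused logit/loss computation by processing the batch in chunks, with fewer chunks
+ # when tensor parallelism already splits the work (8 chunks at tp=1, 2 at tp=4);
+ # this reading is an inference from the expression, not documented behavior.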
423
+ # defer weight initialization until after parallelisms are applied
424
+ model.apply(lambda m: setattr(m, "_is_hf_initialized", False))
425
+ logger.info(f"{color.blue}\n{model}{color.reset}\n")
426
+
427
+ # Build the collection of model converters. No-op if `model.converters` empty
428
+ model_converters = build_model_converters(job_config, parallel_dims)
429
+ model_converters.convert(model)
430
+
431
+ # calculate model size and flops per token
432
+ model_param_count, num_flops_per_token = get_nparams_and_flops(
433
+ model, model_config, job_config.training.context_len
434
+ )
435
+
436
+ # move sharded model to CPU/GPU and initialize weights via DTensor
437
+ if job_config.checkpoint.create_seed_checkpoint:
438
+ init_device = "cpu"
439
+ elif job_config.training.enable_cpu_offload:
440
+ init_device = "cpu"
441
+ else:
442
+ init_device = device_type
443
+
444
+ # apply parallelisms and initialization
445
+ if parallel_dims.pp_enabled:
446
+ # apply PT-D Pipeline Parallel
447
+ (
448
+ pp_schedule,
449
+ model_parts,
450
+ has_first_stage,
451
+ has_last_stage,
452
+ ) = train_spec.pipelining_fn(
453
+ model,
454
+ pp_mesh,
455
+ parallel_dims,
456
+ job_config,
457
+ device,
458
+ model_config,
459
+ train_spec.loss_fn,
460
+ )
461
+ # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
462
+ del model
463
+
464
+ # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
465
+ # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
466
+ # optimizer, and checkpointing
467
+ for m in model_parts:
468
+ # apply SPMD-style PT-D techniques
469
+ train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
470
+ m.to_empty(device=init_device)
471
+ with torch.no_grad():
472
+ m.post_init()
473
+ m.train()
474
+
475
+ # confirm that user will be able to view loss metrics on the console
476
+ ensure_pp_loss_visible(parallel_dims, job_config, color)
477
+ else:
478
+ # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
479
+ train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
480
+ model.to_empty(device=init_device)
481
+ with torch.no_grad():
482
+ model.post_init()
483
+ model.train()
484
+
485
+ model_parts = [model]
486
+
487
+ device_mem_stats = device_memory_monitor.get_peak_stats()
488
+ logger.info(
489
+ f"{device_type.upper()} memory usage for model: "
490
+ f"{device_mem_stats.max_reserved_gib:.2f}GiB"
491
+ f"({device_mem_stats.max_reserved_pct:.2f}%)"
492
+ )
493
+
494
+ # build optimizer after applying parallelisms to the model
495
+ optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager)
496
+ lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config)
497
+ # Post optimizer step model converters hook.
498
+ # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
499
+ # where it issues a single all-reduce for all parameters at once for better performance
500
+ optimizers.register_step_post_hook(
501
+ lambda *args, **kwargs: model_converters.post_optimizer_hook(model_parts)
502
+ )
503
+
504
+ train_state = TrainState()
505
+
506
+ # load initial checkpoint
507
+ checkpoint = CheckpointManager(
508
+ dataloader=dataloader,
509
+ model_parts=model_parts,
510
+ optimizers=optimizers,
511
+ lr_schedulers=lr_schedulers,
512
+ states={"train_state": train_state},
513
+ job_config=job_config,
514
+ ft_manager=ft_manager,
515
+ )
516
+
517
+ if job_config.checkpoint.create_seed_checkpoint:
518
+ assert world_size == 1, (
519
+ "Must create seed checkpoint using a single device, to disable sharding"
520
+ )
521
+ assert job_config.checkpoint.enable_checkpoint, (
522
+ "Must enable checkpointing when creating a seed checkpoint"
523
+ )
524
+ checkpoint.save(curr_step=0, force=True)
525
+ logger.info("Created seed checkpoint")
526
+ return
527
+
528
+ checkpoint.load(step=job_config.checkpoint.load_step)
529
+ metric_logger = build_metrics_processor(job_config, parallel_dims)
530
+ # Set dependent attributes for metric_logger
531
+ metric_logger.num_flops_per_token = num_flops_per_token
532
+ metric_logger.optimizers = optimizers # Pass optimizers if needed by logger logic
533
+ metric_logger.lr_schedulers = (
534
+ lr_schedulers # Pass schedulers if needed by logger logic
535
+ )
536
+
537
+ # plot losses loaded from checkpoint (if any) to TensorBoard
538
+ # NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
539
+ # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
540
+ if train_state.step > 0 and len(metric_logger.data_loading_times) > 0:
541
+ for idx, step in enumerate(train_state.log_steps):
542
+ metric_logger.log(
543
+ step,
544
+ global_avg_loss=train_state.global_avg_losses[idx],
545
+ global_max_loss=train_state.global_max_losses[idx],
546
+ )
547
+
548
+ data_iterator = iter(dataloader)
549
+
550
+ train_context = dist_utils.get_train_context(
551
+ parallel_dims.loss_parallel_enabled,
552
+ job_config.experimental.enable_compiled_autograd,
553
+ )
554
+
555
+ # variables used to keep info for metrics logging
556
+ device_memory_monitor.reset_peak_stats()
557
+
558
+ global_batch_size = (
559
+ job_config.training.batch_size
560
+ * dp_degree
561
+ * job_config.training.gradient_accumulation_steps
562
+ )
563
+ num_tokens_per_step = global_batch_size * job_config.training.seq_len
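+ # Worked example with hypothetical values: batch_size = 8, dp_degree = 4, and
+ # gradient_accumulation_steps = 4 give global_batch_size = 8 * 4 * 4 = 128 sequences
+ # per optimizer step; with seq_len = 2048 that is
+ # num_tokens_per_step = 128 * 2048 = 262,144 tokens.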
564
+ # train loop
565
+ logger.info(f"{color.red}***** Running training *****{color.reset}")
566
+ logger.info(f"{color.green} Training starts at step {train_state.step + 1}")
567
+ logger.info(
568
+ f"{color.green} Number of tokens per sequence = {job_config.training.seq_len:,}"
569
+ )
570
+ logger.info(
571
+ f"{color.green} Gradient Accumulation steps = {job_config.training.gradient_accumulation_steps}"
572
+ )
573
+ logger.info(
574
+ f"{color.green} Instantaneous batch size (per device) = {job_config.training.batch_size:,}"
575
+ )
576
+ logger.info(
577
+ f"{color.green} Global batch size (w. parallel, distributed & accumulation) = {global_batch_size:,}"
578
+ f" ({num_tokens_per_step:,} tokens)"
579
+ )
580
+ logger.info(
581
+ f"{color.green} Total optimization steps = {job_config.training.steps:,} "
582
+ f"({job_config.training.steps * num_tokens_per_step:,} tokens)"
583
+ )
584
+ logger.info(
585
+ f"{color.green} Warmup steps = {job_config.lr_scheduler.warmup_steps:,}"
586
+ f" ({job_config.lr_scheduler.warmup_steps * num_tokens_per_step:,} tokens)"
587
+ )
588
+ logger.info(
589
+ f"{color.green} Number of parameters = {model_param_count:,} {color.reset}"
590
+ )
591
+
592
+ with (
593
+ maybe_enable_profiling(
594
+ job_config, global_step=train_state.step
595
+ ) as torch_profiler,
596
+ maybe_enable_memory_snapshot(
597
+ job_config, global_step=train_state.step
598
+ ) as memory_profiler,
599
+ ):
600
+ while train_state.step < job_config.training.steps:
601
+ train_state.step += 1
602
+ gc_handler.run(train_state.step)
603
+
604
+ optimizers.zero_grad()
605
+
606
+ losses = defaultdict(list)
607
+ actual_loss = []
608
+ # do gradient accumulation if enabled
609
+ for _ in range(job_config.training.gradient_accumulation_steps):
610
+ # get batch
611
+ data_load_start = time.perf_counter()
612
+ batch = next(data_iterator)
613
+ # Recall that for TOP and MTP the batch is laid out as:
614
+ # input_ids : (B, seq_len)
615
+ # labels : (B, seq_len * 2)
616
+ input_ids, labels = batch["input_ids"][:, :job_config.training.seq_len], batch["labels"]
617
+
618
+ # Update metrics processor state before forward/backward
619
+ metric_logger.ntokens_since_last_log += input_ids.numel()
620
+ metric_logger.data_loading_times.append(
621
+ time.perf_counter() - data_load_start
622
+ )
623
+
624
+ input_ids = input_ids.to(device_type)
625
+
626
+ """
627
+ TODO[flame]: We need to carefully handle the position_ids for TP/CP
628
+ Depending on the model's positional encoding, the position_ids might differ.
630
+
631
+ e.g. for TP
632
+ For RoPE, all ranks have the same position_ids. [FOR HF model]
633
+ For sinusoidal PE, each rank has the corresponding chunked position_ids. [FOR HF model]
634
+
635
+ e.g. for CP [optional_context_parallel_ctx should automatically distribute the position_ids]
636
+ Each rank has the corresponding chunked position_ids. [FOR all models]
636
+
637
+ """
638
+ labels = labels.to(device_type)
639
+ cu_seqlens = (
640
+ batch["cu_seqlens"].to(device_type)
641
+ if "cu_seqlens" in batch
642
+ else None
643
+ )
644
+ if cu_seqlens is not None:
645
+ position_ids = prepare_position_ids(cu_seqlens).to(torch.int32)
646
+ else:
647
+ position_ids = (
648
+ torch.arange(0, input_ids.shape[1], device=device_type)
649
+ .repeat(input_ids.shape[0], 1)
650
+ .to(torch.int32)
651
+ )
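+ # Sketch of the two cases above (assuming prepare_position_ids restarts the count at
+ # every sequence boundary, as its name suggests): for a packed batch with
+ # cu_seqlens = [0, 3, 7] the positions are [0, 1, 2, 0, 1, 2, 3], whereas the dense
+ # fallback simply yields [0, 1, ..., seq_len - 1] repeated for every row.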
652
+ # apply context parallelism if cp is enabled
653
+ # ensure CP handles the separate freqs_cis buffer for each pp stage
654
+ optional_context_parallel_ctx = (
655
+ dist_utils.create_context_parallel_ctx(
656
+ cp_mesh=world_mesh["cp"],
657
+ cp_buffers=[input_ids, labels, position_ids],
658
+ cp_seq_dims=[1, 1, 1],
659
+ cp_no_restore_buffers={input_ids, labels, position_ids},
660
+ cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
661
+ )
662
+ if parallel_dims.cp_enabled
663
+ else None
664
+ )
665
+
666
+ # ! TODO[flame]: we should also distribute the position_ids with CP
667
+ if parallel_dims.pp_enabled:
668
+ raise NotImplementedError(
669
+ "Pipeline parallelism is not supported in this version"
670
+ )
671
+ # Pipeline Parallel forward / backward inside step() call
672
+ with train_context(optional_context_parallel_ctx):
673
+ targets, losses = (
674
+ (labels, []) if has_last_stage else (None, None)
675
+ )
676
+
677
+ if has_first_stage:
678
+ pp_schedule.step(input_ids, target=targets, losses=losses)
679
+ else:
680
+ pp_schedule.step(target=targets, losses=losses)
681
+
682
+ # accumulate losses across pipeline microbatches
683
+ # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
684
+ loss = (
685
+ torch.mean(torch.stack(losses)).to(device)
686
+ if has_last_stage
687
+ else torch.tensor([-1.0], device=device)
688
+ )
689
+ else:
690
+ # Non-PP forward / backward
691
+ with train_context(optional_context_parallel_ctx):
692
+ output = model(
693
+ input_ids=input_ids,
694
+ labels=labels,
695
+ position_ids=position_ids,
696
+ cu_seqlens=cu_seqlens,
697
+ )
698
+ output_attributes = [field.name for field in dataclasses.fields(output)]
699
+ loss_attributes = [x for x in output_attributes if "loss" in x and x != "loss"]
700
+ loss = (
701
+ output.loss
702
+ / job_config.training.gradient_accumulation_steps
703
+ )
704
+ loss.backward()
705
+
706
+ actual_loss.append(loss)
707
+ for loss_attr in loss_attributes:
708
+ custom_loss = getattr(output, loss_attr, None)
709
+ if custom_loss is not None:
710
+ custom_loss = custom_loss / job_config.training.gradient_accumulation_steps
712
+ losses[loss_attr].append(custom_loss)
713
+
714
+ loss = sum(actual_loss)
715
+ for loss_attr, loss_values in losses.items():
716
+ losses[loss_attr] = sum(loss_values)
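+ # Because each micro-batch loss was already divided by gradient_accumulation_steps
+ # before backward(), the sums above equal the mean loss over the accumulation window;
+ # e.g. (hypothetical values) per-micro-batch losses of 2.0, 2.4, 2.2, 2.6 with 4
+ # accumulation steps are stored as 0.5, 0.6, 0.55, 0.65 and sum to 2.3, their mean.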
717
+
718
+ # clip gradients
719
+ grad_norm = dist_utils.clip_grad_norm_(
720
+ [p for m in model_parts for p in m.parameters()],
721
+ job_config.training.max_norm,
722
+ foreach=True,
723
+ pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
724
+ )
725
+
726
+ # optimizer step
727
+ checkpoint.maybe_wait_for_staging()
728
+ if job_config.training.skip_nan_inf and (
729
+ grad_norm.isnan() or grad_norm.isinf()
730
+ ):
731
+ logger.warning(
732
+ f"Skipping optimizer step - detected invalid gradient norm: {grad_norm:.4f}"
733
+ )
734
+ optimizers.zero_grad()
735
+ train_state.skipped_step += 1
736
+ else:
737
+ optimizers.step()
738
+ lr_schedulers.step()
739
+
740
+ # log metrics - Use MetricsProcessor
741
+ global_avg_custom_loss = {}
742
+ global_max_custom_loss = {}
743
+ if metric_logger.should_log(train_state.step):
744
+ if (
745
+ parallel_dims.dp_replicate_enabled
746
+ or parallel_dims.dp_shard_enabled
747
+ or parallel_dims.cp_enabled
748
+ ):
749
+ loss = loss.detach()
750
+ # Use dist_mean/max on the accumulated loss for the step
751
+ global_avg_loss, global_max_loss = (
752
+ dist_utils.dist_mean(
753
+ loss,
754
+ world_mesh["dp_cp"],
755
+ ),
756
+ dist_utils.dist_max(
757
+ loss,
758
+ world_mesh["dp_cp"],
759
+ ),
760
+ )
761
+ for loss_attr, loss_value in losses.items():
762
+ global_avg_custom_loss[loss_attr] = dist_utils.dist_mean(
763
+ loss_value, world_mesh["dp_cp"]
764
+ )
765
+ global_max_custom_loss[loss_attr] = dist_utils.dist_max(
766
+ loss_value, world_mesh["dp_cp"]
767
+ )
768
+ else:
769
+ # Scale back the loss before logging
770
+ global_avg_loss = global_max_loss = loss.item()
771
+ for loss_attr, loss_value in losses.items():
772
+ global_avg_custom_loss[loss_attr] = global_max_custom_loss[
773
+ loss_attr
774
+ ] = loss_value.item()
775
+
776
+ # Update train state tokens and elapsed time
777
+ time_now = time.perf_counter()
778
+ time_delta = (
779
+ time_now - metric_logger.time_last_log
780
+ ) # Use metric_logger's time
781
+ train_state.token += (
782
+ metric_logger.ntokens_since_last_log # Use tokens tracked by metric_logger
783
+ * parallel_dims.world_size
784
+ / parallel_dims.non_data_parallel_size
785
+ )
786
+ train_state.elapsed += timedelta(seconds=time_delta)
787
+ train_state.log_steps.append(train_state.step)
788
+ train_state.global_avg_losses.append(global_avg_loss)
789
+ train_state.global_max_losses.append(global_max_loss)
790
+
791
+ # Log using the metric processor
792
+ last_lr = lr_schedulers.schedulers[0].get_last_lr()[0]
793
+ eta = (
794
+ train_state.elapsed
795
+ * (job_config.training.steps - train_state.step)
796
+ / train_state.step
797
+ )
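+ # e.g. (hypothetical) 2h elapsed after 10,000 of 40,000 steps gives
+ # eta = 2h * (40,000 - 10,000) / 10,000 = 6h remaining.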
798
+ extra_metrics = {
799
+ "optimizer/lr": last_lr,
800
+ "optimizer/grad_norm": grad_norm.item(),
801
+ "optimizer/skipped_step": train_state.skipped_step,
802
+ }
803
+ for loss_attr, loss_value in global_avg_custom_loss.items():
804
+ extra_metrics[f"loss_metrics/global_avg_{loss_attr}"] = loss_value.item() if isinstance(loss_value, torch.Tensor) else loss_value
805
+ metric_logger.log(
806
+ train_state.step,
807
+ global_avg_loss,
808
+ global_max_loss,
809
+ extra_metrics=extra_metrics,
810
+ )
811
+
812
+ logger.info(
813
+ f"{color.blue}lr: {last_lr:.4e} gnorm: {grad_norm:5.2f} "
814
+ f"{color.magenta}[{str(train_state.elapsed).split('.')[0]:>8}<{str(eta).split('.')[0]:>8}]{color.reset}"
815
+ )
816
+
817
+ checkpoint.save(
818
+ train_state.step, force=(train_state.step == job_config.training.steps)
819
+ )
820
+
821
+ if torch.distributed.get_rank() == 0:
822
+ if job_config.checkpoint.enable_checkpoint:
823
+ hf_target_path = None
824
+ dcp_save_path = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder, f"step-{train_state.step}")
825
+
826
+ # TODO: Haven't tested this one yet
827
+ if getattr(job_config.checkpoint, "convert_to_hf_on_save", False):
828
+ try:
829
+ # Get the path where DCP was just saved
830
+ # Check CheckpointManager API for the best way, assuming get_save_path exists
831
+ hf_target_path = dcp_save_path # reuse the DCP checkpoint directory as the HF export target
832
+
833
+ logger.info(f"Converting step {train_state.step} DCP checkpoint to HF format at: {hf_target_path}")
834
+ save_pretrained( # Call the imported function
835
+ path=hf_target_path, # Pass target HF path as 'path'
836
+ step=train_state.step,
837
+ config=job_config.model.config, # Pass model config path/id
838
+ tokenizer=job_config.model.tokenizer_path # Pass tokenizer path/id
839
+ )
840
+ logger.info(f"Successfully converted step {train_state.step} to HF format.")
841
+
842
+ except Exception as e:
843
+ logger.error(f"Failed to convert checkpoint step {train_state.step} to HF format: {e}", exc_info=True)
844
+
845
+ base_checkpoint_dir = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder)
846
+ if getattr(job_config.checkpoint, "hf_upload_enabled", True):
847
+ upload_format = getattr(job_config.checkpoint, "hf_upload_format", "hf")
848
+ keep_k_hub = getattr(job_config.checkpoint, "hf_keep_latest_k", 5)
849
+
850
+ local_path_to_upload = None
851
+ if upload_format == "hf":
852
+ if hf_target_path and os.path.isdir(hf_target_path):
853
+ local_path_to_upload = hf_target_path
854
+ elif upload_format == "dcp":
855
+ if dcp_save_path and os.path.isdir(dcp_save_path):
856
+ local_path_to_upload = dcp_save_path
857
+
858
+ if local_path_to_upload:
859
+ try:
860
+ upload_checkpoint_to_hf(
861
+ local_path=local_path_to_upload,
862
+ step=train_state.step,
863
+ hf_repo_id_for_run=run_specific_repo_id,
864
+ upload_format=upload_format,
865
+ hf_keep_latest_k=keep_k_hub,
866
+ )
867
+ except Exception as e:
868
+ logger.error(f"Failed during HF Hub upload for step {train_state.step}: {e}", exc_info=True)
869
+
870
+ # signal the profiler that the next profiling step has started
871
+ if torch_profiler:
872
+ torch_profiler.step()
873
+ if memory_profiler:
874
+ memory_profiler.step()
875
+
876
+ # reduce timeout after first train step for faster signal
877
+ # (assuming lazy init and compilation are finished)
878
+ if train_state.step == 1:
879
+ dist_utils.set_pg_timeouts(
880
+ timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
881
+ world_mesh=world_mesh,
882
+ )
883
+
884
+ if torch.distributed.get_rank() == 0:
885
+ logger.info("Sleeping 2 seconds for other ranks to complete")
886
+ time.sleep(2)
887
+
888
+ metric_logger.close()
889
+ logger.info("Training completed")
890
+
891
+
892
+ if __name__ == "__main__":
893
+ init_logger()
894
+ config = JobConfig()
895
+ config.parse_args()
896
+ main(config)
897
+ torch.distributed.destroy_process_group()
flame/utils/__init__.py ADDED
File without changes
flame/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (136 Bytes).