zaydzuhri committed
Commit 0298ad2 · verified · 1 Parent(s): 854552e

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
Files changed (50):
  1. LICENSE +21 -0
  2. README.md +471 -0
  3. config.json +33 -0
  4. configs/delta_net_1B.json +29 -0
  5. configs/delta_net_340M.json +27 -0
  6. configs/gla_340M.json +24 -0
  7. configs/gla_7B.json +25 -0
  8. configs/gsa_340M.json +29 -0
  9. configs/hgrn2_340M.json +20 -0
  10. configs/rectified_transformer_120M.json +19 -0
  11. configs/rectified_transformer_340M.json +19 -0
  12. configs/softpick_transformer_120M.json +19 -0
  13. configs/softpick_transformer_340M.json +19 -0
  14. configs/transformer_120M.json +18 -0
  15. configs/transformer_1B.json +22 -0
  16. configs/transformer_340M.json +18 -0
  17. configs/transformer_7B.json +21 -0
  18. configs/vanilla_transformer_120M.json +19 -0
  19. configs/vanilla_transformer_340M.json +19 -0
  20. download_checkpoint.py +35 -0
  21. fla/__init__.py +110 -0
  22. fla/layers/gla.py +294 -0
  23. fla/layers/rwkv7.py +221 -0
  24. fla/utils.py +221 -0
  25. flame/__init__.py +1 -0
  26. flame/components/__init__.py +0 -0
  27. flame/components/checkpoint.py +59 -0
  28. flame/config_manager.py +940 -0
  29. flame/data.py +570 -0
  30. flame/models/__init__.py +0 -0
  31. flame/models/activation_offloading.py +447 -0
  32. flame/models/fla.toml +67 -0
  33. flame/models/parallelize_fla.py +550 -0
  34. flame/models/pipeline_fla.py +162 -0
  35. flame/tools/__init__.py +0 -0
  36. flame/tools/utils.py +41 -0
  37. flame/train.py +851 -0
  38. flame/utils/__init__.py +0 -0
  39. flame/utils/checkpoint.py +50 -0
  40. flame/utils/convert_dcp_to_hf.py +66 -0
  41. flame/utils/convert_hf_to_dcp.py +34 -0
  42. flame/utils/hf_utils.py +77 -0
  43. generation_config.json +6 -0
  44. logs/none_nygareex/attempt_0/0/stderr.log +0 -0
  45. logs/none_nygareex/attempt_0/0/stdout.log +0 -0
  46. logs/none_nygareex/attempt_0/1/stderr.log +0 -0
  47. logs/none_nygareex/attempt_0/1/stdout.log +0 -0
  48. logs/none_nygareex/attempt_0/2/stderr.log +0 -0
  49. logs/none_nygareex/attempt_0/2/stdout.log +0 -0
  50. logs/none_nygareex/attempt_0/3/stderr.log +0 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023-2025 Songlin Yang, Yu Zhang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md ADDED
@@ -0,0 +1,471 @@
<div align="center">

# 🔥 Flame: Flash Linear Attention Made Easy

</div>

Welcome to 🔥 `flame`, a minimal and efficient framework built on `torchtitan` for training Flash Linear Attention (FLA) models (and more broadly, arbitrary autoregressive language models) with blazing efficiency.

**Feature Highlights:**

- 🚀 Minimal, easy-to-use, extensible training framework
- 🤗 Seamless integration with `fla` and `transformers`
- 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and multiple datasets support
- 🔮 4D parallelism (coming soon)

## Setup

To get started, clone the `flame` repository and install the required dependencies:

```bash
git clone https://github.com/fla-org/flame.git
cd flame
pip install .
```

`flame` keeps its dependencies minimal, including only `fla` and `torchtitan` as submodules.
After installation, initialize and update the submodules:
```sh
git submodule update --init --recursive
```

## Dataset Preparation
To download the dataset to your local disk, create a new Python file with the following content and execute it:

```py
from datasets import load_dataset

# load fineweb-edu with parallel processing
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")

# or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
```

## Training Recipes

Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus in streaming mode.

> [!WARNING]
> If the dataset is not downloaded beforehand, streaming mode will attempt to fetch it from a remote server and download it on the fly, which can be highly unstable during training due to network issues.
> For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)). Otherwise, we assume you are only experimenting with a new corpus.

```sh
bash train.sh \
  --job.config_file flame/models/fla.toml \
  --job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr3e-4.cosine \
  --model.config configs/transformer_340M.json \
  --model.tokenizer_path fla-hub/transformer-1.3B-100B \
  --optimizer.name AdamW \
  --optimizer.eps 1e-15 \
  --optimizer.lr 3e-4 \
  --lr_scheduler.warmup_steps 1024 \
  --lr_scheduler.lr_min 0.1 \
  --lr_scheduler.decay_type cosine \
  --training.batch_size 1 \
  --training.seq_len 65536 \
  --training.context_len 4096 \
  --training.varlen \
  --training.gradient_accumulation_steps 1 \
  --training.steps 20480 \
  --training.max_norm 1.0 \
  --training.skip_nan_inf \
  --training.dataset HuggingFaceFW/fineweb-edu \
  --training.dataset_name sample-100BT \
  --training.dataset_split train \
  --training.streaming \
  --training.num_workers 32 \
  --training.prefetch_factor 2 \
  --training.seed 42 \
  --training.compile \
  --checkpoint.interval 2048 \
  --checkpoint.load_step -1 \
  --checkpoint.keep_latest_k 2 \
  --metrics.log_freq 1
```

You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
**For single-GPU debugging, set `NGPU=1`.**

We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
By default, the learning rate is set to 3e-4 with a cosine scheduler. Other schedulers, such as WSD (`wsd`), are also supported.

**Key parameters:**
- `--lr_scheduler.decay_ratio`: The proportion of the steps allocated to the decay phase. The learning rate will remain stable after the warmup period and only start decaying during the last `decay_ratio` portion of the total training steps, which is known as the Warmup-Stable-Decay (WSD) schedule.
- `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
- `--training.steps`: Total number of training steps.
- `--training.batch_size`: Batch size per device, must be 1 if `--training.varlen` is set.
- `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
- `--training.context_len`: The max allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
- `--training.varlen`: Whether to conduct variable-length sequence training.
- `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.

> [!WARNING]
> The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as `batch_size × gradient_accumulation_steps × num_gpus`.
> Each step processes `global_batch_size * seq_len` tokens.
> Monitor the values of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of the hyperparameters!

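As a quick sanity check, the token budget of the example command above can be computed directly. The sketch below just applies the formula from the warning to the flags shown, assuming the default `NGPU=8`:

```py
# Token budget of the example 340M recipe above (assuming the default NGPU=8).
batch_size = 1                     # --training.batch_size
gradient_accumulation_steps = 1    # --training.gradient_accumulation_steps
num_gpus = 8                       # NGPU (default)
seq_len = 65536                    # --training.seq_len
steps = 20480                      # --training.steps

global_batch_size = batch_size * gradient_accumulation_steps * num_gpus  # 8 sequences per step
tokens_per_step = global_batch_size * seq_len                            # 524,288 tokens
total_tokens = tokens_per_step * steps                                   # ~10.7B tokens, i.e. the "10B" in the dump folder name
print(f"{tokens_per_step=:,} {total_tokens=:,}")
```
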
For a detailed explanation of all parameters, run:

```sh
bash train.sh -h
```

<details>
<summary>Usage</summary>

```
options:
  -h, --help  show this help message and exit
  --job.config_file JOB.CONFIG_FILE
        Job config file
  --job.dump_folder JOB.DUMP_FOLDER
        Folder to dump job outputs
  --job.description JOB.DESCRIPTION
        Description of the job
  --job.use_for_integration_test
        Add this config to the integration test suite
  --job.print_args  Print the args to terminal
  --model.config MODEL.CONFIG
        Path to the model config
  --model.norm_type MODEL.NORM_TYPE
        Type of layer normalization to use [layernorm,
        np_layernorm, rmsnorm, fused_rmsnorm]
  --model.tokenizer_path MODEL.TOKENIZER_PATH
        Tokenizer path
  --profiling.enable_profiling
        Whether to enable pytorch profiler
  --profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
        Trace files location
  --profiling.profile_freq PROFILING.PROFILE_FREQ
        How often to collect profiler traces, in iterations
  --profiling.enable_memory_snapshot
        Whether to dump memory snapshot
  --profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
        Memory snapshot files location
  --optimizer.name OPTIMIZER.NAME
        Optimizer to use
  --optimizer.eps OPTIMIZER.EPS
        Epsilon value for the optimizer.
  --optimizer.fused  Whether the fused implementation (CUDA only) is used.
  --optimizer.scheduler {wsd,cosine,linear}
        Scheduler to use. Currently supported: wsd, cosine,
        and linear.
  --optimizer.lr OPTIMIZER.LR
        Learning rate to use
  --optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
        Min lr ratio for lr scheduler
  --optimizer.early_step_in_backward
        Whether to apply optimizer in the backward. Caution,
        optimizer_in_backward is not compatible with gradients
        clipping, users should not call
        register_post_accumulate_grad_hook after the optimizer
        is built.
  --training.batch_size TRAINING.BATCH_SIZE
        Batch size
  --training.seq_len TRAINING.SEQ_LEN
        Sequence length
  --training.context_len TRAINING.CONTEXT_LEN
        Max length allowed for each sequence
  --training.varlen  Whether to take sequences of variable length as input
  --training.warmup_steps TRAINING.WARMUP_STEPS
        Steps for lr scheduler warmup, normally 1/5 of
        --training.steps
  --training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
        Number of steps to accumulate gradients before
        updating parameters
  --training.steps TRAINING.STEPS
        How many train steps to run
  --training.max_norm TRAINING.MAX_NORM
        Max norm for gradient clipping
  --training.skip_nan_inf
        Skip batch updates when NaN or INF gradients are
        encountered during training
  --training.dataset TRAINING.DATASET
        Dataset to use, with comma separated values
  --training.dataset_name TRAINING.DATASET_NAME
        The name of the dataset config, with comma separated
        values if provided
  --training.dataset_split TRAINING.DATASET_SPLIT
        Dataset split to use, with comma separated values if
        provided
  --training.data_dir TRAINING.DATA_DIR
        Data dirs to use, with comma separated values if
        provided
  --training.data_files TRAINING.DATA_FILES
        Data files to use, with comma separated values if
        provided
  --training.data_probs TRAINING.DATA_PROBS
        Data sampling probabilities, with comma separated
        values if provided
  --training.streaming  Whether to load the dataset in streaming mode, used for
        huge datasets
  --training.num_workers TRAINING.NUM_WORKERS
        Number of subprocesses to use for data loading. 0
        means that the data will be loaded in the main
        process.
  --training.prefetch_factor TRAINING.PREFETCH_FACTOR
        Number of batches loaded in advance by each worker. 2
        means there will be a total of 2 * num_workers batches
        prefetched across all workers.
  --training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
        The `data_parallel_replicate_degree` argument
        specifies the degree of data parallelism for weight
        replication. When this value is greater than 1,
        weights will be replicated across
        `data_parallel_replicate_degree` ranks. If
        `data_parallel_shard_degree` is also greater than 1,
        the parallelism method used is HSDP (Hybrid Sharded
        Data Parallelism). Otherwise, the parallelism method
        used is DDP (Distributed Data Parallelism). 1 means
        disabled.
  --training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
        The `data_parallel_shard_degree` argument specifies
        the degree of data parallelism for weight sharding.
        When this value is greater than 1, weights will be
        sharded across `data_parallel_shard_degree` ranks. If
        `data_parallel_replicate_degree` is also greater than
        1, the parallelism method used is HSDP (Hybrid Sharded
        Data Parallelism). Otherwise, the parallelism method
        used is FSDP (Fully Sharded Data Parallelism). -1
        means leftover ranks will be used (after
        DP_REPLICATE/SP/PP). Note that only
        `data_parallel_shard_degree` can be negative. 1 means
        disabled.
  --training.enable_cpu_offload
        Whether to apply CPU offloading of parameters,
        gradients, and optimizer states in FSDP
  --training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
        Tensor Parallelism degree. 1 means disabled.
  --training.disable_loss_parallel
        Whether to apply loss parallel when sequence parallel
        is enabled
  --training.mixed_precision_param {bfloat16,float32}
        torch dtype to use for parameters when applying mixed
        precision via FSDP. This feature only takes effect
        when data_parallel_shard_degree > 1
  --training.mixed_precision_reduce {float32}
        torch dtype to use for reductions when applying mixed
        precision via FSDP. This feature only takes effect
        when data_parallel_shard_degree > 1
  --training.compile  Whether to compile the model
  --training.gc_freq TRAINING.GC_FREQ
        Python garbage control scheduling interval, in steps
  --training.seed TRAINING.SEED
        Choose the base RNG seed used for training
  --training.deterministic
        Use deterministic algorithms wherever possible, may be
        slower
  --metrics.log_freq METRICS.LOG_FREQ
        How often to log metrics to TensorBoard, in iterations
  --metrics.enable_tensorboard
        Whether to log metrics to TensorBoard
  --metrics.disable_color_printing
        Whether to disable color printing in logs
  --metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
        Folder to dump TensorBoard states
  --metrics.rank_0_only
        Whether to save TensorBoard metrics only for rank 0 or
        for all ranks. When pipeline_parallel_degree is > 1,
        this option uses the 0th rank of the last stage
        pipeline group, which is the only stage that computes
        loss metrics.
  --metrics.enable_wandb
        Whether to log metrics to Weights & Biases
  --experimental.enable_async_tensor_parallel
        Whether to apply async tensor parallel (currently only
        effective when compile is enabled)
  --experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
        Pipeline Parallelism degree, or number of ranks. 1
        means disabled. If using looped schedules, this still
        specifies the number of physical ranks, not the number
        of stages. Stages per rank are inferred from split
        points degree, and schedule.
  --experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
        Specify comma-separated names of modules to use as the
        beginning of a split point. e.g. "layers.0,layers.2"
        will cause the model to be split into 3 stages, the
        first containing all the layers up to layers.0, the
        second containing layers.0 and up to layers.2, the
        third containing layers.2 and all the remaining
        layers. Note: fully-automated splitting may be enabled
        in the future, but currently the split points must be
        specified manually.
  --experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
        Specify the Pipeline Parallel schedule to use. The supported schedules are:
        https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
        The schedule must be compatible with the split points
        and stages_per_rank. Looped schedules (e.g.
        Interleaved1F1B) require specifying
        pipeline_parallel_degree = number of ranks, and
        split_points = number of stages - 1
  --experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
        Specify the path to the pipeline parallel schedule csv
        file to use. The pipeline_parallel_schedule argument
        must be either PipelineScheduleSingle,
        PipelineScheduleMulti, or _PipelineScheduleRuntime.
  --experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
        How many microbatches to split the global training
        batch into when using pipeline parallelism. The global
        training batch size must be evenly divisible by the
        number of microbatches. The default value will be the
        number of pipeline stages, if unspecified.
  --experimental.enable_compiled_autograd
        Enable CompiledAutograd to compile the backward.
  --experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
        Context parallelism degree. 1 means disabled.
  --experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
        The collective to use in context parallel SDPA for kv
        shards exchange. 'allgather' means to all-gather all
        kv shards on ranks after the first sub-SDPA
        computation, 'alltoall' means to all-to-all shuffle
        the kv shards. The default value is 'allgather'.
  --checkpoint.enable_checkpoint
        Whether to enable checkpointing
  --checkpoint.folder CHECKPOINT.FOLDER
        The folder to store the checkpoints. When
        enable_checkpoint is set to true, checkpoints will be
        in {--job.dump_folder}/{--checkpoint.folder}.
  --checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
        Checkpointing interval unit of measurement ['step',
        'seconds']
  --checkpoint.interval CHECKPOINT.INTERVAL
        Checkpointing interval, in steps or seconds depending
        on --checkpoint.interval_type
  --checkpoint.model_weights_only
        When model_weights_only=True, only model weights will
        be saved at the end of training. With this,
        checkpoints can be loaded using `torch.load(...,
        weights_only=True)` after conversion. When
        model_weights_only=False, the full checkpoint will be
        saved. A full checkpoint includes model, optimizer and
        train_state, which can be used to resume training. The
        default value is false.
  --checkpoint.export_dtype {float16,bfloat16,float32}
        Converts to the specified precision when training
        completes and model_weights_only=true. Currently
        supports float32, float16, and bfloat16. The default
        value is float32.
  --checkpoint.create_seed_checkpoint
        Initializes the full model without applying
        parallelisms, and then saves it as a seed checkpoint.
        Note: requires user to call train.py without
        specifying any parallelisms, e.g. NGPU=1. Could be
        implemented as a separate script, but this way shares
        more code.
  --checkpoint.async_mode CHECKPOINT.ASYNC_MODE
        Which async checkpoint mode to use. Currently there
        are 3 different modes. 1. "disabled": synchronized
        checkpointing will be used. 2. "async":
        torch.distributed.checkpoint.async_save will be used.
        3. "async_with_pinned_mem": this option utilizes a
        dedicated pinned memory space and creates a separate
        process for faster GPU->CPU transfer performance and
        eliminating GIL contention. The cost is increased CPU
        memory usage. If insufficient CPU memory is available,
        performance may degrade due to memory paging. For most
        users, "async" should suffice as the performance
        overhead is typically small (on the order of tens of
        seconds) compared to checkpointing frequency. This
        mode can be employed to pursue near-zero checkpointing
        times (e.g., < 1 second) given appropriate hardware
        support such as ample CPU memory and fast PCIe.
        "disabled" is the default mode.
  --checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
        Keeps only the latest k checkpoints, purging older
        ones. If 0, keep all checkpoints. 0 is the default
        value.
  --checkpoint.load_step CHECKPOINT.LOAD_STEP
        Load the checkpoint at the specified step. If -1, load
        the latest checkpoint.
  --float8.enable_float8_linear
        If true, swaps `torch.nn.Linear` with `Float8Linear`.
        This feature requires you to install 'torchao' which
        can be found here: https://github.com/pytorch/ao
  --float8.enable_fsdp_float8_all_gather
        Whether to enable float8 all-gather in FSDP
  --float8.precompute_float8_dynamic_scale_for_fsdp
        Whether to precompute float8 scales dynamically for FSDP
  --float8.scaling_type_input {dynamic,delayed}
        float8 scaling for input, dynamic (default) or delayed
  --float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
        float8 scaling for weight, dynamic (default) or delayed
  --float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
        float8 scaling for grad_output, dynamic (default) or delayed
  --comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
        Timeout for communication operations, during
        initialization and first train step.
  --comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
        Timeout for communication operations after the first
        train step -- usually a tighter bound than during
        initialization.
  --comm.trace_buf_size COMM.TRACE_BUF_SIZE
        Flight recorder ring buffer size, >0 means recording
        by default, 0 means disabled
  --memory_estimation.enabled
        Whether to estimate memory usage for FSDP
  --memory_estimation.disable_fake_mode
        Whether to estimate memory under FakeTensorMode
```
</details>

### Training with `torch.compile`

Starting from `torch 2.0`, `torch.compile` has been introduced as a new feature to seamlessly accelerate training.
In `flame`, you can enable `torch.compile` simply by adding the `--training.compile` flag to your training script.

However, `fla` integrates numerous fused kernels for acceleration, which may conflict with `torch.compile`.
We are actively working on resolving these issues to make compilation transparent to users.
In the meantime, please ensure you are using the latest dependencies.

Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.

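A quick way to confirm your environment meets this recommendation (a small sketch, nothing `flame`-specific):

```py
import torch
import triton

print("torch:", torch.__version__, "| triton:", triton.__version__)
# Strip any local build suffix such as "+cu124" before comparing.
torch_ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
triton_ver = tuple(int(x) for x in triton.__version__.split("+")[0].split(".")[:2])
assert torch_ver >= (2, 6) and triton_ver >= (3, 0), "please upgrade torch/triton"
```
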
### Training with multiple datasets

If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
`flame` makes training with multiple datasets easy.
For example, you can specify the following arguments to train on 6 datasets with different proportions:

```sh
--training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
--training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
```

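The dataset list and the sampling probabilities are matched by position, so it is worth checking that both comma-separated lists have the same length and that the weights sum to (roughly) one. A minimal sketch of such a check (an illustration only, not `flame`'s internal logic):

```py
datasets = (
    "HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,"
    "math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus"
).split(",")
probs = [float(p) for p in "0.6,0.15,0.15,0.014,0.058,0.028".split(",")]

assert len(datasets) == len(probs), "one probability per dataset"
assert abs(sum(probs) - 1.0) < 1e-6, f"weights sum to {sum(probs)}, expected ~1.0"
for name, p in zip(datasets, probs):
    print(f"{p:>6.1%}  {name}")
```
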
### ~Finalizing training~

> [!NOTE]
> We have done this conversion automatically in the training script since our latest updates.

Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
To facilitate this, we provide a straightforward conversion script:

```sh
python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
```
After this, your model will be in the 🤗 format, ready to be shared or deployed.
You can then easily publish your model using the `huggingface_hub` for wider accessibility.

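For instance, once converted, the checkpoint can be loaded and published like any other 🤗 model. This is a sketch only: the paths and repository name are placeholders, and `import fla` is assumed to register the custom `fla` architectures with `transformers`:

```py
import fla  # assumed to register fla model/config classes with transformers' Auto classes
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "<path_to_model>"  # the folder produced by convert_dcp_to_hf
model = AutoModelForCausalLM.from_pretrained(path)
tokenizer = AutoTokenizer.from_pretrained(path)

# Publish to the Hub under a repository of your choice (placeholder name).
model.push_to_hub("<your-username>/<your-model-name>")
tokenizer.push_to_hub("<your-username>/<your-model-name>")
```
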
### Continual training

If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
This allows you to seamlessly resume training with `flame`.
```sh
python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
```
Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.

Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.

## Multi-node training

If you have access to multi-node GPUs, consider leveraging them for optimal performance.
This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).

To set up multi-node training:
* Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes.
* If you're using a job scheduler like Slurm, it will handle these variables for you.

`torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
config.json ADDED
@@ -0,0 +1,33 @@
{
  "architectures": [
    "TransformerForCausalLM"
  ],
  "attention_bias": false,
  "attn_impl": "parallel_softpick_attn",
  "bos_token_id": 1,
  "elementwise_affine": true,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "fuse_swiglu": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "max_position_embeddings": 8192,
  "model_type": "transformer",
  "norm_eps": 1e-06,
  "num_heads": 16,
  "num_hidden_layers": 24,
  "num_kv_heads": null,
  "qk_norm": false,
  "qkv_bias": false,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.51.3",
  "use_cache": true,
  "vocab_size": 32000,
  "window_size": null
}
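The `config.json` above describes a 340M-scale `transformer` model with `attn_impl` set to `parallel_softpick_attn`. Assuming the `fla` package is installed and registers its custom model type with `transformers` (its `fla/__init__.py` below exposes `TransformerForCausalLM`), the checkpoint in this repository could be loaded roughly as follows; the repository id is a placeholder:

```py
import fla  # assumed to register the custom "transformer" model type with transformers
from transformers import AutoConfig, AutoModelForCausalLM

repo = "<this_repo_id>"  # placeholder for this model repository
config = AutoConfig.from_pretrained(repo)
print(config.model_type, config.attn_impl)  # values from the config.json above
model = AutoModelForCausalLM.from_pretrained(repo, torch_dtype="auto")
```
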
configs/delta_net_1B.json ADDED
@@ -0,0 +1,29 @@
{
  "attn": null,
  "attn_mode": "chunk",
  "bos_token_id": 1,
  "conv_size": 4,
  "eos_token_id": 2,
  "expand_k": 1,
  "expand_v": 1,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 2048,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "model_type": "delta_net",
  "norm_eps": 1e-06,
  "num_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 2,
  "qk_activation": "silu",
  "qk_norm": "l2",
  "tie_word_embeddings": false,
  "use_beta": true,
  "use_cache": true,
  "use_gate": false,
  "use_output_norm": true,
  "use_short_conv": true
}
configs/delta_net_340M.json ADDED
@@ -0,0 +1,27 @@
{
  "attn_mode": "chunk",
  "bos_token_id": 1,
  "conv_size": 4,
  "eos_token_id": 2,
  "expand_k": 1,
  "expand_v": 1,
  "fuse_cross_entropy": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "model_type": "delta_net",
  "norm_eps": 1e-06,
  "norm_first": false,
  "num_heads": 8,
  "num_hidden_layers": 24,
  "qk_activation": "silu",
  "qk_norm": "l2",
  "tie_word_embeddings": false,
  "use_beta": true,
  "use_cache": true,
  "use_gate": false,
  "use_output_norm": true,
  "use_short_conv": true
}
configs/gla_340M.json ADDED
@@ -0,0 +1,24 @@
{
  "attn_mode": "chunk",
  "bos_token_id": 1,
  "clamp_min": null,
  "eos_token_id": 2,
  "expand_k": 0.5,
  "expand_v": 1,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "model_type": "gla",
  "num_heads": 4,
  "num_hidden_layers": 24,
  "norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "use_cache": true,
  "use_gk": true,
  "use_gv": false,
  "vocab_size": 32000
}
configs/gla_7B.json ADDED
@@ -0,0 +1,25 @@
{
  "attn": null,
  "attn_mode": "chunk",
  "bos_token_id": 1,
  "eos_token_id": 2,
  "expand_k": 0.5,
  "expand_v": 1,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 4096,
  "initializer_range": 0.006,
  "intermediate_size": 11008,
  "model_type": "gla",
  "norm_eps": 1e-06,
  "num_heads": 16,
  "num_hidden_layers": 32,
  "tie_word_embeddings": false,
  "use_cache": true,
  "use_gk": true,
  "use_gv": false,
  "use_output_gate": true,
  "use_short_conv": false
}
configs/gsa_340M.json ADDED
@@ -0,0 +1,29 @@
{
  "bos_token_id": 1,
  "conv_size": 4,
  "eos_token_id": 2,
  "expand_k": 1,
  "expand_v": 1,
  "elementwise_affine": false,
  "feature_map": "swish",
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "gate_logit_normalizer": 4,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "model_type": "gsa",
  "num_heads": 4,
  "num_hidden_layers": 24,
  "num_slots": 64,
  "norm_eps": 1e-06,
  "share_conv_kernel": true,
  "tie_word_embeddings": false,
  "use_cache": true,
  "use_norm": true,
  "use_output_gate": true,
  "use_rope": false,
  "use_short_conv": false
}
configs/hgrn2_340M.json ADDED
@@ -0,0 +1,20 @@
{
  "attn_mode": "chunk",
  "bos_token_id": 1,
  "eos_token_id": 2,
  "expand_ratio": 128,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "model_type": "hgrn2",
  "num_heads": 8,
  "num_hidden_layers": 24,
  "norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "use_cache": true,
  "vocab_size": 32000
}
configs/rectified_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": false,
  "hidden_act": "swish",
  "hidden_size": 768,
  "initializer_range": 0.02,
  "max_position_embeddings": 4096,
  "model_type": "transformer",
  "num_heads": 12,
  "num_hidden_layers": 14,
  "norm_eps": 1e-06,
  "tie_word_embeddings": true,
  "use_cache": true,
  "vocab_size": 32000,
  "attn_impl": "naive_rectified_attn"
}
configs/rectified_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "max_position_embeddings": 8192,
  "model_type": "transformer",
  "num_heads": 16,
  "num_hidden_layers": 24,
  "norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "use_cache": true,
  "vocab_size": 32000,
  "attn_impl": "parallel_rectified_attn"
}
configs/softpick_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": false,
  "hidden_act": "swish",
  "hidden_size": 768,
  "initializer_range": 0.02,
  "max_position_embeddings": 4096,
  "model_type": "transformer",
  "num_heads": 12,
  "num_hidden_layers": 14,
  "norm_eps": 1e-06,
  "tie_word_embeddings": true,
  "use_cache": true,
  "vocab_size": 32000,
  "attn_impl": "naive_softpick_attn"
}
configs/softpick_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "max_position_embeddings": 8192,
  "model_type": "transformer",
  "num_heads": 16,
  "num_hidden_layers": 24,
  "norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "use_cache": true,
  "vocab_size": 32000,
  "attn_impl": "parallel_softpick_attn"
}
configs/transformer_120M.json ADDED
@@ -0,0 +1,18 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": false,
  "hidden_act": "swish",
  "hidden_size": 768,
  "initializer_range": 0.02,
  "max_position_embeddings": 4096,
  "model_type": "transformer",
  "num_heads": 12,
  "num_hidden_layers": 14,
  "norm_eps": 1e-06,
  "tie_word_embeddings": true,
  "use_cache": true,
  "vocab_size": 32000
}
configs/transformer_1B.json ADDED
@@ -0,0 +1,22 @@
{
  "bos_token_id": 1,
  "elementwise_affine": true,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "fuse_swiglu": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 2048,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "max_position_embeddings": 8192,
  "model_type": "transformer",
  "norm_eps": 1e-06,
  "num_heads": 32,
  "num_hidden_layers": 24,
  "num_kv_heads": null,
  "pad_token_id": 2,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false
}
configs/transformer_340M.json ADDED
@@ -0,0 +1,18 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "max_position_embeddings": 8192,
  "model_type": "transformer",
  "num_heads": 16,
  "num_hidden_layers": 24,
  "norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "use_cache": true,
  "vocab_size": 32000
}
configs/transformer_7B.json ADDED
@@ -0,0 +1,21 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 4096,
  "initializer_range": 0.006,
  "intermediate_size": 14336,
  "model_type": "transformer",
  "norm_eps": 1e-06,
  "num_heads": 32,
  "num_hidden_layers": 32,
  "num_kv_heads": 8,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "use_cache": true,
  "window_size": null
}
configs/vanilla_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": false,
  "hidden_act": "swish",
  "hidden_size": 768,
  "initializer_range": 0.02,
  "max_position_embeddings": 4096,
  "model_type": "transformer",
  "num_heads": 12,
  "num_hidden_layers": 14,
  "norm_eps": 1e-06,
  "tie_word_embeddings": true,
  "use_cache": true,
  "vocab_size": 32000,
  "attn_impl": "naive_attn"
}
configs/vanilla_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "max_position_embeddings": 8192,
  "model_type": "transformer",
  "num_heads": 16,
  "num_hidden_layers": 24,
  "norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "use_cache": true,
  "vocab_size": 32000,
  "attn_impl": "parallel_attn"
}
download_checkpoint.py ADDED
@@ -0,0 +1,35 @@
import argparse
import os
from huggingface_hub import HfApi, HfFolder, snapshot_download

def main(args):
    api = HfApi()
    token = HfFolder.get_token()
    experiment_checkpoint_folder = os.path.join(args.experiment_checkpoint_folder, "checkpoint")
    os.makedirs(
        experiment_checkpoint_folder,
        exist_ok=True
    )

    snapshot_download(
        repo_id=args.repo_id,
        token=token,
        local_dir=experiment_checkpoint_folder,
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download a checkpoint from Hugging Face Hub.")
    parser.add_argument(
        "--repo_id",
        type=str,
        required=True,
        help="The repository ID on Hugging Face Hub.",
    )
    parser.add_argument(
        "--experiment_checkpoint_folder",
        type=str,
        required=True,
        help="The local directory to save the downloaded checkpoint.",
    )
    args = parser.parse_args()
    main(args)
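For reference, the script above would be invoked as `python download_checkpoint.py --repo_id <repo_id> --experiment_checkpoint_folder <exp_dir>`, which downloads the repository snapshot into `<exp_dir>/checkpoint/` (presumably so a Hub-hosted checkpoint can be resumed locally).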
fla/__init__.py ADDED
@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-

from fla.layers import (
    ABCAttention,
    Attention,
    BasedLinearAttention,
    BitAttention,
    DeltaNet,
    GatedDeltaNet,
    GatedDeltaProduct,
    GatedLinearAttention,
    GatedSlotAttention,
    HGRN2Attention,
    HGRNAttention,
    LightNetAttention,
    LinearAttention,
    MultiScaleRetention,
    NativeSparseAttention,
    ReBasedLinearAttention,
    RWKV6Attention,
    RWKV7Attention
)
from fla.models import (
    ABCForCausalLM,
    ABCModel,
    BitNetForCausalLM,
    BitNetModel,
    DeltaNetForCausalLM,
    DeltaNetModel,
    GatedDeltaNetForCausalLM,
    GatedDeltaNetModel,
    GatedDeltaProductForCausalLM,
    GatedDeltaProductModel,
    GLAForCausalLM,
    GLAModel,
    GSAForCausalLM,
    GSAModel,
    HGRN2ForCausalLM,
    HGRN2Model,
    HGRNForCausalLM,
    LightNetForCausalLM,
    LightNetModel,
    LinearAttentionForCausalLM,
    LinearAttentionModel,
    NSAForCausalLM,
    NSAModel,
    RetNetForCausalLM,
    RetNetModel,
    RWKV6ForCausalLM,
    RWKV6Model,
    RWKV7ForCausalLM,
    RWKV7Model,
    TransformerForCausalLM,
    TransformerModel
)

__all__ = [
    'ABCAttention',
    'Attention',
    'BasedLinearAttention',
    'BitAttention',
    'DeltaNet',
    'GatedDeltaNet',
    'GatedDeltaProduct',
    'GatedLinearAttention',
    'GatedSlotAttention',
    'HGRNAttention',
    'HGRN2Attention',
    'LightNetAttention',
    'LinearAttention',
    'MultiScaleRetention',
    'NativeSparseAttention',
    'ReBasedLinearAttention',
    'RWKV6Attention',
    'RWKV7Attention',
    'ABCForCausalLM',
    'ABCModel',
    'BitNetForCausalLM',
    'BitNetModel',
    'DeltaNetForCausalLM',
    'DeltaNetModel',
    'GatedDeltaNetForCausalLM',
    'GatedDeltaNetModel',
    'GatedDeltaProductForCausalLM',
    'GatedDeltaProductModel',
    'GLAForCausalLM',
    'GLAModel',
    'GSAForCausalLM',
    'GSAModel',
    'HGRNForCausalLM',
    'HGRNModel',
    'HGRN2ForCausalLM',
    'HGRN2Model',
    'LightNetForCausalLM',
    'LightNetModel',
    'LinearAttentionForCausalLM',
    'LinearAttentionModel',
    'NSAForCausalLM',
    'NSAModel',
    'RetNetForCausalLM',
    'RetNetModel',
    'RWKV6ForCausalLM',
    'RWKV6Model',
    'RWKV7ForCausalLM',
    'RWKV7Model',
    'TransformerForCausalLM',
    'TransformerModel',
]

__version__ = '0.1.2'
fla/layers/gla.py ADDED
@@ -0,0 +1,294 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang


from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
from fla.modules.activations import ACT2FN
from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


class GatedLinearAttention(nn.Module):
    r"""
    The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635).  # noqa

    Args:
        mode (str, Optional):
            Which GLA kernel to use.
            Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`.
            Default: `chunk`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        expand_k (float, Optional):
            The expansion ratio for the key dim. Default: 0.5.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 1.0.
        num_heads (int, Optional):
            The number of heads. Default: 4.
        num_kv_heads (int, Optional):
            The number of key/value heads, used for MQA. Default: None.
        feature_map (str, Optional):
            Feature map function applied to queries/keys. Default: None.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `False`.
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        use_output_gate (bool, Optional):
            Whether to use output gate. Default: `True`.
        gate_fn (str, Optional):
            The activation function for the output gate. Default: `swish`.
        elementwise_affine (bool, Optional):
            If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
        norm_eps (float, Optional):
            The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
        gate_logit_normalizer (int, Optional):
            The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
        gate_low_rank_dim (int, Optional):
            The low rank dim for the gate projection. Default: 16.
        clamp_min (float, Optional):
            The minimum value for the gate logits. Default: None.
        fuse_norm (bool, Optional):
            Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
        layer_idx (int, Optional):
            The index of the layer. Default: None.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 0.5,
        expand_v: float = 1.0,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        feature_map: Optional[str] = None,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        use_output_gate: bool = True,
        gate_fn: str = 'swish',
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_logit_normalizer: int = 16,
        gate_low_rank_dim: int = 16,
        clamp_min: Optional[float] = None,
        fuse_norm: bool = True,
        layer_idx: int = None,
    ) -> GatedLinearAttention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.use_output_gate = use_output_gate

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups
        self.clamp_min = clamp_min
        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
        if self.use_output_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
            self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
            self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')

        self.gk_proj = nn.Sequential(nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
                                     nn.Linear(gate_low_rank_dim, self.key_dim_per_group, bias=True))
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        if gate_fn == 'swish' and fuse_norm and use_output_gate:
            self.g_norm_swish_gate = FusedRMSNormGated(
                hidden_size=self.head_v_dim,
                elementwise_affine=elementwise_affine,
                eps=norm_eps
            )
            self.fuse_norm_and_gate = True
        else:
            self.fuse_norm_and_gate = False
            self.g_norm = RMSNorm(
                hidden_size=self.head_v_dim,
                elementwise_affine=elementwise_affine,
                eps=norm_eps
            )
            self.gate_fn = ACT2FN[gate_fn]

        self.gate_logit_normalizer = gate_logit_normalizer

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)
        gk = self.gk_proj(hidden_states)

        if self.feature_map_fn is not None:
            q, k = map(self.feature_map_fn, (q, k))
        # dealing with left-padding
        if attention_mask is not None:
            v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
        q = rearrange(q, 'b t (h d) -> b t h d', d=self.head_k_dim)
        if self.num_kv_groups > 1:
            k, gk = (repeat(x, 'b t (h d) -> b t (h g) d', g=self.num_kv_groups, d=self.head_k_dim) for x in (k, gk))
            v = repeat(v, 'b t (h d) -> b t (h g) d', g=self.num_kv_groups, d=self.head_v_dim)
        else:
            k, gk = (rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim) for x in (k, gk))
            v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
        gk = F.logsigmoid(gk) / self.gate_logit_normalizer

        if self.clamp_min is not None:
            gk = torch.clamp_min(gk, self.clamp_min)

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gla(
                q=q,
                k=k,
                v=v,
                gk=gk,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'fused_chunk':
            o, recurrent_state = fused_chunk_gla(
                q=q,
                k=k,
                v=v,
                g=gk,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                head_first=False
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_gla(
                q=q,
                k=k,
                v=v,
                g=gk,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        if self.use_output_gate:
            g = self.g_proj(hidden_states)
            if self.fuse_norm_and_gate:
                g = rearrange(g, 'b t (h d) -> b t h d', d=self.head_v_dim)
                o = self.g_norm_swish_gate(o, g)
                o = rearrange(o, 'b t h d -> b t (h d)')
            else:
                o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
                o = o * self.gate_fn(g)
        else:
            o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values

    def state_size(self, **kwargs) -> int:
        state_size = self.key_dim * self.head_v_dim
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
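For orientation, the layer above can be exercised in isolation roughly as follows. This is a sketch only: it assumes a CUDA device, since the underlying Triton kernels are GPU kernels, and uses the layer's default settings:

```py
import torch
from fla.layers.gla import GatedLinearAttention

# Default config: expand_k=0.5 -> key_dim=512, expand_v=1.0 -> value_dim=1024, 4 heads.
layer = GatedLinearAttention(hidden_size=1024, num_heads=4, layer_idx=0).cuda()
x = torch.randn(2, 128, 1024, device="cuda")  # [batch, seq_len, hidden]
o, _, _ = layer(x)                            # seq_len > 64, so the 'chunk' kernel path is taken
print(o.shape)                                # torch.Size([2, 128, 1024])
```
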
fla/layers/rwkv7.py ADDED
@@ -0,0 +1,221 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn import functional as F

from fla.layers.rwkv6 import LoRA
from fla.modules import GroupNorm
from fla.modules.l2norm import l2_norm
from fla.ops.rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7

if TYPE_CHECKING:
    from fla.models.utils import Cache


class RWKV7Attention(nn.Module):

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        head_dim: Optional[int] = 64,
        num_heads: Optional[int] = None,
        decay_low_rank_dim: int = 64,
        gate_low_rank_dim: int = 128,
        a_low_rank_dim: int = 64,
        v_low_rank_dim: int = 16,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        layer_idx: int = None,
        fuse_norm: bool = False,
        value_dim: int = None,
        **kwargs
    ) -> RWKV7Attention:
        super().__init__()

        self.mode = mode
        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        self.hidden_size = hidden_size

        self.key_dim = hidden_size
        self.value_dim = value_dim if value_dim is not None else hidden_size
        if head_dim is None and num_heads is None:
            raise ValueError("Either `head_dim` or `num_heads` must be specified.")
        elif head_dim is not None:
            self.head_dim = head_dim
            self.num_heads = int(hidden_size // head_dim)
        elif num_heads is not None:
            self.head_dim = int(hidden_size // num_heads)
            self.num_heads = num_heads
        self.head_v_dim = int(self.value_dim // self.num_heads)

        self.decay_low_rank_dim = decay_low_rank_dim
        self.gate_low_rank_dim = gate_low_rank_dim
        self.a_low_rank_dim = a_low_rank_dim
        self.v_low_rank_dim = v_low_rank_dim
        self.layer_idx = layer_idx
        self.fuse_norm = fuse_norm

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.x_x = nn.Parameter(torch.zeros(6, hidden_size))

        self.k_k = nn.Parameter(torch.zeros(self.key_dim))
        self.k_a = nn.Parameter(torch.zeros(self.key_dim))
        self.r_k = nn.Parameter(torch.zeros(self.num_heads, self.head_dim))

        self.r_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        self.w_lora = LoRA(hidden_size, self.key_dim, low_rank_dim=decay_low_rank_dim, activation='tanh')
        if self.layer_idx != 0:
            self.v_lora = LoRA(hidden_size, self.value_dim, low_rank_dim=v_low_rank_dim, activation=None)
        self.a_lora = LoRA(hidden_size, self.key_dim, low_rank_dim=a_low_rank_dim, activation=None)
        self.g_lora = LoRA(hidden_size, self.value_dim, low_rank_dim=gate_low_rank_dim, activation='sigmoid', bias=False)

        if self.fuse_norm:
            self.g_norm = GroupNorm(
                num_groups=self.num_heads,
                hidden_size=self.value_dim,
                elementwise_affine=elementwise_affine,
                eps=self.head_dim*norm_eps,
                bias=True,
            )
        else:
            self.g_norm = nn.GroupNorm(
                num_groups=self.num_heads,
                num_channels=self.value_dim,
                eps=self.head_dim*norm_eps,
                affine=elementwise_affine
            )

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        if isinstance(module, nn.Parameter):
            nn.init.xavier_uniform_(module, gain=2 ** -2.5)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        v_first: torch.Tensor = None,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
129
+ )
130
+
131
+ batch_size, seq_len, _ = hidden_states.shape
132
+
133
+ if self.training:
134
+ # if training, use chunk mode no matter how short the sequence is
135
+ mode = 'chunk'
136
+ else:
137
+ # launching the triton kernel for just one token will actually be slower
138
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
139
+
140
+ last_state = None
141
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
142
+ last_state = past_key_values[self.layer_idx]
143
+
144
+ if attention_mask is not None:
145
+ hidden_states = hidden_states.mul(attention_mask[:, -hidden_states.shape[-2]:, None])
146
+ if hidden_states.shape[1] == 1 and last_state is not None:
147
+ shifted = last_state['conv_state'].unsqueeze(1)
148
+ else:
149
+ shifted = self.time_shift(hidden_states)
150
+ if last_state is not None:
151
+ shifted[:, 0] = last_state['conv_state']
152
+
153
+ # [batch_size, seq_len, hidden_size]
154
+ delta = shifted - hidden_states
155
+ xr, xw, xk, xv, xa, xg = hidden_states.addcmul(delta, self.x_x.view(6, 1, 1, -1)).unbind(0)
156
+
157
+ r = self.r_proj(xr)
158
+ # -math.exp(-0.5) = -0.6065306597126334
159
+ # The .to(torch.float) cast is likely unnecessary here, since the LoRA is computed in bfloat16
160
+ # and applying sigmoid to a bf16 input does not cause numerical issues
161
+ # FIXME: check if we can remove .to(torch.float)
162
+ w = -0.6065306597126334 * self.w_lora(xw).to(torch.float).sigmoid()
163
+
164
+ k = self.k_proj(xk)
165
+ v = self.v_proj(xv)
166
+
167
+ if self.layer_idx == 0:
168
+ v_first = v
169
+ else:
170
+ v = torch.lerp(v, v_first, self.v_lora(xv).sigmoid())
171
+ a = self.a_lora(xa).sigmoid()
172
+ g = self.g_lora(xg)
173
+
174
+ if self.fuse_norm:
175
+ kk = l2_norm(rearrange(k * self.k_k, 'b t (h d) -> b t h d', d=self.head_dim))
176
+ else:
177
+ kk = F.normalize(rearrange(k * self.k_k, 'b t (h d) -> b t h d', d=self.head_dim), dim=-1, p=2.0)
178
+
179
+ k = k.addcmul(k * (a - 1), self.k_a)
180
+
181
+ # dealing with left-padding
182
+ if attention_mask is not None:
183
+ v = v * attention_mask[:, -v.shape[-2]:, None]
184
+ r, w, k, a = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_dim), (r, w, k, a))
185
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
186
+
187
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
188
+
189
+ rwkv7_fn = chunk_rwkv7 if mode == 'chunk' else fused_recurrent_rwkv7
190
+ cu_seqlens = kwargs.get('cu_seqlens', None)
191
+ o, recurrent_state = rwkv7_fn(
192
+ r=r,
193
+ w=w,
194
+ k=k,
195
+ v=v,
196
+ a=-kk,
197
+ b=kk * a,
198
+ scale=1.,
199
+ initial_state=recurrent_state,
200
+ output_final_state=use_cache,
201
+ cu_seqlens=cu_seqlens,
202
+ head_first=False
203
+ )
204
+
205
+ if past_key_values is not None:
206
+ past_key_values.update(
207
+ recurrent_state=recurrent_state,
208
+ conv_state=hidden_states[:, -1],
209
+ layer_idx=self.layer_idx,
210
+ offset=r.shape[1]
211
+ )
212
+
213
+ if self.fuse_norm:
214
+ o = self.g_norm(rearrange(o, '... h d -> ... (h d)'))
215
+ else:
216
+ o = self.g_norm(rearrange(o, 'b t h d -> (b t) (h d)')).view(batch_size, seq_len, -1)
217
+
218
+ o = o + ((r * k * self.r_k).sum(-1, keepdim=True) * v).view(batch_size, seq_len, -1)
219
+ o = self.o_proj(o * g)
220
+
221
+ return o, None, past_key_values, v_first
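A minimal sketch of how the `v_first` plumbing above is used when stacking layers (assuming a CUDA device with triton available; the residual/MLP blocks of a full model are omitted and the sizes are illustrative):

import torch
from fla.layers.rwkv7 import RWKV7Attention

layers = [RWKV7Attention(hidden_size=512, head_dim=64, layer_idx=i).cuda().to(torch.bfloat16) for i in range(2)]
x = torch.randn(1, 64, 512, device='cuda', dtype=torch.bfloat16)
v_first = None
for layer in layers:
    # layer 0 produces v_first; deeper layers mix their values toward it via v_lora
    x, _, _, v_first = layer(x, v_first=v_first)
print(x.shape)  # torch.Size([1, 64, 512])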
fla/utils.py ADDED
@@ -0,0 +1,221 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import contextlib
4
+ import functools
5
+ import os
6
+ from enum import Enum
7
+ from functools import lru_cache
8
+ from typing import Any, Callable, Dict, Literal, Optional, Tuple
9
+
10
+ import torch
11
+ import triton
12
+ from packaging import version
13
+
14
+
15
+ def tensor_cache(
16
+ fn: Callable[..., torch.Tensor]
17
+ ) -> Callable[..., torch.Tensor]:
18
+ """
19
+ A decorator that caches the most recent result of a function with tensor inputs.
20
+
21
+ This decorator will store the output of the decorated function for the most recent set of input tensors.
22
+ If the function is called again with the same input tensors, it will return the cached result.
23
+
24
+
25
+ Args:
26
+ fn (Callable[..., torch.Tensor]):
27
+ The function to be decorated. It should take tensor inputs and return tensor outputs.
28
+
29
+ Returns:
30
+ Callable[..., torch.Tensor]:
31
+ A wrapped version of the input function with single-entry caching.
32
+ """
33
+ last_args: Optional[Tuple] = None
34
+ last_kwargs: Optional[Dict] = None
35
+ last_result: Any = None
36
+
37
+ @functools.wraps(fn)
38
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
39
+ nonlocal last_args, last_kwargs, last_result
40
+
41
+ if last_args is not None and last_kwargs is not None:
42
+ if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
43
+ if all(a is b for a, b in zip(args, last_args)) and \
44
+ all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()):
45
+ return last_result
46
+
47
+ result = fn(*args, **kwargs)
48
+ last_args, last_kwargs, last_result = args, kwargs, result
49
+ return result
50
+
51
+ return wrapper
52
+
53
+
54
+ def input_guard(
55
+ fn: Callable[..., torch.Tensor]
56
+ ) -> Callable[..., torch.Tensor]:
57
+ """
58
+ A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
59
+ """
60
+
61
+ @functools.wraps(fn)
62
+ def wrapper(*args, **kwargs):
63
+ contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
64
+ contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}
65
+
66
+ tensor = None
67
+ for arg in args:
68
+ if isinstance(arg, torch.Tensor):
69
+ tensor = arg
70
+ break
71
+ if tensor is None:
72
+ for value in kwargs.values():
73
+ if isinstance(value, torch.Tensor):
74
+ tensor = value
75
+ break
76
+
77
+ if tensor is not None:
78
+ ctx = custom_device_ctx(tensor.device.index)
79
+ else:
80
+ ctx = contextlib.nullcontext()
81
+
82
+ with ctx:
83
+ return fn(*contiguous_args, **contiguous_kwargs)
84
+
85
+ return wrapper
86
+
87
+
88
+ contiguous = input_guard
89
+
90
+
91
+ def require_version(version, hint):
92
+ """
93
+ Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
94
+ """
95
+ def decorator(fn):
96
+ @functools.wraps(fn)
97
+ def wrapper(ctx, *args, **kwargs):
98
+ from transformers.utils.versions import require_version
99
+ require_version(version, hint)
100
+ return fn(ctx,
101
+ *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),
102
+ **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})
103
+ return wrapper
104
+ return decorator
105
+
106
+
107
+ def checkpoint(fn):
108
+ def wrapper(*args, **kwargs):
109
+ return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
110
+ return wrapper
111
+
112
+
113
+ @lru_cache(maxsize=None)
114
+ def check_pytorch_version(version_s: str = '2.4') -> bool:
115
+ return version.parse(torch.__version__) >= version.parse(version_s)
116
+
117
+
118
+ def _cpu_device_warning():
119
+ import warnings
120
+ warnings.warn('Triton is not supported on the current platform; falling back to CPU.', stacklevel=1)
121
+
122
+
123
+ @lru_cache(maxsize=None)
124
+ def get_multiprocessor_count(tensor_idx: int = 0) -> int:
125
+ try:
126
+ return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)['multiprocessor_count']
127
+ except BaseException:
128
+ _cpu_device_warning()
129
+ return -1
130
+
131
+
132
+ @lru_cache(maxsize=None)
133
+ def get_available_device() -> str:
134
+ try:
135
+ return triton.runtime.driver.active.get_current_target().backend
136
+ except BaseException:
137
+ _cpu_device_warning()
138
+ return 'cpu'
139
+
140
+
141
+ @lru_cache(maxsize=None)
142
+ def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
143
+ device = get_available_device()
144
+ if device == 'cuda':
145
+ return 'nvidia'
146
+ elif device == 'hip':
147
+ return 'amd'
148
+ elif device == 'xpu':
149
+ return 'intel'
150
+ else:
151
+ return device
152
+
153
+
154
+ # For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
155
+ # However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
156
+ # Therefore, we need to check the triton backend to determine the actual GPU vendor.
157
+ device = get_available_device() if get_available_device() != 'hip' else 'cuda'
158
+ device_torch_lib = getattr(torch, device)
159
+ device_platform = _check_platform()
160
+
161
+ is_amd = (device_platform == 'amd')
162
+ is_intel = (device_platform == 'intel')
163
+ is_nvidia = (device_platform == 'nvidia')
164
+ is_intel_alchemist = (is_intel and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
165
+ is_nvidia_hopper = (is_nvidia and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
166
+ use_cuda_graph = (is_nvidia and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')
167
+
168
+ # Nvidia Ampere or newer; haven't checked AMD and Intel yet.
169
+ is_tf32_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8)
170
+ is_gather_supported = hasattr(triton.language, 'gather')
171
+
172
+
173
+ def get_all_max_shared_mem():
174
+ try:
175
+ return [
176
+ triton.runtime.driver.active.utils.get_device_properties(i)['max_shared_mem']
177
+ for i in range(device_torch_lib.device_count())
178
+ ]
179
+ except BaseException:
180
+ _cpu_device_warning()
181
+ return [-1]
182
+
183
+
184
+ class Backend(Enum):
185
+ ADA = 101376 # RTX 4090
186
+ AMPERE = 166912 # A100
187
+ HOPPER = 232448 # H100
188
+ DEFAULT = 102400 # Default
189
+
190
+ @classmethod
191
+ def get_shared_memory(cls, arch: str) -> int:
192
+ try:
193
+ return cls[arch.upper()].value
194
+ except KeyError:
195
+ return cls.DEFAULT.value
196
+
197
+
198
+ @lru_cache(maxsize=None)
199
+ def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
200
+ try:
201
+ device_shared_mem_list = get_all_max_shared_mem()
202
+ max_shared_memory = device_shared_mem_list[tensor_idx]
203
+ return max_shared_memory >= Backend.get_shared_memory(arch)
204
+ except Exception:
205
+ return False
206
+
207
+
208
+ if check_pytorch_version('2.4'):
209
+ device = 'cuda' if device == 'cpu' else device
210
+ autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
211
+ autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)
212
+
213
+ def custom_device_ctx(index: int):
214
+ return device_torch_lib.device(index)
215
+ else:
216
+ assert device == 'cuda', 'Only cuda device is supported for PyTorch version < 2.4.0.'
217
+ autocast_custom_fwd = device_torch_lib.amp.custom_fwd
218
+ autocast_custom_bwd = device_torch_lib.amp.custom_bwd
219
+
220
+ def custom_device_ctx(index: int):
221
+ return torch.cuda.device(index)
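A small sketch of how the two decorators above are meant to be used, assuming fla imports cleanly in the current environment; the helper functions here are illustrative. `tensor_cache` memoizes the most recent call by tensor identity, and `input_guard` (aliased as `contiguous`) makes tensor arguments contiguous and pins the device context:

import torch
from fla.utils import input_guard, tensor_cache

@tensor_cache
def seq_lengths(cu_seqlens: torch.Tensor) -> torch.Tensor:
    return cu_seqlens[1:] - cu_seqlens[:-1]

@input_guard
def row_sums(x: torch.Tensor) -> torch.Tensor:
    return x.sum(-1)

cu = torch.tensor([0, 3, 7, 12])
assert seq_lengths(cu) is seq_lengths(cu)          # second call hits the single-entry cache
print(row_sums(torch.arange(12.).view(3, 4).t()))  # non-contiguous input is made contiguous first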
flame/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.1.0"
flame/components/__init__.py ADDED
File without changes
flame/components/checkpoint.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass, field
8
+ from datetime import timedelta
9
+ from io import BytesIO
10
+ from typing import Any, Dict, List
11
+
12
+ import torch
13
+ from torch.distributed.checkpoint.stateful import Stateful
14
+
15
+
16
+ @dataclass
17
+ class TrainState(Stateful):
18
+ step: int = 0
19
+ skipped_step: int = 0
20
+ token: int = 0
21
+ elapsed: timedelta = timedelta(0)
22
+ global_avg_losses: List[float] = field(default_factory=list)
23
+ global_max_losses: List[float] = field(default_factory=list)
24
+ log_steps: List[int] = field(default_factory=list)
25
+
26
+ def state_dict(self) -> Dict[str, Any]:
27
+ # Only checkpoint global_avg_losses and global_max_losses per log frequency
28
+ # to avoid sync overhead in every iteration.
29
+ global_avg_losses_bytes = BytesIO()
30
+ torch.save(self.global_avg_losses, global_avg_losses_bytes)
31
+ global_max_losses_bytes = BytesIO()
32
+ torch.save(self.global_max_losses, global_max_losses_bytes)
33
+ log_steps_bytes = BytesIO()
34
+ torch.save(self.log_steps, log_steps_bytes)
35
+ return {
36
+ "step": torch.tensor(self.step, dtype=torch.int32),
37
+ "skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
38
+ "token": torch.tensor(self.token, dtype=torch.int64),
39
+ "elapsed": self.elapsed,
40
+ "global_avg_losses": global_avg_losses_bytes,
41
+ "global_max_losses": global_max_losses_bytes,
42
+ "log_steps": log_steps_bytes,
43
+ }
44
+
45
+ def load_state_dict(self, state_dict) -> None:
46
+ self.step = state_dict["step"].item()
47
+ self.skipped_step = state_dict.get("skipped_step", torch.tensor(0)).item()  # tensor default so .item() works when the key is absent
48
+ self.token = state_dict["token"].item()
49
+ self.elapsed = state_dict["elapsed"]
50
+ state_dict["global_avg_losses"].seek(0)
51
+ self.global_avg_losses = torch.load(
52
+ state_dict["global_avg_losses"], weights_only=False
53
+ )
54
+ state_dict["global_max_losses"].seek(0)
55
+ self.global_max_losses = torch.load(
56
+ state_dict["global_max_losses"], weights_only=False
57
+ )
58
+ state_dict["log_steps"].seek(0)
59
+ self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
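A round-trip sketch of the TrainState above (the step/token/loss values are illustrative): the loss and step histories are serialized into BytesIO buffers by state_dict() and restored by load_state_dict().

from datetime import timedelta
from flame.components.checkpoint import TrainState

state = TrainState(step=100, token=1_000_000, elapsed=timedelta(minutes=5))
state.global_avg_losses, state.global_max_losses, state.log_steps = [2.31], [2.87], [100]

restored = TrainState()
restored.load_state_dict(state.state_dict())
print(restored.step, restored.token, restored.global_avg_losses)  # 100 1000000 [2.31]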
flame/config_manager.py ADDED
@@ -0,0 +1,940 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import sys
9
+ from collections import defaultdict
10
+ from typing import Tuple
11
+
12
+ import torch
13
+
14
+ try:
15
+ import tomllib
16
+ except ModuleNotFoundError:
17
+ import tomli as tomllib
18
+
19
+ from torchtitan.tools.logging import logger
20
+
21
+ TORCH_DTYPE_MAP = {
22
+ "float16": torch.float16,
23
+ "float32": torch.float32,
24
+ "bfloat16": torch.bfloat16,
25
+ }
26
+
27
+
28
+ def string_list(raw_arg):
29
+ """Comma-separated string list argument."""
30
+ return [s.strip() for s in raw_arg.split(",") if s.strip()]
31
+
32
+
33
+ def check_string_list_argument(args_dict: dict[str, any], fullargname: str):
34
+ section, name = fullargname.split(".")
35
+ # Split string-list values that are still raw strings.
36
+ if (
37
+ section in args_dict
38
+ and name in args_dict[section]
39
+ and isinstance(args_dict[section][name], str)
40
+ ):
41
+ sec = args_dict[section]
42
+ sec[name] = string_list(sec[name])
43
+
44
+
45
+ class JobConfig:
46
+ """
47
+ A helper class to manage the train configuration.
48
+ Semantics:
49
+ - Default config is loaded from a toml file. If no toml file is provided,
50
+ then the default config is loaded from argparse defaults.
51
+ - if toml file has missing keys, they are filled with argparse defaults.
52
+ - if additional explicit cmd args are provided in addition to the toml
53
+ file, they will override the toml config and the argparse defaults
54
+
55
+ precedence order: cmdline > toml > argparse default
56
+
57
+ Arg parsing semantics:
58
+
59
+ Each argument starts with <prefix>_ which is the section name in the toml file
60
+ followed by the name of the option in the toml file. For example,
61
+ model.name translates to:
62
+ [model]
63
+ name
64
+ in the toml file
65
+ """
66
+
67
+ def __init__(self):
68
+ self.args_dict = None
69
+ # main parser
70
+ self.parser = argparse.ArgumentParser(description="torchtitan arg parser.")
71
+
72
+ self.parser.add_argument(
73
+ "--job.config_file",
74
+ type=str,
75
+ default=None,
76
+ help="Job config file",
77
+ )
78
+
79
+ # job level configs
80
+ self.parser.add_argument(
81
+ "--job.dump_folder",
82
+ type=str,
83
+ default="./torchtitan/outputs",
84
+ help="Folder to dump job outputs",
85
+ )
86
+ self.parser.add_argument(
87
+ "--job.description",
88
+ type=str,
89
+ default="default job",
90
+ help="Description of the job",
91
+ )
92
+ self.parser.add_argument(
93
+ "--job.use_for_integration_test",
94
+ action="store_true",
95
+ help="Add this config to the integration test suite",
96
+ )
97
+ self.parser.add_argument(
98
+ "--job.print_args",
99
+ action="store_true",
100
+ help="Print the args to terminal",
101
+ )
102
+
103
+ # model configs
104
+ self.parser.add_argument(
105
+ "--model.name",
106
+ type=str,
107
+ default="fla",
108
+ help="Which model to train",
109
+ )
110
+ self.parser.add_argument(
111
+ "--model.config",
112
+ type=str,
113
+ default="fla-hub/transformer-1.3B-100B",
114
+ help="Path to the model config",
115
+ )
116
+ self.parser.add_argument(
117
+ "--model.tokenizer_path",
118
+ type=str,
119
+ default="fla-hub/transformer-1.3B-100B",
120
+ help="Tokenizer path",
121
+ )
122
+ self.parser.add_argument(
123
+ "--model.converters",
124
+ type=string_list,
125
+ nargs="+",
126
+ default=[],
127
+ help="""
128
+ Comma separated list of converters to apply to the model.
129
+ For instance, the `float8` converter swaps `torch.nn.Linear`
130
+ with `Float8Linear`. This feature requires you to install 'torchao'
131
+ which can be found here: https://github.com/pytorch/ao
132
+ """,
133
+ )
134
+ self.parser.add_argument(
135
+ "--model.print_after_conversion",
136
+ action="store_true",
137
+ help="""
138
+ If true, model definition will be printed to stdout after all model
139
+ converters have been applied.
140
+ """,
141
+ )
142
+
143
+ # profiling configs
144
+ self.parser.add_argument(
145
+ "--profiling.enable_profiling",
146
+ action="store_true",
147
+ help="Whether to enable pytorch profiler",
148
+ )
149
+ self.parser.add_argument(
150
+ "--profiling.save_traces_folder",
151
+ type=str,
152
+ default="profile_traces",
153
+ help="Trace files location",
154
+ )
155
+ self.parser.add_argument(
156
+ "--profiling.profile_freq",
157
+ type=int,
158
+ default=10,
159
+ help="How often to collect profiler traces, in iterations",
160
+ )
161
+ self.parser.add_argument(
162
+ "--profiling.enable_memory_snapshot",
163
+ action="store_true",
164
+ help="Whether to dump memory snapshot",
165
+ )
166
+ self.parser.add_argument(
167
+ "--profiling.save_memory_snapshot_folder",
168
+ type=str,
169
+ default="memory_snapshot",
170
+ help="Memory snapshot files location",
171
+ )
172
+
173
+ # optimizer configs
174
+ self.parser.add_argument(
175
+ "--optimizer.name", type=str, default="AdamW", help="Optimizer to use"
176
+ )
177
+ self.parser.add_argument(
178
+ "--optimizer.eps",
179
+ type=float,
180
+ default=1e-8,
181
+ help="Epsilon value for the optimizer.",
182
+ )
183
+ self.parser.add_argument(
184
+ "--optimizer.lr", type=float, default=8e-4, help="Learning rate to use"
185
+ )
186
+ self.parser.add_argument(
187
+ "--optimizer.implementation",
188
+ type=str,
189
+ default="fused",
190
+ choices=["for-loop", "foreach", "fused"],
191
+ help="""
192
+ Specify which optimizer implementation to use:
193
+ - 'fused': Use fused implementation (CUDA only) for best performance.
194
+ - 'foreach': Use some horizontal fusion of tensors for better performance.
195
+ - 'for-loop': Use the default implementation for the optimizer (slowest).
196
+ - more info: https://pytorch.org/docs/stable/optim.html
197
+ """,
198
+ )
199
+ self.parser.add_argument(
200
+ "--optimizer.early_step_in_backward",
201
+ action="store_true",
202
+ help="""
203
+ Whether to apply the optimizer step in the backward pass. Caution: optimizer_in_backward
204
+ is not compatible with gradient clipping; users should not call
205
+ register_post_accumulate_grad_hook after the optimizer is built.""",
206
+ )
207
+
208
+ # lr scheduler configs
209
+ self.parser.add_argument(
210
+ "--lr_scheduler.warmup_steps",
211
+ type=int,
212
+ default=200,
213
+ help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
214
+ )
215
+ self.parser.add_argument(
216
+ "--lr_scheduler.decay_ratio",
217
+ type=float,
218
+ default=None,
219
+ help="""
220
+ Controls the proportion of the training steps allocated to the learning rate decay phase.
221
+
222
+ If `None`, the learning rate will begin decaying immediately after the warmup period.
223
+ Otherwise, the learning rate will remain stable after the warmup period and
224
+ only start decaying during the last `decay_ratio` portion of the total training steps.
225
+
226
+ This is known as the Warmup-Stable-Decay (WSD) schedule, as described in https://arxiv.org/abs/2404.06395.
227
+ """,
228
+ )
229
+ self.parser.add_argument(
230
+ "--lr_scheduler.decay_type",
231
+ type=str,
232
+ default="linear",
233
+ choices=["linear", "sqrt", "cosine"],
234
+ help="""
235
+ Learning rate decay type to use during training:
236
+ - 'linear': linearly decays learning rate from initial to final value
237
+ - 'sqrt': decays learning rate following a 1 minus square root curve
238
+ - 'cosine': smoothly decays learning rate following a cosine curve
239
+ """,
240
+ )
241
+ self.parser.add_argument(
242
+ "--lr_scheduler.lr_min",
243
+ type=float,
244
+ default=0.0,
245
+ help="""
246
+ Min lr ratio for lr scheduler.
247
+
248
+ If provided, the range of decay factor is scaled from 1 to `lr_min`
249
+ to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`.
250
+ """,
251
+ )
252
+
253
+ # training configs
254
+ self.parser.add_argument(
255
+ "--training.batch_size", type=int, default=8, help="Batch size"
256
+ )
257
+ self.parser.add_argument(
258
+ "--training.seq_len", type=int, default=2048, help="Sequence length"
259
+ )
260
+ self.parser.add_argument(
261
+ "--training.context_len",
262
+ type=int,
263
+ default=2048,
264
+ help="Max length allowed for each sequence",
265
+ )
266
+ self.parser.add_argument(
267
+ "--training.varlen",
268
+ action="store_true",
269
+ help="Whether to take sequences of variable length as input",
270
+ )
271
+ self.parser.add_argument(
272
+ "--training.gradient_accumulation_steps",
273
+ type=int,
274
+ default=1,
275
+ help="Number of steps to accumulate gradients before updating parameters",
276
+ )
277
+ self.parser.add_argument(
278
+ "--training.steps",
279
+ type=int,
280
+ default=10000,
281
+ help="How many train steps to run",
282
+ )
283
+ self.parser.add_argument(
284
+ "--training.max_norm",
285
+ type=float,
286
+ default=1.0,
287
+ help="Max norm for gradient clipping",
288
+ )
289
+ self.parser.add_argument(
290
+ "--training.skip_nan_inf",
291
+ action="store_true",
292
+ help="Skip batch updates when NaN or INF gradients are encountered during training",
293
+ )
294
+ self.parser.add_argument(
295
+ "--training.dataset",
296
+ default="HuggingFaceFW/fineweb-edu",
297
+ help="Dataset to use, with comma separated values",
298
+ )
299
+ self.parser.add_argument(
300
+ "--training.dataset_name",
301
+ default=None,
302
+ help="The name of the dataset config, with comma separated values if provided",
303
+ )
304
+ self.parser.add_argument(
305
+ "--training.dataset_split",
306
+ default=None,
307
+ help="Dataset split to use, with comma separated values if provided",
308
+ )
309
+ self.parser.add_argument(
310
+ "--training.data_dir",
311
+ default=None,
312
+ help="Data dirs to use, with comma separated values if provided",
313
+ )
314
+ self.parser.add_argument(
315
+ "--training.data_files",
316
+ default=None,
317
+ help="Data files to use, with comma separated values if provided",
318
+ )
319
+ self.parser.add_argument(
320
+ "--training.data_probs",
321
+ default=None,
322
+ help="Data sampling probabilities, with comma separated values if provided",
323
+ )
324
+ self.parser.add_argument(
325
+ "--training.streaming",
326
+ action="store_true",
327
+ help="Whether to load dataset in streaming mode, used for huge dataset",
328
+ )
329
+ self.parser.add_argument(
330
+ "--training.num_workers",
331
+ type=int,
332
+ default=32,
333
+ help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
334
+ )
335
+ self.parser.add_argument(
336
+ "--training.prefetch_factor",
337
+ type=int,
338
+ default=2,
339
+ help="Number of batches loaded in advance by each worker. "
340
+ "2 means there will be a total of 2 * num_workers batches prefetched across all workers.",
341
+ )
342
+ self.parser.add_argument(
343
+ "--training.data_parallel_replicate_degree",
344
+ type=int,
345
+ default=1,
346
+ help="""
347
+ The `data_parallel_replicate_degree` argument specifies the degree of
348
+ data parallelism for weight replication. When this value is greater
349
+ than 1, weights will be replicated across `data_parallel_replicate_degree`
350
+ ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism
351
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
352
+ parallelism method used is DDP (Distributed Data Parallelism).
353
+ 1 means disabled.""",
354
+ )
355
+ self.parser.add_argument(
356
+ "--training.data_parallel_shard_degree",
357
+ type=int,
358
+ default=-1,
359
+ help="""
360
+ The `data_parallel_shard_degree` argument specifies the degree of data
361
+ parallelism for weight sharding. When this value is greater than 1, weights
362
+ will be sharded across `data_parallel_shard_degree` ranks. If
363
+ `data_parallel_replicate_degree` is also greater than 1, the parallelism
364
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
365
+ parallelism method used is FSDP (Fully Sharded Data Parallelism).
366
+
367
+ -1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that
368
+ only `data_parallel_shard_degree` can be negative. 1 means disabled.""",
369
+ )
370
+ self.parser.add_argument(
371
+ "--training.enable_cpu_offload",
372
+ action="store_true",
373
+ help="""
374
+ Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
375
+ )
376
+ self.parser.add_argument(
377
+ "--training.tensor_parallel_degree",
378
+ type=int,
379
+ default=1,
380
+ help="Tensor Parallelism degree. 1 means disabled.",
381
+ )
382
+ self.parser.add_argument(
383
+ "--training.disable_loss_parallel",
384
+ action="store_true",
385
+ help="Whether to apply loss parallel when sequence parallel is enabled",
386
+ )
387
+ self.parser.add_argument(
388
+ "--training.fsdp_reshard_after_forward",
389
+ type=str,
390
+ default="default",
391
+ choices=["default", "always", "never"],
392
+ help="""
393
+ `reshard_after_forward` specifies the policy for applying `reshard_after_forward`
394
+ within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward,
395
+ trading off memory and communication. See torch's `fully_shard` API for more documentation
396
+ on `reshard_after_forward`.
397
+ The supported policies include "default", "always" and "never":
398
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal
399
+ scenarios.
400
+ - "always" will enable `reshard_after_forward` for all forward passes.
401
+ - "never" will disable `reshard_after_forward` for all forward passes.
402
+ """,
403
+ )
404
+ self.parser.add_argument(
405
+ "--training.mixed_precision_param",
406
+ type=str,
407
+ default="bfloat16",
408
+ choices=["bfloat16", "float32"],
409
+ help="""
410
+ torch dtype to use for parameters when applying mixed precision via FSDP.
411
+ This feature only takes effect when data_parallel_shard_degree > 1
412
+ """,
413
+ )
414
+ self.parser.add_argument(
415
+ "--training.mixed_precision_reduce",
416
+ type=str,
417
+ default="float32",
418
+ choices=["float32"],
419
+ help="""
420
+ torch dtype to use for reductions when applying mixed precision via FSDP.
421
+ This feature only takes effect when data_parallel_shard_degree > 1
422
+ """,
423
+ )
424
+ self.parser.add_argument(
425
+ "--training.compile",
426
+ action="store_true",
427
+ help="Whether to compile the model",
428
+ )
429
+ self.parser.add_argument(
430
+ "--training.gc_freq",
431
+ type=int,
432
+ default=50,
433
+ help="Python garbage control scheduling interval, in steps",
434
+ )
435
+ self.parser.add_argument(
436
+ "--training.seed",
437
+ type=int,
438
+ default=42,
439
+ help="Choose the base RNG seed used for training",
440
+ )
441
+ self.parser.add_argument(
442
+ "--training.deterministic",
443
+ action="store_true",
444
+ help="Use deterministic algorithms wherever possible, may be slower",
445
+ )
446
+ # metrics configs
447
+ self.parser.add_argument(
448
+ "--metrics.log_freq",
449
+ type=int,
450
+ default=10,
451
+ help="How often to log metrics to TensorBoard, in iterations",
452
+ )
453
+ self.parser.add_argument(
454
+ "--metrics.enable_tensorboard",
455
+ action="store_true",
456
+ help="Whether to log metrics to TensorBoard",
457
+ )
458
+ self.parser.add_argument(
459
+ "--metrics.disable_color_printing",
460
+ action="store_true",
461
+ help="Whether to disable color printing in logs",
462
+ )
463
+ self.parser.add_argument(
464
+ "--metrics.save_tb_folder",
465
+ type=str,
466
+ default="tb",
467
+ help="Folder to dump TensorBoard states",
468
+ )
469
+ self.parser.add_argument(
470
+ "--metrics.save_for_all_ranks",
471
+ action="store_true",
472
+ default=False,
473
+ help="""
474
+ Whether to save TensorBoard/Wandb metrics only for rank 0 or for all ranks.
475
+ When this option is False and pipeline_parallel_degree is > 1, the metrics
476
+ component uses the 0th rank of the last stage pipeline group, which is the
477
+ only stage that computes loss metrics.
478
+ """,
479
+ )
480
+ self.parser.add_argument(
481
+ "--metrics.enable_wandb",
482
+ action="store_true",
483
+ help="Whether to log metrics to Weights & Biases",
484
+ )
485
+
486
+ self.parser.add_argument(
487
+ "--experimental.enable_async_tensor_parallel",
488
+ action="store_true",
489
+ help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
490
+ )
491
+ self.parser.add_argument(
492
+ "--experimental.pipeline_parallel_degree",
493
+ type=int,
494
+ default=1,
495
+ help="""
496
+ Pipeline Parallelism degree, or number of ranks. 1 means disabled.
497
+ If using looped schedules, this still specifies the number of physical ranks, not the number
498
+ of stages. Stages per rank are inferred from split points degree, and schedule.""",
499
+ )
500
+ self.parser.add_argument(
501
+ "--experimental.pipeline_parallel_split_points",
502
+ type=string_list,
503
+ nargs="+",
504
+ default=[],
505
+ help="""
506
+ Specify comma-separated names of modules to use as the beginning of a split point.
507
+
508
+ e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
509
+ the first containing all the layers up to layers.0,
510
+ the second containing layers.0 and up to layers.2,
511
+ the third containing layers.2 and all the remaining layers.
512
+
513
+ Note: fully-automated splitting may be enabled in the future,
514
+ but currently the split points must be specified manually.""",
515
+ )
516
+ self.parser.add_argument(
517
+ "--experimental.pipeline_parallel_schedule",
518
+ type=str,
519
+ default="1F1B",
520
+ help="""
521
+ Specify the Pipeline Parallel schedule to use. The supported schedules are:
522
+ https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
523
+ The schedule must be compatible with the split points and stages_per_rank.
524
+
525
+ Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks,
526
+ and split_points = number of stages - 1
527
+ """,
528
+ )
529
+ self.parser.add_argument(
530
+ "--experimental.pipeline_parallel_schedule_csv",
531
+ type=str,
532
+ default="",
533
+ help="""
534
+ Specify the path to the pipeline parallel schedule csv file to use.
535
+ The pipeline_parallel_schedule argument must be either
536
+ PipelineScheduleSingle, PipelineScheduleMulti, or _PipelineScheduleRuntime.
537
+ """,
538
+ )
539
+
540
+ self.parser.add_argument(
541
+ "--experimental.pipeline_parallel_microbatches",
542
+ type=int,
543
+ default=None,
544
+ help="""
545
+ How many microbatches to split the global training batch into when using pipeline parallelism.
546
+
547
+ The global training batch size must be evenly divisible by the number of microbatches.
548
+
549
+ The default value will be the number of pipeline stages, if unspecified.
550
+ """,
551
+ )
552
+ self.parser.add_argument(
553
+ "--experimental.enable_compiled_autograd",
554
+ action="store_true",
555
+ help="Enable CompiledAutograd to compile the backward.",
556
+ )
557
+ self.parser.add_argument(
558
+ "--experimental.context_parallel_degree",
559
+ type=int,
560
+ default=1,
561
+ help="Context parallelism degree. 1 means disabled.",
562
+ )
563
+ self.parser.add_argument(
564
+ "--experimental.context_parallel_rotate_method",
565
+ type=str,
566
+ default="allgather",
567
+ help="""
568
+ The collective to use in context parallel SDPA for kv shards exchange.
569
+
570
+ 'allgather' means to all-gather all kv shards on ranks after the first sub-SDPA computation,
571
+
572
+ 'alltoall' means to all-to-all shuffle the kv shards.
573
+
574
+ The default value is 'allgather'.
575
+ """,
576
+ )
577
+ # I'm not particularly fond of this. Users can choose to write their own wrapper
578
+ # module, import the TorchTitan training loop, and execute it, which looks cleaner.
579
+ # One reason to provide this option is to allow users to use the existing run script.
580
+ # While the script is pretty trivial now, we may add more logic when integrating
581
+ # with TorchFT.
582
+ # This option is subject to change and may be deleted in the future.
583
+ self.parser.add_argument(
584
+ "--experimental.custom_model_path",
585
+ type=str,
586
+ default="",
587
+ help="""
588
+ The --custom_model_path option allows you to specify a custom path to a model module
589
+ that is not natively implemented within TorchTitan.
590
+ Acceptable values are the file system path to the module (e.g., my_models/model_x)
591
+ or the dotted import path of the module (e.g., some_package.model_x).
592
+ """,
593
+ )
594
+ # checkpointing configs
595
+ self.parser.add_argument(
596
+ "--checkpoint.enable_checkpoint",
597
+ action="store_true",
598
+ help="Whether to enable checkpoint",
599
+ )
600
+ self.parser.add_argument(
601
+ "--checkpoint.folder",
602
+ type=str,
603
+ default="checkpoint",
604
+ help="""
605
+ The folder to store the checkpoints.
606
+ When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
607
+ """,
608
+ )
609
+ self.parser.add_argument(
610
+ "--checkpoint.interval",
611
+ type=int,
612
+ default=500,
613
+ help="Checkpointing interval in steps.",
614
+ )
615
+ self.parser.add_argument(
616
+ "--checkpoint.model_weights_only",
617
+ action="store_true",
618
+ help="""
619
+ When model_weights_only=True, only model weights will be saved at the end of training.
620
+ With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion.
621
+ When model_weights_only=False, the full checkpoint will be saved.
622
+ A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
623
+ The default value is false.
624
+ """,
625
+ )
626
+ self.parser.add_argument(
627
+ "--checkpoint.export_dtype",
628
+ type=str,
629
+ default="float32",
630
+ choices=["float16", "bfloat16", "float32"],
631
+ help="""
632
+ Converts to the specified precision when training completes and model_weights_only=true.
633
+ Currently supports float32, float16, and bfloat16.
634
+ The default value is float32.
635
+ """,
636
+ )
637
+ self.parser.add_argument(
638
+ "--checkpoint.create_seed_checkpoint",
639
+ action="store_true",
640
+ help="""
641
+ Initializes the full model without applying parallelisms, and then saves it as a seed checkpoint.
642
+ Note: requires user to call train.py without specifying any parallelisms, e.g. NGPU=1.
643
+ Could be implemented as a separate script, but this way shares more code.
644
+ """,
645
+ )
646
+ self.parser.add_argument(
647
+ "--checkpoint.async_mode",
648
+ type=str,
649
+ default="disabled",
650
+ help="""
651
+ Which async checkpoint mode to use. Currently there are 3 different modes.
652
+ 1. "disabled": synchronized checkpointing will be used.
653
+ 2. "async": torch.distributed.checkpoint.async_save will be used.
654
+ 3. "async_with_pinned_mem": this option utilizes a dedicated pinned memory
655
+ space and creates a separate process for faster GPU->CPU transfer
656
+ performance and eliminating GIL contention. The cost is increased CPU
657
+ memory usage. If insufficient CPU memory is available, performance may
658
+ degrade due to memory paging. For most users, "async" should suffice as
659
+ the performance overhead is typically small (on the order of tens of
660
+ seconds) compared to checkpointing frequency. This mode can be employed
661
+ to pursue near-zero checkpointing times (e.g., < 1 second) given
662
+ appropriate hardware support such as ample CPU memory and fast PCIe.
663
+
664
+ "disabled" is the default mode.
665
+ """,
666
+ )
667
+ self.parser.add_argument(
668
+ "--checkpoint.keep_latest_k",
669
+ type=int,
670
+ default=0,
671
+ help="""
672
+ Keeps only the latest k checkpoints, purging older ones. If 0, keep all checkpoints.
673
+ 0 is the default value. k cannot be 1 as the last one may be in the process of being
674
+ saved. As a result, the metadata of the last one may not be ready yet.
675
+ """,
676
+ )
677
+ self.parser.add_argument(
678
+ "--checkpoint.load_step",
679
+ type=int,
680
+ default=-1,
681
+ help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
682
+ )
683
+ self.parser.add_argument(
684
+ "--checkpoint.exclude_from_loading",
685
+ type=string_list,
686
+ nargs="*",
687
+ default=[],
688
+ help="""
689
+ Exclude specific keys from being loaded from the checkpoint.
690
+ Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
691
+ This will load the model only, excluding the specified keys.
692
+ """,
693
+ )
694
+ self.parser.add_argument(
695
+ "--checkpoint.convert_to_hf_on_save",
696
+ action="store_true",
697
+ help="""
698
+ If true, automatically convert the saved DCP checkpoint to Hugging Face format
699
+ in a parallel directory (e.g., step-1000-hf) after each save.
700
+ """,
701
+ )
702
+ self.parser.add_argument(
703
+ "--checkpoint.hf_upload_enabled",
704
+ action="store_true",
705
+ help="Enable uploading converted Hugging Face checkpoints to the Hub.",
706
+ )
707
+ self.parser.add_argument(
708
+ "--checkpoint.hf_repo_base_name",
709
+ type=str,
710
+ default=None,
711
+ help="Hugging Face Hub repository ID to upload checkpoints to (e.g., 'username/repo').",
712
+ )
713
+ self.parser.add_argument(
714
+ "--checkpoint.hf_upload_format",
715
+ type=str,
716
+ default="dcp",
717
+ choices=["dcp", "hf"],
718
+ help="""
719
+ Format to upload to Hugging Face Hub. 'dcp' for DCP format, 'hf' for Hugging Face format.
720
+ Note: 'hf' is only supported for models with a single pipeline stage.
721
+ """,
722
+ )
723
+ # activation checkpointing configs
724
+ self.parser.add_argument(
725
+ "--activation_checkpoint.mode",
726
+ type=str,
727
+ default="selective",
728
+ help="Type of activation checkpointing to use ['none', 'full', 'selective']",
729
+ )
730
+ self.parser.add_argument(
731
+ "--activation_checkpoint.selective_ac_option",
732
+ type=str,
733
+ default="2", # 2 = checkpoint every other layer
734
+ help="""
735
+ Selective activation checkpointing options ['int', 'op'].
736
+ 'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
737
+ """,
738
+ )
739
+
740
+ self.parser.add_argument(
741
+ "--activation_offload.mode",
742
+ type=str,
743
+ default="none",
744
+ help="""
745
+ Whether to use activation offloading. Options are ['none', 'full'].
746
+ """,
747
+ )
748
+
749
+ # float8 configs
750
+ self.parser.add_argument(
751
+ "--float8.enable_fsdp_float8_all_gather",
752
+ action="store_true",
753
+ help="Whether enable float8 all-gather in FSDP, recommended for tensorwise scaling",
754
+ )
755
+ self.parser.add_argument(
756
+ "--float8.precompute_float8_dynamic_scale_for_fsdp",
757
+ action="store_true",
758
+ help="Whether precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling",
759
+ )
760
+ self.parser.add_argument(
761
+ "--float8.force_recompute_fp8_weight_in_bwd",
762
+ action="store_true",
763
+ help="""
764
+ Whether to force the recomputation of FP8 weights during backward pass.
765
+ When using FSDP with tensorwise scaling, it is recommended to enable
766
+ `force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights
767
+ for backward computation.
768
+ """,
769
+ )
770
+ self.parser.add_argument(
771
+ "--float8.recipe_name",
772
+ type=str,
773
+ default=None,
774
+ choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"],
775
+ help="""
776
+ If specified, creates float8 config from recipe name, valid choices are
777
+ `tensorwise`, `rowwise` and `rowwise_with_gw_hp`.
778
+ """,
779
+ )
780
+
781
+ # communications library settings
782
+ self.parser.add_argument(
783
+ "--comm.init_timeout_seconds",
784
+ type=int,
785
+ default=300,
786
+ help="Timeout for communication operations, during initialization and first train step.",
787
+ )
788
+ self.parser.add_argument(
789
+ "--comm.train_timeout_seconds",
790
+ type=int,
791
+ default=100,
792
+ help=(
793
+ "Timeout for communication operations after the first train step -- "
794
+ "usually a tighter bound than during initialization."
795
+ ),
796
+ )
797
+ self.parser.add_argument(
798
+ "--comm.trace_buf_size",
799
+ type=int,
800
+ default=20000,
801
+ help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
802
+ )
803
+
804
+ # memory estimation settings
805
+ self.parser.add_argument(
806
+ "--memory_estimation.enabled",
807
+ help="Whether to estimate memory usage for FSDP",
808
+ action="store_true",
809
+ )
810
+
811
+ self.parser.add_argument(
812
+ "--memory_estimation.disable_fake_mode",
813
+ help="Whether to estimate memory under FakeTensorMode",
814
+ action="store_true",
815
+ )
816
+
817
+ self.parser.add_argument(
818
+ "--fault_tolerance.enable",
819
+ action="store_true",
820
+ help="""
821
+ Enable TorchFT integration. When TorchFT is enabled, HSDP will be used.
822
+ When it is enabled, --fault_tolerance.data_parallel_replicate_degree should be 1 and
823
+ --fault_tolerance.group_size is used to control the maximum
824
+ replicate group size, as the replicate group size is dynamic.
825
+
826
+ Note that this is still an experimental feature.
827
+ """,
828
+ )
829
+
830
+ self.parser.add_argument(
831
+ "--fault_tolerance.replica_id",
832
+ type=int,
833
+ default=0,
834
+ help="The TorchFT replica ID of this run.",
835
+ )
836
+
837
+ self.parser.add_argument(
838
+ "--fault_tolerance.group_size",
839
+ type=int,
840
+ default=0,
841
+ help="""
842
+ The number of TorchFT replicate groups. This number will be used for
843
+ dataloader to split the dataset across the replicate groups and FSDP
844
+ dimension
845
+ """,
846
+ )
847
+
848
+ self.parser.add_argument(
849
+ "--fault_tolerance.min_replica_size",
850
+ type=int,
851
+ default=1,
852
+ help="The minimum number of FT replica for each step.",
853
+ )
854
+
855
+ def to_dict(self):
856
+ return self.args_dict
857
+
858
+ def parse_args(self, args_list: list = sys.argv[1:]):
859
+ args, cmd_args = self.parse_args_from_command_line(args_list)
860
+ config_file = getattr(args, "job.config_file", None)
861
+ # build up a two level dict
862
+ args_dict = self._args_to_two_level_dict(args)
863
+ if config_file is not None:
864
+ try:
865
+ with open(config_file, "rb") as f:
866
+ for k, v in tomllib.load(f).items():
867
+ # to prevent overwrite of non-specified keys
868
+ args_dict[k] |= v
869
+ except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
870
+ logger.exception(
871
+ f"Error while loading the configuration file: {config_file}"
872
+ )
873
+ logger.exception(f"Error details: {str(e)}")
874
+ raise e
875
+
876
+ # Checking string-list arguments are properly split into a list
877
+ # if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
878
+ string_list_argnames = self._get_string_list_argument_names()
879
+ for n in string_list_argnames:
880
+ check_string_list_argument(args_dict, n)
881
+
882
+ # override args dict with cmd_args
883
+ cmd_args_dict = self._args_to_two_level_dict(cmd_args)
884
+ for section, section_args in cmd_args_dict.items():
885
+ for k, v in section_args.items():
886
+ args_dict[section][k] = v
887
+
888
+ self.args_dict = args_dict
889
+
890
+ for k, v in args_dict.items():
891
+ class_type = type(k.title(), (), v)
892
+ setattr(self, k, class_type())
893
+ self._validate_config()
894
+
895
+ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
896
+ args_dict = defaultdict(defaultdict)
897
+ for k, v in vars(args).items():
898
+ first_level_key, second_level_key = k.split(".", 1)
899
+ args_dict[first_level_key][second_level_key] = v
900
+ return args_dict
901
+
902
+ def _validate_config(self) -> None:
903
+ # TODO: Add more mandatory validations
904
+ assert self.model.config
905
+ assert self.model.tokenizer_path
906
+
907
+ def _get_string_list_argument_names(self) -> list[str]:
908
+ """Get the parser argument names of type `string_list`."""
909
+ string_list_args = [
910
+ v.dest for v in self.parser._actions if v.type is string_list
911
+ ]
912
+ return string_list_args
913
+
914
+ def parse_args_from_command_line(
915
+ self, args_list
916
+ ) -> Tuple[argparse.Namespace, argparse.Namespace]:
917
+ """
918
+ Parse command line arguments and return the parsed args and the command line only args
919
+ """
920
+ args = self.parser.parse_args(args_list)
921
+ string_list_argnames = set(self._get_string_list_argument_names())
922
+
923
+ # aux parser to parse the command line only args, with no defaults from main parser
924
+ aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
925
+ for arg, val in vars(args).items():
926
+ if isinstance(val, bool):
927
+ aux_parser.add_argument(
928
+ "--" + arg, action="store_true" if val else "store_false"
929
+ )
930
+ elif arg in string_list_argnames:
931
+ # without this special case, type inference breaks here,
932
+ # since the inferred type is just 'list' and it ends up flattening
933
+ # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
934
+ aux_parser.add_argument("--" + arg, type=string_list)
935
+ else:
936
+ aux_parser.add_argument("--" + arg, type=type(val))
937
+
938
+ cmd_args, _ = aux_parser.parse_known_args(args_list)
939
+
940
+ return args, cmd_args
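A short sketch of the precedence rules documented in the JobConfig docstring (cmdline > toml > argparse default), assuming flame and its torchtitan dependency are importable; the config path and values are illustrative:

from flame.config_manager import JobConfig

config = JobConfig()
config.parse_args([
    "--model.config", "configs/transformer_340M.json",
    "--optimizer.lr", "3e-4",          # overrides the 8e-4 argparse default
    "--model.converters", "float8",    # parsed by string_list into ['float8']
])
print(config.model.config)       # configs/transformer_340M.json
print(config.optimizer.lr)       # 0.0003
print(config.model.converters)   # ['float8']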
flame/data.py ADDED
@@ -0,0 +1,570 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ import pickle
7
+ from copy import deepcopy
8
+ from dataclasses import dataclass
9
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Union
10
+
11
+ import datasets
12
+ import numpy as np
13
+ import torch
14
+ from datasets import Dataset, IterableDataset
15
+ from datasets.iterable_dataset import ShufflingConfig
16
+ from torch.distributed.checkpoint.stateful import Stateful
17
+ from torchdata.stateful_dataloader import StatefulDataLoader
18
+ from transformers import PreTrainedTokenizer
19
+
20
+ from torchtitan.tools.logging import logger
21
+
22
+
23
+ class BufferShuffledIterableDataset(IterableDataset):
24
+ def __init__(
25
+ self,
26
+ dataset: Dataset,
27
+ tokenizer: PreTrainedTokenizer,
28
+ seq_len: int = 2048,
29
+ rank: int = 0,
30
+ world_size: int = 1,
31
+ buffer_size: int = 1024,
32
+ ) -> BufferShuffledIterableDataset:
33
+ self.dataset = dataset
34
+ self.tokenizer = tokenizer
35
+
36
+ self.data = dataset.shard(world_size, rank)
37
+ self.seq_len = seq_len
38
+
39
+ self.rank = rank
40
+ self.world_size = world_size
41
+ self.buffer_size = buffer_size
42
+
43
+ if tokenizer.vocab_size < torch.iinfo(torch.int16).max:
44
+ self.dtype = torch.int16
45
+ elif tokenizer.vocab_size < torch.iinfo(torch.int32).max:
46
+ self.dtype = torch.int32
47
+ else:
48
+ self.dtype = torch.int64
49
+ self.states = None
50
+ self.buffer = torch.tensor([], dtype=self.dtype)
51
+ self.tokens = []
52
+ self.rand_id = 0
53
+ self.token_id = 0
54
+ self.rng_state = None
55
+ self._epoch = 0
56
+
57
+ def __iter__(self):
58
+ g = torch.Generator()
59
+ g.manual_seed(self._epoch + self.rank)
60
+ if self.rng_state is not None:
61
+ g.set_state(self.rng_state)
62
+
63
+ rand_it = self.randint(0, self.buffer_size, g=g)
64
+ if self.states is not None:
65
+ self.data.load_state_dict(self.states)
66
+
67
+ # max number of tokens allowed in the chunk buffer
68
+ n_tokens = self.buffer_size * self.seq_len
69
+
70
+ while True:
71
+ for sample in self.tokenize(self.data):
72
+ # keep appending the samples to the token buffer
73
+ self.tokens += sample
74
+ # if the token buffer is full, start sampling
75
+ # NOTE: we first convert the token ids to a tensor of shape [n_chunks, seq_len] for efficiency
76
+ if len(self.buffer) == 0 and len(self.tokens) >= n_tokens:
77
+ self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=self.dtype).view(self.buffer_size, -1)
78
+ self.tokens = self.tokens[n_tokens:]
79
+ if len(self.buffer) == self.buffer_size:
80
+ yield from self.sample(rand_it)
81
+
82
+ n_chunks = len(self.tokens) // self.seq_len
83
+ # handle the left tokens in the buffer
84
+ if n_chunks > 0:
85
+ n_tokens = n_chunks * self.seq_len
86
+ indices = torch.randperm(n_chunks, generator=g).tolist()
87
+ self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=torch.long).view(n_chunks, -1)
88
+ self.tokens = self.tokens[n_tokens:]
89
+ for i in indices:
90
+ yield {'input_ids': self.buffer[i]}
91
+
92
+ def tokenize(self, data, batch_size: int = 64):
93
+ texts, states = [], []
94
+ for sample in data:
95
+ texts.append(sample['text'])
96
+ states.append(self.data.state_dict())
97
+ if len(texts) == batch_size:
98
+ for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
99
+ self.states = s
100
+ yield tokenized
101
+ texts, states = [], []
102
+ if len(texts) > 0:
103
+ for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
104
+ self.states = s
105
+ yield tokenized
106
+
107
+ def sample(self, indices):
108
+ n_tokens = (len(self.tokens) // self.seq_len) * self.seq_len
109
+ while self.token_id < n_tokens:
110
+ i = next(indices)
111
+ start, end = self.token_id, self.token_id + self.seq_len
112
+ self.token_id += self.seq_len
113
+ yield {'input_ids': self.buffer[i].to(torch.long)}
114
+ self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
115
+ self.token_id = 0
116
+ self.tokens = self.tokens[n_tokens:]
117
+
118
+ def randint(self, low: int, high: int, buffer_size: int = 1024, g: torch.Generator = torch.Generator()) -> Iterable[int]:
119
+ indices = torch.empty(buffer_size, dtype=torch.long)
120
+ while True:
121
+ # record the generator states before sampling
122
+ self.rng_state = g.get_state()
123
+ indices = torch.randint(low, high, (buffer_size,), out=indices, generator=g)
124
+ for i in indices[self.rand_id:].tolist():
125
+ self.rand_id += 1
126
+ yield i
127
+ self.rand_id = 0
128
+
129
+ def set_epoch(self, epoch):
130
+ self._epoch = epoch
131
+ if hasattr(self.dataset, 'set_epoch'):
132
+ self.dataset.set_epoch(epoch)
133
+
134
+ def state_dict(self):
135
+ return {
136
+ 'states': self.states,
137
+ 'buffer': self.buffer.clone(),
138
+ 'tokens': deepcopy(self.tokens),
139
+ 'rand_id': self.rand_id,
140
+ 'token_id': self.token_id,
141
+ 'rng_state': self.rng_state,
142
+ 'epoch': self._epoch,
143
+ }
144
+
145
+ def load_state_dict(self, state_dict):
146
+ self.states = state_dict['states']
147
+ self.buffer = state_dict['buffer'].clone()
148
+ self.tokens = deepcopy(state_dict['tokens'])
149
+ self.rand_id = state_dict['rand_id']
150
+ self.token_id = state_dict['token_id']
151
+ self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
152
+ self._epoch = state_dict['epoch']
153
+
154
+
155
+ class OnlineTokenizedIterableDataset(IterableDataset):
156
+ def __init__(
157
+ self, dataset: Dataset, tokenizer: PreTrainedTokenizer, seq_len: int = 2048, rank: int = 0, world_size: int = 1
158
+ ) -> OnlineTokenizedIterableDataset:
159
+ self.dataset = dataset
160
+ self.tokenizer = tokenizer
161
+
162
+ self.data = dataset.shard(world_size, rank)
163
+ self.seq_len = seq_len
164
+ self.rank = rank
165
+ self.world_size = world_size
166
+
167
+ self.states = None
168
+ self.tokens = []
169
+
170
+ def __iter__(self):
171
+ if self.states is not None:
172
+ self.data.load_state_dict(self.states)
173
+
174
+ while True:
175
+ for sample in self.tokenize(self.data):
176
+ # keep appending the samples to the token buffer
177
+ self.tokens += sample
178
+
179
+ while len(self.tokens) >= self.seq_len:
180
+ input_ids = torch.tensor(self.tokens[:self.seq_len], dtype=torch.long)
181
+ self.tokens = self.tokens[self.seq_len:]
182
+ yield {'input_ids': input_ids}
183
+
184
+ def tokenize(self, data, buffer_size: int = 64):
185
+ buffer, states = [], []
186
+ for sample in data:
187
+ if sample.get('text', None) is not None:
188
+ buffer.append(sample['text'])
189
+ elif sample.get('content', None) is not None:
190
+ buffer.append(sample['content'])
191
+ else:
192
+ raise ValueError(f"No 'text' or 'content' field found in sample:\n{sample}")
193
+ states.append(self.data.state_dict())
194
+ if len(buffer) == buffer_size:
195
+ for s, tokenized in zip(states, self.tokenizer(buffer, return_attention_mask=False)['input_ids']):
196
+ self.states = s
197
+ yield tokenized
198
+ buffer, states = [], []
199
+ if len(buffer) > 0:
200
+ for s, tokenized in zip(states, self.tokenizer(buffer, return_attention_mask=False)['input_ids']):
201
+ self.states = s
202
+ yield tokenized
203
+
204
+ def state_dict(self):
205
+ return {'states': self.states, 'tokens': deepcopy(self.tokens)}
206
+
207
+ def load_state_dict(self, state_dict):
208
+ self.states = state_dict['states']
209
+ self.tokens = deepcopy(state_dict['tokens'])
210
+
211
+
212
+ class BufferShuffledExamplesIterable(datasets.iterable_dataset.BufferShuffledExamplesIterable):
213
+ def __init__(self, *args, **kwargs):
214
+ super().__init__(*args, **kwargs)
215
+
216
+ def _init_state_dict(self) -> dict:
217
+ self._state_dict = self.ex_iterable._init_state_dict()
218
+ self._state_dict['mem_buffer'] = ([],)
219
+ self._state_dict['bit_generator_state'] = self.generator.bit_generator.state
220
+ self._state_dict['bit_generator_index_offset'] = 0
221
+ self._state_dict['bit_generator_index_offset_shuffle'] = 0
222
+ return self._state_dict
223
+
224
+ def __iter__(self):
225
+ buffer_size = self.buffer_size
226
+ rng = deepcopy(self.generator)
227
+ # this is the shuffle buffer that we keep in memory
228
+ mem_buffer = self._state_dict['mem_buffer'][0]
229
+ # this is an infinite iterator that randomly samples the index of the source to pick examples from
230
+ index_offset = self._state_dict['bit_generator_index_offset'] if self._state_dict else 0
231
+ if self._state_dict:
232
+ rng.bit_generator.state = self._state_dict['bit_generator_state']
233
+ indices_iterator = self._iter_random_indices(rng, buffer_size, random_batch_size=buffer_size)
234
+ # skip already consumed ones
235
+ for _ in range(index_offset):
236
+ i = next(indices_iterator)
237
+
238
+ for x in self.ex_iterable:
239
+ if len(mem_buffer) < buffer_size: # if the buffer is not full, keep filling the buffer
240
+ mem_buffer.append(x)
241
+ else: # otherwise, pick an example from it
242
+ i = next(indices_iterator)
243
+ index_offset = (index_offset + 1) % buffer_size
244
+ if self._state_dict:
245
+ self._state_dict['bit_generator_index_offset'] = index_offset
246
+ if index_offset == 0:
247
+ self._state_dict['bit_generator_state'] = rng.bit_generator.state
248
+ selected = mem_buffer[i]
249
+ mem_buffer[i] = x # replace the picked example by a new one
250
+ yield selected
251
+
252
+ index_offset = self._state_dict['bit_generator_index_offset_shuffle'] if self._state_dict else 0
253
+ if self._state_dict:
254
+ rng.bit_generator.state = self._state_dict['bit_generator_state']
255
+
256
+ # when we run out of examples, we shuffle the remaining examples in the buffer and yield them
257
+ for i in rng.permutation(len(mem_buffer))[index_offset:].tolist():
258
+ index_offset = index_offset + 1
259
+ if self._state_dict:
260
+ self._state_dict['bit_generator_index_offset_shuffle'] = index_offset
261
+ yield mem_buffer[i]
262
+
263
+ def shuffle_data_sources(self, generator: np.random.Generator) -> BufferShuffledExamplesIterable:
264
+ """Shuffle the wrapped examples iterable as well as the shuffling buffer."""
265
+ return BufferShuffledExamplesIterable(
266
+ self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator
267
+ )
268
+
269
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> BufferShuffledExamplesIterable:
270
+ """Keep only the requested shard."""
271
+ return BufferShuffledExamplesIterable(
272
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
273
+ buffer_size=self.buffer_size,
274
+ generator=self.generator,
275
+ )
276
+
277
+ def load_state_dict(self, state_dict: dict) -> dict:
278
+ def _inner_load_state_dict(state, new_state):
279
+ if new_state is not None and isinstance(state, dict):
280
+ for key in new_state:
281
+ state[key] = _inner_load_state_dict(state[key], new_state[key])
282
+ return state
283
+ elif new_state is not None and isinstance(state, list):
284
+ for i in range(len(state)):
285
+ state[i] = _inner_load_state_dict(state[i], new_state[i])
286
+ return state
287
+ return new_state
288
+
289
+ return _inner_load_state_dict(self._state_dict, state_dict)
290
+
291
+
292
+ def shuffle(
293
+ dataset: IterableDataset,
294
+ seed: int = 42,
295
+ generator: np.random.Generator = None,
296
+ buffer_size: int = 1024,
297
+ ):
298
+ generator = np.random.default_rng(seed) if generator is None else deepcopy(generator)
299
+ return IterableDataset(
300
+ ex_iterable=BufferShuffledExamplesIterable(dataset._ex_iterable, buffer_size=buffer_size, generator=generator),
301
+ info=dataset._info.copy(),
302
+ split=dataset._split,
303
+ formatting=dataset._formatting,
304
+ shuffling=ShufflingConfig(generator=generator, _original_seed=seed),
305
+ distributed=copy.deepcopy(dataset._distributed),
306
+ token_per_repo_id=dataset._token_per_repo_id,
307
+ )
308
+
309
+
310
+ @dataclass
311
+ class DataCollatorForLanguageModeling:
312
+ """
313
+ Data collator used for language modeling. Inputs are dynamically padded if `varlen=False`.
314
+ If `varlen=True`, sequences are expected to be concatenated, and labels match inputs.
315
+
316
+ Args:
317
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
318
+ The tokenizer used for encoding the data.
319
+ context_len (`int`, optional):
320
+ When `varlen=True`, sequences longer than this length within a document
321
+ (as determined by `cu_seqlens`) will be further chunked.
322
+ varlen (`bool`):
323
+ Whether to handle variable length concatenated sequences (`True`) or padded batches (`False`).
324
+
325
+ Returns:
326
+ A dictionary with the following keys:
327
+ - `input_ids`: Tensor of input IDs. Shape `[batch_size, seq_len]` if `varlen=False`, `[1, total_len]` if `varlen=True`.
328
+ - `labels`: Tensor of labels. Shape matches `input_ids`. Padding positions are masked with -100 if `varlen=False`.
329
+ - `attention_mask`: Tensor indicating non-padding tokens (only if `varlen=False`). Shape matches `input_ids`.
330
+ - `cu_seqlens`: Tensor of cumulative sequence lengths (only if `varlen=True`). Shape `[1, num_sequences + 1]`.
331
+
332
+ NOTE: When `varlen=True`, the `batch_size` must be 1.
333
+ """
334
+
335
+ tokenizer: PreTrainedTokenizer
336
+ context_len: Optional[int] = None
337
+ varlen: bool = False
338
+
339
+ def __call__(self, examples: List[Union[List[int], Dict[str, Any]]]) -> Dict[str, Any]:
340
+ if not isinstance(examples[0], Dict):
341
+ examples = [{'input_ids': example} for example in examples]
342
+
343
+ def tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
344
+ tensorized = {}
345
+ for key in ['input_ids', 'cu_seqlens']:
346
+ if key not in example:
347
+ continue
348
+ if isinstance(example[key], List):
349
+ tensorized[key] = torch.tensor(example[key], dtype=torch.long)
350
+ elif isinstance(example[key], np.ndarray):
351
+ tensorized[key] = torch.from_numpy(example[key])
352
+ else:
353
+ tensorized[key] = example[key]
354
+ return tensorized
355
+
356
+ examples = list(map(tensorize, examples))
357
+
358
+ if not self.varlen:
359
+ # --- Handling for varlen=False (Batch Padding) ---
360
+ length_of_first = examples[0]['input_ids'].size(0)
361
+ needs_padding = not all(example['input_ids'].size(0) == length_of_first for example in examples)
362
+
363
+ if needs_padding:
364
+ # Check for pad token if padding is actually required
365
+ if self.tokenizer.pad_token_id is None:
366
+ raise ValueError(
367
+ f'You are attempting to pad samples but the tokenizer you are using '
368
+ f'({self.tokenizer.__class__.__name__}) does not have a pad token.'
369
+ )
370
+ # Pad using the tokenizer, ensuring attention_mask is returned
371
+ batch = self.tokenizer.pad(examples, return_tensors='pt', return_attention_mask=True)
372
+ else:
373
+ # No padding needed, stack directly and create a full attention mask
374
+ input_ids = torch.stack([example['input_ids'] for example in examples], dim=0)
375
+ batch = {
376
+ 'input_ids': input_ids,
377
+ # Create attention mask of all ones
378
+ 'attention_mask': torch.ones_like(input_ids),
379
+ }
380
+
381
+ # Create labels by cloning input_ids
382
+ labels = batch['input_ids'].clone()
383
+ # Mask labels only where attention_mask is 0 (padding positions)
384
+ if 'attention_mask' in batch:
385
+ labels[batch['attention_mask'] == 0] = -100
386
+ batch['labels'] = labels
387
+
388
+ else:
389
+ # --- Handling for varlen=True (Concatenated Sequences) ---
390
+ if len(examples) > 1:
391
+ raise ValueError('The batch size must be 1 for inputs with variable lengths (varlen=True).')
392
+
393
+ batch = {'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)}
394
+
395
+ # --- cu_seqlens calculation logic remains the same ---
396
+ if 'cu_seqlens' in examples[0]:
397
+ batch['cu_seqlens'] = (
398
+ torch.cat([example['cu_seqlens'] for example in examples], dim=0).unsqueeze(0).to(dtype=torch.int32)
399
+ ) # Ensure int32
400
+ else:
401
+ # determine boundaries by bos/eos positions
402
+ # Check for bos_token_id first
403
+ if self.tokenizer.bos_token_id is not None:
404
+ cu_seqlens = []
405
+ # Handle case where the sequence doesn't start with BOS
406
+ if batch['input_ids'][0, 0] != self.tokenizer.bos_token_id:
407
+ cu_seqlens.append(torch.tensor([0], device=batch['input_ids'].device)) # Match device
408
+ # Find all BOS token positions
409
+ bos_positions = torch.where(batch['input_ids'].eq(self.tokenizer.bos_token_id))[1]
410
+ # Ensure bos_positions is on the correct device if empty
411
+ if bos_positions.numel() == 0 and len(cu_seqlens) > 0:
412
+ cu_seqlens.append(bos_positions.to(cu_seqlens[0].device))
413
+ elif bos_positions.numel() > 0:
414
+ cu_seqlens.append(bos_positions)
415
+ # Add the end of the entire batch
416
+ cu_seqlens.append(
417
+ torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
418
+ ) # Match device and use size(1)
419
+ # Filter out empty tensors before cat
420
+ cu_seqlens = [t for t in cu_seqlens if t.numel() > 0]
421
+ if not cu_seqlens: # Handle case where input is empty or has no BOS
422
+ batch['cu_seqlens'] = torch.tensor(
423
+ [0, batch['input_ids'].size(1)], dtype=torch.int32, device=batch['input_ids'].device
424
+ )
425
+ else:
426
+ batch['cu_seqlens'] = torch.cat(cu_seqlens, dim=0).to(dtype=torch.int32)
427
+
428
+ # Else, check for eos_token_id
429
+ elif self.tokenizer.eos_token_id is not None:
430
+ cu_seqlens = [torch.tensor([0], device=batch['input_ids'].device)] # Match device
431
+ # Find positions *after* EOS tokens
432
+ eos_positions = torch.where(batch['input_ids'].eq(self.tokenizer.eos_token_id))[1] + 1
433
+ # Ensure eos_positions is on the correct device if empty
434
+ if eos_positions.numel() > 0:
435
+ cu_seqlens.append(eos_positions)
436
+ # Handle case where the sequence doesn't end with EOS
437
+ if batch['input_ids'][0, -1] != self.tokenizer.eos_token_id:
438
+ # Only add the final length if the last found EOS wasn't already the end
439
+ if eos_positions.numel() == 0 or eos_positions[-1] != batch['input_ids'].size(1):
440
+ cu_seqlens.append(
441
+ torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
442
+ ) # Match device and use size(1)
443
+ # Filter out empty tensors before cat
444
+ cu_seqlens = [t for t in cu_seqlens if t.numel() > 0]
445
+ if not cu_seqlens: # Handle case where input is empty or has no EOS
446
+ batch['cu_seqlens'] = torch.tensor(
447
+ [0, batch['input_ids'].size(1)], dtype=torch.int32, device=batch['input_ids'].device
448
+ )
449
+ else:
450
+ batch['cu_seqlens'] = torch.cat(cu_seqlens, dim=0).to(dtype=torch.int32)
451
+ # Else, neither BOS nor EOS is usable
452
+ else:
453
+ raise ValueError(
454
+ 'For varlen=True without precomputed cu_seqlens, the tokenizer must have either a bos_token_id '
455
+ 'or an eos_token_id defined to act as sequence separators.'
456
+ )
457
+
458
+ # --- cu_seqlens validation checks remain the same ---
459
+ if batch['cu_seqlens'].numel() < 2:
460
+ raise ValueError(f'Calculated cu_seqlens must have at least start and end: {batch["cu_seqlens"]}')
461
+ if not torch.all(batch['cu_seqlens'][1:] >= batch['cu_seqlens'][:-1]):
462
+ raise ValueError(f'Calculated cu_seqlens are not monotonically increasing: {batch["cu_seqlens"]}')
463
+ if batch['cu_seqlens'][0] != 0:
464
+ raise ValueError(f'Calculated cu_seqlens do not start at 0: {batch["cu_seqlens"]}')
465
+ if batch['cu_seqlens'][-1] != batch['input_ids'].size(1):
466
+ # Allow empty sequence case where cu_seqlens=[0, 0] and input_ids.size(1)=0
467
+ if not (batch['cu_seqlens'].tolist() == [0, 0] and batch['input_ids'].size(1) == 0):
468
+ raise ValueError(
469
+ f'Calculated cu_seqlens do not end at total length {batch["input_ids"].size(1)}: '
470
+ f'{batch["cu_seqlens"]}'
471
+ )
472
+
473
+ # --- context_len splitting logic remains the same ---
474
+ if self.context_len is not None:
475
+ # This logic splits sequences based on context_len *after* initial boundaries are found
476
+ bos = batch['cu_seqlens'][:-1].tolist()
477
+ eos = batch['cu_seqlens'][1:].tolist()
478
+ # Handle empty sequences between boundaries
479
+ split_boundaries = []
480
+ for i, j in zip(bos, eos):
481
+ if i < j: # Only process non-empty sequences
482
+ split_boundaries.append(torch.arange(i, j, self.context_len, device=batch['input_ids'].device))
483
+ # Add the final end point if it wasn't included by arange
484
+ final_end_point = torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
485
+ # Concatenate all boundaries
486
+ if not split_boundaries: # Handle case of completely empty input
487
+ batch['cu_seqlens'] = torch.tensor([0, 0], dtype=torch.int32, device=batch['input_ids'].device)
488
+ else:
489
+ batch['cu_seqlens'] = torch.cat(split_boundaries + [final_end_point]).to(dtype=torch.int32)
490
+ # Ensure uniqueness and sort, as arange might duplicate the endpoint
491
+ batch['cu_seqlens'] = torch.unique(batch['cu_seqlens'])
492
+
493
+ # Create labels directly from input_ids, NO padding mask needed for varlen
494
+ labels = batch['input_ids'].clone()
495
+ batch['labels'] = labels
496
+
497
+ return batch
498
+
499
+
500
+ class ParallelAwareDataLoader(StatefulDataLoader, Stateful):
501
+ """
502
+ A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
503
+ """
504
+
505
+ def __init__(
506
+ self,
507
+ rank: int,
508
+ dataset: IterableDataset,
509
+ batch_size: int,
510
+ collate_fn: Callable,
511
+ num_workers: int = 0,
512
+ pin_memory: bool = False,
513
+ prefetch_factor: int = 2,
514
+ persistent_workers: bool = False,
515
+ snapshot_every_n_steps: Optional[int] = 1,
516
+ ):
517
+ super().__init__(
518
+ dataset=dataset,
519
+ batch_size=batch_size,
520
+ collate_fn=collate_fn,
521
+ num_workers=num_workers,
522
+ pin_memory=pin_memory,
523
+ prefetch_factor=prefetch_factor,
524
+ persistent_workers=persistent_workers,
525
+ snapshot_every_n_steps=snapshot_every_n_steps,
526
+ )
527
+ self.rank = rank
528
+
529
+ def state_dict(self) -> Dict[str, Any]:
530
+ # Store state only for dp rank to avoid replicating the same state across other dimensions
531
+ return {f'rank_{self.rank}': pickle.dumps(super().state_dict())}
532
+
533
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
534
+ # State being empty is valid
535
+ if not state_dict:
536
+ return
537
+
538
+ if f'rank_{self.rank}' not in state_dict:
539
+ logger.warning(f'DataLoader state is empty for dp rank {self.rank}, expected key rank_{self.rank}')
540
+ return
541
+ super().load_state_dict(pickle.loads(state_dict[f'rank_{self.rank}']))
542
+
543
+
544
+ def build_dataloader(
545
+ dataset: IterableDataset,
546
+ tokenizer: PreTrainedTokenizer,
547
+ rank: int,
548
+ world_size: int,
549
+ batch_size: int,
550
+ seq_len: int,
551
+ context_len: Optional[int] = None,
552
+ varlen: bool = False,
553
+ num_workers: int = 0,
554
+ pin_memory: bool = False,
555
+ persistent_workers: bool = False,
556
+ snapshot_every_n_steps: Optional[int] = 1,
557
+ ):
558
+ dataset = OnlineTokenizedIterableDataset(
559
+ dataset=dataset, tokenizer=tokenizer, seq_len=seq_len, rank=rank, world_size=world_size
560
+ )
561
+ return ParallelAwareDataLoader(
562
+ rank=rank,
563
+ dataset=dataset,
564
+ batch_size=batch_size,
565
+ collate_fn=DataCollatorForLanguageModeling(tokenizer=tokenizer, context_len=context_len, varlen=varlen),
566
+ num_workers=num_workers,
567
+ pin_memory=pin_memory,
568
+ persistent_workers=persistent_workers,
569
+ snapshot_every_n_steps=snapshot_every_n_steps,
570
+ )
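
For orientation, here is a minimal usage sketch of the data pipeline defined above. It is not part of the diff: the dataset and tokenizer names are the ones referenced in flame/models/fla.toml, and the batch size and worker count are purely illustrative.

from datasets import load_dataset
from transformers import AutoTokenizer

from flame.data import build_dataloader

# Names taken from flame/models/fla.toml; any streaming text dataset exposing a
# 'text' (or 'content') field and any HF tokenizer should work the same way.
tokenizer = AutoTokenizer.from_pretrained("fla-hub/transformer-1.3B-100B")
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", split="train", streaming=True)

loader = build_dataloader(
    dataset=dataset,
    tokenizer=tokenizer,
    rank=0,
    world_size=1,
    batch_size=8,      # illustrative; fla.toml trains with 32
    seq_len=2048,
    varlen=False,
    num_workers=1,     # >= 1 because the loader defaults to prefetch_factor=2, which PyTorch only allows with worker processes
)
batch = next(iter(loader))
# With varlen=False the collator yields 'input_ids', 'labels' and 'attention_mask',
# each of shape [8, 2048]; with varlen=True (batch size must be 1) it would instead
# yield concatenated 'input_ids' plus int32 'cu_seqlens' boundaries.
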
flame/models/__init__.py ADDED
File without changes
flame/models/activation_offloading.py ADDED
@@ -0,0 +1,447 @@
1
+ # Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/training/_activation_offloading.py
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import contextlib
9
+ from typing import Union
10
+ from warnings import warn
11
+
12
+ import psutil
13
+ import torch
14
+ from torch import nn
15
+ from torch.autograd.graph import saved_tensors_hooks
16
+
17
+ from torchtitan.tools.logging import logger
18
+
19
+ try:
20
+ import torchao
21
+ from torchao.dtypes.nf4tensor import NF4Tensor
22
+ except ImportError:
23
+ torchao = None
24
+ NF4Tensor = None
25
+ logger.warning("torchao not found; NF4Tensor offloading support is disabled.")
26
+
27
+ # from torchtune.modules import TiedLinear
28
+
29
+
30
+ class OffloadActivations(saved_tensors_hooks):
31
+ """Context manager under which activation tensors created in the forward pass will be offloaded.
32
+
33
+ Enable the memory efficiency technique of activation offloading, where activations bigger than
34
+ min_offload_size bytes will be offloaded to CPU in the forward and brought back in the backward.
35
+ This is in contrast to maintaining the activation on GPU VRAM throughout the program.
36
+
37
+ This manager contains the option of using one additional CUDA stream to handle the communication
38
+ between CUDA and CPU, which is intended to overlap with the default computation stream to improve
39
+ runtime. We designed synchronization with a few heuristics for optimizing the tradeoff between
40
+ runtime vs memory usage.
41
+
42
+ Args:
43
+ use_pin_memory (bool): Whether or not the offloaded Tensor will be placed in pinned
44
+ memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly
45
+ but is a limited resource. Default: True.
46
+
47
+ use_streams (bool): Whether or not to use streams for performance optimization where
48
+ the communications get overlapped with the computation. Requires a torch build
49
+ after torch-2.5.0. Default: True.
50
+
51
+ max_fwd_stash_size (int): The maximum size of the forward stash, or the maximum number of
52
+ consecutive activations to keep alive during the forward pass. This number must be at
53
+ least 1. Keeping alive more activations will potentially allow more overlap between the
54
+ communication and compute streams at the cost of increasing memory usage. Keeping alive
55
+ fewer activations will conserve memory, but may cause poor overlap between the streams,
56
+ increasing runtime. Default: 5.
57
+
58
+ min_offload_size (int): The minimum number of bytes a Tensor must be in order to qualify
59
+ for offloading. If the tensor is too small, we do not want to waste bandwidth and resources
60
+ moving it to CPU and back. Default: 1024 bytes.
61
+
62
+ Raises:
63
+ ValueError: if max_fwd_stash_size is not at least 1.
64
+
65
+ Example:
66
+ >>> with OffloadActivations():
67
+ >>> logits = model(inputs)
68
+ >>> loss = ...
69
+ >>> loss.backward()
70
+ """
71
+
72
+ def __init__(
73
+ self,
74
+ use_pin_memory: bool = True,
75
+ use_streams: bool = True,
76
+ max_fwd_stash_size: int = 5,
77
+ min_offload_size: int = 1024,
78
+ ) -> None:
79
+
80
+ self.use_streams: bool = use_streams
81
+
82
+ self.min_tensor_size_bytes = (
83
+ min_offload_size # we don't want to bother with small tensors
84
+ )
85
+ self.tracker = (
86
+ {}
87
+ ) # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where
88
+ self.tensor_id: int = 0
89
+ self.is_first_forward_call = True
90
+ self.is_first_backward_call = True
91
+ self.is_first_forward_pass = True
92
+
93
+ # managing cpu memory
94
+ self.use_pin_memory: bool = use_pin_memory
95
+ self.virtual_memory_safe_pct = (
96
+ 60 # we should not exceed this percentage of memory
97
+ )
98
+
99
+ self.s0 = torch.cuda.default_stream() # comp stream
100
+
101
+ # for streaming
102
+ if self.use_streams:
103
+ self.s1 = torch.cuda.Stream() # comms stream
104
+ self.fwd_stash = {} # tensor_id => (activation, ev1)
105
+ if max_fwd_stash_size < 1:
106
+ raise ValueError(
107
+ f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}"
108
+ )
109
+ self.max_fwd_stash_size = max_fwd_stash_size
110
+ self.bwd_tensor_stash = {} # tensor_id => activation
111
+ self.bwd_ev_stash = {} # tensor_id => ev0
112
+ self.curr_graph_id = None
113
+ self.curr_autograd_node = None
114
+
115
+ # -------- platform util functions -------- #
116
+ def verify_sufficient_virtual_memory():
117
+ curr_pct = get_cpu_ram_pct()
118
+ if curr_pct > self.virtual_memory_safe_pct:
119
+ warn(
120
+ f"***** WARNING: {curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used"
121
+ )
122
+
123
+ def get_cpu_ram_pct() -> float:
124
+ # get the percentage of memory used by the system
125
+ return psutil.virtual_memory().percent
126
+
127
+ def get_tensor_id() -> int:
128
+ # create a unique id for each tensor we are managing
129
+ self.tensor_id += 1
130
+ return self.tensor_id
131
+
132
+ def get_num_bytes_tensor(x: torch.Tensor) -> int:
133
+ # get the number of bytes in a tensor, for memory management purposes
134
+ return (
135
+ x.element_size() * x.nelement()
136
+ ) # x.element_size() * x._base_storage().nbytes()
137
+
138
+ # -------- core pack / unpack work -------- #
139
+ def pack_tensor(activation: torch.Tensor) -> int:
140
+ # activations are passed in during forward pass - from here we take over and return a unique id
141
+ if self.is_first_forward_call:
142
+ assert (
143
+ len(self.tracker) == 0
144
+ ), "backward pass should have cleared tracker of all tensors"
145
+
146
+ # set training phase trackers
147
+ self.is_first_forward_call = False
148
+ self.is_first_backward_call = True
149
+
150
+ # query for basic tensor info
151
+ num_bytes = get_num_bytes_tensor(activation)
152
+ tensor_id = get_tensor_id()
153
+
154
+ # only offload hefty bois if they're activations on CUDA (our heuristic
155
+ # for that is to check if they're not params or buffers)!
156
+ if (
157
+ activation.is_cuda
158
+ and num_bytes >= self.min_tensor_size_bytes
159
+ and (
160
+ not isinstance(activation, torch.nn.Parameter)
161
+ and not isinstance(activation, torch.nn.Buffer)
162
+ )
163
+ ):
164
+ if self.use_streams:
165
+ # First, sync back and dereference previously offloaded tensors
166
+ # as the offloading should be done sufficiently long ago.
167
+ for id in [k for k in self.fwd_stash.keys()]:
168
+ if id <= tensor_id - self.max_fwd_stash_size:
169
+ _, ev = self.fwd_stash[id]
170
+ self.s0.wait_event(ev)
171
+ del self.fwd_stash[id]
172
+ else:
173
+ break
174
+
175
+ # Sync in, offload, and add an event to sync back later
176
+ self.s1.wait_stream(self.s0)
177
+
178
+ stream = self.s1 if self.use_streams else self.s0
179
+ with torch.cuda.stream(stream):
180
+ try:
181
+ cpu_tensor = torch.empty_like(
182
+ activation, pin_memory=self.use_pin_memory, device="cpu"
183
+ )
184
+ except NotImplementedError as e:
185
+ if (
186
+ isinstance(activation, NF4Tensor)
187
+ and torchao.__version__ < "0.6.0.dev20240917"
188
+ ):
189
+ raise RuntimeError(
190
+ "Offloading NF4Tensors requires torchao-0.6.0.dev20240917 or later"
191
+ ) from e
192
+ raise e
193
+ cpu_tensor.copy_(activation, non_blocking=True)
194
+ self.tracker[tensor_id] = (
195
+ cpu_tensor,
196
+ True,
197
+ ) # True = (in future) modified
198
+
199
+ if self.use_streams:
200
+ event = self.s1.record_event()
201
+
202
+ # Stash to keep activation alive til s1 is done
203
+ self.fwd_stash[tensor_id] = (activation, event)
204
+ else:
205
+ self.tracker[tensor_id] = (
206
+ activation,
207
+ False,
208
+ ) # False = not modified, tensor is as is
209
+
210
+ return tensor_id
211
+
212
+ def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
213
+ # backward pass - we are called with the tensor_id, which
214
+ # we will use to retrieve the saved/offloaded tensor
215
+ if self.is_first_backward_call:
216
+ if self.is_first_forward_pass:
217
+ self.is_first_forward_pass = False
218
+ if self.use_pin_memory:
219
+ verify_sufficient_virtual_memory()
220
+
221
+ self.is_first_backward_call = False
222
+ self.is_first_forward_call = True
223
+
224
+ assert (
225
+ unpack_tensor_id in self.tracker
226
+ ), f"untracked tensor with id {unpack_tensor_id}"
227
+
228
+ maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
229
+ if modified:
230
+ gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
231
+ maybe_gpu_tensor = gpu_tensor
232
+
233
+ # clear tensor from tracking
234
+ del self.tracker[unpack_tensor_id]
235
+ return maybe_gpu_tensor
236
+
237
+ def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
238
+ # backward pass - we are called with the tensor_id, which
239
+ # we will use to retrieve the saved/offloaded tensor
240
+ if self.is_first_backward_call:
241
+ self.curr_graph_id = torch._C._current_graph_task_id()
242
+
243
+ def wait_and_del_remaining_references() -> None:
244
+ for id in [k for k in self.bwd_tensor_stash.keys()]:
245
+ event = self.bwd_ev_stash[id]
246
+ self.s1.wait_event(event)
247
+ del self.bwd_tensor_stash[id]
248
+
249
+ # Register a callback to the end of autograd to clean everything up
250
+ torch.autograd.variable.Variable._execution_engine.queue_callback(
251
+ wait_and_del_remaining_references
252
+ )
253
+
254
+ if self.is_first_forward_pass:
255
+ self.is_first_forward_pass = False
256
+ if self.use_pin_memory:
257
+ verify_sufficient_virtual_memory()
258
+
259
+ self.is_first_backward_call = False
260
+ self.is_first_forward_call = True
261
+
262
+ assert (
263
+ unpack_tensor_id in self.tracker
264
+ ), f"untracked tensor with id {unpack_tensor_id}"
265
+
266
+ maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
267
+ if modified:
268
+ # Get data on the current autograd node
269
+ graph_id = torch._C._current_graph_task_id()
270
+ node = torch._C._current_autograd_node()
271
+ prev_node_ids = []
272
+
273
+ # If we're on a new node, mark prev node's tensors to be freed later
274
+ if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
275
+ self.curr_autograd_node = node
276
+ prev_node_ids = [id for id in self.bwd_tensor_stash.keys()]
277
+
278
+ brought_back_from_cpu = True
279
+ if unpack_tensor_id in self.fwd_stash:
280
+ maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0]
281
+ brought_back_from_cpu = False
282
+ else:
283
+ # Kick off the process to bring tensors back
284
+ with torch.cuda.stream(self.s1):
285
+ gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
286
+ maybe_gpu_tensor = gpu_tensor
287
+
288
+ # Tell comp stream to wait for the info to be loaded before executing
289
+ self.s0.wait_stream(self.s1)
290
+
291
+ # Stash the tensor to keep memory alive until compute stream is complete
292
+ self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor
293
+
294
+ # Note: [Track views of the unpacked]
295
+ # Why do we get the use count of the unpacked tensor here? We want an
296
+ # initial count to compare to later, during the post-hook of the
297
+ # backward node, when we need to decide whether we're allowed to free
298
+ # the tensor yet. In what obscure cases must we delay freeing the
299
+ # tensor (and thus call record_stream)?
300
+ # 1. Any of the outputs of the backward node is a view of the unpacked
301
+ # tensor.
302
+ # 2. In the case that this unpacked tensor will be used in a
303
+ # checkpointed region, if one of the recomputed saved tensors ends
304
+ # up as a view of the unpacked tensor.
305
+ # 3. The user abuses the system somehow and manually relies on the
306
+ # unpacked tensor to exist after the backward node has executed.
307
+ storage_refcount = torch._C._storage_Use_Count(
308
+ maybe_gpu_tensor.untyped_storage()._cdata
309
+ )
310
+
311
+ def hook(outputs, inputs):
312
+ # create events for the current node inputs/outputs if they were streamed in
313
+ if brought_back_from_cpu:
314
+ # See Note: [Track views of the unpacked]
315
+ # IF any of the outputs is a view of the tensor, OR if a view of
316
+ # the tensor has been saved as a part of checkpoint's recompute
317
+ # process, OR the user has abusedly incurred a reference on the
318
+ # unpacked tensor, THEN the tensor might be used later and we
319
+ # cannot presume to delete it after only the current node is
320
+ # done! So we use our frenemy, record_stream, to ensure the
321
+ # Tensor stays unmessed with until it's done getting used in the
322
+ # compute stream (s0 here). Note that the con here is we introduce
323
+ # non-deterministic (thus higher) memory usage, but this case
324
+ # should not happen often.
325
+ unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
326
+ if (
327
+ torch._C._storage_Use_Count(
328
+ unpacked_tensor.untyped_storage()._cdata
329
+ )
330
+ > storage_refcount
331
+ ):
332
+ unpacked_tensor.record_stream(self.s0)
333
+ del self.bwd_tensor_stash[unpack_tensor_id]
334
+ else:
335
+ event = self.s0.record_event()
336
+ self.bwd_ev_stash[unpack_tensor_id] = event
337
+
338
+ # if there are still things in the fwd_stash, get rid of them as we're in bwd now
339
+ for id in [k for k in self.fwd_stash.keys()]:
340
+ _, ev = self.fwd_stash[id]
341
+ self.s0.wait_event(ev)
342
+ del self.fwd_stash[id]
343
+
344
+ # wait on prev node's events and del those
345
+ for id in prev_node_ids:
346
+ event = self.bwd_ev_stash[id]
347
+ self.s1.wait_event(event)
348
+ del self.bwd_tensor_stash[id]
349
+
350
+ return outputs
351
+
352
+ node.register_hook(hook)
353
+
354
+ # clear tensor from tracking
355
+ del self.tracker[unpack_tensor_id]
356
+ return maybe_gpu_tensor
357
+
358
+ unpack_tensor = (
359
+ unpack_tensor_with_streams
360
+ if self.use_streams
361
+ else unpack_tensor_single_stream
362
+ )
363
+ super().__init__(pack_tensor, unpack_tensor)
364
+
365
+
366
+ class NoOpManager(saved_tensors_hooks):
367
+ """
368
+ A saved_tensors_hook manager used to disable any other saved_tensors_hook manager
369
+ applied before. This relies on the behavior that only the most recently registered
370
+ saved_tensors_hook will run.
371
+
372
+ One example usage is to opt a local region of code out of activations offloading,
373
+ which is usually applied globally to best track state.
374
+ """
375
+
376
+ def __init__(self) -> None:
377
+ def noop(tensor):
378
+ return tensor
379
+
380
+ super().__init__(noop, noop)
381
+
382
+
383
+ def get_act_offloading_ctx_manager(
384
+ model: nn.Module, enable_activation_offloading: bool
385
+ ) -> Union[OffloadActivations, contextlib.nullcontext]:
386
+ """Returns the activation offloading context manager for the model, which will be
387
+ a null context if enable_activation_offloading is False.
388
+
389
+ If activation offloading is enabled, we return the OffloadActivations context manager.
390
+ If activation offloading is disabled, we return a contextlib.nullcontext() instead.
391
+
392
+ Args:
393
+ model (nn.Module): the model to wrap with the activation offloading context manager.
394
+ enable_activation_offloading (bool): whether or not to enable activation offloading
395
+ for the model.
396
+
397
+ Returns:
398
+ contextlib.ContextDecorator: the activation offloading context manager for the model.
399
+
400
+ Raises:
401
+ NotImplementedError: If the model is a multimodal model and activation offloading is enabled.
402
+ """
403
+ if enable_activation_offloading:
404
+ activations_handling_ctx = OffloadActivations()
405
+
406
+ # Below is our hack to disable offloading the last output Linear in every
407
+ # step, as the cost for offloading the activation and then soon after bringing
408
+ # it back is expensive. Moreover, due to heuristics in our streaming API,
409
+ # we actually use more memory if we offload it as it interferes with chunkedCE.
410
+ output_head_detected = False
411
+ noop_ctx = NoOpManager()
412
+
413
+ if hasattr(model, "output"):
414
+ if isinstance(model.output, nn.Module):
415
+ model.output.register_forward_pre_hook(
416
+ lambda *args: noop_ctx.__enter__()
417
+ )
418
+ model.output.register_forward_hook(
419
+ lambda *args: noop_ctx.__exit__(), always_call=True
420
+ )
421
+ print("registering hooks for model.output ============ ")
422
+ output_head_detected = True
423
+ # ================================
424
+ # ! TODO[flame] check if we need to deal with TiedLinear
425
+ # The following code appears in `torchtune`
426
+ # elif isinstance(model.output, TiedLinear):
427
+ # model.output.linear.register_forward_pre_hook(
428
+ # lambda *args: noop_ctx.__enter__()
429
+ # )
430
+ # model.output.linear.register_forward_hook(
431
+ # lambda *args: noop_ctx.__exit__(), always_call=True
432
+ # )
433
+ # output_head_detected = True
434
+
435
+ if not output_head_detected:
436
+ logger.warning(
437
+ "During activation offloading, no output head was detected. "
438
+ "If your model has an output head, it will be offloaded. "
439
+ "This usually greatly slows training, given the large vocabulary size. "
440
+ "To change this behavior, set your output head as model.output and make it "
441
+ "an nn.Module."
442
+ )
443
+
444
+ else:
445
+ activations_handling_ctx = contextlib.nullcontext()
446
+
447
+ return activations_handling_ctx
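
Below is a small stand-alone sketch of how the offloading context above might be used. It assumes a CUDA device and a toy model/loss; in this repo the trainer is expected to construct the context from its job config rather than by hand.

import torch
import torch.nn as nn

from flame.models.activation_offloading import get_act_offloading_ctx_manager

# Toy model for illustration only; it has no `model.output` attribute, so the
# manager will log its "no output head was detected" warning.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

offload_ctx = get_act_offloading_ctx_manager(model, enable_activation_offloading=True)

x = torch.randn(8, 1024, device="cuda")
with offload_ctx:
    # activations larger than min_offload_size are copied to (pinned) CPU memory here
    loss = model(x).float().pow(2).mean()
    # and streamed back to the GPU on demand during the backward pass
    loss.backward()
optimizer.step()
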
flame/models/fla.toml ADDED
@@ -0,0 +1,67 @@
1
+ [model]
2
+ config = "fla-hub/transformer-1.3B-100B"
3
+ tokenizer_path = "fla-hub/transformer-1.3B-100B"
4
+
5
+ [job]
6
+ dump_folder = "exp"
7
+ print_args = true
8
+
9
+ [training]
10
+ batch_size = 32
11
+ seq_len = 2048
12
+ context_len = 2048
13
+ gradient_accumulation_steps = 1
14
+ steps = 20480
15
+ max_norm = 1.0
16
+ skip_nan_inf = true
17
+ data_parallel_replicate_degree = 1
18
+ data_parallel_shard_degree = -1
19
+ tensor_parallel_degree = 1
20
+ compile = false
21
+ dataset = "HuggingFaceFW/fineweb-edu"
22
+ dataset_name = "default"
23
+ num_workers = 32
24
+ pin_memory = false
25
+ persistent_workers = false
26
+ prefetch_factor = 2
27
+ seed = 42
28
+ varlen = false
29
+
30
+ [optimizer]
31
+ name = "AdamW"
32
+ eps = 1e-15
33
+ lr = 3e-4
34
+
35
+ [lr_scheduler]
36
+ warmup_steps = 1024
37
+ decay_type = "cosine"
38
+ lr_min = 0.1
39
+
40
+ [checkpoint]
41
+ enable_checkpoint = true
42
+ folder = "checkpoint"
43
+ interval_type = "steps"
44
+ interval = 2048
45
+ model_weights_only = false
46
+ export_dtype = "float32"
47
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
48
+
49
+ [profiling]
50
+ enable_profiling = true
51
+ save_traces_folder = "profile_trace"
52
+ profile_freq = 512
53
+
54
+ [metrics]
55
+ log_freq = 32
56
+ enable_wandb = true
57
+
58
+ [experimental]
59
+ context_parallel_degree = 1
60
+ pipeline_parallel_degree = 1
61
+
62
+ [float8]
63
+ enable_fsdp_float8_all_gather = false
64
+ precompute_float8_dynamic_scale_for_fsdp = false
65
+
66
+ [activation_checkpoint]
67
+ mode = "none"
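
As a quick sanity check on the numbers above, the per-rank token budget follows directly from the [training] section; a minimal sketch (tomllib requires Python 3.11+, and the path assumes the repo root as working directory):

import tomllib

with open("flame/models/fla.toml", "rb") as f:
    cfg = tomllib.load(f)

train = cfg["training"]
# tokens consumed by one data-parallel rank per optimizer step
tokens_per_rank_step = train["batch_size"] * train["seq_len"] * train["gradient_accumulation_steps"]
print(tokens_per_rank_step)                   # 32 * 2048 * 1 = 65536
print(tokens_per_rank_step * train["steps"])  # ~1.34B tokens per rank over the 20480 steps
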
flame/models/parallelize_fla.py ADDED
@@ -0,0 +1,550 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D parallelisms (except pipeline parallelism) and various
8
+ # training techniques (e.g. activation checkpointing and compile) to the FLA models.
9
+
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
16
+ from torch.distributed._composable.replicate import replicate
17
+ from torch.distributed._tensor import Replicate, Shard
18
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
19
+ from torch.distributed.tensor.parallel import (
20
+ ColwiseParallel,
21
+ PrepareModuleInput,
22
+ PrepareModuleOutput,
23
+ RowwiseParallel,
24
+ SequenceParallel,
25
+ parallelize_module
26
+ )
27
+
28
+ from fla.modules.fused_linear_cross_entropy import LinearLossParallel
29
+ from fla.modules.mlp import SwiGLULinearParallel
30
+ from fla.modules.parallel import PrepareModuleWeight
31
+ from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
32
+ from torchtitan.distributed.parallel_dims import ParallelDims
33
+ from torchtitan.tools.logging import logger
34
+
35
+
36
+ def parallelize_fla(
37
+ model: nn.Module,
38
+ world_mesh: DeviceMesh,
39
+ parallel_dims: ParallelDims,
40
+ job_config: JobConfig,
41
+ ):
42
+ """
43
+ Apply tensor parallelism, activation checkpointing, torch.compile, and data
44
+ parallelism to the model.
45
+
46
+ NOTE: The passed-in model preferably should be on meta device. Otherwise,
47
+ the model must fit on GPU or CPU memory.
48
+ """
49
+
50
+ if parallel_dims.tp_enabled:
51
+ if (
52
+ job_config.experimental.enable_async_tensor_parallel
53
+ and not job_config.training.compile
54
+ ):
55
+ raise RuntimeError("Async TP requires --training.compile")
56
+ enable_float8_linear = "float8" in job_config.model.converters
57
+ apply_tp(
58
+ model,
59
+ world_mesh["tp"],
60
+ loss_parallel=parallel_dims.loss_parallel_enabled,
61
+ enable_float8=enable_float8_linear,
62
+ enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
63
+ )
64
+
65
+ if job_config.activation_checkpoint.mode != "none":
66
+ apply_ac(model, job_config.activation_checkpoint)
67
+
68
+ # turn on per-block compile after AC wrapping and before FSDP
69
+ if job_config.training.compile:
70
+ apply_compile(model)
71
+
72
+ if (
73
+ parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
74
+ ): # apply FSDP or HSDP, potentially with Context Parallel
75
+ if parallel_dims.dp_replicate_enabled:
76
+ dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
77
+ else:
78
+ dp_mesh_dim_names = ("dp_shard_cp",)
79
+
80
+ apply_fsdp(
81
+ model,
82
+ world_mesh[tuple(dp_mesh_dim_names)],
83
+ param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
84
+ reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
85
+ pp_enabled=parallel_dims.pp_enabled,
86
+ cpu_offload=job_config.training.enable_cpu_offload,
87
+ reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
88
+ )
89
+
90
+ if parallel_dims.dp_replicate_enabled:
91
+ logger.info("Applied HSDP to the model")
92
+ else:
93
+ logger.info("Applied FSDP to the model")
94
+
95
+ if parallel_dims.cp_enabled:
96
+ logger.info("Applied Context Parallel to the model")
97
+
98
+ if job_config.training.enable_cpu_offload:
99
+ logger.info("Applied CPU Offloading to the model")
100
+ elif parallel_dims.dp_replicate_enabled:
101
+ if world_mesh.ndim > 1:
102
+ raise RuntimeError("DDP has not supported > 1D parallelism")
103
+ apply_ddp(
104
+ model,
105
+ world_mesh,
106
+ enable_compile=job_config.training.compile,
107
+ enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
108
+ )
109
+
110
+
111
+ class TPPlan:
112
+ def __init__(
113
+ self,
114
+ model=None,
115
+ loss_parallel=False,
116
+ enable_float8=False,
117
+ ):
118
+ self.model = model
119
+ self.loss_parallel = loss_parallel
120
+ self.enable_float8 = enable_float8
121
+ self.base_model_prefix = getattr(model, "base_model_prefix", "model")
122
+
123
+ # TODO(vkuzo): once float8 configuration supports delayed scaling,
124
+ # add a check here to enforce supported float8 all-gather configurations
125
+ # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
126
+ try:
127
+ from torchao.float8.float8_tensor_parallel import (
128
+ Float8ColwiseParallel,
129
+ Float8RowwiseParallel,
130
+ PrepareFloat8ModuleInput
131
+ )
132
+ except ImportError:
133
+ Float8ColwiseParallel = None
134
+ Float8RowwiseParallel = None
135
+ PrepareFloat8ModuleInput = None
136
+ if self.enable_float8 and Float8ColwiseParallel is not None:
137
+ self.rowwise_parallel = Float8RowwiseParallel
138
+ self.colwise_parallel = Float8ColwiseParallel
139
+ self.prepare_module_input = PrepareFloat8ModuleInput
140
+ self.prepare_module_output = PrepareModuleOutput
141
+ else:
142
+ self.rowwise_parallel = RowwiseParallel
143
+ self.colwise_parallel = ColwiseParallel
144
+ self.prepare_module_input = PrepareModuleInput
145
+ self.prepare_module_output = PrepareModuleOutput
146
+
147
+ @property
148
+ def model_plan(self):
149
+ plans = {
150
+ f"{self.base_model_prefix}.embeddings": RowwiseParallel(
151
+ input_layouts=Replicate(),
152
+ output_layouts=Shard(1),
153
+ ),
154
+ f"{self.base_model_prefix}.norm": SequenceParallel(),
155
+ }
156
+ if self.loss_parallel:
157
+ plans.update(
158
+ {
159
+ "lm_head": ColwiseParallel(
160
+ input_layouts=Shard(1),
161
+ output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
162
+ use_local_output=not self.loss_parallel,
163
+ ),
164
+ }
165
+ )
166
+ else:
167
+ plans.update(
168
+ {
169
+ "lm_head": PrepareModuleWeight(layouts=Replicate()),
170
+ "criterion": LinearLossParallel(),
171
+ }
172
+ )
173
+ return plans
174
+
175
+ @property
176
+ def layer_plan(self):
177
+ return {
178
+ "attn_norm": SequenceParallel(),
179
+ **self.attn_plan,
180
+ "mlp_norm": SequenceParallel(),
181
+ **self.mlp_plan,
182
+ }
183
+
184
+ @property
185
+ def attn_plan(self):
186
+ raise NotImplementedError(
187
+ f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
188
+ )
189
+
190
+ @property
191
+ def mlp_plan(self):
192
+ return {
193
+ "mlp": self.prepare_module_input(
194
+ input_layouts=(Shard(1),),
195
+ desired_input_layouts=(Replicate(),),
196
+ ),
197
+ "mlp.gate_proj": self.colwise_parallel(),
198
+ "mlp.up_proj": self.colwise_parallel(),
199
+ "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
200
+ "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
201
+ }
202
+
203
+
204
+ class TransformerTPPlan(TPPlan):
205
+
206
+ @property
207
+ def attn_plan(self):
208
+ return {
209
+ "attn": self.prepare_module_input(
210
+ input_kwarg_layouts={"hidden_states": Shard(1)},
211
+ desired_input_kwarg_layouts={"hidden_states": Replicate()},
212
+ ),
213
+ "attn.q_proj": self.colwise_parallel(),
214
+ "attn.k_proj": self.colwise_parallel(),
215
+ "attn.v_proj": self.colwise_parallel(),
216
+ "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
217
+ }
218
+
219
+
220
+ class GLATPPlan(TPPlan):
221
+
222
+ @property
223
+ def attn_plan(self):
224
+ return {
225
+ "attn": self.prepare_module_input(
226
+ input_kwarg_layouts={"hidden_states": Shard(1)},
227
+ desired_input_kwarg_layouts={"hidden_states": Replicate()},
228
+ ),
229
+ "attn.q_proj": self.colwise_parallel(),
230
+ "attn.k_proj": self.colwise_parallel(),
231
+ "attn.v_proj": self.colwise_parallel(),
232
+ "attn.g_proj": self.colwise_parallel(),
233
+ "attn.gk_proj.0": PrepareModuleWeight(layouts=Replicate()),
234
+ "attn.gk_proj.1": self.colwise_parallel(),
235
+ "attn.g_norm": SequenceParallel(sequence_dim=-1),
236
+ "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
237
+ }
238
+
239
+
240
+ TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}
241
+
242
+
243
+ def apply_tp(
244
+ model: nn.Module,
245
+ tp_mesh: DeviceMesh,
246
+ loss_parallel: bool,
247
+ enable_float8: bool,
248
+ enable_async_tp: bool,
249
+ ):
250
+ """Apply tensor parallelism."""
251
+ # 1. Parallelize the embedding and shard its outputs (which are the first
252
+ # transformer block's inputs)
253
+ # 2. Parallelize the root norm layer over the sequence dim
254
+ # 3. Parallelize the final linear output layer
255
+ tp_plan = TP_PLAN_MAP[model.config.model_type](
256
+ model, loss_parallel=loss_parallel, enable_float8=enable_float8
257
+ )
258
+ parallelize_module(model, tp_mesh, tp_plan.model_plan)
259
+
260
+ blocks = get_blocks(model)
261
+ if blocks is None:
262
+ logger.warning("No block found for tensor parallelism")
263
+ else:
264
+ for _, block in enumerate(blocks):
265
+ parallelize_module(
266
+ module=block,
267
+ device_mesh=tp_mesh,
268
+ parallelize_plan=tp_plan.layer_plan,
269
+ )
270
+
271
+ if enable_async_tp:
272
+ from torch.distributed._symmetric_memory import enable_symm_mem_for_group
273
+
274
+ torch._inductor.config._micro_pipeline_tp = True
275
+ enable_symm_mem_for_group(tp_mesh.get_group().group_name)
276
+
277
+ logger.info(
278
+ f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
279
+ "Tensor Parallelism to the model"
280
+ )
281
+
282
+
283
+ # for selective op activation checkpointing
284
+ _save_list = {
285
+ torch.ops.aten.mm.default,
286
+ torch.ops.aten._scaled_dot_product_efficient_attention.default,
287
+ torch.ops.aten._scaled_dot_product_flash_attention.default,
288
+ torch.ops._c10d_functional.reduce_scatter_tensor.default,
289
+ # for low precision training, it's useful to always save
290
+ # the result of max, since the absolute maximum is
291
+ # used to compute the scaling factor for quantization.
292
+ torch.ops.aten.max.default,
293
+ }
294
+
295
+
296
+ def _apply_ac_to_block(module: nn.Module, ac_config):
297
+ valid_ac_modes = ("full", "selective")
298
+ if ac_config.mode not in valid_ac_modes:
299
+ raise ValueError(
300
+ f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
301
+ )
302
+
303
+ if ac_config.mode == "full":
304
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
305
+
306
+ assert ac_config.mode == "selective", f"{ac_config.mode}"
307
+ use_op_sac = ac_config.selective_ac_option == "op"
308
+ use_layer_sac = ac_config.selective_ac_option.isdigit()
309
+ if not use_op_sac and not use_layer_sac:
310
+ raise ValueError(
311
+ f"Invalid selective AC option: {ac_config.selective_ac_option}. "
312
+ f"Valid options: 'op' or a positive int representing layer frequency"
313
+ )
314
+ if use_op_sac:
315
+ from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
316
+
317
+ def _get_custom_policy(meta):
318
+ def _custom_policy(ctx, func, *args, **kwargs):
319
+ mode = "recompute" if ctx.is_recompute else "forward"
320
+ mm_count_key = f"{mode}_mm_count"
321
+ if func == torch.ops.aten.mm.default:
322
+ meta[mm_count_key] += 1
323
+ # Saves output of all compute ops, except every second mm
324
+ to_save = func in _save_list and not (
325
+ func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
326
+ )
327
+ return (
328
+ CheckpointPolicy.MUST_SAVE
329
+ if to_save
330
+ else CheckpointPolicy.PREFER_RECOMPUTE
331
+ )
332
+
333
+ return _custom_policy
334
+
335
+ def selective_checkpointing_context_fn():
336
+ meta = defaultdict(int)
337
+ return create_selective_checkpoint_contexts(_get_custom_policy(meta))
338
+
339
+ return ptd_checkpoint_wrapper(
340
+ module,
341
+ context_fn=selective_checkpointing_context_fn,
342
+ preserve_rng_state=False,
343
+ )
344
+ elif use_layer_sac:
345
+ # Checkpoint every `ac_freq` of the modules passed to this function
346
+ ac_freq = int(ac_config.selective_ac_option)
347
+ ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
348
+ ptd_checkpoint_wrapper._count += 1
349
+ if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
350
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
351
+ else:
352
+ return module
353
+
354
+
355
+ def apply_ac(model: nn.Module, ac_config):
356
+ """Apply activation checkpointing to the model."""
357
+ blocks = get_blocks(model)
358
+ if blocks is None:
359
+ logger.warning("No block found for activation checkpointing")
360
+ return
361
+
362
+ for layer_id, block in blocks.named_children():
363
+ block = _apply_ac_to_block(block, ac_config)
364
+ blocks.register_module(layer_id, block)
365
+
366
+ logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
367
+
368
+
369
+ def apply_compile(model: nn.Module):
370
+ """
371
+ Apply torch.compile to each block, which makes compilation efficient due to
372
+ repeated structure. Alternatively one can compile the whole model (after applying DP).
373
+ """
374
+
375
+ blocks = get_blocks(model)
376
+ if blocks is None:
377
+ logger.warning("No block found for torch.compile")
378
+ else:
379
+ for layer_id, block in blocks.named_children():
380
+ block = torch.compile(block)
381
+ blocks.register_module(layer_id, block)
382
+ logger.info("Compiling each block with torch.compile")
383
+
384
+ real_model = get_model(model)
385
+
386
+ logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
387
+ embeddings_key = get_components_name(real_model, "tok_embeddings")
388
+ if embeddings_key is not None:
389
+ embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
390
+ real_model.register_module(embeddings_key, embeddings)
391
+
392
+ norm_key = get_components_name(real_model, "norm")
393
+ if norm_key is not None:
394
+ norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
395
+ real_model.register_module(norm_key, norm)
396
+
397
+ lm_head_key = get_components_name(model, "lm_head")
398
+ if lm_head_key is not None:
399
+ lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
400
+ model.register_module(lm_head_key, lm_head)
401
+
402
+ logger.info("Compiling the entire model with torch.compile")
403
+ model = torch.compile(model)
404
+
405
+
406
+ def apply_fsdp(
407
+ model: nn.Module,
408
+ dp_mesh: DeviceMesh,
409
+ param_dtype: torch.dtype,
410
+ reduce_dtype: torch.dtype,
411
+ pp_enabled: bool,
412
+ cpu_offload: bool = False,
413
+ reshard_after_forward_policy: str = "default",
414
+ ):
415
+ """
416
+ Apply data parallelism (via FSDP2) to the model.
417
+
418
+ Args:
419
+ model (nn.Module): The model to apply data parallelism to.
420
+ dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
421
+ param_dtype (torch.dtype): The data type to use for model parameters.
422
+ reduce_dtype (torch.dtype): The data type to use for reduction operations.
423
+ pp_enabled (bool): Whether pipeline parallelism is enabled.
424
+ cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
425
+ reshard_after_forward_policy (str, optional):
426
+ The policy to use for resharding after forward pass. Defaults to "default".
427
+ Other options: "never", "always".
428
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
429
+ - "always" will enable `reshard_after_forward` for all forward passes.
430
+ - "never" will disable `reshard_after_forward` for all forward passes.
431
+
432
+ """
433
+ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
434
+ fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
435
+ if cpu_offload:
436
+ fsdp_config["offload_policy"] = CPUOffloadPolicy()
437
+
438
+ blocks = get_blocks(model)
439
+ if blocks is None:
440
+ logger.warning("No block found for FSDP")
441
+ else:
442
+ total_blocks = len(blocks)
443
+ for layer_id, block in enumerate(blocks):
444
+ if reshard_after_forward_policy == "always":
445
+ reshard_after_forward = True
446
+ elif reshard_after_forward_policy == "never":
447
+ reshard_after_forward = False
448
+ elif reshard_after_forward_policy == "default":
449
+ if pp_enabled:
450
+ # For PP, do not reshard after forward to avoid per-microbatch
451
+ # all-gathers, which can be expensive and non-overlapped
452
+ reshard_after_forward = False
453
+ else:
454
+ # As an optimization, do not reshard after forward for the last
455
+ # transformer block since FSDP would prefetch it immediately
456
+ reshard_after_forward = int(layer_id) < total_blocks - 1
457
+ else:
458
+ raise ValueError(
459
+ f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
460
+ )
461
+ fully_shard(
462
+ block,
463
+ **fsdp_config,
464
+ reshard_after_forward=reshard_after_forward,
465
+ )
466
+
467
+ fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
468
+
469
+
470
+ def apply_ddp(
471
+ model: nn.Module,
472
+ dp_mesh: DeviceMesh,
473
+ enable_compile: bool,
474
+ enable_compiled_autograd: bool,
475
+ ):
476
+ if enable_compile:
477
+ if enable_compiled_autograd:
478
+ torch._dynamo.config.optimize_ddp = (
479
+ "python_reducer_without_compiled_forward"
480
+ )
481
+ else:
482
+ torch._dynamo.config.optimize_ddp = "ddp_optimizer"
483
+
484
+ replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
485
+
486
+ logger.info("Applied DDP to the model")
487
+
488
+
489
+ def get_model(model):
490
+ base_model_prefix = getattr(model, "base_model_prefix", "model")
491
+ if not hasattr(model, base_model_prefix):
492
+ return None
493
+ model = getattr(model, base_model_prefix)
494
+ return model
495
+
496
+
497
+ def get_blocks(model):
498
+ # TODO[flame]: adapt for networks that do not use a 'layers' attribute
499
+ model = get_model(model)
500
+ if not hasattr(model, "layers"):
501
+ logger.warning('no "layers" attribute found in the model')
502
+ return None
503
+ return model.layers
504
+
505
+
506
+ def get_components_name(model, component_name):
507
+ """
508
+ We try to find the tok_embeddings, norm and lm_head components.
509
+ We do not handle the layer names inside the blocks; for blocks see `get_blocks`.
510
+ We assume the model has the following structure:
511
+ LlamaForCausalLM:
512
+ Model:
513
+ embed_tokens,
514
+ layers,
515
+ norm,
516
+ lm_head
517
+ ***
518
+ so, to search 'tok_embeddings' and 'norm' we need to pass `get_model(model)`
519
+ and for 'lm_head' we need to pass `model`
520
+ ***
521
+ """
522
+
523
+ if component_name == "tok_embeddings":
524
+ if hasattr(model, "tok_embeddings"):
525
+ return "tok_embeddings"
526
+ elif hasattr(model, "embed_tokens"):
527
+ return "embed_tokens"
528
+ elif hasattr(model, "embeddings"):
529
+ return "embeddings"
530
+ else:
531
+ logger.warning("No tok_embeddings found in model")
532
+ return None
533
+
534
+ elif component_name == "norm":
535
+ if hasattr(model, "norm"):
536
+ return "norm"
537
+ elif hasattr(model, "norms"):
538
+ return "norms"
539
+ elif hasattr(model, "layernorm"):
540
+ return "layernorm"
541
+ else:
542
+ logger.warning("No norm found in model")
543
+ return None
544
+
545
+ elif component_name == "lm_head":
546
+ if hasattr(model, "lm_head"):
547
+ return "lm_head"
548
+ else:
549
+ logger.warning("No lm_head found in model")
550
+ return None
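A minimal usage sketch of the helpers above (not part of the committed file). It assumes the torchtitan-style `mode` / `selective_ac_option` fields read by `_apply_ac_to_block`, uses `types.SimpleNamespace` as a hypothetical stand-in for the real job-config section, and loads one of the JSON configs shipped in this repo from the repo root:

from types import SimpleNamespace

from transformers import AutoConfig, AutoModelForCausalLM

import fla  # noqa: F401  (registers FLA model types with the Auto* classes)
from flame.models.parallelize_fla import apply_ac, get_components_name, get_model

# build a small model on CPU just to exercise the helpers
config = AutoConfig.from_pretrained("configs/transformer_340M.json")
model = AutoModelForCausalLM.from_config(config)

# hypothetical AC config: checkpoint every 2nd block ("op" would pick the per-op policy above)
ac_config = SimpleNamespace(mode="selective", selective_ac_option="2")
apply_ac(model, ac_config)

# embeddings/norm are looked up on the inner model, lm_head on the wrapper
print(get_components_name(get_model(model), "tok_embeddings"))  # e.g. "embeddings" or "embed_tokens"
print(get_components_name(model, "lm_head"))                    # "lm_head"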
flame/models/pipeline_fla.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies PT-D pipeline parallelism to FLA models.
8
+
9
+ import copy
10
+ from typing import Callable, Optional, Union
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed.pipelining import PipelineStage
16
+ from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble, _PipelineSchedule, get_schedule_class
17
+ from transformers import PretrainedConfig
18
+
19
+ from flame.models.parallelize_fla import get_blocks, get_components_name, get_model
20
+ from torchtitan.config_manager import JobConfig
21
+ from torchtitan.distributed.parallel_dims import ParallelDims
22
+ from torchtitan.distributed.pipeline import build_pipeline_schedule, generate_split_points, stage_ids_this_rank
23
+ from torchtitan.tools.logging import logger
24
+
25
+ DeviceType = Union[int, str, torch.device]
26
+
27
+
28
+ def pipeline_fla(
29
+ model: nn.Module,
30
+ pp_mesh: DeviceMesh,
31
+ parallel_dims: ParallelDims,
32
+ job_config: JobConfig,
33
+ device: DeviceType,
34
+ model_config: PretrainedConfig,
35
+ loss_fn: Callable[..., torch.Tensor],
36
+ ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
37
+ stages, models = pipeline_fla_manual_split(
38
+ model, pp_mesh, parallel_dims, job_config, device, model_config
39
+ )
40
+
41
+ pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
42
+
43
+ # This is used in the train loop to determine whether to pass in the input_ids and labels
44
+ has_first_stage = False
45
+ has_last_stage = False
46
+ for stage in stages:
47
+ if stage.is_first:
48
+ has_first_stage = True
49
+ if stage.is_last:
50
+ has_last_stage = True
51
+
52
+ return pp_schedule, models, has_first_stage, has_last_stage
53
+
54
+
55
+ def pipeline_fla_manual_split(
56
+ whole_model: nn.Module,
57
+ pp_mesh: DeviceMesh,
58
+ parallel_dims: ParallelDims,
59
+ job_config: JobConfig,
60
+ device: DeviceType,
61
+ model_config: PretrainedConfig,
62
+ ) -> tuple[list[PipelineStage], list[nn.Module]]:
63
+ """
64
+ This API extracts a torch.nn.Module object for the part of the model configured to run inside this stage.
65
+
66
+ It wraps the model chunk in a PipelineStage object and returns both the stage and model objects.
67
+
68
+ The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
69
+ parallelism.
70
+ """
71
+ pp_rank = pp_mesh.get_local_rank()
72
+ pp_size = pp_mesh.size()
73
+
74
+ splits = (
75
+ job_config.experimental.pipeline_parallel_split_points
76
+ or generate_split_points(
77
+ job_config, parallel_dims.pp, model_config.num_hidden_layers
78
+ )
79
+ )
80
+
81
+ def _build_stage(
82
+ stage_idx: int,
83
+ start_layer: Optional[str],
84
+ stop_layer: Optional[str],
85
+ is_first: bool = False,
86
+ is_last: bool = False,
87
+ ) -> tuple[PipelineStage, nn.Module]:
88
+ model = copy.deepcopy(whole_model)
89
+ if not is_first:
90
+ # we do `model.tok_embeddings = None` here
91
+ real_model = get_model(model)
92
+ tok_embeddings_name = get_components_name(real_model, "tok_embeddings")
93
+ setattr(real_model, tok_embeddings_name, None)
94
+
95
+ drop_layers = start_layer is not None
96
+ # Get module dictionary from get_blocks(model)
97
+ # and create a list of keys before modifying the dictionary
98
+ module_dict = get_blocks(model)._modules # Store reference
99
+ layer_names = list(module_dict.keys())
100
+
101
+ # Iterate over the list of keys instead of `_modules.items()`
102
+ for name in layer_names:
103
+ # Dynamically determine prefix (blocks.* or layers.*)
104
+ prefix = start_layer.split(".")[0] if start_layer else "layers"
105
+ layer_name = f"{prefix}.{name}" # Construct the correct name format
106
+
107
+ # Ensure `drop_layers` activation is based on actual naming
108
+ if layer_name == start_layer:
109
+ drop_layers = False
110
+ if layer_name == stop_layer:
111
+ drop_layers = True
112
+
113
+ # Delete layer if drop_layers is active
114
+ if drop_layers:
115
+ del module_dict[name] # Safe deletion from stored dictionary
116
+
117
+ if not is_last:
118
+ # we do `model.norm = None` and `model.output = None`
119
+ real_model = get_model(model)
120
+ norm_name = get_components_name(real_model, "norm")
121
+ setattr(real_model, norm_name, None)
122
+
123
+ head_name = get_components_name(model, "lm_head")
124
+ setattr(model, head_name, None)
125
+
126
+ stage = PipelineStage(
127
+ model,
128
+ stage_idx,
129
+ num_stages,
130
+ device,
131
+ group=pp_mesh.get_group("pp"),
132
+ )
133
+ return stage, model
134
+
135
+ num_stages = len(splits) + 1
136
+ stage_idx = pp_rank
137
+
138
+ stages = []
139
+ models = []
140
+
141
+ schedule_class = get_schedule_class(
142
+ job_config.experimental.pipeline_parallel_schedule
143
+ )
144
+ style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
145
+
146
+ for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
147
+ start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
148
+ stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
149
+ stage, model_chunk = _build_stage(
150
+ stage_idx,
151
+ start_layer,
152
+ stop_layer,
153
+ is_first=stage_idx == 0,
154
+ is_last=stage_idx == num_stages - 1,
155
+ )
156
+ logger.info(
157
+ f"PP rank {pp_rank} is building stage_idx {stage_idx}"
158
+ f" with start_layer {start_layer}, stop_layer {stop_layer}"
159
+ )
160
+ stages.append(stage)
161
+ models.append(model_chunk)
162
+ return stages, models
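The start/stop bookkeeping in `_build_stage` is easier to see in isolation. A standalone sketch with hypothetical names and split points (8 blocks, building the second of two stages):

# Mirrors the drop-layers loop above: keys of get_blocks(model)._modules are "0", "1", ...
layer_names = [str(i) for i in range(8)]
start_layer, stop_layer = "layers.4", None      # hypothetical split point for stage 1 of 2
prefix = start_layer.split(".")[0] if start_layer else "layers"

kept = []
drop = start_layer is not None
for name in layer_names:
    layer_name = f"{prefix}.{name}"
    if layer_name == start_layer:
        drop = False
    if layer_name == stop_layer:
        drop = True
    if not drop:
        kept.append(layer_name)

print(kept)  # ['layers.4', 'layers.5', 'layers.6', 'layers.7']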
flame/tools/__init__.py ADDED
File without changes
flame/tools/utils.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import nn
8
+ from torchtitan.tools.logging import logger
9
+
10
+
11
+ def get_nparams_and_flops(model: nn.Module, model_config, seq_len: int) -> tuple[int, int]:
12
+ nparams = sum(p.numel() for p in model.parameters())
13
+ nparams_embedding = sum(
14
+ sum(p.numel() for p in m.parameters())
15
+ for m in model.children()
16
+ if isinstance(m, nn.Embedding)
17
+ )
18
+
19
+ if hasattr(model_config, "num_heads"):
20
+ num_heads = model_config.num_heads
21
+ elif hasattr(model_config, "num_attention_heads"):
22
+ num_heads = model_config.num_attention_heads
23
+ else:
24
+ num_heads = 1
25
+ logger.warning("num_heads not found in model_config, defaulting to 1.")
26
+
27
+ l, h, q, t = (
28
+ model_config.num_hidden_layers,
29
+ num_heads,
30
+ model_config.hidden_size // num_heads,
31
+ seq_len,
32
+ )
33
+ # Reasoning behind the factor of 12 for the self-attention part of the formula:
34
+ # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
35
+ # 2. the flash attention does 1 more matmul recomputation in the backward
36
+ # but recomputation should not be counted in calculating MFU (+0)
37
+ # 3. each matmul performs 1 multiplication and 1 addition (*2)
38
+ # 4. we follow the convention and do not account for sparsity in causal attention
39
+ num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
40
+
41
+ return nparams, num_flops_per_token
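A rough worked example of the estimate above with hypothetical 340M-scale numbers (illustrative only, not taken from any shipped config):

# l layers, h heads, head dim q, sequence length t
l, h, q, t = 24, 16, 1024 // 16, 4096
nparams, nparams_embedding = 340_000_000, 65_000_000  # hypothetical totals

dense_flops = 6 * (nparams - nparams_embedding)  # fwd + bwd matmuls outside attention scores
attn_flops = 12 * l * h * q * t                  # attention score/value matmuls, fwd + bwd
print(f"{dense_flops + attn_flops:.2e} FLOPs per token")  # ~1.65e9 + ~1.21e9, roughly 2.9e9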
flame/train.py ADDED
@@ -0,0 +1,851 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ import time
10
+ from datetime import timedelta
11
+
12
+ import torch
13
+ from datasets import interleave_datasets, load_dataset
14
+ from torch.distributed.elastic.multiprocessing.errors import record
15
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
16
+
17
+ import fla # noqa
18
+ from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
19
+ from fla.ops.common.utils import prepare_position_ids
20
+ from flame.components.checkpoint import TrainState
21
+ from flame.config_manager import JobConfig
22
+ from flame.data import build_dataloader, shuffle
23
+ from flame.models.parallelize_fla import parallelize_fla
24
+ from flame.models.pipeline_fla import pipeline_fla
25
+ from flame.tools.utils import get_nparams_and_flops
26
+ from flame.utils.checkpoint import cleanup_local_checkpoints
27
+ from flame.utils.convert_dcp_to_hf import save_pretrained
28
+ from flame.utils.hf_utils import upload_checkpoint_to_hf
29
+ from datetime import datetime
30
+ from torchtitan.components.checkpoint import CheckpointManager
31
+ from torchtitan.components.ft import FTParallelDims, init_ft_manager
32
+ from torchtitan.components.loss import build_cross_entropy_loss
33
+ from torchtitan.components.lr_scheduler import build_lr_schedulers
34
+ from torchtitan.components.metrics import build_device_memory_monitor, build_metrics_processor, ensure_pp_loss_visible
35
+ from torchtitan.components.optimizer import build_optimizers
36
+ from torchtitan.distributed import ParallelDims
37
+ from torchtitan.distributed import utils as dist_utils
38
+ from torchtitan.protocols.model_converter import build_model_converters
39
+ from torchtitan.protocols.train_spec import TrainSpec, get_train_spec, register_train_spec
40
+ from torchtitan.tools import utils
41
+ from torchtitan.tools.logging import init_logger, logger
42
+ from torchtitan.tools.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
43
+
44
+
45
+ def build_tokenizer(job_config: JobConfig) -> AutoTokenizer:
46
+ return AutoTokenizer.from_pretrained(job_config.model.tokenizer_path)
47
+
48
+
49
+ register_train_spec(
50
+ TrainSpec(
51
+ name="fla",
52
+ cls=AutoModelForCausalLM,
53
+ config=AutoConfig,
54
+ parallelize_fn=parallelize_fla,
55
+ pipelining_fn=pipeline_fla,
56
+ build_optimizers_fn=build_optimizers,
57
+ build_lr_schedulers_fn=build_lr_schedulers,
58
+ build_dataloader_fn=build_dataloader,
59
+ build_tokenizer_fn=build_tokenizer,
60
+ build_loss_fn=build_cross_entropy_loss,
61
+ )
62
+ )
63
+
64
+
65
+ # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
66
+ @record
67
+ def main(job_config: JobConfig):
68
+ logger.info(f"Starting job: {job_config.job.description}")
69
+
70
+ if job_config.experimental.custom_model_path:
71
+ utils.import_module_from_path(job_config.experimental.custom_model_path)
72
+
73
+ # used for colorful printing
74
+ color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color
75
+
76
+ if job_config.job.print_args:
77
+ logger.info(
78
+ f"{color.green}{json.dumps(job_config.to_dict(), indent=2, sort_keys=True)}{color.reset}"
79
+ )
80
+
81
+ # take control of garbage collection to avoid stragglers
82
+ gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
83
+
84
+ device_module, device_type = utils.device_module, utils.device_type
85
+ device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
86
+ # Device has to be set before creating TorchFT manager.
87
+ device_module.set_device(device)
88
+ ft_manager = init_ft_manager(job_config)
89
+
90
+ run_specific_repo_id = None
91
+ if getattr(job_config.checkpoint, "hf_upload_enabled", False):
92
+ hf_repo_base = getattr(job_config.checkpoint, "hf_repo_base_name", None)
93
+ if hf_repo_base:
94
+ # Generate timestamp (adjust format if desired)
95
+ timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
96
+ run_specific_repo_id = f"{hf_repo_base}-{timestamp}"
97
+ logger.info(f"Target Hugging Face repository for this run: {run_specific_repo_id}")
98
+ else:
99
+ logger.warning("HF Hub upload enabled, but 'checkpoint.hf_repo_base_name' is not set.")
100
+ # Disable upload if base name is missing
101
+ job_config.checkpoint.hf_upload_enabled = False
102
+
103
+ # init distributed
104
+ world_size = int(os.environ["WORLD_SIZE"])
105
+ if not ft_manager.enabled:
106
+ parallel_dims = ParallelDims(
107
+ dp_shard=job_config.training.data_parallel_shard_degree,
108
+ dp_replicate=job_config.training.data_parallel_replicate_degree,
109
+ cp=job_config.experimental.context_parallel_degree,
110
+ tp=job_config.training.tensor_parallel_degree,
111
+ pp=job_config.experimental.pipeline_parallel_degree,
112
+ world_size=world_size,
113
+ enable_loss_parallel=not job_config.training.disable_loss_parallel,
114
+ )
115
+ else:
116
+ parallel_dims = FTParallelDims(
117
+ dp_shard=job_config.training.data_parallel_shard_degree,
118
+ dp_replicate=job_config.training.data_parallel_replicate_degree,
119
+ cp=job_config.experimental.context_parallel_degree,
120
+ tp=job_config.training.tensor_parallel_degree,
121
+ pp=job_config.experimental.pipeline_parallel_degree,
122
+ world_size=world_size,
123
+ enable_loss_parallel=not job_config.training.disable_loss_parallel,
124
+ ft_manager=ft_manager,
125
+ )
126
+ dist_utils.init_distributed(job_config)
127
+ # initialize device memory monitor and get peak flops for MFU calculation
128
+ device_memory_monitor = build_device_memory_monitor()
129
+ gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
130
+ logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
131
+
132
+ # build meshes
133
+ world_mesh = parallel_dims.build_mesh(device_type=device_type)
134
+ if parallel_dims.dp_enabled:
135
+ dp_mesh = world_mesh["dp"]
136
+ dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
137
+ else:
138
+ dp_degree, dp_rank = 1, 0
139
+
140
+ if parallel_dims.pp_enabled:
141
+ raise NotImplementedError(
142
+ "Pipeline parallelism is not supported in this version"
143
+ )
144
+ """
145
+ ! TODO[flame]: We need to fix the pipeline parallelism for flame
146
+ [x] Match the key of models' components with the actual naming
147
+ [ ] Fix the post-init and tie-embedding for pipeline parallelism; HF transformers automatically
148
+ ties the embeddings when the head is None, so we need to handle this case
149
+ [ ]
150
+ """
151
+ pp_mesh = world_mesh["pp"]
152
+
153
+ # Set random seed, and maybe enable deterministic mode (mainly for debugging, expect perf loss)
154
+ dist_utils.set_determinism(
155
+ world_mesh, device, job_config.training.seed, job_config.training.deterministic
156
+ )
157
+ train_spec = get_train_spec(job_config.model.name)
158
+
159
+ logger.info("Loading tokenizer...")
160
+ tokenizer = AutoTokenizer.from_pretrained(
161
+ job_config.model.tokenizer_path,
162
+ trust_remote_code=True,
163
+ model_max_length=int(1e10),
164
+ )
165
+ logger.info(f"{tokenizer}")
166
+ logger.info(
167
+ f"Loading dataset {job_config.training.dataset}"
168
+ + (f":{job_config.training.dataset_name}"
169
+ if job_config.training.dataset_name is not None
170
+ else "")
171
+ )
172
+
173
+ min_num_shards = dp_degree * job_config.training.num_workers
174
+ if len(job_config.training.dataset.split(",")) == 1:
175
+ dataset = load_dataset(
176
+ path=job_config.training.dataset,
177
+ name=getattr(job_config.training, "dataset_name", None),
178
+ data_dir=getattr(job_config.training, "data_dir", None),
179
+ data_files=getattr(job_config.training, "data_files", None),
180
+ split=job_config.training.dataset_split or "train",
181
+ trust_remote_code=True,
182
+ streaming=job_config.training.streaming,
183
+ num_proc=(
184
+ job_config.training.num_workers
185
+ if not job_config.training.streaming
186
+ else None
187
+ ),
188
+ )
189
+ logger.info(f"{dataset}")
190
+
191
+ logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
192
+ if not job_config.training.streaming:
193
+ # the state of a map-style dataset is recoverable after shuffling
194
+ dataset = dataset.shuffle(
195
+ seed=job_config.training.seed
196
+ ).to_iterable_dataset(num_shards=min_num_shards)
197
+ else:
198
+ if dataset.num_shards < min_num_shards:
199
+ logger.warning(
200
+ f"{color.red}"
201
+ f"Dataset {job_config.training.dataset} has insufficient shards ({dataset.num_shards}). "
202
+ f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
203
+ f"{job_config.training.num_workers} dataloader workers. "
204
+ f"Disabling the streaming mode and resharding dataset to {min_num_shards} shards."
205
+ f"{color.reset}"
206
+ )
207
+ dataset = (
208
+ load_dataset(
209
+ path=job_config.training.dataset,
210
+ name=getattr(job_config.training, "dataset_name", None),
211
+ data_dir=getattr(job_config.training, "data_dir", None),
212
+ data_files=getattr(job_config.training, "data_files", None),
213
+ split=job_config.training.dataset_split or "train",
214
+ trust_remote_code=True,
215
+ streaming=False,
216
+ num_proc=job_config.training.num_workers,
217
+ )
218
+ .shuffle(seed=job_config.training.seed)
219
+ .to_iterable_dataset(num_shards=min_num_shards)
220
+ )
221
+ else:
222
+ dataset = shuffle(dataset, seed=job_config.training.seed)
223
+ else:
224
+ datasets = job_config.training.dataset.split(",")
225
+ if job_config.training.dataset_name is not None:
226
+ dataset_names = [
227
+ name or None for name in job_config.training.dataset_name.split(",")
228
+ ]
229
+ assert len(dataset_names) == len(datasets), (
230
+ "The number of dataset names must match the number of datasets"
231
+ )
232
+ else:
233
+ dataset_names = [None] * len(datasets)
234
+ if job_config.training.dataset_split is not None:
235
+ dataset_splits = [
236
+ split or "train"
237
+ for split in job_config.training.dataset_split.split(",")
238
+ ]
239
+ assert len(dataset_splits) == len(datasets), (
240
+ "The number of dataset splits must match the number of datasets"
241
+ )
242
+ else:
243
+ dataset_splits = ["train"] * len(datasets)
244
+ if job_config.training.data_dir is not None:
245
+ data_dirs = [
246
+ data_dir or None for data_dir in job_config.training.data_dir.split(",")
247
+ ]
248
+ assert len(data_dirs) == len(datasets), (
249
+ "The number of data dirs must match the number of datasets"
250
+ )
251
+ else:
252
+ data_dirs = [None] * len(datasets)
253
+ if job_config.training.data_files is not None:
254
+ data_files = job_config.training.data_files.split(",")
255
+ assert len(data_files) == len(datasets), (
256
+ "The number of data files must match the number of datasets"
257
+ )
258
+ else:
259
+ data_files = [None] * len(datasets)
260
+ if job_config.training.data_probs is not None:
261
+ data_probs = [float(p) for p in job_config.training.data_probs.split(",")]
262
+ assert len(data_probs) == len(datasets), (
263
+ "The number of data probabilities must match the number of datasets"
264
+ )
265
+ else:
266
+ raise ValueError(
267
+ "Data sampling probabilities are required if using multiple datasets"
268
+ )
269
+
270
+ subsets = []
271
+ for i, prob in enumerate(data_probs):
272
+ subset = load_dataset(
273
+ path=datasets[i],
274
+ name=dataset_names[i],
275
+ data_dir=data_dirs[i],
276
+ data_files=data_files[i],
277
+ split=dataset_splits[i],
278
+ trust_remote_code=True,
279
+ streaming=job_config.training.streaming,
280
+ num_proc=(
281
+ job_config.training.num_workers
282
+ if not job_config.training.streaming
283
+ else None
284
+ ),
285
+ )
286
+ logger.info(
287
+ f"Subset {color.cyan}{datasets[i]}"
288
+ + (f":{dataset_names[i]} " if dataset_names[i] else " ")
289
+ + f"(p = {prob:.3f}){color.reset}:\n"
290
+ + f"{subset}"
291
+ )
292
+
293
+ logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
294
+ if not job_config.training.streaming:
295
+ # the states of map-style dataset is recoverable after shuffling
296
+ subset = subset.shuffle(
297
+ seed=job_config.training.seed
298
+ ).to_iterable_dataset(num_shards=min_num_shards)
299
+ else:
300
+ if subset.num_shards < min_num_shards:
301
+ logger.warning(
302
+ f"{color.red}"
303
+ f"Dataset {datasets[i]} has insufficient shards ({subset.num_shards}). "
304
+ f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
305
+ f"{job_config.training.num_workers} dataloader workers. "
306
+ f"Resharding dataset to {min_num_shards} shards and disabling streaming mode."
307
+ f"{color.reset}"
308
+ )
309
+ # again, it's ok to directly shuffle the map-style dataset
310
+ # we expect an error to be raised if the map-style dataset still does not have enough data shards
311
+ subset = (
312
+ load_dataset(
313
+ path=datasets[i],
314
+ name=dataset_names[i],
315
+ data_dir=data_dirs[i],
316
+ data_files=data_files[i],
317
+ split=dataset_splits[i],
318
+ trust_remote_code=True,
319
+ streaming=False,
320
+ num_proc=job_config.training.num_workers,
321
+ )
322
+ .shuffle(seed=job_config.training.seed)
323
+ .to_iterable_dataset(min_num_shards)
324
+ )
325
+ else:
326
+ # we set a relatively small buffer size here, as interleaving already provides some randomness
327
+ subset = shuffle(
328
+ subset,
329
+ seed=job_config.training.seed,
330
+ buffer_size=max(128, 1024 // len(datasets)),
331
+ )
332
+
333
+ if "text" in subset.column_names:
334
+ subset = subset.select_columns("text")
335
+ elif "content" in subset.column_names:
336
+ subset = subset.select_columns("content")
337
+ else:
338
+ raise ValueError(
339
+ f"Subset {datasets[i]} has no 'text' or 'content' column"
340
+ )
341
+ subsets.append(subset)
342
+
343
+ logger.info(
344
+ f"Interleaving {len(subsets)} datasets with probabilities {data_probs}"
345
+ )
346
+ dataset = interleave_datasets(
347
+ datasets=subsets,
348
+ probabilities=data_probs,
349
+ stopping_strategy="all_exhausted",
350
+ seed=job_config.training.seed,
351
+ )
352
+ logger.info(f"{dataset}")
353
+
354
+ logger.info("Building dataloader...")
355
+ dataloader = build_dataloader(
356
+ dataset=dataset,
357
+ tokenizer=tokenizer,
358
+ rank=dp_rank,
359
+ world_size=dp_degree,
360
+ batch_size=job_config.training.batch_size,
361
+ seq_len=job_config.training.seq_len,
362
+ context_len=job_config.training.context_len,
363
+ varlen=job_config.training.varlen,
364
+ num_workers=job_config.training.num_workers,
365
+ pin_memory=job_config.training.pin_memory,
366
+ persistent_workers=job_config.training.persistent_workers,
367
+ snapshot_every_n_steps=job_config.checkpoint.interval,
368
+ )
369
+
370
+ logger.info(f"Loading model config from {job_config.model.config}")
371
+ model_config = AutoConfig.from_pretrained(job_config.model.config)
372
+ # set the model configs from training inputs:
373
+ # 1. norm type to decide which norm layer to use
374
+ # 2. disable fused norm if TP is enabled
375
+ # 3. vocab size from tokenizer
376
+ # 4. context_len based on training inputs
377
+ if parallel_dims.tp_enabled:
378
+ if model_config.fuse_norm:
379
+ logger.warning(
380
+ f"{color.red}"
381
+ f"Fused norm is not compatible with tensor parallelism. "
382
+ f"Disabling it for now."
383
+ f"{color.reset}"
384
+ )
385
+ model_config.fuse_norm = False
386
+ if parallel_dims.loss_parallel_enabled:
387
+ if model_config.fuse_cross_entropy:
388
+ logger.warning(
389
+ f"{color.red}"
390
+ f"Loss parallel enabled. Disabling fused cross entropy for now."
391
+ f"{color.reset}"
392
+ )
393
+ model_config.fuse_cross_entropy = False
394
+ model_config.vocab_size = max(tokenizer.vocab_size, model_config.vocab_size)
395
+
396
+ logger.info(
397
+ f"Building model from the config\n{color.green}{model_config}{color.reset}"
398
+ )
399
+ with torch.device("meta"):
400
+ model = AutoModelForCausalLM.from_config(model_config)
401
+ if (
402
+ getattr(model_config, "fuse_cross_entropy", False)
403
+ and FusedLinearCrossEntropyLoss is not None
404
+ ):
405
+ model.criterion = FusedLinearCrossEntropyLoss(
406
+ num_chunks=8 // parallel_dims.tp
407
+ )
408
+ # defer weight initialization until after parallelisms are applied
409
+ model.apply(lambda m: setattr(m, "_is_hf_initialized", False))
410
+ logger.info(f"{color.blue}\n{model}{color.reset}\n")
411
+
412
+ # Build the collection of model converters. No-op if `model.converters` empty
413
+ model_converters = build_model_converters(job_config, parallel_dims)
414
+ model_converters.convert(model)
415
+
416
+ # calculate model size and flops per token
417
+ model_param_count, num_flops_per_token = get_nparams_and_flops(
418
+ model, model_config, job_config.training.context_len
419
+ )
420
+
421
+ # move sharded model to CPU/GPU and initialize weights via DTensor
422
+ if job_config.checkpoint.create_seed_checkpoint:
423
+ init_device = "cpu"
424
+ elif job_config.training.enable_cpu_offload:
425
+ init_device = "cpu"
426
+ else:
427
+ init_device = device_type
428
+
429
+ # apply parallelisms and initialization
430
+ if parallel_dims.pp_enabled:
431
+ # apply PT-D Pipeline Parallel
432
+ (
433
+ pp_schedule,
434
+ model_parts,
435
+ has_first_stage,
436
+ has_last_stage,
437
+ ) = train_spec.pipelining_fn(
438
+ model,
439
+ pp_mesh,
440
+ parallel_dims,
441
+ job_config,
442
+ device,
443
+ model_config,
444
+ train_spec.loss_fn,
445
+ )
446
+ # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
447
+ del model
448
+
449
+ # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
450
+ # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
451
+ # optimizer, and checkpointing
452
+ for m in model_parts:
453
+ # apply SPMD-style PT-D techniques
454
+ train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
455
+ m.to_empty(device=init_device)
456
+ with torch.no_grad():
457
+ m.post_init()
458
+ m.train()
459
+
460
+ # confirm that user will be able to view loss metrics on the console
461
+ ensure_pp_loss_visible(parallel_dims, job_config, color)
462
+ else:
463
+ # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
464
+ train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
465
+ model.to_empty(device=init_device)
466
+ with torch.no_grad():
467
+ model.post_init()
468
+ model.train()
469
+
470
+ model_parts = [model]
471
+
472
+ device_mem_stats = device_memory_monitor.get_peak_stats()
473
+ logger.info(
474
+ f"{device_type.upper()} memory usage for model: "
475
+ f"{device_mem_stats.max_reserved_gib:.2f}GiB"
476
+ f"({device_mem_stats.max_reserved_pct:.2f}%)"
477
+ )
478
+
479
+ # build optimizer after applying parallelisms to the model
480
+ optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager)
481
+ lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config)
482
+ # Post optimizer step model converters hook.
483
+ # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
484
+ # where it issues a single all-reduce for all parameters at once for better performance
485
+ optimizers.register_step_post_hook(
486
+ lambda *args, **kwargs: model_converters.post_optimizer_hook(model_parts)
487
+ )
488
+
489
+ train_state = TrainState()
490
+
491
+ # load initial checkpoint
492
+ checkpoint = CheckpointManager(
493
+ dataloader=dataloader,
494
+ model_parts=model_parts,
495
+ optimizers=optimizers,
496
+ lr_schedulers=lr_schedulers,
497
+ states={"train_state": train_state},
498
+ job_config=job_config,
499
+ ft_manager=ft_manager,
500
+ )
501
+
502
+ if job_config.checkpoint.create_seed_checkpoint:
503
+ assert world_size == 1, (
504
+ "Must create seed checkpoint using a single device, to disable sharding"
505
+ )
506
+ assert job_config.checkpoint.enable_checkpoint, (
507
+ "Must enable checkpointing when creating a seed checkpoint"
508
+ )
509
+ checkpoint.save(curr_step=0, force=True)
510
+ logger.info("Created seed checkpoint")
511
+ return
512
+
513
+ checkpoint.load(step=job_config.checkpoint.load_step)
514
+ metric_logger = build_metrics_processor(job_config, parallel_dims)
515
+ # Set dependent attributes for metric_logger
516
+ metric_logger.num_flops_per_token = num_flops_per_token
517
+ metric_logger.optimizers = optimizers # Pass optimizers if needed by logger logic
518
+ metric_logger.lr_schedulers = (
519
+ lr_schedulers # Pass schedulers if needed by logger logic
520
+ )
521
+
522
+ # plot losses loaded from checkpoint (if any) to TensorBoard
523
+ # NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
524
+ # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
525
+ if train_state.step > 0 and len(metric_logger.data_loading_times) > 0:
526
+ for idx, step in enumerate(train_state.log_steps):
527
+ metric_logger.log(
528
+ step,
529
+ global_avg_loss=train_state.global_avg_losses[idx],
530
+ global_max_loss=train_state.global_max_losses[idx],
531
+ )
532
+
533
+ data_iterator = iter(dataloader)
534
+
535
+ train_context = dist_utils.get_train_context(
536
+ parallel_dims.loss_parallel_enabled,
537
+ job_config.experimental.enable_compiled_autograd,
538
+ )
539
+
540
+ # variables used to keep info for metrics logging
541
+ device_memory_monitor.reset_peak_stats()
542
+
543
+ global_batch_size = (
544
+ job_config.training.batch_size
545
+ * dp_degree
546
+ * job_config.training.gradient_accumulation_steps
547
+ )
548
+ num_tokens_per_step = global_batch_size * job_config.training.seq_len
549
+ # train loop
550
+ logger.info(f"{color.red}***** Running training *****{color.reset}")
551
+ logger.info(f"{color.green} Training starts at step {train_state.step + 1}")
552
+ logger.info(
553
+ f"{color.green} Number of tokens per sequence = {job_config.training.seq_len:,}"
554
+ )
555
+ logger.info(
556
+ f"{color.green} Gradient Accumulation steps = {job_config.training.gradient_accumulation_steps}"
557
+ )
558
+ logger.info(
559
+ f"{color.green} Instantaneous batch size (per device) = {job_config.training.batch_size:,}"
560
+ )
561
+ logger.info(
562
+ f"{color.green} Global batch size (w. parallel, distributed & accumulation) = {global_batch_size:,}"
563
+ f" ({num_tokens_per_step:,} tokens)"
564
+ )
565
+ logger.info(
566
+ f"{color.green} Total optimization steps = {job_config.training.steps:,} "
567
+ f"({job_config.training.steps * num_tokens_per_step:,} tokens)"
568
+ )
569
+ logger.info(
570
+ f"{color.green} Warmup steps = {job_config.lr_scheduler.warmup_steps:,}"
571
+ f" ({job_config.lr_scheduler.warmup_steps * num_tokens_per_step:,} tokens)"
572
+ )
573
+ logger.info(
574
+ f"{color.green} Number of parameters = {model_param_count:,} {color.reset}"
575
+ )
576
+
577
+ with (
578
+ maybe_enable_profiling(
579
+ job_config, global_step=train_state.step
580
+ ) as torch_profiler,
581
+ maybe_enable_memory_snapshot(
582
+ job_config, global_step=train_state.step
583
+ ) as memory_profiler,
584
+ ):
585
+ while train_state.step < job_config.training.steps:
586
+ train_state.step += 1
587
+ gc_handler.run(train_state.step)
588
+
589
+ optimizers.zero_grad()
590
+
591
+ losses = []
592
+ # do gradient accumulation if enabled
593
+ for _ in range(job_config.training.gradient_accumulation_steps):
594
+ # get batch
595
+ data_load_start = time.perf_counter()
596
+ batch = next(data_iterator)
597
+ input_ids, labels = batch["input_ids"], batch["labels"]
598
+
599
+ # Update metrics processor state before forward/backward
600
+ metric_logger.ntokens_since_last_log += labels.numel()
601
+ metric_logger.data_loading_times.append(
602
+ time.perf_counter() - data_load_start
603
+ )
604
+
605
+ input_ids = input_ids.to(device_type)
606
+
607
+ """
608
+ TODO[flame]: We need to carefully handle the position_ids for TP/CP
609
+ Depending on the model's positional encoding, the position_ids might be different.
610
+
611
+ e.g. for TP
612
+ For RoPE, all ranks have the same position_ids. [FOR HF model]
613
+ For sinusoidal, each rank has the corresponding chunked position_ids. [FOR HF model]
614
+
615
+ e.g. for CP, [optional_context_parallel_ctx should automatically distribute the position_ids]
616
+ Each rank has the corresponding chunked position_ids. [FOR all models]
617
+
618
+ """
619
+ labels = labels.to(device_type)
620
+ cu_seqlens = (
621
+ batch["cu_seqlens"].to(device_type)
622
+ if "cu_seqlens" in batch
623
+ else None
624
+ )
625
+ if cu_seqlens is not None:
626
+ position_ids = prepare_position_ids(cu_seqlens).to(torch.int32)
627
+ else:
628
+ position_ids = (
629
+ torch.arange(0, input_ids.shape[1], device=device_type)
630
+ .repeat(input_ids.shape[0], 1)
631
+ .to(torch.int32)
632
+ )
633
+ # apply context parallelism if cp is enabled
634
+ # ensure CP handles the separate freqs_cis buffer for each pp stage
635
+ optional_context_parallel_ctx = (
636
+ dist_utils.create_context_parallel_ctx(
637
+ cp_mesh=world_mesh["cp"],
638
+ cp_buffers=[input_ids, labels, position_ids],
639
+ cp_seq_dims=[1, 1, 1],
640
+ cp_no_restore_buffers={input_ids, labels, position_ids},
641
+ cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
642
+ )
643
+ if parallel_dims.cp_enabled
644
+ else None
645
+ )
646
+
647
+ # #! TODO[flame], we should distribute the position_ids as well with CP
648
+ if parallel_dims.pp_enabled:
649
+ raise NotImplementedError(
650
+ "Pipeline parallelism is not supported in this version"
651
+ )
652
+ # Pipeline Parallel forward / backward inside step() call
653
+ with train_context(optional_context_parallel_ctx):
654
+ targets, losses = (
655
+ (labels, []) if has_last_stage else (None, None)
656
+ )
657
+
658
+ if has_first_stage:
659
+ pp_schedule.step(input_ids, target=targets, losses=losses)
660
+ else:
661
+ pp_schedule.step(target=targets, losses=losses)
662
+
663
+ # accumulate losses across pipeline microbatches
664
+ # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
665
+ loss = (
666
+ torch.mean(torch.stack(losses)).to(device)
667
+ if has_last_stage
668
+ else torch.tensor([-1.0], device=device)
669
+ )
670
+ else:
671
+ # Non-PP forward / backward
672
+ with train_context(optional_context_parallel_ctx):
673
+ output = model(
674
+ input_ids=input_ids,
675
+ labels=labels,
676
+ position_ids=position_ids,
677
+ cu_seqlens=cu_seqlens,
678
+ )
679
+ loss = (
680
+ output.loss
681
+ / job_config.training.gradient_accumulation_steps
682
+ )
683
+ loss.backward()
684
+
685
+ losses.append(loss)
686
+ loss = sum(losses)
687
+
688
+ # clip gradients
689
+ grad_norm = dist_utils.clip_grad_norm_(
690
+ [p for m in model_parts for p in m.parameters()],
691
+ job_config.training.max_norm,
692
+ foreach=True,
693
+ pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
694
+ )
695
+
696
+ # optimizer step
697
+ checkpoint.maybe_wait_for_staging()
698
+ if job_config.training.skip_nan_inf and (
699
+ grad_norm.isnan() or grad_norm.isinf()
700
+ ):
701
+ logger.warning(
702
+ f"Skipping optimizer step - detected invalid gradient norm: {grad_norm:.4f}"
703
+ )
704
+ optimizers.zero_grad()
705
+ train_state.skipped_step += 1
706
+ else:
707
+ optimizers.step()
708
+ lr_schedulers.step()
709
+
710
+ # log metrics - Use MetricsProcessor
711
+ if metric_logger.should_log(train_state.step):
712
+ if (
713
+ parallel_dims.dp_replicate_enabled
714
+ or parallel_dims.dp_shard_enabled
715
+ or parallel_dims.cp_enabled
716
+ ):
717
+ loss = loss.detach()
718
+ # Use dist_mean/max on the accumulated loss for the step
719
+ global_avg_loss, global_max_loss = (
720
+ dist_utils.dist_mean(
721
+ loss,
722
+ world_mesh["dp_cp"],
723
+ ),
724
+ dist_utils.dist_max(
725
+ loss,
726
+ world_mesh["dp_cp"],
727
+ ),
728
+ )
729
+ else:
730
+ # Scale back the loss before logging
731
+ global_avg_loss = global_max_loss = loss.item()
732
+
733
+ # Update train state tokens and elapsed time
734
+ time_now = time.perf_counter()
735
+ time_delta = (
736
+ time_now - metric_logger.time_last_log
737
+ ) # Use metric_logger's time
738
+ train_state.token += (
739
+ metric_logger.ntokens_since_last_log # Use tokens tracked by metric_logger
740
+ * parallel_dims.world_size
741
+ / parallel_dims.non_data_parallel_size
742
+ )
743
+ train_state.elapsed += timedelta(seconds=time_delta)
744
+ train_state.log_steps.append(train_state.step)
745
+ train_state.global_avg_losses.append(global_avg_loss)
746
+ train_state.global_max_losses.append(global_max_loss)
747
+
748
+ # Log using the metric processor
749
+ last_lr = lr_schedulers.schedulers[0].get_last_lr()[0]
750
+ eta = (
751
+ train_state.elapsed
752
+ * (job_config.training.steps - train_state.step)
753
+ / train_state.step
754
+ )
755
+ metric_logger.log(
756
+ train_state.step,
757
+ global_avg_loss,
758
+ global_max_loss,
759
+ extra_metrics={
760
+ "optimizer/lr": last_lr,
761
+ "optimizer/grad_norm": grad_norm.item(),
762
+ "optimizer/skipped_step": train_state.skipped_step,
763
+ },
764
+ )
765
+
766
+ logger.info(
767
+ f"{color.blue}lr: {last_lr:.4e} gnorm: {grad_norm:5.2f} "
768
+ f"{color.magenta}[{str(train_state.elapsed).split('.')[0]:>8}<{str(eta).split('.')[0]:>8}]{color.reset}"
769
+ )
770
+
771
+ checkpoint.save(
772
+ train_state.step, force=(train_state.step == job_config.training.steps)
773
+ )
774
+
775
+ if torch.distributed.get_rank() == 0:
776
+ if job_config.checkpoint.enable_checkpoint:
777
+ hf_target_path = None
778
+ dcp_save_path = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder, f"step-{train_state.step}")
779
+
780
+ # TODO: Haven't tested this one yet
781
+ if getattr(job_config.checkpoint, "convert_to_hf_on_save", False):
782
+ try:
783
+ # Get the path where DCP was just saved
784
+ # Check CheckpointManager API for the best way, assuming get_save_path exists
785
+ hf_target_path = f"{dcp_save_path}" # e.g., .../checkpoint/step-1000-hf
786
+
787
+ logger.info(f"Converting step {train_state.step} DCP checkpoint to HF format at: {hf_target_path}")
788
+ save_pretrained( # Call the imported function
789
+ path=hf_target_path, # Pass target HF path as 'path'
790
+ step=train_state.step,
791
+ config=job_config.model.config, # Pass model config path/id
792
+ tokenizer=job_config.model.tokenizer_path # Pass tokenizer path/id
793
+ )
794
+ logger.info(f"Successfully converted step {train_state.step} to HF format.")
795
+
796
+ except Exception as e:
797
+ logger.error(f"Failed to convert checkpoint step {train_state.step} to HF format: {e}", exc_info=True)
798
+
799
+ base_checkpoint_dir = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder)
800
+ if getattr(job_config.checkpoint, "hf_upload_enabled", False):
801
+ upload_format = getattr(job_config.checkpoint, "hf_upload_format", "hf")
802
+ keep_k_hub = getattr(job_config.checkpoint, "hf_keep_latest_k", 5)
803
+
804
+ local_path_to_upload = None
805
+ if upload_format == "hf":
806
+ if hf_target_path and os.path.isdir(hf_target_path):
807
+ local_path_to_upload = hf_target_path
808
+ elif upload_format == "dcp":
809
+ if dcp_save_path and os.path.isdir(dcp_save_path):
810
+ local_path_to_upload = dcp_save_path
811
+
812
+ if local_path_to_upload:
813
+ try:
814
+ upload_checkpoint_to_hf(
815
+ local_path=local_path_to_upload,
816
+ step=train_state.step,
817
+ hf_repo_id_for_run=run_specific_repo_id,
818
+ upload_format=upload_format,
819
+ hf_keep_latest_k=keep_k_hub,
820
+ )
821
+ except Exception as e:
822
+ logger.error(f"Failed during HF Hub upload for step {train_state.step}: {e}", exc_info=True)
823
+
824
+ # signal the profiler that the next profiling step has started
825
+ if torch_profiler:
826
+ torch_profiler.step()
827
+ if memory_profiler:
828
+ memory_profiler.step()
829
+
830
+ # reduce timeout after first train step for faster signal
831
+ # (assuming lazy init and compilation are finished)
832
+ if train_state.step == 1:
833
+ dist_utils.set_pg_timeouts(
834
+ timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
835
+ world_mesh=world_mesh,
836
+ )
837
+
838
+ if torch.distributed.get_rank() == 0:
839
+ logger.info("Sleeping 2 seconds for other ranks to complete")
840
+ time.sleep(2)
841
+
842
+ metric_logger.close()
843
+ logger.info("Training completed")
844
+
845
+
846
+ if __name__ == "__main__":
847
+ init_logger()
848
+ config = JobConfig()
849
+ config.parse_args()
850
+ main(config)
851
+ torch.distributed.destroy_process_group()
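For reference, the arithmetic behind the "Global batch size" and tokens-per-step log lines in the training loop above, with hypothetical values:

batch_size = 8      # job_config.training.batch_size (per device)
dp_degree = 4       # data-parallel degree
grad_accum = 4      # job_config.training.gradient_accumulation_steps
seq_len = 4096      # job_config.training.seq_len

global_batch_size = batch_size * dp_degree * grad_accum  # 128 sequences per optimizer step
num_tokens_per_step = global_batch_size * seq_len        # 524,288 tokens per optimizer step
print(global_batch_size, num_tokens_per_step)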
flame/utils/__init__.py ADDED
File without changes
flame/utils/checkpoint.py ADDED
@@ -0,0 +1,50 @@
1
+ import os
2
+ import glob
3
+ import re
4
+ import shutil
5
+ from torchtitan.tools.logging import logger
6
+
7
+
8
+ def cleanup_local_checkpoints(checkpoint_dir: str, keep_latest_k: int):
9
+ """Removes older checkpoint directories locally, keeping only the latest k for both DCP and HF formats."""
10
+ if keep_latest_k <= 0:
11
+ return # Keep all checkpoints
12
+
13
+ logger.info(f"Cleaning up local checkpoints in {checkpoint_dir}, keeping latest {keep_latest_k}")
14
+
15
+ # Cleanup DCP checkpoints (step-*)
16
+ dcp_checkpoints = sorted(
17
+ glob.glob(os.path.join(checkpoint_dir, "step-*")),
18
+ key=lambda x: int(re.search(r"step-(\d+)", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)", os.path.basename(x)) and not x.endswith("-hf") else -1,
19
+ reverse=True
20
+ )
21
+ # Filter out HF format directories
22
+ dcp_checkpoints = [d for d in dcp_checkpoints if not d.endswith("-hf")]
23
+
24
+ if len(dcp_checkpoints) > keep_latest_k:
25
+ checkpoints_to_delete = dcp_checkpoints[keep_latest_k:]
26
+ logger.info(f"Deleting {len(checkpoints_to_delete)} old DCP checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
27
+ for ckpt_path in checkpoints_to_delete:
28
+ if os.path.isdir(ckpt_path): # Ensure it's a directory
29
+ try:
30
+ shutil.rmtree(ckpt_path)
31
+ except OSError as e:
32
+ logger.error(f"Error removing directory {ckpt_path}: {e}")
33
+
34
+
35
+ # Cleanup HF checkpoints (step-*-hf)
36
+ hf_checkpoints = sorted(
37
+ glob.glob(os.path.join(checkpoint_dir, "step-*-hf")),
38
+ key=lambda x: int(re.search(r"step-(\d+)-hf", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)-hf", os.path.basename(x)) else -1,
39
+ reverse=True
40
+ )
41
+
42
+ if len(hf_checkpoints) > keep_latest_k:
43
+ checkpoints_to_delete = hf_checkpoints[keep_latest_k:]
44
+ logger.info(f"Deleting {len(checkpoints_to_delete)} old HF checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
45
+ for ckpt_path in checkpoints_to_delete:
46
+ if os.path.isdir(ckpt_path): # Ensure it's a directory
47
+ try:
48
+ shutil.rmtree(ckpt_path)
49
+ except OSError as e:
50
+ logger.error(f"Error removing directory {ckpt_path}: {e}")
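A hedged usage sketch (hypothetical directory): keep only the three newest `step-*` and `step-*-hf` directories.

from flame.utils.checkpoint import cleanup_local_checkpoints

cleanup_local_checkpoints("exp/checkpoint", keep_latest_k=3)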
flame/utils/convert_dcp_to_hf.py ADDED
@@ -0,0 +1,66 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import argparse
5
+ import io
6
+ import os
7
+ import tempfile
8
+ from datetime import timedelta
9
+
10
+ import torch
11
+ import torch.serialization
12
+ from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
13
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
14
+
15
+ import fla # noqa
16
+ from torchtitan.tools.logging import init_logger, logger
17
+
18
+
19
+ @torch.inference_mode()
20
+ def save_pretrained(
21
+ path: str,
22
+ step: int,
23
+ config: str,
24
+ tokenizer: str
25
+ ):
26
+ logger.info(f"Loading the config from {config}")
27
+ config = AutoConfig.from_pretrained(config, trust_remote_code=True)
28
+
29
+ logger.info(f"Saving the config to {path}")
30
+ config.save_pretrained(path)
31
+ logger.info(f"Loading the tokenizer from {tokenizer}")
32
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
33
+ logger.info(f"Saving the tokenizer to {path}")
34
+ tokenizer.save_pretrained(path)
35
+
36
+ with tempfile.TemporaryDirectory() as tmpdir:
37
+ # base_checkpoint_dir = os.path.dirname(path)
38
+ base_checkpoint_dir = path
39
+ checkpoint = os.path.join(base_checkpoint_dir, f'checkpoint/step-{step}')
40
+ checkpoint_path = os.path.join(tmpdir, 'checkpoint.pt')
41
+ logger.info(f"Saving the distributed checkpoint to {checkpoint_path}")
42
+ dcp_to_torch_save(checkpoint, checkpoint_path)
43
+
44
+ logger.info(f"Initializing the model from config\n{config}")
45
+ model = AutoModelForCausalLM.from_config(config)
46
+ logger.info(model)
47
+ logger.info("Loading state dict from the checkpoint")
48
+
49
+ # Add datetime.timedelta and io.BytesIO to safe globals
50
+ torch.serialization.add_safe_globals([timedelta, io.BytesIO])
51
+ # torch.load now with default weights_only=True will work
52
+ model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model'])
53
+
54
+ logger.info(f"Saving the model to {path}")
55
+ model.save_pretrained(path)
56
+
57
+
58
+ if __name__ == "__main__":
59
+ init_logger()
60
+ parser = argparse.ArgumentParser("Convert DCP format model weights to huggingface-style.")
61
+ parser.add_argument("--path", type=str, required=True)
62
+ parser.add_argument("--step", type=int, required=True)
63
+ parser.add_argument("--config", type=str, required=True)
64
+ parser.add_argument("--tokenizer", type=str, required=True)
65
+ args = parser.parse_args()
66
+ save_pretrained(args.path, args.step, args.config, args.tokenizer)
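A hedged usage sketch of the converter above; the paths are placeholders, and the DCP shards are expected under `<path>/checkpoint/step-<step>` as constructed in the function:

from flame.utils.convert_dcp_to_hf import save_pretrained

save_pretrained(
    path="exp/transformer-340M",            # placeholder output dir; also where the DCP checkpoint lives
    step=20000,
    config="configs/transformer_340M.json",
    tokenizer="path/to/tokenizer",          # placeholder tokenizer path or hub id
)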
flame/utils/convert_hf_to_dcp.py ADDED
@@ -0,0 +1,34 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import argparse
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ import torch.distributed.checkpoint as DCP
9
+ from transformers import AutoModelForCausalLM
10
+
11
+ import fla # noqa
12
+ from torchtitan.tools.logging import init_logger, logger
13
+
14
+
15
+ @torch.inference_mode()
16
+ def convert_hf_weights(model: str, checkpoint: Path):
17
+ logger.info(f"Loading model from {model}")
18
+ model = AutoModelForCausalLM.from_pretrained(model)
19
+ state_dict = model.state_dict()
20
+
21
+ logger.info(f"Writing to DCP at '{checkpoint}'")
22
+ checkpoint.mkdir(parents=True, exist_ok=True)
23
+ storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
24
+ DCP.save({"model": state_dict}, storage_writer=storage_writer)
25
+
26
+
27
+ if __name__ == "__main__":
28
+ init_logger()
29
+ parser = argparse.ArgumentParser(description="Convert huggingface-style model weights to DCP format.")
30
+ parser.add_argument("--model", type=str, required=True)
31
+ parser.add_argument("--checkpoint", type=Path, required=True)
32
+ args = parser.parse_args()
33
+
34
+ convert_hf_weights(args.model, args.checkpoint)
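A hedged usage sketch of the HF-to-DCP converter; paths are placeholders:

from pathlib import Path

from flame.utils.convert_hf_to_dcp import convert_hf_weights

convert_hf_weights(
    model="path/to/hf-model",                       # placeholder HF checkpoint dir or hub id
    checkpoint=Path("exp/seed/checkpoint/step-0"),  # DCP output directory (created if missing)
)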
flame/utils/hf_utils.py ADDED
@@ -0,0 +1,77 @@
1
+ import os
2
+ import re
3
+ from huggingface_hub import HfApi, HfFolder, logging as hf_logging, create_repo
4
+ from torchtitan.tools.logging import logger
5
+
6
+ def upload_checkpoint_to_hf(
7
+ local_path: str,
8
+ step: int,
9
+ hf_repo_id_for_run: str,
10
+ hf_keep_latest_k: int,
11
+ upload_format: str
12
+ ):
13
+ """Uploads a checkpoint directory to HF Hub and manages retention."""
14
+ if not os.path.isdir(local_path):
15
+ logger.error(f"Local path for upload does not exist or is not a directory: {local_path}")
16
+ return
17
+
18
+ api = HfApi()
19
+ token = HfFolder.get_token()
20
+ if not token:
21
+ logger.warning("Hugging Face Hub token not found. Skipping upload. Login via `huggingface-cli login` or set HF_TOKEN.")
22
+ return
23
+
24
+ # --- Ensure the specific repository for this run exists ---
25
+ try:
26
+ logger.info(f"Ensuring repository {hf_repo_id_for_run} exists...")
27
+ # Use create_repo which handles creation only if it doesn't exist
28
+ create_repo(repo_id=hf_repo_id_for_run, token=token, repo_type="model", exist_ok=True)
29
+ logger.info(f"Repository {hf_repo_id_for_run} ensured.")
30
+ except Exception as e:
31
+ logger.error(f"Failed to create or ensure repository {hf_repo_id_for_run}: {e}", exc_info=True)
32
+ return # Stop if repo interaction fails
33
+
34
+ commit_message = f"Upload {upload_format.upper()} checkpoint step {step}"
35
+ path_in_repo = f"step-{step}"
36
+
37
+ logger.info(f"Uploading {local_path} to {hf_repo_id_for_run}/{path_in_repo} on Hugging Face Hub...")
38
+ try:
39
+ api.upload_folder(
40
+ folder_path=local_path,
41
+ path_in_repo=path_in_repo,
42
+ repo_id=hf_repo_id_for_run,
43
+ repo_type="model",
44
+ commit_message=commit_message,
45
+ token=token,
46
+ )
47
+ logger.info(f"Successfully uploaded step {step} to {hf_repo_id_for_run}.")
48
+ except Exception as e:
49
+ logger.error(f"Failed to upload checkpoint step {step} to {hf_repo_id_for_run}: {e}", exc_info=True)
50
+ if hf_keep_latest_k > 0:
51
+ logger.info(f"Cleaning up old checkpoints on {hf_repo_id_for_run}, keeping latest {hf_keep_latest_k}")
52
+ try:
53
+ repo_files = api.list_repo_tree(hf_repo_id_for_run, repo_type="model", token=token, recursive=False)
54
+ step_folders = [
55
+ item.path for item in repo_files
56
+ if item.path.startswith("step-") and item.path[5:].isdigit()
57
+ ]
58
+
59
+ step_folders.sort(key=lambda x: int(x.split('-')[1]), reverse=True)
60
+
61
+ if len(step_folders) > hf_keep_latest_k:
62
+ folders_to_delete = step_folders[hf_keep_latest_k:]
63
+ logger.info(f"Found {len(step_folders)} checkpoints on Hub. Deleting {len(folders_to_delete)} older ones: {folders_to_delete}")
64
+ for folder in folders_to_delete:
65
+ # Deleting requires repo_id, path_in_repo, and token
66
+ api.delete_folder(
67
+ repo_id=hf_repo_id_for_run,
68
+ path_in_repo=folder,
69
+ repo_type="model",
70
+ commit_message=f"Delete old checkpoint {folder}",
71
+ token=token
72
+ )
73
+ logger.info("Hub cleanup complete.")
74
+ else:
75
+ logger.info("No old checkpoints found on Hub to delete.")
76
+ except Exception as e:
77
+ logger.error(f"Error during Hub checkpoint cleanup for {hf_repo_id_for_run}: {e}", exc_info=True)
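A hedged usage sketch of the uploader; the repo id and paths are placeholders:

from flame.utils.hf_utils import upload_checkpoint_to_hf

upload_checkpoint_to_hf(
    local_path="exp/checkpoint/step-20000",   # directory produced during training
    step=20000,
    hf_repo_id_for_run="someuser/flame-run-20250101-000000",  # placeholder run-specific repo
    hf_keep_latest_k=3,                       # keep only the 3 newest step-* folders on the Hub
    upload_format="dcp",
)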
generation_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.51.3"
6
+ }
logs/none_nygareex/attempt_0/0/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_nygareex/attempt_0/0/stdout.log ADDED
File without changes
logs/none_nygareex/attempt_0/1/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_nygareex/attempt_0/1/stdout.log ADDED
File without changes
logs/none_nygareex/attempt_0/2/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_nygareex/attempt_0/2/stdout.log ADDED
File without changes
logs/none_nygareex/attempt_0/3/stderr.log ADDED
The diff for this file is too large to render. See raw diff