zaydzuhri committed
Commit 2be9f66 (verified)
1 Parent(s): 8fbfec1

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. flame/__pycache__/__init__.cpython-312.pyc +0 -0
  2. flame/__pycache__/config_manager.cpython-312.pyc +0 -0
  3. flame/__pycache__/data.cpython-312.pyc +0 -0
  4. flame/components/__init__.py +0 -0
  5. flame/components/__pycache__/__init__.cpython-312.pyc +0 -0
  6. flame/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
  7. flame/components/checkpoint.py +59 -0
  8. flame/models/__pycache__/__init__.cpython-312.pyc +0 -0
  9. flame/models/__pycache__/parallelize_fla.cpython-312.pyc +0 -0
  10. flame/models/__pycache__/pipeline_fla.cpython-312.pyc +0 -0
  11. flame/models/activation_offloading.py +447 -0
  12. flame/models/fla.toml +67 -0
  13. flame/models/parallelize_fla.py +550 -0
  14. flame/models/pipeline_fla.py +162 -0
  15. flame/tools/__init__.py +0 -0
  16. flame/tools/__pycache__/__init__.cpython-312.pyc +0 -0
  17. flame/tools/__pycache__/utils.cpython-312.pyc +0 -0
  18. flame/utils/__init__.py +0 -0
  19. flame/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  20. flame/utils/__pycache__/checkpoint.cpython-312.pyc +0 -0
  21. flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc +0 -0
  22. flame/utils/__pycache__/hf_utils.cpython-312.pyc +0 -0
  23. flame/utils/checkpoint.py +50 -0
  24. flame/utils/convert_dcp_to_hf.py +66 -0
  25. flame/utils/convert_hf_to_dcp.py +34 -0
  26. flame/utils/hf_utils.py +77 -0
  27. tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/logs/debug-internal.log +90 -0
  28. torchtitan/components/__pycache__/float8.cpython-312.pyc +0 -0
  29. torchtitan/components/__pycache__/ft.cpython-312.pyc +0 -0
  30. torchtitan/components/float8.py +150 -0
  31. torchtitan/components/ft.py +143 -0
  32. torchtitan/experiments/deepseek_v3/model_config.py +204 -0
  33. torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py +11 -0
  34. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py +159 -0
  35. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py +260 -0
  36. torchtitan/experiments/deepseek_v3/train.py +142 -0
  37. torchtitan/experiments/flux/dataset/tokenizer.py +64 -0
  38. torchtitan/experiments/flux/flux_argparser.py +42 -0
  39. torchtitan/experiments/flux/model/autoencoder.py +388 -0
  40. torchtitan/experiments/flux/model/hf_embedder.py +40 -0
  41. torchtitan/experiments/flux/model/layers.py +286 -0
  42. torchtitan/experiments/flux/model/model.py +177 -0
  43. torchtitan/experiments/flux/scripts/download_autoencoder.py +61 -0
  44. torchtitan/experiments/flux/tests/test_flux_dataloader.py +103 -0
  45. torchtitan/experiments/flux/tests/test_generate_image.py +252 -0
  46. torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
  47. torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py +885 -0
  48. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py +299 -0
  49. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py +1304 -0
  50. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py +126 -0
flame/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (156 Bytes).

flame/__pycache__/config_manager.cpython-312.pyc ADDED
Binary file (36.9 kB).

flame/__pycache__/data.cpython-312.pyc ADDED
Binary file (31.3 kB).

flame/components/__init__.py ADDED
File without changes
flame/components/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (141 Bytes).

flame/components/__pycache__/checkpoint.cpython-312.pyc ADDED
Binary file (3.21 kB).
 
flame/components/checkpoint.py ADDED
@@ -0,0 +1,59 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from dataclasses import dataclass, field
+ from datetime import timedelta
+ from io import BytesIO
+ from typing import Any, Dict, List
+
+ import torch
+ from torch.distributed.checkpoint.stateful import Stateful
+
+
+ @dataclass
+ class TrainState(Stateful):
+     step: int = 0
+     skipped_step: int = 0
+     token: int = 0
+     elapsed: timedelta = timedelta(0)
+     global_avg_losses: List[float] = field(default_factory=list)
+     global_max_losses: List[float] = field(default_factory=list)
+     log_steps: List[int] = field(default_factory=list)
+
+     def state_dict(self) -> Dict[str, Any]:
+         # Only checkpoint global_avg_losses and global_max_losses per log frequency
+         # to avoid sync overhead in every iteration.
+         global_avg_losses_bytes = BytesIO()
+         torch.save(self.global_avg_losses, global_avg_losses_bytes)
+         global_max_losses_bytes = BytesIO()
+         torch.save(self.global_max_losses, global_max_losses_bytes)
+         log_steps_bytes = BytesIO()
+         torch.save(self.log_steps, log_steps_bytes)
+         return {
+             "step": torch.tensor(self.step, dtype=torch.int32),
+             "skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
+             "token": torch.tensor(self.token, dtype=torch.int64),
+             "elapsed": self.elapsed,
+             "global_avg_losses": global_avg_losses_bytes,
+             "global_max_losses": global_max_losses_bytes,
+             "log_steps": log_steps_bytes,
+         }
+
+     def load_state_dict(self, state_dict) -> None:
+         self.step = state_dict["step"].item()
+         self.skipped_step = state_dict.get("skipped_step", torch.tensor(0)).item()
+         self.token = state_dict["token"].item()
+         self.elapsed = state_dict["elapsed"]
+         state_dict["global_avg_losses"].seek(0)
+         self.global_avg_losses = torch.load(
+             state_dict["global_avg_losses"], weights_only=False
+         )
+         state_dict["global_max_losses"].seek(0)
+         self.global_max_losses = torch.load(
+             state_dict["global_max_losses"], weights_only=False
+         )
+         state_dict["log_steps"].seek(0)
+         self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
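
For reference, a minimal usage sketch of TrainState (assuming flame is importable; the values are arbitrary). In flame the instance would be registered with torch.distributed.checkpoint alongside model and optimizer state; this only exercises the Stateful round trip:

from datetime import timedelta

from flame.components.checkpoint import TrainState

state = TrainState(step=100, token=1_000_000, elapsed=timedelta(minutes=5))
state.log_steps.append(100)
state.global_avg_losses.append(2.73)

# Scalars become tensors, the loss/step lists are serialized into BytesIO buffers.
snapshot = state.state_dict()

restored = TrainState()
restored.load_state_dict(snapshot)
assert restored.step == 100 and restored.global_avg_losses == [2.73]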
flame/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (137 Bytes).

flame/models/__pycache__/parallelize_fla.cpython-312.pyc ADDED
Binary file (22.1 kB).

flame/models/__pycache__/pipeline_fla.cpython-312.pyc ADDED
Binary file (5.75 kB).
 
flame/models/activation_offloading.py ADDED
@@ -0,0 +1,447 @@
+ # Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/training/_activation_offloading.py
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import contextlib
+ from typing import Union
+ from warnings import warn
+
+ import psutil
+ import torch
+ from torch import nn
+ from torch.autograd.graph import saved_tensors_hooks
+
+ from torchtitan.tools.logging import logger
+
+ try:
+     import torchao
+     from torchao.dtypes.nf4tensor import NF4Tensor
+ except ImportError:
+     torchao = None
+     NF4Tensor = None
+     logger.warning("torchao not found.")
+
+ # from torchtune.modules import TiedLinear
+
+
+ class OffloadActivations(saved_tensors_hooks):
+     """Context manager under which activation tensors created in the forward pass will be offloaded.
+
+     Enable the memory efficiency technique of activation offloading, where activations bigger than
+     min_offload_size bytes will be offloaded to CPU in the forward and brought back in the backward.
+     This is in contrast to maintaining the activation on GPU VRAM throughout the program.
+
+     This manager contains the option of using one additional CUDA stream to handle the communication
+     between CUDA and CPU, which is intended to overlap with the default computation stream to improve
+     runtime. We designed synchronization with a few heuristics for optimizing the tradeoff between
+     runtime vs memory usage.
+
+     Args:
+         use_pin_memory (bool): Whether or not the offloaded Tensor will be placed in pinned
+             memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly
+             but is a limited resource. Default: True.
+
+         use_streams (bool): Whether or not to use streams for performance optimization where
+             the communications get overlapped with the computation. Requires a torch build
+             after torch-2.5.0. Default: True.
+
+         max_fwd_stash_size (int): The maximum size of the forward stash, or the maximum number of
+             consecutive activations to keep alive during the forward pass. This number must be at
+             least 1. Keeping alive more activations will potentially allow more overlap between the
+             communication and compute streams at the cost of increasing memory usage. Keeping alive
+             fewer activations will conserve memory, but may cause poor overlap between the streams,
+             increasing runtime. Default: 5.
+
+         min_offload_size (int): The minimum number of bytes a Tensor must be in order to qualify
+             for offloading. If the tensor is too small, we do not want to waste bandwidth and resources
+             moving it to CPU and back. Default: 1024 bytes.
+
+     Raises:
+         ValueError: if max_fwd_stash_size is not at least 1.
+
+     Example:
+         >>> with OffloadActivations():
+         >>>     logits = model(inputs)
+         >>>     loss = ...
+         >>>     loss.backward()
+     """
+
+     def __init__(
+         self,
+         use_pin_memory: bool = True,
+         use_streams: bool = True,
+         max_fwd_stash_size: int = 5,
+         min_offload_size: int = 1024,
+     ) -> None:
+
+         self.use_streams: bool = use_streams
+
+         self.min_tensor_size_bytes = (
+             min_offload_size  # we don't want to bother with small tensors
+         )
+         self.tracker = (
+             {}
+         )  # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where
+         self.tensor_id: int = 0
+         self.is_first_forward_call = True
+         self.is_first_backward_call = True
+         self.is_first_forward_pass = True
+
+         # managing cpu memory
+         self.use_pin_memory: bool = use_pin_memory
+         self.virtual_memory_safe_pct = (
+             60  # we should not exceed this percentage of memory
+         )
+
+         self.s0 = torch.cuda.default_stream()  # comp stream
+
+         # for streaming
+         if self.use_streams:
+             self.s1 = torch.cuda.Stream()  # comms stream
+             self.fwd_stash = {}  # tensor_id => (activation, ev1)
+             if max_fwd_stash_size < 1:
+                 raise ValueError(
+                     f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}"
+                 )
+             self.max_fwd_stash_size = max_fwd_stash_size
+             self.bwd_tensor_stash = {}  # tensor_id => activation
+             self.bwd_ev_stash = {}  # tensor_id => ev0
+             self.curr_graph_id = None
+             self.curr_autograd_node = None
+
+         # -------- platform util functions -------- #
+         def verify_sufficient_virtual_memory():
+             curr_pct = get_cpu_ram_pct()
+             if curr_pct > self.virtual_memory_safe_pct:
+                 warn(
+                     f"***** WARNING: {curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used"
+                 )
+
+         def get_cpu_ram_pct() -> float:
+             # get the percentage of memory used by the system
+             return psutil.virtual_memory().percent
+
+         def get_tensor_id() -> int:
+             # create a unique id for each tensor we are managing
+             self.tensor_id += 1
+             return self.tensor_id
+
+         def get_num_bytes_tensor(x: torch.Tensor) -> int:
+             # get the number of bytes in a tensor, for memory management purposes
+             return (
+                 x.element_size() * x.nelement()
+             )  # x.element_size() * x._base_storage().nbytes()
+
+         # -------- core pack / unpack work -------- #
+         def pack_tensor(activation: torch.Tensor) -> int:
+             # activations are passed in during forward pass - from here we take over and return a unique id
+             if self.is_first_forward_call:
+                 assert (
+                     len(self.tracker) == 0
+                 ), "backward pass should have cleared tracker of all tensors"
+
+                 # set training phase trackers
+                 self.is_first_forward_call = False
+                 self.is_first_backward_call = True
+
+             # query for basic tensor info
+             num_bytes = get_num_bytes_tensor(activation)
+             tensor_id = get_tensor_id()
+
+             # only offload hefty bois if they're activations on CUDA (our heuristic
+             # for that is to check if they're not params or buffers)!
+             if (
+                 activation.is_cuda
+                 and num_bytes >= self.min_tensor_size_bytes
+                 and (
+                     not isinstance(activation, torch.nn.Parameter)
+                     and not isinstance(activation, torch.nn.Buffer)
+                 )
+             ):
+                 if self.use_streams:
+                     # First, sync back and dereference previously offloaded tensors
+                     # as the offloading should be done sufficiently long ago.
+                     for id in [k for k in self.fwd_stash.keys()]:
+                         if id <= tensor_id - self.max_fwd_stash_size:
+                             _, ev = self.fwd_stash[id]
+                             self.s0.wait_event(ev)
+                             del self.fwd_stash[id]
+                         else:
+                             break
+
+                     # Sync in, offload, and add an event to sync back later
+                     self.s1.wait_stream(self.s0)
+
+                 stream = self.s1 if self.use_streams else self.s0
+                 with torch.cuda.stream(stream):
+                     try:
+                         cpu_tensor = torch.empty_like(
+                             activation, pin_memory=self.use_pin_memory, device="cpu"
+                         )
+                     except NotImplementedError as e:
+                         if (
+                             isinstance(activation, NF4Tensor)
+                             and torchao.__version__ < "0.6.0.dev20240917"
+                         ):
+                             raise RuntimeError(
+                                 "Offloading NF4Tensors requires torchao-0.6.0.dev20240917 or later"
+                             ) from e
+                         raise e
+                     cpu_tensor.copy_(activation, non_blocking=True)
+                     self.tracker[tensor_id] = (
+                         cpu_tensor,
+                         True,
+                     )  # True = (in future) modified
+
+                 if self.use_streams:
+                     event = self.s1.record_event()
+
+                     # Stash to keep activation alive til s1 is done
+                     self.fwd_stash[tensor_id] = (activation, event)
+             else:
+                 self.tracker[tensor_id] = (
+                     activation,
+                     False,
+                 )  # False = not modified, tensor is as is
+
+             return tensor_id
+
+         def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
+             # backward pass - we are called with the tensor_id, which
+             # we will use to retrieve the saved/offloaded tensor
+             if self.is_first_backward_call:
+                 if self.is_first_forward_pass:
+                     self.is_first_forward_pass = False
+                     if self.use_pin_memory:
+                         verify_sufficient_virtual_memory()
+
+                 self.is_first_backward_call = False
+                 self.is_first_forward_call = True
+
+             assert (
+                 unpack_tensor_id in self.tracker
+             ), f"untracked tensor with id {unpack_tensor_id}"
+
+             maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
+             if modified:
+                 gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
+                 maybe_gpu_tensor = gpu_tensor
+
+             # clear tensor from tracking
+             del self.tracker[unpack_tensor_id]
+             return maybe_gpu_tensor
+
+         def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
+             # backward pass - we are called with the tensor_id, which
+             # we will use to retrieve the saved/offloaded tensor
+             if self.is_first_backward_call:
+                 self.curr_graph_id = torch._C._current_graph_task_id()
+
+                 def wait_and_del_remaining_references() -> None:
+                     for id in [k for k in self.bwd_tensor_stash.keys()]:
+                         event = self.bwd_ev_stash[id]
+                         self.s1.wait_event(event)
+                         del self.bwd_tensor_stash[id]
+
+                 # Register a callback to the end of autograd to clean everything up
+                 torch.autograd.variable.Variable._execution_engine.queue_callback(
+                     wait_and_del_remaining_references
+                 )
+
+                 if self.is_first_forward_pass:
+                     self.is_first_forward_pass = False
+                     if self.use_pin_memory:
+                         verify_sufficient_virtual_memory()
+
+                 self.is_first_backward_call = False
+                 self.is_first_forward_call = True
+
+             assert (
+                 unpack_tensor_id in self.tracker
+             ), f"untracked tensor with id {unpack_tensor_id}"
+
+             maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
+             if modified:
+                 # Get data on the current autograd node
+                 graph_id = torch._C._current_graph_task_id()
+                 node = torch._C._current_autograd_node()
+                 prev_node_ids = []
+
+                 # If we're on a new node, mark prev node's tensors to be freed later
+                 if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
+                     self.curr_autograd_node = node
+                     prev_node_ids = [id for id in self.bwd_tensor_stash.keys()]
+
+                 brought_back_from_cpu = True
+                 if unpack_tensor_id in self.fwd_stash:
+                     maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0]
+                     brought_back_from_cpu = False
+                 else:
+                     # Kick off the process to bring tensors back
+                     with torch.cuda.stream(self.s1):
+                         gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
+                         maybe_gpu_tensor = gpu_tensor
+
+                     # Tell comp stream to wait for the info to be loaded before executing
+                     self.s0.wait_stream(self.s1)
+
+                     # Stash the tensor to keep memory alive until compute stream is complete
+                     self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor
+
+                 # Note: [Track views of the unpacked]
+                 # Why do we get the use count of the unpacked tensor here? We want an
+                 # initial count to compare to later, during the post-hook of the
+                 # backward node, when we need to decide whether we're allowed to free
+                 # the tensor yet. In what obscure cases must we delay freeing the
+                 # tensor (and thus call record_stream)?
+                 # 1. Any of the outputs of the backward node is a view of the unpacked
+                 #    tensor.
+                 # 2. In the case that this unpacked tensor will be used in a
+                 #    checkpointed region, if one of the recomputed saved tensors ends
+                 #    up as a view of the unpacked tensor.
+                 # 3. The user abuses the system somehow and manually relies on the
+                 #    unpacked tensor to exist after the backward node has executed.
+                 storage_refcount = torch._C._storage_Use_Count(
+                     maybe_gpu_tensor.untyped_storage()._cdata
+                 )
+
+                 def hook(outputs, inputs):
+                     # create events for the current node inputs/outputs if they were streamed in
+                     if brought_back_from_cpu:
+                         # See Note: [Track views of the unpacked]
+                         # IF any of the outputs is a view of the tensor, OR if a view of
+                         # the tensor has been saved as a part of checkpoint's recompute
+                         # process, OR the user has abusedly incurred a reference on the
+                         # unpacked tensor, THEN the tensor might be used later and we
+                         # cannot presume to delete it after only the current node is
+                         # done! So we use our frenemy, record_stream, to ensure the
+                         # Tensor stays unmessed with until it's done getting used in the
+                         # compute stream (s0 here). Note that the con here is we introduce
+                         # non-deterministic (thus higher) memory usage, but this case
+                         # should not happen often.
+                         unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
+                         if (
+                             torch._C._storage_Use_Count(
+                                 unpacked_tensor.untyped_storage()._cdata
+                             )
+                             > storage_refcount
+                         ):
+                             unpacked_tensor.record_stream(self.s0)
+                             del self.bwd_tensor_stash[unpack_tensor_id]
+                         else:
+                             event = self.s0.record_event()
+                             self.bwd_ev_stash[unpack_tensor_id] = event
+
+                     # if there are still things in the fwd_stash, get rid of them as we're in bwd now
+                     for id in [k for k in self.fwd_stash.keys()]:
+                         _, ev = self.fwd_stash[id]
+                         self.s0.wait_event(ev)
+                         del self.fwd_stash[id]
+
+                     # wait on prev node's events and del those
+                     for id in prev_node_ids:
+                         event = self.bwd_ev_stash[id]
+                         self.s1.wait_event(event)
+                         del self.bwd_tensor_stash[id]
+
+                     return outputs
+
+                 node.register_hook(hook)
+
+             # clear tensor from tracking
+             del self.tracker[unpack_tensor_id]
+             return maybe_gpu_tensor
+
+         unpack_tensor = (
+             unpack_tensor_with_streams
+             if self.use_streams
+             else unpack_tensor_single_stream
+         )
+         super().__init__(pack_tensor, unpack_tensor)
+
+
+ class NoOpManager(saved_tensors_hooks):
+     """
+     A saved_tensors_hook manager used to disable any other saved_tensors_hook manager
+     applied before. This relies on the behavior that only the most recently registered
+     saved_tensors_hook will run.
+
+     One example usage is to opt a local region of code out of activations offloading,
+     which is usually applied globally to best track state.
+     """
+
+     def __init__(self) -> None:
+         def noop(tensor):
+             return tensor
+
+         super().__init__(noop, noop)
+
+
+ def get_act_offloading_ctx_manager(
+     model: nn.Module, enable_activation_offloading: bool
+ ) -> Union[OffloadActivations, contextlib.nullcontext]:
+     """Returns the activation offloading context manager for the model, which will be
+     a null context if enable_activation_offloading is False.
+
+     If activation offloading is enabled, we return the OffloadActivations context manager.
+     If activation offloading is disabled, we return a null context manager (contextlib.nullcontext).
+
+     Args:
+         model (nn.Module): the model to wrap with the activation offloading context manager.
+         enable_activation_offloading (bool): whether or not to enable activation offloading
+             for the model.
+
+     Returns:
+         contextlib.ContextDecorator: the activation offloading context manager for the model.
+
+     Raises:
+         NotImplementedError: If the model is a multimodal model and activation offloading is enabled.
+     """
+     if enable_activation_offloading:
+         activations_handling_ctx = OffloadActivations()
+
+         # Below is our hack to disable offloading the last output Linear in every
+         # step, as the cost for offloading the activation and then soon after bringing
+         # it back is expensive. Moreover, due to heuristics in our streaming API,
+         # we actually use more memory if we offload it as it interferes with chunkedCE.
+         output_head_detected = False
+         noop_ctx = NoOpManager()
+
+         if hasattr(model, "output"):
+             if isinstance(model.output, nn.Module):
+                 model.output.register_forward_pre_hook(
+                     lambda *args: noop_ctx.__enter__()
+                 )
+                 model.output.register_forward_hook(
+                     lambda *args: noop_ctx.__exit__(), always_call=True
+                 )
+                 logger.info("Registering no-op offloading hooks for model.output")
+                 output_head_detected = True
+             # ================================
+             # ! TODO[flame] check if we need to deal with TiedLinear
+             # The following code appears in `torchtune`
+             # elif isinstance(model.output, TiedLinear):
+             #     model.output.linear.register_forward_pre_hook(
+             #         lambda *args: noop_ctx.__enter__()
+             #     )
+             #     model.output.linear.register_forward_hook(
+             #         lambda *args: noop_ctx.__exit__(), always_call=True
+             #     )
+             #     output_head_detected = True
+
+         if not output_head_detected:
+             logger.warning(
+                 "During activation offloading, no output head was detected. "
+                 "If your model has an output head, it will be offloaded. "
+                 "This usually greatly slows training, given the large vocabulary size. "
+                 "To change this behavior, set your output head as model.output and make it "
+                 "an nn.Module."
+             )
+
+     else:
+         activations_handling_ctx = contextlib.nullcontext()
+
+     return activations_handling_ctx
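
For reference, a minimal sketch of wiring the offloading context into a forward/backward pass, in the spirit of the class docstring (assumes a CUDA device and that flame is importable; TinyLM is a made-up stand-in whose `output` attribute exercises the no-op hook path above):

import torch
from torch import nn

from flame.models.activation_offloading import get_act_offloading_ctx_manager

class TinyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Linear(256, 256)
        self.output = nn.Linear(256, 1024)  # detected via hasattr(model, "output")

    def forward(self, x):
        return self.output(torch.relu(self.body(x)))

model = TinyLM().cuda()
ctx = get_act_offloading_ctx_manager(model, enable_activation_offloading=True)

x = torch.randn(4, 1024, 256, device="cuda")
with ctx:
    # activations larger than min_offload_size are copied to pinned CPU memory here
    loss = model(x).float().pow(2).mean()
    # and streamed back to the GPU as backward reaches each node
    loss.backward()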
flame/models/fla.toml ADDED
@@ -0,0 +1,67 @@
+ [model]
+ config = "fla-hub/transformer-1.3B-100B"
+ tokenizer_path = "fla-hub/transformer-1.3B-100B"
+
+ [job]
+ dump_folder = "exp"
+ print_args = true
+
+ [training]
+ batch_size = 32
+ seq_len = 2048
+ context_len = 2048
+ gradient_accumulation_steps = 1
+ steps = 20480
+ max_norm = 1.0
+ skip_nan_inf = true
+ data_parallel_replicate_degree = 1
+ data_parallel_shard_degree = -1
+ tensor_parallel_degree = 1
+ compile = false
+ dataset = "HuggingFaceFW/fineweb-edu"
+ dataset_name = "default"
+ num_workers = 32
+ pin_memory = false
+ persistent_workers = false
+ prefetch_factor = 2
+ seed = 42
+ varlen = false
+
+ [optimizer]
+ name = "AdamW"
+ eps = 1e-15
+ lr = 3e-4
+
+ [lr_scheduler]
+ warmup_steps = 1024
+ decay_type = "cosine"
+ lr_min = 0.1
+
+ [checkpoint]
+ enable_checkpoint = true
+ folder = "checkpoint"
+ interval_type = "steps"
+ interval = 2048
+ model_weights_only = false
+ export_dtype = "float32"
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
+
+ [profiling]
+ enable_profiling = true
+ save_traces_folder = "profile_trace"
+ profile_freq = 512
+
+ [metrics]
+ log_freq = 32
+ enable_wandb = true
+
+ [experimental]
+ context_parallel_degree = 1
+ pipeline_parallel_degree = 1
+
+ [float8]
+ enable_fsdp_float8_all_gather = false
+ precompute_float8_dynamic_scale_for_fsdp = false
+
+ [activation_checkpoint]
+ mode = "none"
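
For reference, a short sketch of inspecting this config with the standard-library tomllib (Python 3.11+); flame's own config_manager is the real consumer of the file:

import tomllib

with open("flame/models/fla.toml", "rb") as f:
    cfg = tomllib.load(f)

t = cfg["training"]
# tokens consumed per optimizer step on a single data-parallel rank
print(t["batch_size"] * t["seq_len"] * t["gradient_accumulation_steps"])  # 32 * 2048 * 1 = 65536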
flame/models/parallelize_fla.py ADDED
@@ -0,0 +1,550 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # This file applies the PT-D parallelisms (except pipeline parallelism) and various
+ # training techniques (e.g. activation checkpointing and compile) to the Llama model.
+
+ from collections import defaultdict
+
+ import torch
+ import torch.nn as nn
+ from torch.distributed import DeviceMesh
+ from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
+ from torch.distributed._composable.replicate import replicate
+ from torch.distributed._tensor import Replicate, Shard
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
+ from torch.distributed.tensor.parallel import (
+     ColwiseParallel,
+     PrepareModuleInput,
+     PrepareModuleOutput,
+     RowwiseParallel,
+     SequenceParallel,
+     parallelize_module
+ )
+
+ from fla.modules.fused_linear_cross_entropy import LinearLossParallel
+ from fla.modules.mlp import SwiGLULinearParallel
+ from fla.modules.parallel import PrepareModuleWeight
+ from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
+ from torchtitan.distributed.parallel_dims import ParallelDims
+ from torchtitan.tools.logging import logger
+
+
+ def parallelize_fla(
+     model: nn.Module,
+     world_mesh: DeviceMesh,
+     parallel_dims: ParallelDims,
+     job_config: JobConfig,
+ ):
+     """
+     Apply tensor parallelism, activation checkpointing, torch.compile, and data
+     parallelism to the model.
+
+     NOTE: The passed-in model preferably should be on meta device. Otherwise,
+     the model must fit on GPU or CPU memory.
+     """
+
+     if parallel_dims.tp_enabled:
+         if (
+             job_config.experimental.enable_async_tensor_parallel
+             and not job_config.training.compile
+         ):
+             raise RuntimeError("Async TP requires --training.compile")
+         enable_float8_linear = "float8" in job_config.model.converters
+         apply_tp(
+             model,
+             world_mesh["tp"],
+             loss_parallel=parallel_dims.loss_parallel_enabled,
+             enable_float8=enable_float8_linear,
+             enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
+         )
+
+     if job_config.activation_checkpoint.mode != "none":
+         apply_ac(model, job_config.activation_checkpoint)
+
+     # turn on per-block compile after AC wrapping and before FSDP
+     if job_config.training.compile:
+         apply_compile(model)
+
+     if (
+         parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
+     ):  # apply FSDP or HSDP, potentially with Context Parallel
+         if parallel_dims.dp_replicate_enabled:
+             dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
+         else:
+             dp_mesh_dim_names = ("dp_shard_cp",)
+
+         apply_fsdp(
+             model,
+             world_mesh[tuple(dp_mesh_dim_names)],
+             param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+             reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
+             pp_enabled=parallel_dims.pp_enabled,
+             cpu_offload=job_config.training.enable_cpu_offload,
+             reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
+         )
+
+         if parallel_dims.dp_replicate_enabled:
+             logger.info("Applied HSDP to the model")
+         else:
+             logger.info("Applied FSDP to the model")
+
+         if parallel_dims.cp_enabled:
+             logger.info("Applied Context Parallel to the model")
+
+         if job_config.training.enable_cpu_offload:
+             logger.info("Applied CPU Offloading to the model")
+     elif parallel_dims.dp_replicate_enabled:
+         if world_mesh.ndim > 1:
+             raise RuntimeError("DDP has not supported > 1D parallelism")
+         apply_ddp(
+             model,
+             world_mesh,
+             enable_compile=job_config.training.compile,
+             enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
+         )
+
+
+ class TPPlan:
+     def __init__(
+         self,
+         model=None,
+         loss_parallel=False,
+         enable_float8=False,
+     ):
+         self.model = model
+         self.loss_parallel = loss_parallel
+         self.enable_float8 = enable_float8
+         self.base_model_prefix = getattr(model, "base_model_prefix", "model")
+
+         # TODO(vkuzo): once float8 configuration supports delayed scaling,
+         # add a check here to enforce supported float8 all-gather configurations
+         # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
+         try:
+             from torchao.float8.float8_tensor_parallel import (
+                 Float8ColwiseParallel,
+                 Float8RowwiseParallel,
+                 PrepareFloat8ModuleInput
+             )
+         except ImportError:
+             Float8ColwiseParallel = None
+             Float8RowwiseParallel = None
+             PrepareFloat8ModuleInput = None
+         if self.enable_float8 and Float8ColwiseParallel is not None:
+             self.rowwise_parallel = Float8RowwiseParallel
+             self.colwise_parallel = Float8ColwiseParallel
+             self.prepare_module_input = PrepareFloat8ModuleInput
+             self.prepare_module_output = PrepareModuleOutput
+         else:
+             self.rowwise_parallel = RowwiseParallel
+             self.colwise_parallel = ColwiseParallel
+             self.prepare_module_input = PrepareModuleInput
+             self.prepare_module_output = PrepareModuleOutput
+
+     @property
+     def model_plan(self):
+         plans = {
+             f"{self.base_model_prefix}.embeddings": RowwiseParallel(
+                 input_layouts=Replicate(),
+                 output_layouts=Shard(1),
+             ),
+             f"{self.base_model_prefix}.norm": SequenceParallel(),
+         }
+         if self.loss_parallel:
+             plans.update(
+                 {
+                     "lm_head": ColwiseParallel(
+                         input_layouts=Shard(1),
+                         output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
+                         use_local_output=not self.loss_parallel,
+                     ),
+                 }
+             )
+         else:
+             plans.update(
+                 {
+                     "lm_head": PrepareModuleWeight(layouts=Replicate()),
+                     "criterion": LinearLossParallel(),
+                 }
+             )
+         return plans
+
+     @property
+     def layer_plan(self):
+         return {
+             "attn_norm": SequenceParallel(),
+             **self.attn_plan,
+             "mlp_norm": SequenceParallel(),
+             **self.mlp_plan,
+         }
+
+     @property
+     def attn_plan(self):
+         raise NotImplementedError(
+             f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
+         )
+
+     @property
+     def mlp_plan(self):
+         return {
+             "mlp": self.prepare_module_input(
+                 input_layouts=(Shard(1),),
+                 desired_input_layouts=(Replicate(),),
+             ),
+             "mlp.gate_proj": self.colwise_parallel(),
+             "mlp.up_proj": self.colwise_parallel(),
+             "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
+             "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
+         }
+
+
+ class TransformerTPPlan(TPPlan):
+
+     @property
+     def attn_plan(self):
+         return {
+             "attn": self.prepare_module_input(
+                 input_kwarg_layouts={"hidden_states": Shard(1)},
+                 desired_input_kwarg_layouts={"hidden_states": Replicate()},
+             ),
+             "attn.q_proj": self.colwise_parallel(),
+             "attn.k_proj": self.colwise_parallel(),
+             "attn.v_proj": self.colwise_parallel(),
+             "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
+         }
+
+
+ class GLATPPlan(TPPlan):
+
+     @property
+     def attn_plan(self):
+         return {
+             "attn": self.prepare_module_input(
+                 input_kwarg_layouts={"hidden_states": Shard(1)},
+                 desired_input_kwarg_layouts={"hidden_states": Replicate()},
+             ),
+             "attn.q_proj": self.colwise_parallel(),
+             "attn.k_proj": self.colwise_parallel(),
+             "attn.v_proj": self.colwise_parallel(),
+             "attn.g_proj": self.colwise_parallel(),
+             "attn.gk_proj.0": PrepareModuleWeight(layouts=Replicate()),
+             "attn.gk_proj.1": self.colwise_parallel(),
+             "attn.g_norm": SequenceParallel(sequence_dim=-1),
+             "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
+         }
+
+
+ TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}
+
+
+ def apply_tp(
+     model: nn.Module,
+     tp_mesh: DeviceMesh,
+     loss_parallel: bool,
+     enable_float8: bool,
+     enable_async_tp: bool,
+ ):
+     """Apply tensor parallelism."""
+     # 1. Parallelize the embedding and shard its outputs (which are the first
+     #    transformer block's inputs)
+     # 2. Parallelize the root norm layer over the sequence dim
+     # 3. Parallelize the final linear output layer
+     tp_plan = TP_PLAN_MAP[model.config.model_type](
+         model, loss_parallel=loss_parallel, enable_float8=enable_float8
+     )
+     parallelize_module(model, tp_mesh, tp_plan.model_plan)
+
+     blocks = get_blocks(model)
+     if blocks is None:
+         logger.warning("No block found for tensor parallelism")
+     else:
+         for _, block in enumerate(blocks):
+             parallelize_module(
+                 module=block,
+                 device_mesh=tp_mesh,
+                 parallelize_plan=tp_plan.layer_plan,
+             )
+
+     if enable_async_tp:
+         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+         torch._inductor.config._micro_pipeline_tp = True
+         enable_symm_mem_for_group(tp_mesh.get_group().group_name)
+
+     logger.info(
+         f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
+         "Tensor Parallelism to the model"
+     )
+
+
+ # for selective op activation checkpointing
+ _save_list = {
+     torch.ops.aten.mm.default,
+     torch.ops.aten._scaled_dot_product_efficient_attention.default,
+     torch.ops.aten._scaled_dot_product_flash_attention.default,
+     torch.ops._c10d_functional.reduce_scatter_tensor.default,
+     # for low precision training, it's useful to always save
+     # the result of max, since the absolute maximum is
+     # used to compute the scaling factor for quantization.
+     torch.ops.aten.max.default,
+ }
+
+
+ def _apply_ac_to_block(module: nn.Module, ac_config):
+     valid_ac_modes = ("full", "selective")
+     if ac_config.mode not in valid_ac_modes:
+         raise ValueError(
+             f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
+         )
+
+     if ac_config.mode == "full":
+         return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
+
+     assert ac_config.mode == "selective", f"{ac_config.mode}"
+     use_op_sac = ac_config.selective_ac_option == "op"
+     use_layer_sac = ac_config.selective_ac_option.isdigit()
+     if not use_op_sac and not use_layer_sac:
+         raise ValueError(
+             f"Invalid selective AC option: {ac_config.selective_ac_option}. "
+             f"Valid options: 'op' or a positive int representing layer frequency"
+         )
+     if use_op_sac:
+         from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
+
+         def _get_custom_policy(meta):
+             def _custom_policy(ctx, func, *args, **kwargs):
+                 mode = "recompute" if ctx.is_recompute else "forward"
+                 mm_count_key = f"{mode}_mm_count"
+                 if func == torch.ops.aten.mm.default:
+                     meta[mm_count_key] += 1
+                 # Saves output of all compute ops, except every second mm
+                 to_save = func in _save_list and not (
+                     func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
+                 )
+                 return (
+                     CheckpointPolicy.MUST_SAVE
+                     if to_save
+                     else CheckpointPolicy.PREFER_RECOMPUTE
+                 )
+
+             return _custom_policy
+
+         def selective_checkpointing_context_fn():
+             meta = defaultdict(int)
+             return create_selective_checkpoint_contexts(_get_custom_policy(meta))
+
+         return ptd_checkpoint_wrapper(
+             module,
+             context_fn=selective_checkpointing_context_fn,
+             preserve_rng_state=False,
+         )
+     elif use_layer_sac:
+         # Checkpoint every `ac_freq` of the modules passed to this function
+         ac_freq = int(ac_config.selective_ac_option)
+         ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
+         ptd_checkpoint_wrapper._count += 1
+         if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
+             return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
+         else:
+             return module
+
+
+ def apply_ac(model: nn.Module, ac_config):
+     """Apply activation checkpointing to the model."""
+     blocks = get_blocks(model)
+     if blocks is None:
+         logger.warning("No block found for activation checkpointing")
+         return
+
+     for layer_id, block in blocks.named_children():
+         block = _apply_ac_to_block(block, ac_config)
+         blocks.register_module(layer_id, block)
+
+     logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
+
+
+ def apply_compile(model: nn.Module):
+     """
+     Apply torch.compile to each block, which makes compilation efficient due to
+     repeated structure. Alternatively one can compile the whole model (after applying DP).
+     """
+
+     blocks = get_blocks(model)
+     if blocks is None:
+         logger.warning("No block found for torch.compile")
+     else:
+         for layer_id, block in blocks.named_children():
+             block = torch.compile(block)
+             blocks.register_module(layer_id, block)
+         logger.info("Compiling each block with torch.compile")
+
+     real_model = get_model(model)
+
+     logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
+     embeddings_key = get_components_name(real_model, "tok_embeddings")
+     if embeddings_key is not None:
+         embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
+         real_model.register_module(embeddings_key, embeddings)
+
+     norm_key = get_components_name(real_model, "norm")
+     if norm_key is not None:
+         norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
+         real_model.register_module(norm_key, norm)
+
+     lm_head_key = get_components_name(model, "lm_head")
+     if lm_head_key is not None:
+         lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
+         model.register_module(lm_head_key, lm_head)
+
+     logger.info("Compiling the entire model with torch.compile")
+     model = torch.compile(model)
+
+
+ def apply_fsdp(
+     model: nn.Module,
+     dp_mesh: DeviceMesh,
+     param_dtype: torch.dtype,
+     reduce_dtype: torch.dtype,
+     pp_enabled: bool,
+     cpu_offload: bool = False,
+     reshard_after_forward_policy: str = "default",
+ ):
+     """
+     Apply data parallelism (via FSDP2) to the model.
+
+     Args:
+         model (nn.Module): The model to apply data parallelism to.
+         dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
+         param_dtype (torch.dtype): The data type to use for model parameters.
+         reduce_dtype (torch.dtype): The data type to use for reduction operations.
+         pp_enabled (bool): Whether pipeline parallelism is enabled.
+         cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
+         reshard_after_forward_policy (str, optional):
+             The policy to use for resharding after forward pass. Defaults to "default".
+             Other options: "never", "always".
+             - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
+             - "always" will enable `reshard_after_forward` for all forward passes.
+             - "never" will disable `reshard_after_forward` for all forward passes.
+
+     """
+     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
+     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+     if cpu_offload:
+         fsdp_config["offload_policy"] = CPUOffloadPolicy()
+
+     blocks = get_blocks(model)
+     if blocks is None:
+         logger.warning("No block found for FSDP")
+     else:
+         total_blocks = len(blocks)
+         for layer_id, block in enumerate(blocks):
+             if reshard_after_forward_policy == "always":
+                 reshard_after_forward = True
+             elif reshard_after_forward_policy == "never":
+                 reshard_after_forward = False
+             elif reshard_after_forward_policy == "default":
+                 if pp_enabled:
+                     # For PP, do not reshard after forward to avoid per-microbatch
+                     # all-gathers, which can be expensive and non-overlapped
+                     reshard_after_forward = False
+                 else:
+                     # As an optimization, do not reshard after forward for the last
+                     # transformer block since FSDP would prefetch it immediately
+                     reshard_after_forward = int(layer_id) < total_blocks - 1
+             else:
+                 raise ValueError(
+                     f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
+                 )
+             fully_shard(
+                 block,
+                 **fsdp_config,
+                 reshard_after_forward=reshard_after_forward,
+             )
+
+     fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
+
+
+ def apply_ddp(
+     model: nn.Module,
+     dp_mesh: DeviceMesh,
+     enable_compile: bool,
+     enable_compiled_autograd: bool,
+ ):
+     if enable_compile:
+         if enable_compiled_autograd:
+             torch._dynamo.config.optimize_ddp = (
+                 "python_reducer_without_compiled_forward"
+             )
+         else:
+             torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+
+     replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+
+     logger.info("Applied DDP to the model")
+
+
+ def get_model(model):
+     base_model_prefix = getattr(model, "base_model_prefix", "model")
+     if not hasattr(model, base_model_prefix):
+         return None
+     model = getattr(model, base_model_prefix)
+     return model
+
+
+ def get_blocks(model):
+     # TODO[flame]: adapt for networks not using a 'layers' attribute
+     model = get_model(model)
+     if not hasattr(model, "layers"):
+         logger.warning('no "layers" attribute found in model')
+         return None
+     return model.layers
+
+
+ def get_components_name(model, component_name):
+     """
+     Resolve the attribute names of the tok_embeddings, norm and lm_head layers.
+     Layer names inside the blocks are not handled here; for blocks see `get_blocks`.
+     We assume the model has the following structure:
+         LlamaForCausalLM:
+             Model:
+                 embed_tokens,
+                 layers,
+                 norm,
+             lm_head
+     ***
+     so, to search for 'tok_embeddings' and 'norm' we need to pass `get_model(model)`,
+     and for 'lm_head' we need to pass `model`.
+     ***
+     """
+
+     if component_name == "tok_embeddings":
+         if hasattr(model, "tok_embeddings"):
+             return "tok_embeddings"
+         elif hasattr(model, "embed_tokens"):
+             return "embed_tokens"
+         elif hasattr(model, "embeddings"):
+             return "embeddings"
+         else:
+             logger.warning("No tok_embeddings found in model")
+             return None
+
+     elif component_name == "norm":
+         if hasattr(model, "norm"):
+             return "norm"
+         elif hasattr(model, "norms"):
+             return "norms"
+         elif hasattr(model, "layernorm"):
+             return "layernorm"
+         else:
+             logger.warning("No norm found in model")
+             return None
+
+     elif component_name == "lm_head":
+         if hasattr(model, "lm_head"):
+             return "lm_head"
+         else:
+             logger.warning("No lm_head found in model")
+             return None
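
For reference, a sketch of the attribute resolution these helpers perform on an fla-style causal LM (assumes network access and trust_remote_code for the repo id taken from fla.toml; the printed names depend on the concrete model class):

from transformers import AutoConfig, AutoModelForCausalLM

import fla  # noqa, registers the fla model classes with transformers
from flame.models.parallelize_fla import get_blocks, get_components_name, get_model

config = AutoConfig.from_pretrained("fla-hub/transformer-1.3B-100B", trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config)

real_model = get_model(model)                             # the module under `base_model_prefix`
print(len(get_blocks(model)))                             # number of decoder blocks ("layers")
print(get_components_name(real_model, "tok_embeddings"))  # e.g. "embeddings" or "embed_tokens"
print(get_components_name(model, "lm_head"))              # "lm_head"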
flame/models/pipeline_fla.py ADDED
@@ -0,0 +1,162 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # This file applies the PT-D pipeline parallelism to the Llama model.
+
+ import copy
+ from typing import Callable, Optional, Union
+
+ import torch
+ import torch.nn as nn
+ from torch.distributed import DeviceMesh
+ from torch.distributed.pipelining import PipelineStage
+ from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble, _PipelineSchedule, get_schedule_class
+ from transformers import PretrainedConfig
+
+ from flame.models.parallelize_fla import get_blocks, get_components_name, get_model
+ from torchtitan.config_manager import JobConfig
+ from torchtitan.distributed.parallel_dims import ParallelDims
+ from torchtitan.distributed.pipeline import build_pipeline_schedule, generate_split_points, stage_ids_this_rank
+ from torchtitan.tools.logging import logger
+
+ DeviceType = Union[int, str, torch.device]
+
+
+ def pipeline_fla(
+     model: nn.Module,
+     pp_mesh: DeviceMesh,
+     parallel_dims: ParallelDims,
+     job_config: JobConfig,
+     device: DeviceType,
+     model_config: PretrainedConfig,
+     loss_fn: Callable[..., torch.Tensor],
+ ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
+     stages, models = pipeline_fla_manual_split(
+         model, pp_mesh, parallel_dims, job_config, device, model_config
+     )
+
+     pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
+
+     # This is used in the train loop to determine whether to pass in the input_ids and labels
+     has_first_stage = False
+     has_last_stage = False
+     for stage in stages:
+         if stage.is_first:
+             has_first_stage = True
+         if stage.is_last:
+             has_last_stage = True
+
+     return pp_schedule, models, has_first_stage, has_last_stage
+
+
+ def pipeline_fla_manual_split(
+     whole_model: nn.Module,
+     pp_mesh: DeviceMesh,
+     parallel_dims: ParallelDims,
+     job_config: JobConfig,
+     device: DeviceType,
+     model_config: PretrainedConfig,
+ ) -> tuple[list[PipelineStage], list[nn.Module]]:
+     """
+     This API extracts one torch.nn.Module object for the part of the model configured to run inside each stage.
+
+     It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.
+
+     The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
+     parallelism.
+     """
+     pp_rank = pp_mesh.get_local_rank()
+     pp_size = pp_mesh.size()
+
+     splits = (
+         job_config.experimental.pipeline_parallel_split_points
+         or generate_split_points(
+             job_config, parallel_dims.pp, model_config.num_hidden_layers
+         )
+     )
+
+     def _build_stage(
+         stage_idx: int,
+         start_layer: Optional[str],
+         stop_layer: Optional[str],
+         is_first: bool = False,
+         is_last: bool = False,
+     ) -> tuple[PipelineStage, nn.Module]:
+         model = copy.deepcopy(whole_model)
+         if not is_first:
+             # we do `model.tok_embeddings = None` here
+             real_model = get_model(model)
+             tok_embeddings_name = get_components_name(real_model, "tok_embeddings")
+             setattr(real_model, tok_embeddings_name, None)
+
+         drop_layers = start_layer is not None
+         # Get the module dictionary from get_blocks(model)
+         # and create a list of keys before modifying the dictionary
+         module_dict = get_blocks(model)._modules  # Store reference
+         layer_names = list(module_dict.keys())
+
+         # Iterate over the list of keys instead of `_modules.items()`
+         for name in layer_names:
+             # Dynamically determine prefix (blocks.* or layers.*)
+             prefix = start_layer.split(".")[0] if start_layer else "layers"
+             layer_name = f"{prefix}.{name}"  # Construct the correct name format
+
+             # Ensure `drop_layers` activation is based on actual naming
+             if layer_name == start_layer:
+                 drop_layers = False
+             if layer_name == stop_layer:
+                 drop_layers = True
+
+             # Delete layer if drop_layers is active
+             if drop_layers:
+                 del module_dict[name]  # Safe deletion from stored dictionary
+
+         if not is_last:
+             # we do `model.norm = None` and `model.output = None`
+             real_model = get_model(model)
+             norm_name = get_components_name(real_model, "norm")
+             setattr(real_model, norm_name, None)
+
+             head_name = get_components_name(model, "lm_head")
+             setattr(model, head_name, None)
+
+         stage = PipelineStage(
+             model,
+             stage_idx,
+             num_stages,
+             device,
+             group=pp_mesh.get_group("pp"),
+         )
+         return stage, model
+
+     num_stages = len(splits) + 1
+     stage_idx = pp_rank
+
+     stages = []
+     models = []
+
+     schedule_class = get_schedule_class(
+         job_config.experimental.pipeline_parallel_schedule
+     )
+     style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
+
+     for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
+         start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
+         stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
+         stage, model_chunk = _build_stage(
+             stage_idx,
+             start_layer,
+             stop_layer,
+             is_first=stage_idx == 0,
+             is_last=stage_idx == num_stages - 1,
+         )
+         logger.info(
+             f"PP rank {pp_rank} is building stage_idx {stage_idx}"
+             f" with start_layer {start_layer}, stop_layer {stop_layer}"
+         )
+         stages.append(stage)
+         models.append(model_chunk)
+     return stages, models
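
For reference, the stage-splitting arithmetic from pipeline_fla_manual_split run standalone with hypothetical split points: len(splits) + 1 stages, each keeping the layers from start_layer (inclusive) up to stop_layer (exclusive), with None meaning the start or end of the model:

# assumed example; normally taken from pipeline_parallel_split_points or generate_split_points
splits = ["layers.8", "layers.16", "layers.24"]
num_stages = len(splits) + 1
for stage_idx in range(num_stages):
    start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
    stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
    print(f"stage {stage_idx}: start_layer={start_layer}, stop_layer={stop_layer}")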
flame/tools/__init__.py ADDED
File without changes
flame/tools/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (136 Bytes).

flame/tools/__pycache__/utils.cpython-312.pyc ADDED
Binary file (2.14 kB).

flame/utils/__init__.py ADDED
File without changes
flame/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (136 Bytes).

flame/utils/__pycache__/checkpoint.cpython-312.pyc ADDED
Binary file (4.07 kB).

flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc ADDED
Binary file (3.73 kB).

flame/utils/__pycache__/hf_utils.cpython-312.pyc ADDED
Binary file (4.46 kB).
 
flame/utils/checkpoint.py ADDED
@@ -0,0 +1,50 @@
+ import os
+ import glob
+ import re
+ import shutil
+ from torchtitan.tools.logging import logger
+
+
+ def cleanup_local_checkpoints(checkpoint_dir: str, keep_latest_k: int):
+     """Removes older checkpoint directories locally, keeping only the latest k for both DCP and HF formats."""
+     if keep_latest_k <= 0:
+         return  # Keep all checkpoints
+
+     logger.info(f"Cleaning up local checkpoints in {checkpoint_dir}, keeping latest {keep_latest_k}")
+
+     # Cleanup DCP checkpoints (step-*)
+     dcp_checkpoints = sorted(
+         glob.glob(os.path.join(checkpoint_dir, "step-*")),
+         key=lambda x: int(re.search(r"step-(\d+)", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)", os.path.basename(x)) and not x.endswith("-hf") else -1,
+         reverse=True
+     )
+     # Filter out HF format directories
+     dcp_checkpoints = [d for d in dcp_checkpoints if not d.endswith("-hf")]
+
+     if len(dcp_checkpoints) > keep_latest_k:
+         checkpoints_to_delete = dcp_checkpoints[keep_latest_k:]
+         logger.info(f"Deleting {len(checkpoints_to_delete)} old DCP checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
+         for ckpt_path in checkpoints_to_delete:
+             if os.path.isdir(ckpt_path):  # Ensure it's a directory
+                 try:
+                     shutil.rmtree(ckpt_path)
+                 except OSError as e:
+                     logger.error(f"Error removing directory {ckpt_path}: {e}")
+
+
+     # Cleanup HF checkpoints (step-*-hf)
+     hf_checkpoints = sorted(
+         glob.glob(os.path.join(checkpoint_dir, "step-*-hf")),
+         key=lambda x: int(re.search(r"step-(\d+)-hf", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)-hf", os.path.basename(x)) else -1,
+         reverse=True
+     )
+
+     if len(hf_checkpoints) > keep_latest_k:
+         checkpoints_to_delete = hf_checkpoints[keep_latest_k:]
+         logger.info(f"Deleting {len(checkpoints_to_delete)} old HF checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
+         for ckpt_path in checkpoints_to_delete:
+             if os.path.isdir(ckpt_path):  # Ensure it's a directory
+                 try:
+                     shutil.rmtree(ckpt_path)
+                 except OSError as e:
+                     logger.error(f"Error removing directory {ckpt_path}: {e}")
flame/utils/convert_dcp_to_hf.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import argparse
5
+ import io
6
+ import os
7
+ import tempfile
8
+ from datetime import timedelta
9
+
10
+ import torch
11
+ import torch.serialization
12
+ from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
13
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
14
+
15
+ import fla # noqa
16
+ from torchtitan.tools.logging import init_logger, logger
17
+
18
+
19
+ @torch.inference_mode()
20
+ def save_pretrained(
21
+ path: str,
22
+ step: int,
23
+ config: str,
24
+ tokenizer: str
25
+ ):
26
+ logger.info(f"Loading the config from {config}")
27
+ config = AutoConfig.from_pretrained(config, trust_remote_code=True)
28
+
29
+ logger.info(f"Saving the config to {path}")
30
+ config.save_pretrained(path)
31
+ logger.info(f"Loading the tokenizer from {tokenizer}")
32
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
33
+ logger.info(f"Saving the tokenizer to {path}")
34
+ tokenizer.save_pretrained(path)
35
+
36
+ with tempfile.TemporaryDirectory() as tmpdir:
37
+ # base_checkpoint_dir = os.path.dirname(path)
38
+ base_checkpoint_dir = path
39
+ checkpoint = os.path.join(base_checkpoint_dir, f'checkpoint/step-{step}')
40
+ checkpoint_path = os.path.join(tmpdir, 'checkpoint.pt')
41
+ logger.info(f"Saving the distributed checkpoint to {checkpoint_path}")
42
+ dcp_to_torch_save(checkpoint, checkpoint_path)
43
+
44
+ logger.info(f"Initializing the model from config\n{config}")
45
+ model = AutoModelForCausalLM.from_config(config)
46
+ logger.info(model)
47
+ logger.info("Loading state dict from the checkpoint")
48
+
49
+ # Add datetime.timedelta and io.BytesIO to safe globals
50
+ torch.serialization.add_safe_globals([timedelta, io.BytesIO])
51
+ # torch.load now with default weights_only=True will work
52
+ model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model'])
53
+
54
+ logger.info(f"Saving the model to {path}")
55
+ model.save_pretrained(path)
56
+
57
+
58
+ if __name__ == "__main__":
59
+ init_logger()
60
+ parser = argparse.ArgumentParser("Convert DCP format model weights to huggingface-style.")
61
+ parser.add_argument("--path", type=str, required=True)
62
+ parser.add_argument("--step", type=int, required=True)
63
+ parser.add_argument("--config", type=str, required=True)
64
+ parser.add_argument("--tokenizer", type=str, required=True)
65
+ args = parser.parse_args()
66
+ save_pretrained(args.path, args.step, args.config, args.tokenizer)
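The converter can also be driven from Python rather than through the CLI flags above; a hedged sketch with placeholder paths (the run directory must already contain a `checkpoint/step-<step>` DCP folder written during training, and the config/tokenizer identifiers below are hypothetical):

from flame.utils.convert_dcp_to_hf import save_pretrained

# Reads DCP shards from exp/checkpoint/step-10000 and writes config, tokenizer,
# and HF-format weights into exp/.
save_pretrained(
    path="exp",
    step=10000,
    config="configs/my_config.json",   # hypothetical model config
    tokenizer="my-org/my-tokenizer",   # hypothetical tokenizer repo
)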
flame/utils/convert_hf_to_dcp.py ADDED
@@ -0,0 +1,34 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import argparse
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ import torch.distributed.checkpoint as DCP
9
+ from transformers import AutoModelForCausalLM
10
+
11
+ import fla # noqa
12
+ from torchtitan.tools.logging import init_logger, logger
13
+
14
+
15
+ @torch.inference_mode()
16
+ def convert_hf_weights(model: str, checkpoint: Path):
17
+ logger.info(f"Loading model from {model}")
18
+ model = AutoModelForCausalLM.from_pretrained(model)
19
+ state_dict = model.state_dict()
20
+
21
+ logger.info(f"Writing to DCP at '{checkpoint}'")
22
+ checkpoint.mkdir(parents=True, exist_ok=True)
23
+ storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
24
+ DCP.save({"model": state_dict}, storage_writer=storage_writer)
25
+
26
+
27
+ if __name__ == "__main__":
28
+ init_logger()
29
+ parser = argparse.ArgumentParser(description="Convert huggingface-style model weights to DCP format.")
30
+ parser.add_argument("--model", type=str, required=True)
31
+ parser.add_argument("--checkpoint", type=Path, required=True)
32
+ args = parser.parse_args()
33
+
34
+ convert_hf_weights(args.model, args.checkpoint)
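The reverse conversion is symmetric; a short sketch with placeholder identifiers (the model id below is hypothetical):

from pathlib import Path

from flame.utils.convert_hf_to_dcp import convert_hf_weights

# Loads a HuggingFace-style checkpoint and writes DCP shards under exp/checkpoint/step-0,
# which training can then resume from.
convert_hf_weights("my-org/my-model", Path("exp/checkpoint/step-0"))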
flame/utils/hf_utils.py ADDED
@@ -0,0 +1,77 @@
1
+ import os
2
+ import re
3
+ from huggingface_hub import HfApi, HfFolder, logging as hf_logging, create_repo
4
+ from torchtitan.tools.logging import logger
5
+
6
+ def upload_checkpoint_to_hf(
7
+ local_path: str,
8
+ step: int,
9
+ hf_repo_id_for_run: str,
10
+ hf_keep_latest_k: int,
11
+ upload_format: str
12
+ ):
13
+ """Uploads a checkpoint directory to HF Hub and manages retention."""
14
+ if not os.path.isdir(local_path):
15
+ logger.error(f"Local path for upload does not exist or is not a directory: {local_path}")
16
+ return
17
+
18
+ api = HfApi()
19
+ token = HfFolder.get_token()
20
+ if not token:
21
+ logger.warning("Hugging Face Hub token not found. Skipping upload. Login via `huggingface-cli login` or set HF_TOKEN.")
22
+ return
23
+
24
+ # --- Ensure the specific repository for this run exists ---
25
+ try:
26
+ logger.info(f"Ensuring repository {hf_repo_id_for_run} exists...")
27
+ # Use create_repo which handles creation only if it doesn't exist
28
+ create_repo(repo_id=hf_repo_id_for_run, token=token, repo_type="model", exist_ok=True)
29
+ logger.info(f"Repository {hf_repo_id_for_run} ensured.")
30
+ except Exception as e:
31
+ logger.error(f"Failed to create or ensure repository {hf_repo_id_for_run}: {e}", exc_info=True)
32
+ return # Stop if repo interaction fails
33
+
34
+ commit_message = f"Upload {upload_format.upper()} checkpoint step {step}"
35
+ path_in_repo = f"step-{step}"
36
+
37
+ logger.info(f"Uploading {local_path} to {hf_repo_id_for_run}/{path_in_repo} on Hugging Face Hub...")
38
+ try:
39
+ api.upload_folder(
40
+ folder_path=local_path,
41
+ path_in_repo=path_in_repo,
42
+ repo_id=hf_repo_id_for_run,
43
+ repo_type="model",
44
+ commit_message=commit_message,
45
+ token=token,
46
+ )
47
+ logger.info(f"Successfully uploaded step {step} to {hf_repo_id_for_run}.")
48
+ except Exception as e:
49
+ logger.error(f"Failed to upload checkpoint step {step} to {hf_repo_id_for_run}: {e}", exc_info=True)
50
+ if hf_keep_latest_k > 0:
51
+ logger.info(f"Cleaning up old checkpoints on {hf_repo_id_for_run}, keeping latest {hf_keep_latest_k}")
52
+ try:
53
+ repo_files = api.list_repo_tree(hf_repo_id_for_run, repo_type="model", token=token, recursive=False)
54
+ step_folders = [
55
+ item.path for item in repo_files
56
+ if item.path.startswith("step-") and item.path[5:].isdigit()
57
+ ]
58
+
59
+ step_folders.sort(key=lambda x: int(x.split('-')[1]), reverse=True)
60
+
61
+ if len(step_folders) > hf_keep_latest_k:
62
+ folders_to_delete = step_folders[hf_keep_latest_k:]
63
+ logger.info(f"Found {len(step_folders)} checkpoints on Hub. Deleting {len(folders_to_delete)} older ones: {folders_to_delete}")
64
+ for folder in folders_to_delete:
65
+ # Deleting requires repo_id, path_in_repo, and token
66
+ api.delete_folder(
67
+ repo_id=hf_repo_id_for_run,
68
+ path_in_repo=folder,
69
+ repo_type="model",
70
+ commit_message=f"Delete old checkpoint {folder}",
71
+ token=token
72
+ )
73
+ logger.info("Hub cleanup complete.")
74
+ else:
75
+ logger.info("No old checkpoints found on Hub to delete.")
76
+ except Exception as e:
77
+ logger.error(f"Error during Hub checkpoint cleanup for {hf_repo_id_for_run}: {e}", exc_info=True)
tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/logs/debug-internal.log ADDED
@@ -0,0 +1,90 @@
1
+ {"time":"2025-07-16T22:10:00.785425491Z","level":"INFO","msg":"stream: starting","core version":"0.21.0"}
2
+ {"time":"2025-07-16T22:10:01.508654924Z","level":"INFO","msg":"stream: created new stream","id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
3
+ {"time":"2025-07-16T22:10:01.508690211Z","level":"INFO","msg":"stream: started","id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
4
+ {"time":"2025-07-16T22:10:01.508739999Z","level":"INFO","msg":"writer: Do: started","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
5
+ {"time":"2025-07-16T22:10:01.508759314Z","level":"INFO","msg":"handler: started","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
6
+ {"time":"2025-07-16T22:10:01.508803829Z","level":"INFO","msg":"sender: started","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
7
+ {"time":"2025-07-16T23:09:45.740737848Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
8
+ {"time":"2025-07-16T23:18:29.56428269Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
9
+ {"time":"2025-07-16T23:19:01.917480335Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
10
+ {"time":"2025-07-16T23:19:36.868918826Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
11
+ {"time":"2025-07-16T23:20:16.297827588Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
12
+ {"time":"2025-07-16T23:20:18.619477493Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp: lookup api.wandb.ai on 127.0.0.53:53: read udp 127.0.0.1:46470->127.0.0.53:53: i/o timeout"}
13
+ {"time":"2025-07-16T23:20:30.740650327Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp: lookup api.wandb.ai on 127.0.0.53:53: read udp 127.0.0.1:47482->127.0.0.53:53: i/o timeout"}
14
+ {"time":"2025-07-16T23:21:04.536690541Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
15
+ {"time":"2025-07-16T23:21:49.291673175Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
16
+ {"time":"2025-07-16T23:22:07.542159208Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
17
+ {"time":"2025-07-16T23:23:23.103733736Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
18
+ {"time":"2025-07-16T23:23:37.543151076Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
19
+ {"time":"2025-07-16T23:25:07.544031298Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
20
+ {"time":"2025-07-16T23:26:37.545971769Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
21
+ {"time":"2025-07-16T23:27:42.194377246Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"}
22
+ {"time":"2025-07-16T23:27:59.564813743Z","level":"WARN","msg":"sender: taking a long time","seconds":600.000912631,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"ft8cf3fgtodg\" connection_id:\"1(@)\")"}
23
+ {"time":"2025-07-16T23:28:07.547697617Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
24
+ {"time":"2025-07-16T23:29:37.549836886Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded (Client.Timeout exceeded while awaiting headers)"}
25
+ {"time":"2025-07-16T23:31:01.930916994Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000672411,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
26
+ {"time":"2025-07-16T23:31:02.101966833Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000995925,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
27
+ {"time":"2025-07-16T23:31:07.103368571Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000796336,"work":"WorkRecord(*service_go_proto.Request_PartialHistory); Control(local:true connection_id:\"1(@)\")"}
28
+ {"time":"2025-07-16T23:31:07.551682713Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
29
+ {"time":"2025-07-16T23:32:37.553473869Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
30
+ {"time":"2025-07-16T23:33:58.248779065Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": unexpected EOF"}
31
+ {"time":"2025-07-16T23:34:58.351555112Z","level":"INFO","msg":"sender: succeeded after taking longer than expected","seconds":1018.787711083,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"ft8cf3fgtodg\" connection_id:\"1(@)\")"}
32
+ {"time":"2025-07-16T23:34:58.351650283Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":836.421498346,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
33
+ {"time":"2025-07-16T23:34:58.351778293Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":831.249242004,"work":"WorkRecord(*service_go_proto.Request_PartialHistory); Control(local:true connection_id:\"1(@)\")"}
34
+ {"time":"2025-07-16T23:34:58.351785775Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":836.250829923,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
35
+ {"time":"2025-07-17T01:31:13.353253854Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"}
36
+ {"time":"2025-07-17T08:06:16.748740406Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
37
+ {"time":"2025-07-17T09:50:19.526737851Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": read tcp 10.0.2.15:54882->35.186.228.49:443: read: connection reset by peer"}
38
+ {"time":"2025-07-17T09:52:30.348552703Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
39
+ {"time":"2025-07-17T09:53:02.422139335Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
40
+ {"time":"2025-07-17T09:53:36.600890938Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
41
+ {"time":"2025-07-17T09:54:16.203516351Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
42
+ {"time":"2025-07-17T09:55:05.357439477Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
43
+ {"time":"2025-07-17T09:56:15.05960959Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
44
+ {"time":"2025-07-17T09:57:45.061688428Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
45
+ {"time":"2025-07-17T09:59:15.063226591Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
46
+ {"time":"2025-07-17T10:00:45.065259852Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
47
+ {"time":"2025-07-17T10:01:04.518171545Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"}
48
+ {"time":"2025-07-17T10:02:00.347889757Z","level":"WARN","msg":"sender: taking a long time","seconds":600.000372919,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"it0uq1ptdf5l\" connection_id:\"1(@)\")"}
49
+ {"time":"2025-07-17T10:02:15.066174619Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
50
+ {"time":"2025-07-17T10:03:45.067145051Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
51
+ {"time":"2025-07-17T10:05:02.098970791Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000073665,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
52
+ {"time":"2025-07-17T10:05:07.474477054Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000841939,"work":"WorkRecord(*service_go_proto.Request_PartialHistory); Control(local:true connection_id:\"1(@)\")"}
53
+ {"time":"2025-07-17T10:05:15.068468165Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"}
54
+ {"time":"2025-07-17T10:05:16.930808745Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000229861,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
55
+ {"time":"2025-07-17T10:06:07.008582668Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"}
56
+ {"time":"2025-07-17T10:06:45.070340311Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
57
+ {"time":"2025-07-17T10:07:57.799911415Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": unexpected EOF"}
58
+ {"time":"2025-07-17T10:08:57.969386735Z","level":"INFO","msg":"sender: succeeded after taking longer than expected","seconds":1017.621908973,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"it0uq1ptdf5l\" connection_id:\"1(@)\")"}
59
+ {"time":"2025-07-17T10:08:57.969579361Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":835.870728331,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
60
+ {"time":"2025-07-17T10:08:57.969680501Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":821.039158554,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
61
+ {"time":"2025-07-17T10:08:57.969682134Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":830.496074059,"work":"WorkRecord(*service_go_proto.Request_PartialHistory); Control(local:true connection_id:\"1(@)\")"}
62
+ {"time":"2025-07-17T12:53:12.780364188Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
63
+ {"time":"2025-07-17T16:43:31.998287109Z","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
64
+ {"time":"2025-07-18T00:01:06.015630566Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
65
+ {"time":"2025-07-18T06:56:24.118529653Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
66
+ {"time":"2025-07-18T14:32:12.830145916Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
67
+ {"time":"2025-07-18T19:51:31.703829065Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
68
+ {"time":"2025-07-19T03:35:03.743864446Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
69
+ {"time":"2025-07-19T21:22:32.639517404Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": read tcp 10.0.2.15:51870->35.186.228.49:443: read: connection reset by peer"}
70
+ {"time":"2025-07-19T21:31:32.643369264Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": read tcp 10.0.2.15:38482->35.186.228.49:443: read: connection reset by peer"}
71
+ {"time":"2025-07-20T00:27:42.221361901Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
72
+ {"time":"2025-07-20T09:40:16.319872482Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
73
+ {"time":"2025-07-20T09:45:18.218885403Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
74
+ {"time":"2025-07-20T19:19:37.674808147Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
75
+ {"time":"2025-07-20T20:26:46.102126738Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
76
+ {"time":"2025-07-20T21:40:42.245223721Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
77
+ {"time":"2025-07-20T21:42:31.526229193Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
78
+ {"time":"2025-07-20T22:42:07.859288654Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
79
+ {"time":"2025-07-21T03:41:28.397742169Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
80
+ {"time":"2025-07-21T04:49:16.742257697Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
81
+ {"time":"2025-07-21T05:48:28.62347913Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
82
+ {"time":"2025-07-21T06:22:31.529351974Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"}
83
+ {"time":"2025-07-21T14:47:44.545628902Z","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
84
+ {"time":"2025-07-21T21:19:44.840025606Z","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
85
+ {"time":"2025-07-21T21:19:44.94975041Z","level":"INFO","msg":"handler: operation stats","stats":{}}
86
+ {"time":"2025-07-21T21:19:44.958211652Z","level":"INFO","msg":"stream: closing","id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
87
+ {"time":"2025-07-21T21:19:44.958407771Z","level":"INFO","msg":"writer: Close: closed","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
88
+ {"time":"2025-07-21T21:19:44.958426934Z","level":"INFO","msg":"handler: closed","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
89
+ {"time":"2025-07-21T21:19:44.958428316Z","level":"INFO","msg":"sender: closed","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
90
+ {"time":"2025-07-21T21:19:44.958480192Z","level":"INFO","msg":"stream: closed","id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"}
torchtitan/components/__pycache__/float8.cpython-312.pyc ADDED
Binary file (6.2 kB). View file
 
torchtitan/components/__pycache__/ft.cpython-312.pyc ADDED
Binary file (6.75 kB). View file
 
torchtitan/components/float8.py ADDED
@@ -0,0 +1,150 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # [Note] Getting the 'torchao' package:
8
+ # This script requires the 'torchao' package to function correctly.
9
+ # Please ensure you have this package installed from the appropriate repository.
10
+ # You can obtain it from https://github.com/pytorch/ao by following the
11
+ # installation instructions.
12
+
13
+ # Note: Performance
14
+ # Float8 experimental is intended to be run under `torch.compile` for competitive performance
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from torchtitan.config_manager import JobConfig
20
+ from torchtitan.distributed import ParallelDims
21
+ from torchtitan.protocols.model_converter import (
22
+ ModelConverter,
23
+ register_model_converter,
24
+ )
25
+ from torchtitan.tools.logging import logger
26
+
27
+
28
+ def _is_sm89_or_later():
29
+ # Float8 is only supported on SM89 or later (H100+ GPUs)
30
+ return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
31
+
32
+
33
+ class Float8Converter(ModelConverter):
34
+ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
35
+ self.enabled = False
36
+
37
+ float8_config = job_config.float8
38
+ if not _is_sm89_or_later():
39
+ logger.warning(
40
+ "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
41
+ )
42
+ return
43
+ try:
44
+ from torchao.float8 import Float8LinearConfig
45
+ except ImportError as e:
46
+ raise ImportError(
47
+ "torchao is not installed. Please install it to use float8 linear layers."
48
+ ) from e
49
+
50
+ if float8_config.recipe_name is not None and not hasattr(
51
+ Float8LinearConfig, "from_recipe_name"
52
+ ):
53
+ logger.warning(
54
+ "Failed to swap to Float8Linear with recipe lookup because the torchao version "
55
+ "is too old, please install torchao v0.9.0 or later and try again",
56
+ )
57
+ return
58
+
59
+ self.enabled = True
60
+ self.filter_fqns = float8_config.filter_fqns
61
+
62
+ if float8_config.recipe_name is not None:
63
+ assert (
64
+ not float8_config.enable_fsdp_float8_all_gather
65
+ ), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported"
66
+ assert (
67
+ not float8_config.force_recompute_fp8_weight_in_bwd
68
+ ), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported"
69
+ self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
70
+ self.precompute_scale = False
71
+ logger.info(
72
+ f"Float8 training active with recipe {float8_config.recipe_name}"
73
+ )
74
+
75
+ else:
76
+ # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
77
+ enable_fsdp_float8_all_gather = (
78
+ parallel_dims.dp_shard_enabled
79
+ and float8_config.enable_fsdp_float8_all_gather
80
+ )
81
+ self.config = Float8LinearConfig(
82
+ enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
83
+ force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
84
+ )
85
+ # for precompute_float8_dynamic_scale_for_fsdp
86
+ self.precompute_scale = (
87
+ enable_fsdp_float8_all_gather
88
+ and float8_config.precompute_float8_dynamic_scale_for_fsdp
89
+ )
90
+ logger.info("Float8 tensorwise scaled training active")
91
+
92
+ def convert(self, model: nn.Module):
93
+ return self.convert_to_float8_training(model)
94
+
95
+ def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
96
+ return self.precompute_float8_dynamic_scale_for_fsdp(model)
97
+
98
+ def convert_to_float8_training(self, model: nn.Module):
99
+ """
100
+ This function converts the linear layers of `model` to `Float8Linear`.
101
+ Note that today, only dynamic tensor scaling (the default) is supported.
102
+ This will mutate the model inplace.
103
+ """
104
+ if not self.enabled:
105
+ return
106
+
107
+ from torchao.float8 import convert_to_float8_training
108
+
109
+ # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
110
+ convert_to_float8_training(
111
+ model,
112
+ config=self.config,
113
+ module_filter_fn=self._module_filter_fn,
114
+ )
115
+ logger.info(
116
+ "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
117
+ f"{self.config.enable_fsdp_float8_all_gather}"
118
+ )
119
+
120
+ def _module_filter_fn(self, mod: nn.Module, fqn: str) -> bool:
121
+ if not isinstance(mod, nn.Linear):
122
+ return False
123
+
124
+ # All dims must be divisible by 16 due to float8 tensorcore hardware requirements.
125
+ dims_multiples_of_16 = (
126
+ mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0
127
+ )
128
+
129
+ # If the fqn matches any filtered fqn, then we should not convert this module.
130
+ is_filtered_fqn = any(filtered_fqn in fqn for filtered_fqn in self.filter_fqns)
131
+
132
+ return dims_multiples_of_16 and not is_filtered_fqn
133
+
134
+ def precompute_float8_dynamic_scale_for_fsdp(
135
+ self, model: nn.Module | list[nn.Module]
136
+ ):
137
+ if not self.enabled:
138
+ return
139
+
140
+ if not self.precompute_scale:
141
+ return
142
+
143
+ from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
144
+
145
+ models = [model] if isinstance(model, nn.Module) else model
146
+ for m in models:
147
+ precompute_float8_dynamic_scale_for_fsdp(m)
148
+
149
+
150
+ register_model_converter(Float8Converter, "float8")
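The `_module_filter_fn` above is what decides which layers get swapped: only `nn.Linear` modules whose weight dimensions are both multiples of 16 and whose FQN does not match a filtered FQN are converted. A standalone, illustrative sketch of that check (not the converter itself; `lm_head` is just an example filter entry):

import torch.nn as nn

filter_fqns = ["lm_head"]  # example filter list

def would_convert(mod: nn.Module, fqn: str) -> bool:
    # Mirrors the shape and FQN checks in Float8Converter._module_filter_fn.
    if not isinstance(mod, nn.Linear):
        return False
    dims_ok = mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0
    return dims_ok and not any(f in fqn for f in filter_fqns)

print(would_convert(nn.Linear(4096, 4096), "layers.0.mlp.up_proj"))  # True
print(would_convert(nn.Linear(4096, 100), "layers.0.mlp.gate"))      # False: 100 % 16 != 0
print(would_convert(nn.Linear(4096, 4096), "lm_head"))               # False: filtered FQN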
torchtitan/components/ft.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import copy
8
+ import importlib
9
+ from dataclasses import dataclass
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.distributed._functional_collectives as funcol
14
+ from torch.distributed.device_mesh import DeviceMesh
15
+ from torch.distributed.tensor import DTensor
16
+ from torchtitan.config_manager import JobConfig
17
+ from torchtitan.distributed import ParallelDims
18
+
19
+ if importlib.util.find_spec("torchft") is not None:
20
+ import torchft as ft
21
+
22
+ has_torchft = True
23
+ else:
24
+ has_torchft = False
25
+
26
+
27
+ class FTManager:
28
+ def __init__(
29
+ self,
30
+ manager: Optional["ft.Manager"],
31
+ group_size: int = 1,
32
+ replica_id: int = 0,
33
+ ) -> None:
34
+ self._manager = manager
35
+ self.group_size = group_size
36
+ self.replica_id = replica_id
37
+
38
+ @property
39
+ def enabled(self) -> bool:
40
+ return self._manager is not None
41
+
42
+ @property
43
+ def manager(self) -> "ft.Manager":
44
+ assert self._manager is not None
45
+ return self._manager
46
+
47
+ def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]:
48
+ return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank
49
+
50
+
51
+ def init_ft_manager(job: JobConfig) -> FTManager:
52
+ """Initialize the FT manager if TorchFT is enabled.
53
+
54
+ Args:
55
+ job (JobConfig): The job configuration.
56
+
57
+ Returns:
58
+ FTManager: The FT manager wrapper, enabled if TorchFT is on; otherwise it wraps None.
59
+ """
60
+ if not job.fault_tolerance.enable:
61
+ return FTManager(None)
62
+
63
+ if not has_torchft:
64
+ raise ImportError("torchft is not installed. Please install it.")
65
+
66
+ if job.fault_tolerance.min_replica_size < 1:
67
+ raise ValueError("At least one FT replica is required.")
68
+
69
+ pg = ft.ProcessGroupBabyNCCL()
70
+
71
+ return FTManager(
72
+ ft.Manager(
73
+ pg=pg,
74
+ min_replica_size=job.fault_tolerance.min_replica_size,
75
+ load_state_dict=None,
76
+ state_dict=None,
77
+ use_async_quorum=True,
78
+ replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}",
79
+ ),
80
+ group_size=job.fault_tolerance.group_size,
81
+ replica_id=job.fault_tolerance.replica_id,
82
+ )
83
+
84
+
85
+ @dataclass
86
+ class FTParallelDims(ParallelDims):
87
+ ft_manager: FTManager
88
+
89
+ def build_mesh(self, device_type: str) -> DeviceMesh:
90
+ def func(
91
+ device_type: str, mesh_shape: list[int], mesh_dim_names: list[str]
92
+ ) -> DeviceMesh:
93
+ from torchft.process_group import ft_init_device_mesh
94
+
95
+ return ft_init_device_mesh(
96
+ device_type=device_type,
97
+ mesh_shape=mesh_shape,
98
+ mesh_dim_names=mesh_dim_names,
99
+ replicate_dim=mesh_dim_names.index("dp_replicate"),
100
+ manager=self.ft_manager.manager,
101
+ )
102
+
103
+ dims = []
104
+ names = []
105
+ for d, name in zip(
106
+ [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
107
+ ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
108
+ ):
109
+ if d > 1 or name == "dp_replicate":
110
+ dims.append(d)
111
+ names.append(name)
112
+
113
+ return self._build_mesh(device_type, dims, names, func)
114
+
115
+ @property
116
+ def dp_replicate_enabled(self):
117
+ return True
118
+
119
+
120
+ def ft_dist_reduce(
121
+ x: torch.Tensor, reduceOp: str, mesh: DeviceMesh
122
+ ) -> tuple[torch.Tensor, str, DeviceMesh]:
123
+ if has_torchft and isinstance(mesh, ft.process_group._FlattenDeviceMesh):
124
+ x = funcol.all_reduce(
125
+ x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg
126
+ )
127
+ return x, reduceOp, mesh.managed_mesh.mesh
128
+ return x, reduceOp, mesh
129
+
130
+
131
+ def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
132
+ if has_torchft:
133
+ mesh = total_norm._spec.mesh
134
+ if isinstance(mesh, ft.process_group.ManagedDeviceMesh):
135
+ # The gradients along the replicated dim have already been reduced.
136
+ # So we don't need another reduction before removing the
137
+ # replicate dimension
138
+ local_tensor = total_norm.to_local()
139
+ placements = list(copy.copy(total_norm._spec.placements))
140
+ placements.pop(mesh.replicate_dim)
141
+ return DTensor.from_local(local_tensor, mesh.mesh, placements)
142
+
143
+ return total_norm
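`FTManager.get_dp_info` is the hook other components use to fold the fault-tolerance replica group into the data-parallel layout: the effective DP degree is the local degree times the replica group size, and the global DP rank is offset by the replica id. A small worked example (the arithmetic does not need a live TorchFT manager):

from torchtitan.components.ft import FTManager

# Two FT replica groups; this process sits in replica 1 at local DP rank 3 of 4.
ft = FTManager(None, group_size=2, replica_id=1)
print(ft.get_dp_info(dp_degree=4, dp_rank=3))  # (8, 7): global degree 4*2, global rank 4*1 + 3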
torchtitan/experiments/deepseek_v3/model_config.py ADDED
@@ -0,0 +1,204 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass, field
8
+
9
+
10
+ @dataclass
11
+ class ModelArgs:
12
+ r"""
13
+ This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a DeepSeek
14
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
15
+ defaults will yield a similar configuration to that of the DeepSeek-V3.
16
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
17
+ documentation from [`PretrainedConfig`] for more information.
18
+ Args:
19
+ vocab_size (`int`, *optional*, defaults to 129280):
20
+ Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the
21
+ `inputs_ids` passed when calling [`DeepseekV3Model`]
22
+ hidden_size (`int`, *optional*, defaults to 4096):
23
+ Dimension of the hidden representations.
24
+ intermediate_size (`int`, *optional*, defaults to 11008):
25
+ Dimension of the MLP representations.
26
+ moe_intermediate_size (`int`, *optional*, defaults to 1407):
27
+ Dimension of the MoE representations.
28
+ num_hidden_layers (`int`, *optional*, defaults to 32):
29
+ Number of hidden layers in the Transformer decoder.
30
+ num_nextn_predict_layers (`int`, *optional*, defaults to 1):
31
+ Number of nextn predict layers in the DeepSeekV3 Model.
32
+ num_attention_heads (`int`, *optional*, defaults to 32):
33
+ Number of attention heads for each attention layer in the Transformer decoder.
34
+ n_shared_experts (`int`, *optional*, defaults to None):
35
+ Number of shared experts, None means dense model.
36
+ n_routed_experts (`int`, *optional*, defaults to None):
37
+ Number of routed experts, None means dense model.
38
+ routed_scaling_factor (`float`, *optional*, defaults to 1.0):
39
+ Scaling factor of routed experts.
40
+ topk_method (`str`, *optional*, defaults to `greedy`):
41
+ Topk method used in routed gate.
42
+ n_group (`int`, *optional*, defaults to None):
43
+ Number of groups for routed experts.
44
+ topk_group (`int`, *optional*, defaults to None):
45
+ Number of selected groups for each token (ensuring that the selected experts are only within
46
+ `topk_group` groups).
47
+ num_experts_per_tok (`int`, *optional*, defaults to None):
48
+ Number of selected experts, None means dense model.
49
+ moe_layer_freq (`int`, *optional*, defaults to 1):
50
+ The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
51
+ first_k_dense_replace (`int`, *optional*, defaults to 0):
52
+ Number of dense layers in shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).
53
+ \--k dense layers--/
54
+ norm_topk_prob (`bool`, *optional*, defaults to False):
55
+ Whether to normalize the weights of the routed experts.
56
+ scoring_func (`str`, *optional*, defaults to 'softmax'):
57
+ Method of computing expert weights.
58
+ aux_loss_alpha (`float`, *optional*, defaults to 0.001):
59
+ Auxiliary loss weight coefficient.
60
+ seq_aux (`bool`, *optional*, defaults to True):
61
+ Whether to compute the auxiliary loss for each individual sample.
62
+ num_key_value_heads (`int`, *optional*):
63
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
64
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
65
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
66
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
67
+ by meanpooling all the original heads within that group. For more details check out [this
68
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
69
+ `num_attention_heads`.
70
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
71
+ The non-linear activation function (function or string) in the decoder.
72
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
73
+ The maximum sequence length that this model might ever be used with.
74
+ initializer_range (`float`, *optional*, defaults to 0.02):
75
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
76
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
77
+ The epsilon used by the rms normalization layers.
78
+ use_cache (`bool`, *optional*, defaults to `True`):
79
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
80
+ relevant if `config.is_decoder=True`.
81
+ pad_token_id (`int`, *optional*):
82
+ Padding token id.
83
+ bos_token_id (`int`, *optional*, defaults to 1):
84
+ Beginning of stream token id.
85
+ eos_token_id (`int`, *optional*, defaults to 2):
86
+ End of stream token id.
87
+ pretraining_tp (`int`, *optional*, defaults to 1):
88
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
89
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
90
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
91
+ issue](https://github.com/pytorch/pytorch/issues/76232).
92
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
93
+ Whether to tie weight embeddings
94
+ rope_theta (`float`, *optional*, defaults to 10000.0):
95
+ The base period of the RoPE embeddings.
96
+ rope_scaling (`Dict`, *optional*):
97
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
98
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
99
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
100
+ `max_position_embeddings` to the expected new maximum.
101
+ attention_bias (`bool`, *optional*, defaults to `False`):
102
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
103
+ attention_dropout (`float`, *optional*, defaults to 0.0):
104
+ The dropout ratio for the attention probabilities.
105
+ """
106
+
107
+ vocab_size: int = 129280
108
+ hidden_size: int = 7168
109
+ intermediate_size: int = 18432
110
+ moe_intermediate_size: int = 2048
111
+ num_hidden_layers: int = 61
112
+ num_nextn_predict_layers: int = 1
113
+ num_attention_heads: int = 128
114
+ num_key_value_heads: int = 128
115
+ n_shared_experts: int = 1
116
+ n_routed_experts: int = 256
117
+ ep_size: int = 1
118
+ routed_scaling_factor: float = 2.5
119
+ kv_lora_rank: int = 512
120
+ q_lora_rank: int = 1536
121
+ qk_rope_head_dim: int = 64
122
+ v_head_dim: int = 128
123
+ qk_nope_head_dim: int = 128
124
+ topk_method: str = "noaux_tc"
125
+ n_group: int = 8
126
+ topk_group: int = 4
127
+ num_experts_per_tok: int = 8
128
+ moe_layer_freq: int = 1
129
+ first_k_dense_replace: int = 3
130
+ norm_topk_prob: bool = True
131
+ scoring_func: str = "sigmoid"
132
+ aux_loss_alpha: float = 0.001
133
+ seq_aux: bool = True
134
+ hidden_act: str = "silu"
135
+ max_position_embeddings: int = 163840
136
+ initializer_range: float = 0.02
137
+ rms_norm_eps: float = 1e-6
138
+ rope_theta: float = 10000.0
139
+ rope_scaling: dict = field(
140
+ default_factory=lambda: {
141
+ "beta_fast": 32,
142
+ "beta_slow": 1,
143
+ "factor": 40,
144
+ "mscale": 1.0,
145
+ "mscale_all_dim": 1.0,
146
+ "original_max_position_embeddings": 4096,
147
+ "type": "yarn",
148
+ }
149
+ )
150
+ attention_bias: bool = False
151
+ attention_dropout: float = 0.0
152
+ pad_token_id = None
153
+ # Added for symmetric memory
154
+ max_seq_len: int = 4096
155
+ dtype: str = "bfloat16"
156
+ # Added for pipeline parallel
157
+ num_stages: int = 1
158
+ stage_idx: int = 0
159
+
160
+
161
+ # This is the configuration for deepseek-ai/DeepSeek-V2-Lite.
162
+ deepseek_v2_lite_config = ModelArgs(
163
+ vocab_size=102400,
164
+ hidden_size=2048,
165
+ intermediate_size=10944,
166
+ moe_intermediate_size=1408,
167
+ num_hidden_layers=27,
168
+ num_attention_heads=16,
169
+ num_key_value_heads=16,
170
+ n_shared_experts=2,
171
+ n_routed_experts=64,
172
+ routed_scaling_factor=1.0,
173
+ kv_lora_rank=512,
174
+ q_lora_rank=None,
175
+ qk_rope_head_dim=64,
176
+ v_head_dim=128,
177
+ qk_nope_head_dim=128,
178
+ topk_method="greedy",
179
+ n_group=1,
180
+ topk_group=1,
181
+ num_experts_per_tok=6,
182
+ first_k_dense_replace=1,
183
+ norm_topk_prob=False,
184
+ scoring_func="softmax",
185
+ max_position_embeddings=4096,
186
+ rope_scaling={
187
+ "beta_fast": 32,
188
+ "beta_slow": 1,
189
+ "factor": 40,
190
+ "mscale": 0.707,
191
+ "mscale_all_dim": 0.707,
192
+ "original_max_position_embeddings": 4096,
193
+ "type": "yarn",
194
+ },
195
+ )
196
+
197
+
198
+ # Model configuration registry
199
+ # Key is the model distribution ID on HuggingFace Hub
200
+ deepseek_config_registry = {
201
+ "deepseek-ai/DeepSeek-V2-Lite": deepseek_v2_lite_config,
202
+ "deepseek-ai/DeepSeek-V2-Lite-Chat": deepseek_v2_lite_config,
203
+ "deepseek-ai/deepseek-v3": ModelArgs(),
204
+ }
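Configs are looked up by their HuggingFace Hub model id; a quick sketch of pulling the V2-Lite arguments from the registry and overriding the pipeline fields for a hypothetical 2-stage split:

from dataclasses import replace

from torchtitan.experiments.deepseek_v3.model_config import deepseek_config_registry

args = deepseek_config_registry["deepseek-ai/DeepSeek-V2-Lite"]
stage_args = replace(args, num_stages=2, stage_idx=1)  # this rank owns the second stage
print(stage_args.hidden_size, stage_args.num_hidden_layers)  # 2048 27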
torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .triton_on_device_all_to_all_v import OnDeviceAllToAllV
8
+
9
+ __all__ = [
10
+ "OnDeviceAllToAllV",
11
+ ]
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from .triton_utils import get_flat_bid, get_flat_tid
11
+
12
+
13
+ @triton.jit
14
+ def send_signal(addrs, sem: tl.constexpr):
15
+ if sem == "relaxed":
16
+ tl.inline_asm_elementwise(
17
+ """
18
+ {
19
+ .reg .u32 %tmp32_<1>;
20
+ .reg .pred %p<1>;
21
+
22
+ send_signal:
23
+ atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1;
24
+ setp.eq.u32 %p0, %tmp32_0, 0;
25
+ @!%p0 bra send_signal;
26
+ }
27
+ """,
28
+ "=r, l",
29
+ [addrs],
30
+ dtype=tl.int32,
31
+ is_pure=False,
32
+ pack=1,
33
+ )
34
+ elif sem == "acq_rel":
35
+ tl.inline_asm_elementwise(
36
+ """
37
+ {
38
+ .reg .u32 %tmp32_<1>;
39
+ .reg .pred %p<1>;
40
+
41
+ send_signal:
42
+ atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1;
43
+ setp.eq.u32 %p0, %tmp32_0, 0;
44
+ @!%p0 bra send_signal;
45
+ }
46
+ """,
47
+ "=r, l",
48
+ [addrs],
49
+ dtype=tl.int32,
50
+ is_pure=False,
51
+ pack=1,
52
+ )
53
+ else:
54
+ raise RuntimeError(f"Unrecognized sem: {sem}")
55
+
56
+
57
+ @triton.jit
58
+ def wait_signal(addrs, sem: tl.constexpr):
59
+ if sem == "relaxed":
60
+ tl.inline_asm_elementwise(
61
+ """
62
+ {
63
+ .reg .u32 %tmp32_<1>;
64
+ .reg .pred %p<1>;
65
+
66
+ wait_signal:
67
+ atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0;
68
+ setp.eq.u32 %p0, %tmp32_0, 1;
69
+ @!%p0 bra wait_signal;
70
+ }
71
+ """,
72
+ "=r, l",
73
+ [addrs],
74
+ dtype=tl.int32,
75
+ is_pure=False,
76
+ pack=1,
77
+ )
78
+ elif sem == "acq_rel":
79
+ tl.inline_asm_elementwise(
80
+ """
81
+ {
82
+ .reg .u32 %tmp32_<1>;
83
+ .reg .pred %p<1>;
84
+
85
+ wait_signal:
86
+ atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0;
87
+ setp.eq.u32 %p0, %tmp32_0, 1;
88
+ @!%p0 bra wait_signal;
89
+ }
90
+ """,
91
+ "=r, l",
92
+ [addrs],
93
+ dtype=tl.int32,
94
+ is_pure=False,
95
+ pack=1,
96
+ )
97
+ else:
98
+ raise RuntimeError(f"Unrecognized sem: {sem}")
99
+
100
+
101
+ @triton.jit
102
+ def blockwise_barrier(
103
+ signal_pad_ptrs,
104
+ block_id,
105
+ rank: tl.constexpr,
106
+ world_size: tl.constexpr,
107
+ sem: tl.constexpr,
108
+ ):
109
+ """
110
+ Synchronizes blocks with matching block_id across participating devices.
111
+
112
+ Note: the function itself is not a system level barrier/fence. It is a
113
+ building block for expressing different synchronization patterns.
114
+
115
+ Pattern 0: Ensures that all writes to symm_mem buffers from previous
116
+ kernels across all devices are visible to the current kernel:
117
+
118
+ blockwise_barrier(..., sem="relaxed")
119
+ sync_threads()
120
+
121
+ Pattern 1: Ensures that all writes to symm_mem buffers from the current
122
+ block are visible to all remote blocks with matching blockIdx:
123
+
124
+ sync_threads()
125
+ blockwise_barrier(..., sem="acq_rel")
126
+ sync_threads()
127
+
128
+ Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
129
+ for writing by subsequent kernels across all devices.
130
+
131
+ sync_threads()
132
+ blockwise_barrier(..., sem="relaxed")
133
+
134
+ CUDA graph friendliness:
135
+
136
+ This barrier operates through atomic operations on a zero-filled signal
137
+ pad, which resets to a zero-filled state after each successful
138
+ synchronization. This design eliminates the need for incrementing a
139
+ flag from host.
140
+ """
141
+ if block_id is None:
142
+ block_id = get_flat_bid()
143
+ flat_tid = get_flat_tid()
144
+
145
+ remote_ranks = tl.arange(0, world_size)
146
+ signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
147
+ remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
148
+ tl.pointer_type(tl.uint32)
149
+ )
150
+ send_addrs = remote_signal_pad_addrs + block_id * world_size + rank
151
+
152
+ local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
153
+ tl.pointer_type(tl.uint32)
154
+ )
155
+ wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks
156
+
157
+ if flat_tid < world_size:
158
+ send_signal(send_addrs, sem)
159
+ wait_signal(wait_addrs, sem)
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py ADDED
@@ -0,0 +1,260 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ import torch.distributed._symmetric_memory as symm_mem
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from .triton_barrier import blockwise_barrier
14
+ from .triton_utils import sync_threads
15
+
16
+
17
+ @triton.jit
18
+ def _exchange_row_offsets(
19
+ split_sizes_ptrs,
20
+ rank: tl.constexpr,
21
+ world_size: tl.constexpr,
22
+ BLOCKS_PER_REMOTE_RANK: tl.constexpr,
23
+ ):
24
+ remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
25
+
26
+ # split_sizes_ptr for all ranks
27
+ # All these vectors stack into split_sizes_matrix
28
+ split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64))
29
+
30
+ # split_sizes_matrix[remote_rank, :]
31
+ input_split_sizes_ptr = tl.load(split_sizes_ptrs + remote_rank).to(
32
+ tl.pointer_type(tl.int64)
33
+ )
34
+
35
+ offsets_ = tl.arange(0, world_size)
36
+ input_split_sizes = tl.load(
37
+ input_split_sizes_ptr + offsets_, mask=offsets_ <= rank, other=0
38
+ )
39
+
40
+ num_rows = tl.load(input_split_sizes_ptr + rank)
41
+ input_row_offset = tl.sum(input_split_sizes) - num_rows
42
+
43
+ # split_sizes_matrix[:, rank]
44
+ output_split_sizes_ptrs = (
45
+ tl.load(split_sizes_ptrs + offsets_).to(tl.pointer_type(tl.int64)) + rank
46
+ )
47
+ output_split_sizes = tl.load(
48
+ output_split_sizes_ptrs, mask=offsets_ <= remote_rank, other=0
49
+ )
50
+ output_row_offset = tl.sum(output_split_sizes) - num_rows
51
+
52
+ return input_row_offset, output_row_offset, num_rows
53
+
54
+
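To make the pointer arithmetic above easier to verify, here is a plain-PyTorch restatement (a sketch, not part of the commit) of what `_exchange_row_offsets` computes for one (rank, remote_rank) pair, assuming `split_sizes[i, j]` holds the number of rows rank i sends to rank j:

import torch


def exchange_row_offsets_reference(split_sizes: torch.Tensor, rank: int, remote_rank: int):
    # Rows this rank receives from remote_rank.
    num_rows = int(split_sizes[remote_rank, rank])
    # Where our slice starts inside remote_rank's input buffer.
    input_row_offset = int(split_sizes[remote_rank, :rank].sum())
    # Where remote_rank's slice starts inside our output buffer.
    output_row_offset = int(split_sizes[:remote_rank, rank].sum())
    return input_row_offset, output_row_offset, num_rows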
55
+ @triton.jit
56
+ def on_device_all_to_all_v_kernel(
57
+ output_ptr,
58
+ output_splits_ptr,
59
+ input_ptrs,
60
+ input_splits_ptr,
61
+ signal_pad_ptrs,
62
+ dim: tl.constexpr, # Separate dim for easier vectorization
63
+ rank: tl.constexpr,
64
+ world_size: tl.constexpr,
65
+ BLOCKS_PER_REMOTE_RANK: tl.constexpr,
66
+ UNROLL_FACTOR: tl.constexpr,
67
+ BLOCK_SIZE: tl.constexpr,
68
+ ):
69
+ blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
70
+ sync_threads()
71
+
72
+ remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
73
+ block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK
74
+
75
+ input_row_offset, output_row_offset, num_rows = _exchange_row_offsets(
76
+ input_splits_ptr, rank, world_size, BLOCKS_PER_REMOTE_RANK
77
+ )
78
+
79
+ output_splits_ptr = output_splits_ptr.to(tl.pointer_type(tl.uint64))
80
+ if block_offset == 0:
81
+ # Update output_splits
82
+ tl.store(output_splits_ptr + remote_rank, num_rows)
83
+
84
+ input_ptr = (
85
+ tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to(
86
+ tl.pointer_type(tl.bfloat16)
87
+ )
88
+ + input_row_offset * dim
89
+ )
90
+ output_ptr = output_ptr + output_row_offset * dim
91
+
92
+ outer_loop_step = BLOCK_SIZE * UNROLL_FACTOR
93
+ outer_loop_iters_per_rank = tl.cdiv(
94
+ tl.cdiv(num_rows * dim, outer_loop_step), BLOCKS_PER_REMOTE_RANK
95
+ )
96
+ numel_per_rank = outer_loop_step * outer_loop_iters_per_rank
97
+ offset = numel_per_rank * block_offset
98
+ end = tl.minimum(numel_per_rank * (block_offset + 1), num_rows * dim)
99
+
100
+ unroll_region_size = (end - offset) // outer_loop_step * outer_loop_step
101
+ for i in tl.range(offset, offset + unroll_region_size, outer_loop_step):
102
+ datas = []
103
+ for j in tl.range(
104
+ i,
105
+ i + outer_loop_step,
106
+ BLOCK_SIZE,
107
+ loop_unroll_factor=UNROLL_FACTOR,
108
+ ):
109
+ offsets = j + tl.arange(0, BLOCK_SIZE)
110
+ data = tl.load(input_ptr + offsets)
111
+ tl.store(output_ptr + offsets, data)
112
+
113
+ offset += unroll_region_size
114
+ while offset < end:
115
+ offsets = offset + tl.arange(0, BLOCK_SIZE)
116
+ mask = offsets < num_rows * dim
117
+ data = tl.load(input_ptr + offsets, mask=mask)
118
+ tl.store(output_ptr + offsets, data, mask=mask)
119
+ offset += BLOCK_SIZE
120
+
121
+ sync_threads()
122
+ blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
123
+ return
124
+
125
+
126
+ def _on_device_all_to_all_v(
127
+ output: torch.Tensor,
128
+ output_splits: torch.Tensor,
129
+ input: torch.Tensor,
130
+ input_splits: torch.Tensor,
131
+ group: dist.ProcessGroup = dist.group.WORLD,
132
+ BLOCKS_PER_REMOTE_RANK=8,
133
+ UNROLL_FACTOR: int = 8,
134
+ BLOCK_SIZE: int = 16384,
135
+ ):
136
+ assert output.dim() == 2, f"{output.shape}"
137
+ assert input.dim() == 2, f"{input.shape}"
138
+ assert output.shape[1] == input.shape[1]
139
+
140
+ dim = output.shape[1]
141
+ input_hdl = symm_mem.rendezvous(input, group=group)
142
+ input_splits_hdl = symm_mem.rendezvous(input_splits, group=group)
143
+
144
+ num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK
145
+ kernel = on_device_all_to_all_v_kernel[(num_blocks, 1, 1)](
146
+ output,
147
+ output_splits,
148
+ input_hdl.buffer_ptrs_dev,
149
+ input_splits_hdl.buffer_ptrs_dev,
150
+ input_hdl.signal_pad_ptrs_dev,
151
+ dim=dim,
152
+ rank=input_hdl.rank,
153
+ world_size=input_hdl.world_size,
154
+ BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK,
155
+ UNROLL_FACTOR=UNROLL_FACTOR,
156
+ BLOCK_SIZE=BLOCK_SIZE,
157
+ num_warps=16,
158
+ )
159
+ # log_triton_kernel(kernel)
160
+ return output
161
+
162
+
163
+ class OnDeviceAllToAllV(torch.autograd.Function):
164
+ # A symmetric memory holding the grad_output during backward
165
+ grad_output_buf = None
166
+ # A symmetric memory for exchanging split sizes during both forward and backward
167
+ splits_buf = None
168
+ # Maximum output length (needs to be set before using OnDeviceAllToAllV)
169
+ max_output_len = None
170
+
171
+ @staticmethod
172
+ def forward(
173
+ ctx,
174
+ input: torch.Tensor,
175
+ input_splits: torch.Tensor,
176
+ group: dist.ProcessGroup = dist.group.WORLD,
177
+ ):
178
+ """
179
+ Args:
180
+ input: input tensor with data for all ranks concatenated.
181
+ input_splits: input splits of shape (group.world_size,)
182
+ group: process group to scope the collective.
183
+ """
184
+ # Initialize input splits buffer (one time only)
185
+ if OnDeviceAllToAllV.splits_buf is None:
186
+ OnDeviceAllToAllV.splits_buf = symm_mem.empty(
187
+ *input_splits.shape,
188
+ dtype=input_splits.dtype,
189
+ device=input_splits.device,
190
+ )
191
+
192
+ if OnDeviceAllToAllV.max_output_len is None:
193
+ raise RuntimeError(
194
+ "Please set max output length via `OnDeviceAllToAllV.max_output_len = ...`"
195
+ )
196
+
197
+ # Allocate output buffer
198
+ output = input.new_empty(OnDeviceAllToAllV.max_output_len, *input.shape[1:])
199
+ # Allocate output splits tensor
200
+ output_splits = torch.empty_like(input_splits)
201
+ # Copy input splits to the buffer
202
+ OnDeviceAllToAllV.splits_buf.copy_(input_splits)
203
+
204
+ # Shuffle input to output
205
+ _on_device_all_to_all_v(
206
+ output, output_splits, input, OnDeviceAllToAllV.splits_buf, group=group
207
+ )
208
+
209
+ # Output splits in forward is the input splits in backward
210
+ ctx.save_for_backward(output_splits)
211
+ ctx.group = group
212
+ ctx.input_shape = input.shape
213
+ return output, output_splits
214
+
215
+ @staticmethod
216
+ def backward(ctx, grad_output, grad_splits):
217
+ """
218
+ Backward is implemented as a shuffle of the output's gradients to the input.
219
+ Args:
220
+ `grad_output`: output's gradients passed from the downstream.
221
+ `grad_splits`: unused.
222
+ """
223
+
224
+ # Initialize grad_output buffer (one time only)
225
+ if OnDeviceAllToAllV.grad_output_buf is None:
226
+ assert (
227
+ OnDeviceAllToAllV.max_output_len is not None
228
+ ), "`max_output_len` not set"
229
+ OnDeviceAllToAllV.grad_output_buf = symm_mem.empty(
230
+ OnDeviceAllToAllV.max_output_len,
231
+ *grad_output.shape[1:],
232
+ dtype=grad_output.dtype,
233
+ device=grad_output.device,
234
+ )
235
+
236
+ # TODO: is there a way to tell autograd to feed grad_output directly to
237
+ # our symm_mem buffer?
238
+ OnDeviceAllToAllV.grad_output_buf.narrow(0, 0, grad_output.shape[0]).copy_(
239
+ grad_output
240
+ )
241
+
242
+ # Size info
243
+ (grad_output_splits,) = ctx.saved_tensors
244
+ OnDeviceAllToAllV.splits_buf.copy_(grad_output_splits)
245
+ grad_input_splits = torch.empty_like(grad_output_splits) # unused
246
+ grad_input = grad_output.new_empty(*ctx.input_shape)
247
+
248
+ # Shuffle gradients back to the input
249
+ _on_device_all_to_all_v(
250
+ grad_input,
251
+ grad_input_splits,
252
+ OnDeviceAllToAllV.grad_output_buf,
253
+ OnDeviceAllToAllV.splits_buf,
254
+ group=ctx.group,
255
+ )
256
+ return grad_input, None, None
257
+
258
+
259
+ # Alias
260
+ on_device_all_to_all_v = OnDeviceAllToAllV.apply
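A hedged end-to-end usage sketch (not part of the commit), assuming `OnDeviceAllToAllV` / `on_device_all_to_all_v` are imported from this module: `max_output_len` must be set once before the first call, and the token tensor has to live in symmetric memory because the forward path rendezvous-es it directly. It also assumes the default process group is already initialized (e.g. under torchrun) and uses a uniform split purely for illustration.

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem


def example_shuffle(dim: int = 128, rows_per_peer: int = 32):
    group = dist.group.WORLD
    rank, world_size = dist.get_rank(), dist.get_world_size()
    device = torch.device("cuda", rank % torch.cuda.device_count())

    # Upper bound on rows any rank may receive; required before first use.
    OnDeviceAllToAllV.max_output_len = rows_per_peer * world_size

    # The token buffer must be a symmetric-memory allocation (bfloat16, 2-D).
    inp = symm_mem.empty(
        rows_per_peer * world_size, dim, dtype=torch.bfloat16, device=device
    )
    inp.normal_()
    # Send rows_per_peer rows to every peer.
    input_splits = torch.full(
        (world_size,), rows_per_peer, dtype=torch.int64, device=device
    )

    output, output_splits = on_device_all_to_all_v(inp, input_splits, group)
    return output[: output_splits.sum()], output_splits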
torchtitan/experiments/deepseek_v3/train.py ADDED
@@ -0,0 +1,142 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # torchrun --standalone --nproc-per-node 8 run.py
8
+ import torch
9
+ import torch.distributed as dist
10
+ from checkpoint import load_weights_from_hf
11
+ from model import DeepseekForCausalLM
12
+ from model_config import deepseek_config_registry
13
+
14
+ from torch.distributed.device_mesh import DeviceMesh
15
+ from torch.distributed.fsdp import fully_shard
16
+ from torch.distributed.pipelining import PipelineStage, Schedule1F1B
17
+
18
+
19
+ # Use DeepSeek-V2-Lite as a proxy
20
+ model_id = "deepseek-ai/DeepSeek-V2-Lite"
21
+
22
+
23
+ # Run full model
24
+ def run_full_model(
25
+ mesh: DeviceMesh,
26
+ ):
27
+ rank = dist.get_rank()
28
+ device_count = torch.cuda.device_count()
29
+ device = torch.device("cuda", rank % device_count)
30
+
31
+ pp_mesh = mesh["pp"]
32
+ ep_mesh = mesh["ep"]
33
+ pp_rank = pp_mesh.get_local_rank()
34
+ ep_rank = ep_mesh.get_local_rank()
35
+ pp_size = pp_mesh.size()
36
+ ep_size = ep_mesh.size()
37
+
38
+ # Get model configs
39
+ model_args = deepseek_config_registry[model_id]
40
+ # [Note]: I am making the model smaller for testing / avoiding OOM. If you
41
+ # have sufficient GPUs for model parallelism, you can remove this line.
42
+ model_args.num_hidden_layers = 16
43
+
44
+ # Apply model parallelism
45
+ model_args.ep_size = ep_size
46
+ model_args.num_stages = pp_size
47
+ model_args.stage_idx = pp_rank
48
+ print(model_args)
49
+
50
+ # Instantiate model
51
+ with device, mesh:
52
+ model = DeepseekForCausalLM(model_args)
53
+
54
+ # Load weights
55
+ load_weights_from_hf(model, model_id, device)
56
+ model.train()
57
+
58
+ # Apply data parallelism
59
+ fsdp_mesh = mesh["fsdp"]
60
+ hsdp_mesh = mesh["ep", "fsdp"]
61
+ # Using `reshard_after_forward=False` to implement Zero-2, i.e. sharding the
62
+ # optimizer (Zero-1) and gradients (Zero-2), but not the model weights.
63
+ # Reason: the MoE is "sparsely activated" compared to the dense model, thus
64
+ # it would be uneconomical to re-gather the weights.
65
+ for layer in model.model.layers.values():
66
+ # Apply FSDP to experts
67
+ if hasattr(layer.mlp, "experts"):
68
+ for expert in layer.mlp.experts.values():
69
+ fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False)
70
+ # Apply HSDP to other parts such as attention, layernorm, because they
71
+ # are doing DDP on EP dimension
72
+ fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False)
73
+
74
+ # Apply HSDP on root model (lm_head, embeddings, etc)
75
+ fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False)
76
+
77
+ # Synthetic setting
78
+ microbatches = pp_size * 2
79
+
80
+ # Use Symmetric Memory for MoE token shuffle.
81
+ # TODO: we are rewriting `moe_on_device` function. `setup_symm_mem` is
82
+ # currently supported for forward only. See `generate.py`.
83
+ # model.setup_symm_mem(torch.bfloat16, device)
84
+
85
+ # Example inputs
86
+ torch.manual_seed(ep_rank)
87
+ bs = 4
88
+ seqlen = 128
89
+ x = torch.randint(model_args.vocab_size, (microbatches * bs, seqlen), device=device)
90
+ label = torch.rand(microbatches * bs, seqlen, model_args.vocab_size, device=device)
91
+
92
+ # Create loss function
93
+ loss_fn = torch.nn.functional.cross_entropy
94
+
95
+ # Run forward and backward
96
+ steps = 2
97
+ for _ in range(steps):
98
+ if pp_size > 1:
99
+ # Create pipeline stage
100
+ stage = PipelineStage(
101
+ model,
102
+ pp_rank,
103
+ pp_size,
104
+ device,
105
+ group=pp_mesh.get_group(),
106
+ )
107
+
108
+ # Create pipeline schedule
109
+ losses = []
110
+ pp_schedule = Schedule1F1B(stage, microbatches, loss_fn=loss_fn)
111
+
112
+ if pp_rank == 0:
113
+ y = pp_schedule.step(x)
114
+ elif pp_rank == pp_size - 1:
115
+ y = pp_schedule.step(target=label, losses=losses)
116
+ loss = torch.mean(torch.stack(losses))
117
+ else:
118
+ pp_schedule.step()
119
+ else:
120
+ y = model(x)
121
+ loss = loss_fn(y, label)
122
+ loss.backward()
123
+
124
+ if pp_rank == pp_size - 1:
125
+ print(f"logits: {y.shape}")
126
+ print(f"{loss=}")
127
+
128
+ if pp_rank == 0:
129
+ param = model.get_parameter("model.layers.0.self_attn.q_proj.weight")
130
+ print(f"{torch.linalg.norm(param.grad)=}")
131
+
132
+ model.zero_grad()
133
+
134
+ print("Backward done")
135
+
136
+
137
+ if __name__ == "__main__":
138
+ mesh = dist.init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp"))
139
+
140
+ run_full_model(mesh)
141
+
142
+ dist.destroy_process_group()
torchtitan/experiments/flux/dataset/tokenizer.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
9
+
10
+
11
+ from typing import List
12
+
13
+ from torchtitan.components.tokenizer import Tokenizer
14
+ from transformers import CLIPTokenizer, T5Tokenizer
15
+
16
+
17
+ class FluxTokenizer(Tokenizer):
18
+ """
19
+ Tokenizing and encoding/decoding text using the T5 or CLIP tokenizer.
20
+
21
+ Args:
22
+ model_path (str): Path to the tokenizer on Hugging Face.
23
+
24
+ """
25
+
26
+ def __init__(self, model_path: str = "t5-small", max_length: int = 77):
27
+ super().__init__()
28
+ self._n_words = 8 # TODO(jianiw): check
29
+ self._max_length = max_length
30
+
31
+ self.is_clip = model_path.startswith("openai")
32
+
33
+ if self.is_clip:
34
+ self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
35
+ model_path, max_length=max_length
36
+ )
37
+ else:
38
+ self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
39
+ model_path, max_length=max_length
40
+ )
41
+
42
+ def encode(
43
+ self,
44
+ s: str,
45
+ ) -> List[int]:
46
+ """
47
+ Encode the prompt text into tokens.
48
+ """
49
+ tokens = self._tokenizer(
50
+ s,
51
+ truncation=True,
52
+ max_length=self._max_length,
53
+ return_length=False,
54
+ return_overflowing_tokens=False,
55
+ padding="max_length",
56
+ return_tensors="pt", # return pytorch tensors, default return List[int]
57
+ )["input_ids"]
58
+ return tokens
59
+
60
+ def decode(self, t: List[int]) -> str:
61
+ """
62
+ Decode function. This function will not be called.
63
+ """
64
+ return self._tokenizer.decode(t)
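A small usage sketch (not part of the commit), using the same Hugging Face model names that appear as defaults elsewhere in this commit; both tokenizers are downloaded on first use, and the shapes follow from `padding="max_length"` above.

from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer


def tokenize_prompt(prompt: str = "a photo of a forest"):
    clip_tok = FluxTokenizer("openai/clip-vit-large-patch14", max_length=77)
    t5_tok = FluxTokenizer("google/t5-v1_1-small", max_length=512)
    clip_ids = clip_tok.encode(prompt)  # tensor of shape [1, 77]
    t5_ids = t5_tok.encode(prompt)      # tensor of shape [1, 512]
    return clip_ids, t5_ids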
torchtitan/experiments/flux/flux_argparser.py ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+
9
+ import torch
10
+
11
+
12
+ def extend_parser(parser: argparse.ArgumentParser) -> None:
13
+ parser.add_argument(
14
+ "--training.guidance",
15
+ type=float,
16
+ default=3.5,
17
+ help="guidance value used for guidance distillation",
18
+ )
19
+ parser.add_argument(
20
+ "--encoder.t5_encoder",
21
+ type=str,
22
+ default="google/t5-v1_1-small",
23
+ help="T5 encoder to use, HuggingFace model name.",
24
+ )
25
+ parser.add_argument(
26
+ "--encoder.clip_encoder",
27
+ type=str,
28
+ default="openai/clip-vit-large-patch14",
29
+ help="Clip encoder to use, HuggingFace model name.",
30
+ )
31
+ parser.add_argument(
32
+ "--encoder.encoder_dtype",
33
+ type=torch.dtype,
34
+ default=torch.bfloat16,
35
+ help="Which dtype to load for autoencoder. ",
36
+ )
37
+ parser.add_argument(
38
+ "--encoder.max_t5_encoding_len",
39
+ type=int,
40
+ default=512,
41
+ help="Maximum length of the T5 encoding.",
42
+ )
torchtitan/experiments/flux/model/autoencoder.py ADDED
@@ -0,0 +1,388 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ from dataclasses import dataclass
9
+
10
+ import torch
11
+ from einops import rearrange
12
+ from safetensors.torch import load_file as load_sft
13
+ from torch import nn, Tensor
14
+
15
+
16
+ @dataclass
17
+ class AutoEncoderParams:
18
+ resolution: int = 256
19
+ in_channels: int = 3
20
+ ch: int = 128
21
+ out_ch: int = 3
22
+ ch_mult: tuple[int] = (1, 2, 4, 4)
23
+ num_res_blocks: int = 2
24
+ z_channels: int = 16
25
+ scale_factor: float = 0.3611
26
+ shift_factor: float = 0.1159
27
+
28
+
29
+ def swish(x: Tensor) -> Tensor:
30
+ return x * torch.sigmoid(x)
31
+
32
+
33
+ class AttnBlock(nn.Module):
34
+ def __init__(self, in_channels: int):
35
+ super().__init__()
36
+ self.in_channels = in_channels
37
+
38
+ self.norm = nn.GroupNorm(
39
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
40
+ )
41
+
42
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
43
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
44
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
45
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
46
+
47
+ def attention(self, h_: Tensor) -> Tensor:
48
+ h_ = self.norm(h_)
49
+ q = self.q(h_)
50
+ k = self.k(h_)
51
+ v = self.v(h_)
52
+
53
+ b, c, h, w = q.shape
54
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
55
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
56
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
57
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
58
+
59
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
60
+
61
+ def forward(self, x: Tensor) -> Tensor:
62
+ return x + self.proj_out(self.attention(x))
63
+
64
+
65
+ class ResnetBlock(nn.Module):
66
+ def __init__(self, in_channels: int, out_channels: int):
67
+ super().__init__()
68
+ self.in_channels = in_channels
69
+ out_channels = in_channels if out_channels is None else out_channels
70
+ self.out_channels = out_channels
71
+
72
+ self.norm1 = nn.GroupNorm(
73
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
74
+ )
75
+ self.conv1 = nn.Conv2d(
76
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
77
+ )
78
+ self.norm2 = nn.GroupNorm(
79
+ num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
80
+ )
81
+ self.conv2 = nn.Conv2d(
82
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
83
+ )
84
+ if self.in_channels != self.out_channels:
85
+ self.nin_shortcut = nn.Conv2d(
86
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
87
+ )
88
+
89
+ def forward(self, x):
90
+ h = x
91
+ h = self.norm1(h)
92
+ h = swish(h)
93
+ h = self.conv1(h)
94
+
95
+ h = self.norm2(h)
96
+ h = swish(h)
97
+ h = self.conv2(h)
98
+
99
+ if self.in_channels != self.out_channels:
100
+ x = self.nin_shortcut(x)
101
+
102
+ return x + h
103
+
104
+
105
+ class Downsample(nn.Module):
106
+ def __init__(self, in_channels: int):
107
+ super().__init__()
108
+ # no asymmetric padding in torch conv, must do it ourselves
109
+ self.conv = nn.Conv2d(
110
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
111
+ )
112
+
113
+ def forward(self, x: Tensor):
114
+ pad = (0, 1, 0, 1)
115
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
116
+ x = self.conv(x)
117
+ return x
118
+
119
+
120
+ class Upsample(nn.Module):
121
+ def __init__(self, in_channels: int):
122
+ super().__init__()
123
+ self.conv = nn.Conv2d(
124
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
125
+ )
126
+
127
+ def forward(self, x: Tensor):
128
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
129
+ x = self.conv(x)
130
+ return x
131
+
132
+
133
+ class Encoder(nn.Module):
134
+ def __init__(
135
+ self,
136
+ resolution: int,
137
+ in_channels: int,
138
+ ch: int,
139
+ ch_mult: list[int],
140
+ num_res_blocks: int,
141
+ z_channels: int,
142
+ ):
143
+ super().__init__()
144
+ self.ch = ch
145
+ self.num_resolutions = len(ch_mult)
146
+ self.num_res_blocks = num_res_blocks
147
+ self.resolution = resolution
148
+ self.in_channels = in_channels
149
+ # downsampling
150
+ self.conv_in = nn.Conv2d(
151
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
152
+ )
153
+
154
+ curr_res = resolution
155
+ in_ch_mult = (1,) + tuple(ch_mult)
156
+ self.in_ch_mult = in_ch_mult
157
+ self.down = nn.ModuleList()
158
+ block_in = self.ch
159
+ for i_level in range(self.num_resolutions):
160
+ block = nn.ModuleList()
161
+ attn = nn.ModuleList()
162
+ block_in = ch * in_ch_mult[i_level]
163
+ block_out = ch * ch_mult[i_level]
164
+ for _ in range(self.num_res_blocks):
165
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
166
+ block_in = block_out
167
+ down = nn.Module()
168
+ down.block = block
169
+ down.attn = attn
170
+ if i_level != self.num_resolutions - 1:
171
+ down.downsample = Downsample(block_in)
172
+ curr_res = curr_res // 2
173
+ self.down.append(down)
174
+
175
+ # middle
176
+ self.mid = nn.Module()
177
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
178
+ self.mid.attn_1 = AttnBlock(block_in)
179
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
180
+
181
+ # end
182
+ self.norm_out = nn.GroupNorm(
183
+ num_groups=32, num_channels=block_in, eps=1e-6, affine=True
184
+ )
185
+ self.conv_out = nn.Conv2d(
186
+ block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
187
+ )
188
+
189
+ def forward(self, x: Tensor) -> Tensor:
190
+ # downsampling
191
+ hs = [self.conv_in(x)]
192
+ for i_level in range(self.num_resolutions):
193
+ for i_block in range(self.num_res_blocks):
194
+ h = self.down[i_level].block[i_block](hs[-1])
195
+ if len(self.down[i_level].attn) > 0:
196
+ h = self.down[i_level].attn[i_block](h)
197
+ hs.append(h)
198
+ if i_level != self.num_resolutions - 1:
199
+ hs.append(self.down[i_level].downsample(hs[-1]))
200
+
201
+ # middle
202
+ h = hs[-1]
203
+ h = self.mid.block_1(h)
204
+ h = self.mid.attn_1(h)
205
+ h = self.mid.block_2(h)
206
+ # end
207
+ h = self.norm_out(h)
208
+ h = swish(h)
209
+ h = self.conv_out(h)
210
+ return h
211
+
212
+
213
+ class Decoder(nn.Module):
214
+ def __init__(
215
+ self,
216
+ ch: int,
217
+ out_ch: int,
218
+ ch_mult: list[int],
219
+ num_res_blocks: int,
220
+ in_channels: int,
221
+ resolution: int,
222
+ z_channels: int,
223
+ ):
224
+ super().__init__()
225
+ self.ch = ch
226
+ self.num_resolutions = len(ch_mult)
227
+ self.num_res_blocks = num_res_blocks
228
+ self.resolution = resolution
229
+ self.in_channels = in_channels
230
+ self.ffactor = 2 ** (self.num_resolutions - 1)
231
+
232
+ # compute in_ch_mult, block_in and curr_res at lowest res
233
+ block_in = ch * ch_mult[self.num_resolutions - 1]
234
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
235
+ self.z_shape = (1, z_channels, curr_res, curr_res)
236
+
237
+ # z to block_in
238
+ self.conv_in = nn.Conv2d(
239
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
240
+ )
241
+
242
+ # middle
243
+ self.mid = nn.Module()
244
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
245
+ self.mid.attn_1 = AttnBlock(block_in)
246
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
247
+
248
+ # upsampling
249
+ self.up = nn.ModuleList()
250
+ for i_level in reversed(range(self.num_resolutions)):
251
+ block = nn.ModuleList()
252
+ attn = nn.ModuleList()
253
+ block_out = ch * ch_mult[i_level]
254
+ for _ in range(self.num_res_blocks + 1):
255
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
256
+ block_in = block_out
257
+ up = nn.Module()
258
+ up.block = block
259
+ up.attn = attn
260
+ if i_level != 0:
261
+ up.upsample = Upsample(block_in)
262
+ curr_res = curr_res * 2
263
+ self.up.insert(0, up) # prepend to get consistent order
264
+
265
+ # end
266
+ self.norm_out = nn.GroupNorm(
267
+ num_groups=32, num_channels=block_in, eps=1e-6, affine=True
268
+ )
269
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
270
+
271
+ def forward(self, z: Tensor) -> Tensor:
272
+ # get dtype for proper tracing
273
+ upscale_dtype = next(self.up.parameters()).dtype
274
+
275
+ # z to block_in
276
+ h = self.conv_in(z)
277
+
278
+ # middle
279
+ h = self.mid.block_1(h)
280
+ h = self.mid.attn_1(h)
281
+ h = self.mid.block_2(h)
282
+
283
+ # cast to proper dtype
284
+ h = h.to(upscale_dtype)
285
+ # upsampling
286
+ for i_level in reversed(range(self.num_resolutions)):
287
+ for i_block in range(self.num_res_blocks + 1):
288
+ h = self.up[i_level].block[i_block](h)
289
+ if len(self.up[i_level].attn) > 0:
290
+ h = self.up[i_level].attn[i_block](h)
291
+ if i_level != 0:
292
+ h = self.up[i_level].upsample(h)
293
+
294
+ # end
295
+ h = self.norm_out(h)
296
+ h = swish(h)
297
+ h = self.conv_out(h)
298
+ return h
299
+
300
+
301
+ class DiagonalGaussian(nn.Module):
302
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
303
+ super().__init__()
304
+ self.sample = sample
305
+ self.chunk_dim = chunk_dim
306
+
307
+ def forward(self, z: Tensor) -> Tensor:
308
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
309
+ if self.sample:
310
+ std = torch.exp(0.5 * logvar)
311
+ return mean + std * torch.randn_like(mean)
312
+ else:
313
+ return mean
314
+
315
+
316
+ class AutoEncoder(nn.Module):
317
+ def __init__(self, params: AutoEncoderParams):
318
+ super().__init__()
319
+ self.params = params
320
+ self.encoder = Encoder(
321
+ resolution=params.resolution,
322
+ in_channels=params.in_channels,
323
+ ch=params.ch,
324
+ ch_mult=params.ch_mult,
325
+ num_res_blocks=params.num_res_blocks,
326
+ z_channels=params.z_channels,
327
+ )
328
+ self.decoder = Decoder(
329
+ resolution=params.resolution,
330
+ in_channels=params.in_channels,
331
+ ch=params.ch,
332
+ out_ch=params.out_ch,
333
+ ch_mult=params.ch_mult,
334
+ num_res_blocks=params.num_res_blocks,
335
+ z_channels=params.z_channels,
336
+ )
337
+ self.reg = DiagonalGaussian()
338
+
339
+ self.scale_factor = params.scale_factor
340
+ self.shift_factor = params.shift_factor
341
+
342
+ def encode(self, x: Tensor) -> Tensor:
343
+ z = self.reg(self.encoder(x))
344
+ z = self.scale_factor * (z - self.shift_factor)
345
+ return z
346
+
347
+ def decode(self, z: Tensor) -> Tensor:
348
+ z = z / self.scale_factor + self.shift_factor
349
+ return self.decoder(z)
350
+
351
+ def forward(self, x: Tensor) -> Tensor:
352
+ return self.decode(self.encode(x))
353
+
354
+
355
+ def load_ae(
356
+ ckpt_path: str,
357
+ autoencoder_params: AutoEncoderParams,
358
+ device: str | torch.device = "cuda",
359
+ dtype=torch.bfloat16,
360
+ ) -> AutoEncoder:
361
+ """
362
+ Load the autoencoder from the given checkpoint path.
363
+ Args:
364
+ ckpt_path (str): Path to the autoencoder checkpoint (safetensors file).
365
+ device (str or torch.device): The device to load the autoencoder to.
366
+ Returns:
367
+ AutoEncoder: The loaded autoencoder.
368
+ """
369
+ # Loading the autoencoder
370
+ print("Init AE")
371
+ with torch.device(device):
372
+ ae = AutoEncoder(autoencoder_params)
373
+
374
+ if not os.path.exists(ckpt_path):
375
+ raise ValueError(
376
+ f"Autoencoder path {ckpt_path} does not exist. Please download it first."
377
+ )
378
+
379
+ if ckpt_path is not None:
380
+ sd = load_sft(ckpt_path, device=str(device))
381
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
382
+ if len(missing) > 0:
383
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
384
+ if len(unexpected) > 0:
385
+ print(
386
+ f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
387
+ )
388
+ return ae.to(dtype=dtype)
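A hedged round-trip sketch (not part of the commit). With the default `AutoEncoderParams` (four resolution levels), a 256x256 RGB image is downsampled 8x and encoded into `z_channels=16` latent channels; the checkpoint path below is the default location used by `download_autoencoder.py`.

import torch

from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams, load_ae


def autoencoder_roundtrip(
    ckpt: str = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors",
):
    ae = load_ae(ckpt, AutoEncoderParams(), device="cuda", dtype=torch.bfloat16)
    img = torch.randn(1, 3, 256, 256, device="cuda", dtype=torch.bfloat16)
    z = ae.encode(img)    # -> [1, 16, 32, 32]
    rec = ae.decode(z)    # -> [1, 3, 256, 256]
    return z.shape, rec.shape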
torchtitan/experiments/flux/model/hf_embedder.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import nn, Tensor
8
+ from transformers import CLIPTextModel, T5EncoderModel
9
+
10
+
11
+ class FluxEmbedder(nn.Module):
12
+ def __init__(self, version: str, **hf_kwargs):
13
+ super().__init__()
14
+ self.is_clip = version.startswith("openai")
15
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
16
+
17
+ if self.is_clip:
18
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
19
+ version, **hf_kwargs
20
+ )
21
+ else:
22
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
23
+ version, **hf_kwargs
24
+ )
25
+
26
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
27
+
28
+ def forward(self, batch_tokens: Tensor) -> Tensor:
29
+ """
30
+ batch_tokens: [bsz, embedding_length]
31
+
32
+ For T5 Encoder, embeding_length is 768
33
+ For CLIP, embedding_length is 256
34
+ """
35
+ outputs = self.hf_module(
36
+ input_ids=batch_tokens.to(self.hf_module.device),
37
+ attention_mask=None,
38
+ output_hidden_states=False,
39
+ )
40
+ return outputs[self.output_key]
torchtitan/experiments/flux/model/layers.py ADDED
@@ -0,0 +1,286 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # imported from black-forest-labs/FLUX
8
+ import math
9
+ from dataclasses import dataclass
10
+
11
+ import torch
12
+ from einops import rearrange
13
+ from torch import nn, Tensor
14
+
15
+ from torchtitan.experiments.flux.model.math import attention, rope
16
+
17
+
18
+ class EmbedND(nn.Module):
19
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
20
+ super().__init__()
21
+ self.dim = dim
22
+ self.theta = theta
23
+ self.axes_dim = axes_dim
24
+
25
+ def forward(self, ids: Tensor) -> Tensor:
26
+ n_axes = ids.shape[-1]
27
+ emb = torch.cat(
28
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
29
+ dim=-3,
30
+ )
31
+
32
+ return emb.unsqueeze(1)
33
+
34
+
35
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
36
+ """
37
+ Create sinusoidal timestep embeddings.
38
+ :param t: a 1-D Tensor of N indices, one per batch element.
39
+ These may be fractional.
40
+ :param dim: the dimension of the output.
41
+ :param max_period: controls the minimum frequency of the embeddings.
42
+ :return: an (N, D) Tensor of positional embeddings.
43
+ """
44
+ t = time_factor * t
45
+ half = dim // 2
46
+ freqs = torch.exp(
47
+ -math.log(max_period)
48
+ * torch.arange(start=0, end=half, dtype=torch.float32)
49
+ / half
50
+ ).to(t.device)
51
+
52
+ args = t[:, None].float() * freqs[None]
53
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
54
+ if dim % 2:
55
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
56
+ if torch.is_floating_point(t):
57
+ embedding = embedding.to(t)
58
+ return embedding
59
+
60
+
61
+ class MLPEmbedder(nn.Module):
62
+ def __init__(self, in_dim: int, hidden_dim: int):
63
+ super().__init__()
64
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
65
+ self.silu = nn.SiLU()
66
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
67
+
68
+ def forward(self, x: Tensor) -> Tensor:
69
+ return self.out_layer(self.silu(self.in_layer(x)))
70
+
71
+
72
+ class RMSNorm(torch.nn.Module):
73
+ def __init__(self, dim: int):
74
+ super().__init__()
75
+ self.scale = nn.Parameter(torch.ones(dim))
76
+
77
+ def forward(self, x: Tensor):
78
+ x_dtype = x.dtype
79
+ x = x.float()
80
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
81
+ return (x * rrms).to(dtype=x_dtype) * self.scale
82
+
83
+
84
+ class QKNorm(torch.nn.Module):
85
+ def __init__(self, dim: int):
86
+ super().__init__()
87
+ self.query_norm = RMSNorm(dim) # TODO(jianiw): switch to pytorch nn.RMSNorm
88
+ self.key_norm = RMSNorm(dim)
89
+
90
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
91
+ q = self.query_norm(q)
92
+ k = self.key_norm(k)
93
+ return q.to(v), k.to(v)
94
+
95
+
96
+ class SelfAttention(nn.Module):
97
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
98
+ super().__init__()
99
+ self.num_heads = num_heads
100
+ head_dim = dim // num_heads
101
+
102
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
103
+ self.norm = QKNorm(head_dim)
104
+ self.proj = nn.Linear(dim, dim)
105
+
106
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
107
+ qkv = self.qkv(x)
108
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
109
+ q, k = self.norm(q, k, v)
110
+ x = attention(q, k, v, pe=pe)
111
+ x = self.proj(x)
112
+ return x
113
+
114
+
115
+ @dataclass
116
+ class ModulationOut:
117
+ shift: Tensor
118
+ scale: Tensor
119
+ gate: Tensor
120
+
121
+
122
+ class Modulation(nn.Module):
123
+ def __init__(self, dim: int, double: bool):
124
+ super().__init__()
125
+ self.is_double = double
126
+ self.multiplier = 6 if double else 3
127
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
128
+
129
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
130
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
131
+ self.multiplier, dim=-1
132
+ )
133
+
134
+ return (
135
+ ModulationOut(*out[:3]),
136
+ ModulationOut(*out[3:]) if self.is_double else None,
137
+ )
138
+
139
+
140
+ class DoubleStreamBlock(nn.Module):
141
+ def __init__(
142
+ self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
143
+ ):
144
+ super().__init__()
145
+
146
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
147
+ self.num_heads = num_heads
148
+ self.hidden_size = hidden_size
149
+ self.img_mod = Modulation(hidden_size, double=True)
150
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
151
+ self.img_attn = SelfAttention(
152
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
153
+ )
154
+
155
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
156
+ self.img_mlp = nn.Sequential(
157
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
158
+ nn.GELU(approximate="tanh"),
159
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
160
+ )
161
+
162
+ self.txt_mod = Modulation(hidden_size, double=True)
163
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
164
+ self.txt_attn = SelfAttention(
165
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
166
+ )
167
+
168
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
169
+ self.txt_mlp = nn.Sequential(
170
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
171
+ nn.GELU(approximate="tanh"),
172
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
173
+ )
174
+
175
+ def forward(
176
+ self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
177
+ ) -> tuple[Tensor, Tensor]:
178
+ img_mod1, img_mod2 = self.img_mod(vec)
179
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
180
+
181
+ # prepare image for attention
182
+ img_modulated = self.img_norm1(img)
183
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
184
+ img_qkv = self.img_attn.qkv(img_modulated)
185
+ img_q, img_k, img_v = rearrange(
186
+ img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
187
+ )
188
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
189
+
190
+ # prepare txt for attention
191
+ txt_modulated = self.txt_norm1(txt)
192
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
193
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
194
+ txt_q, txt_k, txt_v = rearrange(
195
+ txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
196
+ )
197
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
198
+
199
+ # run actual attention
200
+ q = torch.cat((txt_q, img_q), dim=2)
201
+ k = torch.cat((txt_k, img_k), dim=2)
202
+ v = torch.cat((txt_v, img_v), dim=2)
203
+
204
+ attn = attention(q, k, v, pe=pe)
205
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
206
+
207
+ # calculate the img blocks
208
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
209
+ img = img + img_mod2.gate * self.img_mlp(
210
+ (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
211
+ )
212
+
213
+ # calculate the txt blocks
214
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
215
+ txt = txt + txt_mod2.gate * self.txt_mlp(
216
+ (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
217
+ )
218
+ return img, txt
219
+
220
+
221
+ class SingleStreamBlock(nn.Module):
222
+ """
223
+ A DiT block with parallel linear layers as described in
224
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
225
+ """
226
+
227
+ def __init__(
228
+ self,
229
+ hidden_size: int,
230
+ num_heads: int,
231
+ mlp_ratio: float = 4.0,
232
+ qk_scale: float | None = None,
233
+ ):
234
+ super().__init__()
235
+ self.hidden_dim = hidden_size
236
+ self.num_heads = num_heads
237
+ head_dim = hidden_size // num_heads
238
+ self.scale = qk_scale or head_dim**-0.5
239
+
240
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
241
+ # qkv and mlp_in
242
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
243
+ # proj and mlp_out
244
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
245
+
246
+ self.norm = QKNorm(head_dim)
247
+
248
+ self.hidden_size = hidden_size
249
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
250
+
251
+ self.mlp_act = nn.GELU(approximate="tanh")
252
+ self.modulation = Modulation(hidden_size, double=False)
253
+
254
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
255
+ mod, _ = self.modulation(vec)
256
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
257
+ qkv, mlp = torch.split(
258
+ self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
259
+ )
260
+
261
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
262
+ q, k = self.norm(q, k, v)
263
+
264
+ # compute attention
265
+ attn = attention(q, k, v, pe=pe)
266
+ # compute activation in mlp stream, cat again and run second linear layer
267
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
268
+ return x + mod.gate * output
269
+
270
+
271
+ class LastLayer(nn.Module):
272
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
273
+ super().__init__()
274
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
275
+ self.linear = nn.Linear(
276
+ hidden_size, patch_size * patch_size * out_channels, bias=True
277
+ )
278
+ self.adaLN_modulation = nn.Sequential(
279
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
280
+ )
281
+
282
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
283
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
284
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
285
+ x = self.linear(x)
286
+ return x
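A quick shape-check sketch (not part of the commit) for two of the helpers above, using the default Flux hidden size of 3072; the batch size of 4 is arbitrary.

import torch

from torchtitan.experiments.flux.model.layers import Modulation, timestep_embedding


def shape_check():
    t = torch.rand(4)                    # one (possibly fractional) timestep per sample
    emb = timestep_embedding(t, 256)     # -> [4, 256], cos/sin halves concatenated

    mod = Modulation(3072, double=True)  # double=True -> 6 chunks -> two triples
    vec = torch.randn(4, 3072)
    mod1, mod2 = mod(vec)                # each of shift/scale/gate: [4, 1, 3072]
    return emb.shape, mod1.shift.shape, mod2.gate.shape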
torchtitan/experiments/flux/model/model.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass, field
8
+
9
+ import torch
10
+
11
+ from torch import nn, Tensor
12
+ from torchtitan.components.tokenizer import Tokenizer
13
+ from torchtitan.config_manager import JobConfig
14
+
15
+ from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
16
+ from torchtitan.experiments.flux.model.layers import (
17
+ DoubleStreamBlock,
18
+ EmbedND,
19
+ LastLayer,
20
+ MLPEmbedder,
21
+ SingleStreamBlock,
22
+ timestep_embedding,
23
+ )
24
+
25
+ from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol
26
+ from torchtitan.tools.logging import logger
27
+
28
+
29
+ @dataclass
30
+ class FluxModelArgs(BaseModelArgs):
31
+ in_channels: int = 64
32
+ out_channels: int = 64
33
+ vec_in_dim: int = 768
34
+ context_in_dim: int = 512
35
+ hidden_size: int = 3072
36
+ mlp_ratio: float = 4.0
37
+ num_heads: int = 24
38
+ depth: int = 19
39
+ depth_single_blocks: int = 38
40
+ axes_dim: tuple = (16, 56, 56)
41
+ theta: int = 10_000
42
+ qkv_bias: bool = True
43
+ guidance_embed: bool = True
44
+ autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)
45
+
46
+ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
47
+ # context_in_dim is the same as the T5 embedding dimension
48
+ self.context_in_dim = job_config.encoder.max_t5_encoding_len
49
+
50
+ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
51
+ # TODO(jianiw): Add the number of flops for the autoencoder
52
+ nparams = sum(p.numel() for p in model.parameters())
53
+ logger.warning("FLUX model hasn't implemented the get_nparams_and_flops() function")
54
+ return nparams, 1
55
+
56
+
57
+ class FluxModel(nn.Module, ModelProtocol):
58
+ """
59
+ Transformer model for flow matching on sequences.
60
+
61
+ Args:
62
+ model_args: FluxModelArgs.
63
+
64
+ Attributes:
65
+ model_args (FluxModelArgs): Model configuration arguments.
66
+ """
67
+
68
+ def __init__(self, model_args: FluxModelArgs):
69
+ super().__init__()
70
+
71
+ self.model_args = model_args
72
+ self.in_channels = model_args.in_channels
73
+ self.out_channels = model_args.out_channels
74
+ if model_args.hidden_size % model_args.num_heads != 0:
75
+ raise ValueError(
76
+ f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}"
77
+ )
78
+ pe_dim = model_args.hidden_size // model_args.num_heads
79
+ if sum(model_args.axes_dim) != pe_dim:
80
+ raise ValueError(
81
+ f"Got {model_args.axes_dim} but expected positional dim {pe_dim}"
82
+ )
83
+ self.hidden_size = model_args.hidden_size
84
+ self.num_heads = model_args.num_heads
85
+ self.pe_embedder = EmbedND(
86
+ dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim
87
+ )
88
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
89
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
90
+ self.vector_in = MLPEmbedder(model_args.vec_in_dim, self.hidden_size)
91
+ self.guidance_in = (
92
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
93
+ if model_args.guidance_embed
94
+ else nn.Identity()
95
+ )
96
+ self.txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size)
97
+
98
+ self.double_blocks = nn.ModuleList(
99
+ [
100
+ DoubleStreamBlock(
101
+ self.hidden_size,
102
+ self.num_heads,
103
+ mlp_ratio=model_args.mlp_ratio,
104
+ qkv_bias=model_args.qkv_bias,
105
+ )
106
+ for _ in range(model_args.depth)
107
+ ]
108
+ )
109
+
110
+ self.single_blocks = nn.ModuleList(
111
+ [
112
+ SingleStreamBlock(
113
+ self.hidden_size, self.num_heads, mlp_ratio=model_args.mlp_ratio
114
+ )
115
+ for _ in range(model_args.depth_single_blocks)
116
+ ]
117
+ )
118
+
119
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
120
+
121
+ def init_weights(self, buffer_device=None):
122
+ # TODO(jianiw): replace placeholder with real weight init
123
+ for param in self.parameters():
124
+ param.data.uniform_(0, 0.1)
125
+
126
+ def forward(
127
+ self,
128
+ img: Tensor,
129
+ img_ids: Tensor,
130
+ txt: Tensor,
131
+ txt_ids: Tensor,
132
+ timesteps: Tensor,
133
+ y: Tensor,
134
+ guidance: Tensor | None = None,
135
+ ) -> Tensor:
136
+ if img.ndim != 3 or txt.ndim != 3:
137
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
138
+
139
+ # running on sequences img
140
+ img = self.img_in(img)
141
+ vec = self.time_in(timestep_embedding(timesteps, 256))
142
+ if self.model_args.guidance_embed:
143
+ if guidance is None:
144
+ raise ValueError(
145
+ "Didn't get guidance strength for guidance distilled model."
146
+ )
147
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
148
+ vec = vec + self.vector_in(y)
149
+ txt = self.txt_in(txt)
150
+
151
+ ids = torch.cat((txt_ids, img_ids), dim=1)
152
+ pe = self.pe_embedder(ids)
153
+
154
+ for block in self.double_blocks:
155
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
156
+
157
+ img = torch.cat((txt, img), 1)
158
+ for block in self.single_blocks:
159
+ img = block(img, vec=vec, pe=pe)
160
+ img = img[:, txt.shape[1] :, ...]
161
+
162
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
163
+ return img
164
+
165
+ @classmethod
166
+ def from_model_args(cls, model_args: FluxModelArgs) -> "FluxModel":
167
+ """
168
+ Initialize a Flux model from a FluxModelArgs object.
169
+
170
+ Args:
171
+ model_args (FluxModelArgs): Model configuration arguments.
172
+
173
+ Returns:
174
+ FluxModel: FluxModel model.
175
+
176
+ """
177
+ return cls(model_args)
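A hedged smoke-test sketch (not part of the commit) showing the tensor shapes `forward` expects. The reduced hidden size and depth are arbitrary choices to keep it cheap; the per-token channel sizes (64 image channels, 512 text channels, 768-dim `y`) come from the `FluxModelArgs` defaults.

import torch

from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs


def smoke_test():
    args = FluxModelArgs(
        hidden_size=768, num_heads=6, depth=1, depth_single_blocks=1
    )
    model = FluxModel(args)
    bsz, img_seq, txt_seq = 2, 64, 16
    out = model(
        img=torch.randn(bsz, img_seq, args.in_channels),
        img_ids=torch.zeros(bsz, img_seq, 3),
        txt=torch.randn(bsz, txt_seq, args.context_in_dim),
        txt_ids=torch.zeros(bsz, txt_seq, 3),
        timesteps=torch.rand(bsz),
        y=torch.randn(bsz, args.vec_in_dim),
        guidance=torch.full((bsz,), 3.5),
    )
    return out.shape  # -> [bsz, img_seq, out_channels]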
torchtitan/experiments/flux/scripts/download_autoencoder.py ADDED
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ from requests.exceptions import HTTPError
10
+
11
+
12
+ def hf_download(
13
+ repo_id: str, file_path: str, local_dir: str, hf_token: Optional[str] = None
14
+ ) -> None:
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ try:
18
+ hf_hub_download(
19
+ repo_id=repo_id,
20
+ filename=file_path,
21
+ local_dir=local_dir,
22
+ local_dir_use_symlinks=False,
23
+ token=hf_token,
24
+ )
25
+ except HTTPError as e:
26
+ if e.response.status_code == 401:
27
+ print(
28
+ "You need to pass a valid `--hf_token=...` to download private checkpoints."
29
+ )
30
+ else:
31
+ raise e
32
+
33
+
34
+ if __name__ == "__main__":
35
+ import argparse
36
+
37
+ parser = argparse.ArgumentParser(description="Download the autoencoder from HuggingFace.")
38
+ parser.add_argument(
39
+ "--repo_id",
40
+ type=str,
41
+ default="black-forest-labs/FLUX.1-dev",
42
+ help="Repository ID to download from. default to Flux-dev model",
43
+ )
44
+ parser.add_argument(
45
+ "--ae_path",
46
+ type=str,
47
+ default="ae.safetensors",
48
+ help="the autoencoder path relative to repo_id",
49
+ )
50
+ parser.add_argument(
51
+ "--hf_token", type=str, default=None, help="HuggingFace API token"
52
+ )
53
+ parser.add_argument(
54
+ "--local_dir",
55
+ type=str,
56
+ default="torchtitan/experiments/flux/assets/autoencoder/",
57
+ help="local directory to save the autoencoder",
58
+ )
59
+
60
+ args = parser.parse_args()
61
+ hf_download(args.repo_id, args.ae_path, args.local_dir, args.hf_token)
torchtitan/experiments/flux/tests/test_flux_dataloader.py ADDED
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import sys
8
+
9
+ from torchtitan.config_manager import JobConfig
10
+ from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
11
+ from torchtitan.tools.profiling import (
12
+ maybe_enable_memory_snapshot,
13
+ maybe_enable_profiling,
14
+ )
15
+
16
+
17
+ class TestFluxDataLoader:
18
+ def test_flux_dataloader(self):
19
+ dataset_name = "cc12m"
20
+ batch_size = 32
21
+ world_size = 4
22
+ rank = 0
23
+
24
+ num_steps = 10
25
+
26
+ path = "torchtitan.experiments.flux.flux_argparser"
27
+ sys.argv.append(f"--experimental.custom_args_module={path}")
28
+ config = JobConfig()
29
+ config.maybe_add_custom_args()
30
+ config.parse_args(
31
+ [
32
+ # Profiling options
33
+ # "--profiling.enable_profiling",
34
+ # "--profiling.profile_freq",
35
+ # "5",
36
+ # "--profiling.enable_memory_snapshot",
37
+ # "--profiling.save_memory_snapshot_folder",
38
+ # "memory_snapshot_flux",
39
+ "--training.dataset",
40
+ dataset_name,
41
+ "--training.batch_size",
42
+ str(batch_size),
43
+ "--encoder.t5_encoder",
44
+ "google/t5-v1_1-small",
45
+ "--encoder.clip_encoder",
46
+ "openai/clip-vit-large-patch14",
47
+ "--encoder.max_t5_encoding_len",
48
+ "512",
49
+ ]
50
+ )
51
+
52
+ with maybe_enable_profiling(
53
+ config, global_step=0
54
+ ) as torch_profiler, maybe_enable_memory_snapshot(
55
+ config, global_step=0
56
+ ) as memory_profiler:
57
+ dl = self._build_dataloader(
58
+ config,
59
+ world_size,
60
+ rank,
61
+ )
62
+ dl = iter(dl)
63
+
64
+ for i in range(0, num_steps):
65
+ input_data, labels = next(dl)
66
+ print(f"Step {i} image size: {labels.shape}")
67
+ if torch_profiler:
68
+ torch_profiler.step()
69
+ if memory_profiler:
70
+ memory_profiler.step()
71
+
72
+ print(len(input_data["clip_tokens"]))
73
+ for k, v in input_data.items():
74
+ print(f"Step {i} {k} value: {type(v), v.shape}")
75
+
76
+ assert len(input_data) == 2 # (clip_encodings, t5_encodings)
77
+ assert labels.shape == (batch_size, 3, 256, 256)
78
+ # assert input_data["clip_tokens"].shape[0] == batch_size
79
+ # assert input_data["t5_tokens"].shape == (batch_size, 512, 512)
80
+
81
+ if torch_profiler:
82
+ torch_profiler.step()
83
+ if memory_profiler:
84
+ memory_profiler.step(exit_ctx=True)
85
+
86
+ def test_preprocess(self):
87
+ # TODO
88
+ pass
89
+
90
+ def _build_dataloader(
91
+ self,
92
+ job_config,
93
+ world_size,
94
+ rank,
95
+ ):
96
+
97
+ return build_flux_dataloader(
98
+ dp_world_size=world_size,
99
+ dp_rank=rank,
100
+ job_config=job_config,
101
+ tokenizer=None,
102
+ infinite=False,
103
+ )
torchtitan/experiments/flux/tests/test_generate_image.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import os
9
+ import time
10
+ from typing import Callable
11
+
12
+ import torch
13
+ from einops import rearrange
14
+
15
+ from PIL import ExifTags, Image
16
+
17
+ from torch import Tensor
18
+
19
+ from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
20
+
21
+ from torchtitan.experiments.flux.model.autoencoder import (
22
+ AutoEncoder,
23
+ AutoEncoderParams,
24
+ load_ae,
25
+ )
26
+ from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
27
+
28
+ from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs
29
+ from torchtitan.experiments.flux.utils import (
30
+ create_position_encoding_for_latents,
31
+ generate_noise_latent,
32
+ pack_latents,
33
+ preprocess_flux_data,
34
+ unpack_latents,
35
+ )
36
+
37
+
38
+ def time_shift(mu: float, sigma: float, t: Tensor):
39
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
40
+
41
+
42
+ def get_lin_function(
43
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
44
+ ) -> Callable[[float], float]:
45
+ m = (y2 - y1) / (x2 - x1)
46
+ b = y1 - m * x1
47
+ return lambda x: m * x + b
48
+
49
+
50
+ def get_schedule(
51
+ num_steps: int,
52
+ image_seq_len: int,
53
+ base_shift: float = 0.5,
54
+ max_shift: float = 1.15,
55
+ shift: bool = True,
56
+ ) -> list[float]:
57
+ # extra step for zero
58
+ timesteps = torch.linspace(1, 0, num_steps + 1)
59
+
60
+ # shifting the schedule to favor high timesteps for higher signal images
61
+ if shift:
62
+ # estimate mu based on linear estimation between two points
63
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
64
+ timesteps = time_shift(mu, 1.0, timesteps)
65
+
66
+ return timesteps.tolist()
67
+
68
+
69
+ class TestGenerateImage:
70
+ def test_generate_image(self):
71
+ """
72
+ Run a forward pass of flux model to generate an image.
73
+ """
74
+ name = "flux-dev"
75
+ img_width = 512
76
+ img_height = 512
77
+ seed = None
78
+ prompt = (
79
+ "a photo of a forest with mist swirling around the tree trunks. The word "
80
+ '"FLUX" is painted over it in big, red brush strokes with visible texture'
81
+ )
82
+ device = "cuda"
83
+ num_steps = None
84
+ loop = False
85
+ guidance = 3.5
86
+ output_dir = "output"
87
+ add_sampling_metadata = True
88
+
89
+ prompt = prompt.split("|")
90
+ if len(prompt) == 1:
91
+ prompt = prompt[0]
92
+ additional_prompts = None
93
+ else:
94
+ additional_prompts = prompt[1:]
95
+ prompt = prompt[0]
96
+
97
+ assert not (
98
+ (additional_prompts is not None) and loop
99
+ ), "Do not provide additional prompts and set loop to True"
100
+
101
+ torch_device = torch.device(device)
102
+ if num_steps is None:
103
+ num_steps = 30
104
+
105
+ # allow for packing and conversion to latent space
106
+ img_height = 16 * (img_height // 16)
107
+ img_width = 16 * (img_width // 16)
108
+
109
+ # init all components
110
+ model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16)
111
+
112
+ ae = load_ae(
113
+ ckpt_path="assets/autoencoder/ae.safetensors",
114
+ autoencoder_params=AutoEncoderParams(),
115
+ device=torch_device,
116
+ dtype=torch.bfloat16,
117
+ )
118
+ clip_tokenizer = FluxTokenizer(
119
+ model_path="openai/clip-vit-large-patch14", max_length=77
120
+ )
121
+ t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512)
122
+ clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to(
123
+ torch_device, dtype=torch.bfloat16
124
+ )
125
+ t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to(
126
+ torch_device, dtype=torch.bfloat16
127
+ )
128
+
129
+ rng = torch.Generator(device="cpu")
130
+
131
+ if seed is None:
132
+ seed = rng.seed()
133
+ print(f"Generating with seed {seed}:\n{prompt}")
134
+ t0 = time.perf_counter()
135
+ output_name = os.path.join(output_dir, f"img_{seed}.jpg")
136
+
137
+ # Tokenize the prompt, on CPU
138
+ clip_tokens = clip_tokenizer.encode(prompt)
139
+ t5_tokens = t5_tokenizer.encode(prompt)
140
+
141
+ batch = preprocess_flux_data(
142
+ device=torch_device,
143
+ dtype=torch.bfloat16,
144
+ autoencoder=None,
145
+ clip_encoder=clip_encoder,
146
+ t5_encoder=t5_encoder,
147
+ batch={
148
+ "clip_tokens": clip_tokens,
149
+ "t5_tokens": t5_tokens,
150
+ },
151
+ )
152
+
153
+ img = self._generate_images(
154
+ device=torch_device,
155
+ dtype=torch.bfloat16,
156
+ model=model,
157
+ decoder=ae,
158
+ img_width=img_width,
159
+ img_height=img_height,
160
+ denoising_steps=num_steps,
161
+ seed=seed,
162
+ clip_encodings=batch["clip_encodings"],
163
+ t5_encodings=batch["t5_encodings"],
164
+ guidance=guidance,
165
+ )
166
+
167
+ if torch.cuda.is_available():
168
+ torch.cuda.synchronize()
169
+ t1 = time.perf_counter()
170
+
171
+ print(f"Done in {t1 - t0:.1f}s.")
172
+
173
+ self._save_image(name, output_name, img, add_sampling_metadata, prompt)
174
+
175
+ def _generate_images(
176
+ self,
177
+ device: torch.device,
178
+ dtype: torch.dtype,
179
+ model: FluxModel,
180
+ decoder: AutoEncoder,
181
+ # image params:
182
+ img_width: int,
183
+ img_height: int,
184
+ # sampling params:
185
+ denoising_steps: int,
186
+ seed: int,
187
+ clip_encodings: torch.Tensor,
188
+ t5_encodings: torch.Tensor,
189
+ guidance: float = 4.0,
190
+ ):
191
+
192
+ bsz = clip_encodings.shape[0]
193
+ latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed)
194
+ _, latent_channels, latent_height, latent_width = latents.shape
195
+
196
+ # create denoising schedule
197
+ timesteps = get_schedule(denoising_steps, latent_channels, shift=True)
198
+
199
+ # create positional encodings
200
+ POSITION_DIM = 3 # constant for Flux flow model
201
+ latent_pos_enc = create_position_encoding_for_latents(
202
+ bsz, latent_height, latent_width, POSITION_DIM
203
+ ).to(latents)
204
+ text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)
205
+
206
+ # convert img-like latents into sequences of patches
207
+ latents = pack_latents(latents)
208
+
209
+ # this is ignored for schnell
210
+ guidance_vec = torch.full((bsz,), guidance, device=device, dtype=dtype)
211
+ for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
212
+ t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
213
+ pred = model(
214
+ img=latents,
215
+ img_ids=latent_pos_enc,
216
+ txt=t5_encodings,
217
+ txt_ids=text_pos_enc,
218
+ y=clip_encodings,
219
+ timesteps=t_vec,
220
+ guidance=guidance_vec,
221
+ )
222
+
223
+ latents = latents + (t_prev - t_curr) * pred
224
+
225
+ # convert sequences of patches into img-like latents
226
+ latents = unpack_latents(latents, latent_height, latent_width)
227
+
228
+ img = decoder.decode(latents)
229
+ return img
230
+
231
+ def _save_image(
232
+ self,
233
+ name: str,
234
+ output_name: str,
235
+ x: torch.Tensor,
236
+ add_sampling_metadata: bool,
237
+ prompt: str,
238
+ ):
239
+ print(f"Saving {output_name}")
240
+ # bring into PIL format and save
241
+ x = x.clamp(-1, 1)
242
+ x = rearrange(x[0], "c h w -> h w c")
243
+
244
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
245
+
246
+ exif_data = Image.Exif()
247
+ exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
248
+ exif_data[ExifTags.Base.Make] = "Black Forest Labs"
249
+ exif_data[ExifTags.Base.Model] = name
250
+ if add_sampling_metadata:
251
+ exif_data[ExifTags.Base.ImageDescription] = prompt
252
+ img.save(output_name, exif=exif_data, quality=95, subsampling=0)
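The sampling loop in _generate_images is a plain Euler integration over the shifted schedule: at each step the model's prediction is scaled by (t_prev - t_curr) and added to the packed latents. A stripped-down, self-contained sketch of just that update rule, with the Flux model replaced by a stand-in callable (names and shapes here are illustrative only):

    import torch

    def euler_denoise(latents, timesteps, predict):
        # timesteps comes from get_schedule() and runs from 1.0 down to 0.0;
        # `predict` stands in for the FluxModel forward call
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            pred = predict(latents, t_curr)
            latents = latents + (t_prev - t_curr) * pred  # Euler step toward t_prev
        return latents

    # toy usage: a zero "model" leaves the latents unchanged
    x = torch.zeros(1, 16, 64)  # (bsz, packed_seq, dim)
    out = euler_denoise(x, [1.0, 0.5, 0.0], lambda z, t: torch.zeros_like(z))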
torchtitan/experiments/flux/train_configs/debug_model.toml ADDED
@@ -0,0 +1,68 @@
1
+
2
+ [job]
3
+ dump_folder = "./outputs"
4
+ description = "Flux debug model"
5
+ print_args = false
6
+ use_for_integration_test = true
7
+
8
+ [profiling]
9
+ enable_profiling = false
10
+ save_traces_folder = "profile_trace"
11
+ profile_freq = 10
12
+ enable_memory_snapshot = false
13
+ save_memory_snapshot_folder = "memory_snapshot"
14
+
15
+ [metrics]
16
+ log_freq = 1
17
+ disable_color_printing = false
18
+ enable_tensorboard = false
19
+ save_tb_folder = "tb"
20
+ enable_wandb = false
21
+
22
+ [model]
23
+ name = "flux"
24
+ flavor = "flux-debug"
25
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
26
+ # test tokenizer.model, for debug purpose only
27
+ # tokenizer_path = "./tests/assets/test_tiktoken.model"
28
+ # converters = "float8"
29
+
30
+
31
+ [optimizer]
32
+ name = "AdamW"
33
+ lr = 8e-4
34
+ eps = 1e-8
35
+
36
+ [lr_scheduler]
37
+ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
38
+ decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
39
+ decay_type = "linear"
40
+ lr_min = 0.0
41
+
42
+ [training]
43
+ batch_size = 32
44
+ seq_len = 512
45
+ max_norm = 1.0 # grad norm clipping
46
+ steps = 10
47
+ compile = false
48
+ dataset = "cc12m"
49
+ guidance = 3.5
50
+ seed = 0
51
+
52
+ [encoder]
53
+ t5_encoder="google/t5-v1_1-small"
54
+ clip_encoder="openai/clip-vit-large-patch14"
55
+ max_t5_encoding_len=512
56
+ auto_encoder_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
57
+
58
+ [parallelism]
59
+ data_parallel_replicate_degree = 1
60
+ data_parallel_shard_degree = 1
61
+ fsdp_reshard_after_forward = "default" # default / never / always
62
+ tensor_parallel_degree = 1
63
+ enable_async_tensor_parallel = false
64
+ pipeline_parallel_degree = 1
65
+ context_parallel_degree = 1
66
+
67
+ [experimental]
68
+ custom_args_module = "torchtitan.experiments.flux.flux_argparser"
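For quick local inspection, a config file in this shape parses with the standard-library TOML reader; this is only a convenience sketch, not how torchtitan's config manager consumes it:

    import tomllib  # Python 3.11+

    with open("torchtitan/experiments/flux/train_configs/debug_model.toml", "rb") as f:
        cfg = tomllib.load(f)

    print(cfg["model"]["flavor"])         # "flux-debug"
    print(cfg["training"]["batch_size"])  # 32
    print(cfg["encoder"]["t5_encoder"])   # "google/t5-v1_1-small"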
torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py ADDED
@@ -0,0 +1,885 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import logging
9
+ import math
10
+ import time
11
+
12
+ from typing import Dict, List, Tuple
13
+
14
+ # import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.optim as optim
19
+
20
+ # from torchao_pr.mg_grouped_gemm import mg_grouped_gemm
21
+
22
+ # Configure logging
23
+ logging.basicConfig(
24
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
25
+ )
26
+
27
+ # Try to import the optimized MG GEMM implementation
28
+ try:
29
+ from torchao_pr.mg_grouped_gemm import ( # grouped_gemm_backward,
30
+ grouped_gemm_forward,
31
+ )
32
+
33
+ has_mg_gemm = True
34
+ except ImportError:
35
+ logging.warning("MG GEMM implementation not found. Will use manual looping only.")
36
+ has_mg_gemm = False
37
+
38
+
39
+ class Router(nn.Module):
40
+ """
41
+ Router module that assigns tokens to experts.
42
+ """
43
+
44
+ def __init__(self, input_dim: int, num_experts: int, top_k: int = 2):
45
+ super().__init__()
46
+ self.input_dim = input_dim
47
+ self.num_experts = num_experts
48
+ self.top_k = top_k
49
+
50
+ # Routing layer
51
+ self.router = nn.Linear(input_dim, num_experts)
52
+
53
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
54
+ """
55
+ Route input tokens to experts.
56
+
57
+ Args:
58
+ x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_dim)
59
+
60
+ Returns:
61
+ Tuple containing:
62
+ - router_logits: Raw routing probabilities
63
+ - dispatch_tensor: One-hot tensor indicating expert assignment
64
+ - expert_indices: List of indices for each expert's tokens
65
+ """
66
+ batch_size, seq_len, _ = x.shape
67
+
68
+ # Flatten batch and sequence dimensions
69
+ x_flat = x.reshape(-1, self.input_dim) # (batch_size * seq_len, input_dim)
70
+
71
+ # Compute routing probabilities
72
+ router_logits = self.router(x_flat) # (batch_size * seq_len, num_experts)
73
+
74
+ # Apply softmax to get probabilities
75
+ router_probs = F.softmax(router_logits, dim=-1)
76
+
77
+ # Get top-k experts for each token
78
+ top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
79
+
80
+ # Normalize top-k probabilities
81
+ top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
82
+
83
+ # Create dispatch tensor (one-hot representation of assignments)
84
+ dispatch_tensor = torch.zeros_like(router_probs)
85
+ token_indices = (
86
+ torch.arange(router_probs.size(0), device=router_probs.device)
87
+ .unsqueeze(1)
88
+ .expand(-1, self.top_k)
89
+ )
90
+ dispatch_tensor.scatter_(1, top_k_indices, top_k_probs)
91
+
92
+ # For each expert, get the indices of tokens routed to it
93
+ expert_indices = []
94
+ for expert_idx in range(self.num_experts):
95
+ # Get indices of tokens that have non-zero probability for this expert
96
+ indices = torch.nonzero(dispatch_tensor[:, expert_idx] > 0, as_tuple=True)[
97
+ 0
98
+ ]
99
+ expert_indices.append(indices)
100
+
101
+ return router_logits, dispatch_tensor, expert_indices
102
+
103
+
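As a standalone illustration of the dispatch tensor the Router builds (softmax, top-k, renormalize, scatter), here is the same recipe on a tiny batch outside the class; all sizes are arbitrary:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(5, 4)                      # 5 tokens, 4 experts
    probs = F.softmax(logits, dim=-1)
    top_p, top_i = torch.topk(probs, k=2, dim=-1)   # keep 2 experts per token
    top_p = top_p / top_p.sum(dim=-1, keepdim=True)
    dispatch = torch.zeros_like(probs).scatter_(1, top_i, top_p)
    # each row of `dispatch` has exactly two non-zero weights that sum to 1
    print(dispatch.sum(dim=-1))                     # tensor of ones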
104
+ class Expert(nn.Module):
105
+ """
106
+ Individual expert module.
107
+ """
108
+
109
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
110
+ super().__init__()
111
+ self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False)
112
+ self.activation = nn.GELU()
113
+ self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False)
114
+
115
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
116
+ x = self.fc1(x)
117
+ x = self.activation(x)
118
+ x = self.fc2(x)
119
+ return x
120
+
121
+
122
+ class MixtureOfExperts(nn.Module):
123
+ """
124
+ Mixture of Experts layer with support for both manual looping and grouped GEMM.
125
+ """
126
+
127
+ def __init__(
128
+ self,
129
+ input_dim: int,
130
+ hidden_dim: int,
131
+ output_dim: int,
132
+ num_experts: int,
133
+ top_k: int = 2,
134
+ use_mg_gemm: bool = False,
135
+ ):
136
+ super().__init__()
137
+ self.input_dim = input_dim
138
+ self.hidden_dim = hidden_dim
139
+ self.output_dim = output_dim
140
+ self.num_experts = num_experts
141
+ self.top_k = top_k
142
+ self.use_mg_gemm = use_mg_gemm and has_mg_gemm
143
+
144
+ # Router
145
+ self.router = Router(input_dim, num_experts, top_k)
146
+
147
+ # Create expert modules
148
+ if self.use_mg_gemm:
149
+ # For MG GEMM, we need a single weight tensor for all experts
150
+ # First layer (input -> hidden)
151
+ self.expert_fc1_weight = nn.Parameter(
152
+ torch.randn(num_experts * hidden_dim, input_dim) / math.sqrt(input_dim)
153
+ )
154
+ # self.expert_fc1_bias = nn.Parameter(torch.zeros(num_experts * hidden_dim))
155
+
156
+ # Second layer (hidden -> output)
157
+ self.expert_fc2_weight = nn.Parameter(
158
+ torch.randn(num_experts * output_dim, hidden_dim)
159
+ / math.sqrt(hidden_dim)
160
+ )
161
+ # self.expert_fc2_bias = nn.Parameter(torch.zeros(num_experts * output_dim))
162
+ else:
163
+ # For manual looping, create separate experts
164
+ self.experts = nn.ModuleList(
165
+ [Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
166
+ )
167
+
168
+ def forward_manual_loop(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
169
+ """
170
+ Forward pass using manual looping over experts.
171
+ """
172
+ batch_size, seq_len, _ = x.shape
173
+ x_flat = x.reshape(-1, self.input_dim) # (batch_size * seq_len, input_dim)
174
+
175
+ # Get routing information
176
+ router_logits, dispatch_tensor, expert_indices = self.router(x)
177
+
178
+ # Initialize output tensor
179
+ final_output = torch.zeros(
180
+ batch_size * seq_len, self.output_dim, device=x.device
181
+ )
182
+
183
+ # Process each expert
184
+ for expert_idx, indices in enumerate(expert_indices):
185
+ if indices.numel() > 0:
186
+ # Get tokens routed to this expert
187
+ expert_inputs = x_flat[indices] # (num_tokens_for_expert, input_dim)
188
+
189
+ # Process tokens through expert
190
+ expert_outputs = self.experts[expert_idx](
191
+ expert_inputs
192
+ ) # (num_tokens_for_expert, output_dim)
193
+
194
+ # Scale outputs by router probabilities
195
+ scaled_outputs = expert_outputs * dispatch_tensor[
196
+ indices, expert_idx
197
+ ].unsqueeze(1)
198
+
199
+ # Add to final output
200
+ final_output.index_add_(0, indices, scaled_outputs)
201
+
202
+ # Reshape back to original dimensions
203
+ output = final_output.reshape(batch_size, seq_len, self.output_dim)
204
+
205
+ return output, router_logits
206
+
207
+ def forward_mg_gemm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
208
+ batch_size, seq_len, _ = x.shape
209
+ x_flat = x.reshape(-1, self.input_dim) # (batch_size * seq_len, input_dim)
210
+ total_tokens = batch_size * seq_len
211
+
212
+ # Get routing information
213
+ router_logits, dispatch_tensor, expert_indices = self.router(x)
214
+
215
+ # Get token counts for each expert
216
+ token_counts = [indices.numel() for indices in expert_indices]
217
+ m_sizes = torch.tensor(token_counts, dtype=torch.int32, device=x.device)
218
+
219
+ print(f"Token counts per expert: {token_counts}")
220
+ print(f"m_sizes: {m_sizes}")
221
+
222
+ # Create the combined input tensor
223
+ combined_input = torch.zeros(sum(token_counts), self.input_dim, device=x.device)
224
+
225
+ start_idx = 0
226
+ for expert_idx, indices in enumerate(expert_indices):
227
+ if indices.numel() > 0:
228
+ end_idx = start_idx + indices.numel()
229
+ combined_input[start_idx:end_idx] = x_flat[indices]
230
+ start_idx = end_idx
231
+
232
+ print(f"combined_input shape: {combined_input.shape}")
233
+
234
+ # First layer: input -> hidden
235
+ fc1_weight_reshaped = self.expert_fc1_weight.reshape(
236
+ self.num_experts, self.hidden_dim, self.input_dim
237
+ )
238
+ fc1_weight_combined = fc1_weight_reshaped.reshape(-1, self.input_dim)
239
+
240
+ print(f"fc1_weight_combined shape: {fc1_weight_combined.shape}")
241
+
242
+ # Run the grouped GEMM
243
+ hidden_outputs = grouped_gemm_forward(
244
+ combined_input, fc1_weight_combined, m_sizes
245
+ )
246
+
247
+ print(f"hidden_outputs shape after first GEMM: {hidden_outputs.shape}")
248
+
249
+ # Apply activation
250
+ hidden_outputs = F.gelu(hidden_outputs)
251
+
252
+ print(f"hidden_outputs shape after activation: {hidden_outputs.shape}")
253
+
254
+ # Second layer: hidden -> output
255
+ # Reshape hidden_outputs to match expected dimensions
256
+ reshaped_hidden_outputs = []
257
+ start_idx = 0
258
+
259
+ for expert_idx, count in enumerate(token_counts):
260
+ if count > 0:
261
+ end_idx = start_idx + count
262
+ # Take this expert's outputs and reshape to [count, hidden_dim]
263
+ expert_output = hidden_outputs[
264
+ start_idx:end_idx,
265
+ expert_idx * self.hidden_dim : (expert_idx + 1) * self.hidden_dim,
266
+ ]
267
+ reshaped_hidden_outputs.append(expert_output)
268
+ start_idx = end_idx
269
+
270
+ # Concatenate all reshaped outputs
271
+ hidden_outputs = torch.cat(reshaped_hidden_outputs, dim=0)
272
+
273
+ # Reshape expert weights for second layer
274
+ fc2_weight_reshaped = self.expert_fc2_weight.reshape(
275
+ self.num_experts, self.output_dim, self.hidden_dim
276
+ )
277
+ fc2_weight_combined = fc2_weight_reshaped.reshape(-1, self.hidden_dim)
278
+
279
+ print(f"fc2_weight_combined shape: {fc2_weight_combined.shape}")
280
+
281
+ # Run the second grouped GEMM
282
+ expert_outputs_combined = grouped_gemm_forward(
283
+ hidden_outputs, fc2_weight_combined, m_sizes
284
+ )
285
+
286
+ # Initialize final output tensor with correct shape
287
+ final_output = torch.zeros(total_tokens, self.output_dim, device=x.device)
288
+
289
+ # Distribute the outputs back to the original token positions
290
+ start_idx = 0
291
+ for expert_idx, indices in enumerate(expert_indices):
292
+ if indices.numel() > 0:
293
+ end_idx = start_idx + indices.numel()
294
+ # Get this expert's outputs
295
+ expert_outputs = expert_outputs_combined[start_idx:end_idx]
296
+
297
+ print(
298
+ f"Expert {expert_idx} - indices shape: {indices.shape}, expert_outputs shape: {expert_outputs.shape}"
299
+ )
300
+
301
+ # Scale outputs by router probabilities
302
+ scaled_outputs = expert_outputs * dispatch_tensor[
303
+ indices, expert_idx
304
+ ].unsqueeze(1)
305
+
306
+ # Ensure dimensions match before using index_add_
307
+ if scaled_outputs.shape[1] != final_output.shape[1]:
308
+ # print(
309
+ # f"Reshaping: Dimension mismatch: scaled_outputs {scaled_outputs.shape}, final_output {final_output.shape}"
310
+ # )
311
+ # Reshape if needed - make sure output_dim is correct
312
+ scaled_outputs = scaled_outputs[:, : self.output_dim]
313
+
314
+ # Add to final output
315
+ final_output.index_add_(0, indices, scaled_outputs)
316
+
317
+ start_idx = end_idx
318
+
319
+ # Reshape back to original dimensions
320
+ output = final_output.reshape(batch_size, seq_len, self.output_dim)
321
+
322
+ return output, router_logits
323
+
324
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
325
+ if self.use_mg_gemm and has_mg_gemm:
326
+ return self.forward_mg_gemm(x)
327
+ else:
328
+ return self.forward_manual_loop(x)
329
+
330
+
331
+ class MoEModel(nn.Module):
332
+ """
333
+ Simple model using MoE layers.
334
+ """
335
+
336
+ def __init__(
337
+ self,
338
+ vocab_size: int,
339
+ embed_dim: int,
340
+ hidden_dim: int,
341
+ num_experts: int,
342
+ top_k: int = 2,
343
+ use_mg_gemm: bool = False,
344
+ ):
345
+ super().__init__()
346
+ self.embedding = nn.Embedding(vocab_size, embed_dim)
347
+ self.moe_layer = MixtureOfExperts(
348
+ input_dim=embed_dim,
349
+ hidden_dim=hidden_dim,
350
+ output_dim=embed_dim,
351
+ num_experts=num_experts,
352
+ top_k=top_k,
353
+ use_mg_gemm=use_mg_gemm,
354
+ )
355
+ self.output_layer = nn.Linear(embed_dim, vocab_size)
356
+
357
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
358
+ # x shape: (batch_size, seq_len)
359
+ embedded = self.embedding(x) # (batch_size, seq_len, embed_dim)
360
+ moe_output, router_logits = self.moe_layer(
361
+ embedded
362
+ ) # (batch_size, seq_len, embed_dim)
363
+ logits = self.output_layer(moe_output) # (batch_size, seq_len, vocab_size)
364
+ return logits, router_logits
365
+
366
+
367
+ def compute_load_balancing_loss(
368
+ router_logits: torch.Tensor, num_experts: int
369
+ ) -> torch.Tensor:
370
+ """
371
+ Compute the load balancing loss for MoE training.
372
+
373
+ Args:
374
+ router_logits (torch.Tensor): Router logits of shape (batch_size * seq_len, num_experts)
375
+ num_experts (int): Number of experts
376
+
377
+ Returns:
378
+ torch.Tensor: Load balancing loss
379
+ """
380
+ # Get router probabilities
381
+ router_probs = F.softmax(
382
+ router_logits, dim=-1
383
+ ) # (batch_size * seq_len, num_experts)
384
+
385
+ # Compute fraction of tokens routed to each expert
386
+ # Sum across the batch dimension and normalize
387
+ router_probs_sum = router_probs.sum(dim=0) # (num_experts,)
388
+ router_probs_sum = router_probs_sum / router_probs_sum.sum()
389
+
390
+ # Compute the mean probability per expert
391
+ mean_prob = 1.0 / num_experts
392
+
393
+ # Compute the fraction of tokens routed to each expert
394
+ # The goal is to have uniform routing across experts
395
+ load_balancing_loss = num_experts * torch.sum(router_probs_sum * router_probs_sum)
396
+
397
+ return load_balancing_loss
398
+
399
+
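A quick numeric sanity check of the loss above (purely illustrative): with perfectly uniform routing the value is 1.0, and it grows as routing collapses onto fewer experts.

    import torch

    num_experts = 4
    uniform = torch.full((num_experts,), 1.0 / num_experts)
    skewed = torch.tensor([0.7, 0.1, 0.1, 0.1])

    # same quantity the function computes after normalizing router_probs_sum
    print(num_experts * torch.sum(uniform * uniform))  # tensor(1.)
    print(num_experts * torch.sum(skewed * skewed))    # tensor(2.0800)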
400
+ def generate_sample_data(
401
+ batch_size: int, seq_len: int, vocab_size: int, device: str = "cuda"
402
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
403
+ """
404
+ Generate sample data for training.
405
+
406
+ Args:
407
+ batch_size (int): Batch size
408
+ seq_len (int): Sequence length
409
+ vocab_size (int): Vocabulary size
410
+ device (str): Device to use
411
+
412
+ Returns:
413
+ Tuple of input tokens and target tokens
414
+ """
415
+ # Generate random input tokens
416
+ inputs = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
417
+
418
+ # Generate random target tokens
419
+ targets = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
420
+
421
+ return inputs, targets
422
+
423
+
424
+ def train_epoch(
425
+ model: nn.Module,
426
+ optimizer: torch.optim.Optimizer,
427
+ batch_size: int,
428
+ seq_len: int,
429
+ vocab_size: int,
430
+ num_batches: int,
431
+ device: str,
432
+ load_balance_coef: float = 0.01,
433
+ ) -> Dict[str, float]:
434
+ """
435
+ Train the model for one epoch.
436
+
437
+ Args:
438
+ model (nn.Module): Model to train
439
+ optimizer (torch.optim.Optimizer): Optimizer
440
+ batch_size (int): Batch size
441
+ seq_len (int): Sequence length
442
+ vocab_size (int): Vocabulary size
443
+ num_batches (int): Number of batches per epoch
444
+ device (str): Device to use
445
+ load_balance_coef (float): Coefficient for load balancing loss
446
+
447
+ Returns:
448
+ Dict containing training metrics
449
+ """
450
+ model.train()
451
+ total_loss = 0.0
452
+ total_acc = 0.0
453
+ start_time = time.time()
454
+
455
+ for i in range(num_batches):
456
+ # Generate sample data
457
+ inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
458
+
459
+ # Forward pass
460
+ optimizer.zero_grad()
461
+ logits, router_logits = model(inputs)
462
+
463
+ # Compute loss
464
+ # Reshape for cross entropy loss
465
+ logits_flat = logits.reshape(-1, vocab_size)
466
+ targets_flat = targets.reshape(-1)
467
+
468
+ # Cross entropy loss
469
+ ce_loss = F.cross_entropy(logits_flat, targets_flat)
470
+
471
+ # Load balancing loss
472
+ lb_loss = compute_load_balancing_loss(
473
+ router_logits, model.moe_layer.num_experts
474
+ )
475
+
476
+ # Combined loss
477
+ loss = ce_loss + load_balance_coef * lb_loss
478
+
479
+ # Backward pass
480
+ loss.backward()
481
+ optimizer.step()
482
+
483
+ # Compute accuracy
484
+ preds = logits_flat.argmax(dim=-1)
485
+ correct = (preds == targets_flat).float().sum()
486
+ acc = correct / (batch_size * seq_len)
487
+
488
+ # Accumulate metrics
489
+ total_loss += loss.item()
490
+ total_acc += acc.item()
491
+
492
+ # Log progress
493
+ if (i + 1) % 10 == 0:
494
+ logging.info(
495
+ f"Batch {i + 1}/{num_batches} | "
496
+ f"Loss: {loss.item():.4f} | "
497
+ f"CE Loss: {ce_loss.item():.4f} | "
498
+ f"LB Loss: {lb_loss.item():.4f} | "
499
+ f"Acc: {acc.item():.4f}"
500
+ )
501
+
502
+ # Compute average metrics
503
+ avg_loss = total_loss / num_batches
504
+ avg_acc = total_acc / num_batches
505
+ epoch_time = time.time() - start_time
506
+
507
+ return {"loss": avg_loss, "acc": avg_acc, "time": epoch_time}
508
+
509
+
510
+ def evaluate(
511
+ model: nn.Module,
512
+ batch_size: int,
513
+ seq_len: int,
514
+ vocab_size: int,
515
+ num_batches: int,
516
+ device: str,
517
+ ) -> Dict[str, float]:
518
+ """
519
+ Evaluate the model.
520
+
521
+ Args:
522
+ model (nn.Module): Model to evaluate
523
+ batch_size (int): Batch size
524
+ seq_len (int): Sequence length
525
+ vocab_size (int): Vocabulary size
526
+ num_batches (int): Number of batches for evaluation
527
+ device (str): Device to use
528
+
529
+ Returns:
530
+ Dict containing evaluation metrics
531
+ """
532
+ model.eval()
533
+ total_loss = 0.0
534
+ total_acc = 0.0
535
+
536
+ with torch.no_grad():
537
+ for i in range(num_batches):
538
+ # Generate sample data
539
+ inputs, targets = generate_sample_data(
540
+ batch_size, seq_len, vocab_size, device
541
+ )
542
+
543
+ # Forward pass
544
+ logits, router_logits = model(inputs)
545
+
546
+ # Compute loss
547
+ logits_flat = logits.reshape(-1, vocab_size)
548
+ targets_flat = targets.reshape(-1)
549
+
550
+ # Cross entropy loss
551
+ loss = F.cross_entropy(logits_flat, targets_flat)
552
+
553
+ # Compute accuracy
554
+ preds = logits_flat.argmax(dim=-1)
555
+ correct = (preds == targets_flat).float().sum()
556
+ acc = correct / (batch_size * seq_len)
557
+
558
+ # Accumulate metrics
559
+ total_loss += loss.item()
560
+ total_acc += acc.item()
561
+
562
+ # Compute average metrics
563
+ avg_loss = total_loss / num_batches
564
+ avg_acc = total_acc / num_batches
565
+
566
+ return {"loss": avg_loss, "acc": avg_acc}
567
+
568
+
569
+ def measure_performance(
570
+ model: nn.Module,
571
+ batch_size: int,
572
+ seq_len: int,
573
+ vocab_size: int,
574
+ num_batches: int,
575
+ device: str,
576
+ ) -> Dict[str, float]:
577
+ """
578
+ Measure forward and backward pass performance.
579
+
580
+ Args:
581
+ model (nn.Module): Model to evaluate
582
+ batch_size (int): Batch size
583
+ seq_len (int): Sequence length
584
+ vocab_size (int): Vocabulary size
585
+ num_batches (int): Number of batches for measurement
586
+ device (str): Device to use
587
+
588
+ Returns:
589
+ Dict containing performance metrics
590
+ """
591
+ model.train()
592
+
593
+ # Create dummy optimizer
594
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
595
+
596
+ # Warmup
597
+ for _ in range(5):
598
+ inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
599
+ logits, router_logits = model(inputs)
600
+ loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
601
+ loss.backward()
602
+ optimizer.zero_grad()
603
+
604
+ # Measure forward pass time
605
+ torch.cuda.synchronize()
606
+ forward_start = time.time()
607
+
608
+ for _ in range(num_batches):
609
+ inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
610
+ with torch.no_grad():
611
+ logits, router_logits = model(inputs)
612
+
613
+ torch.cuda.synchronize()
614
+ forward_end = time.time()
615
+ forward_time = (forward_end - forward_start) / num_batches
616
+
617
+ # Measure backward pass time
618
+ torch.cuda.synchronize()
619
+ backward_start = time.time()
620
+
621
+ for _ in range(num_batches):
622
+ inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
623
+ logits, router_logits = model(inputs)
624
+ loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
625
+ loss.backward()
626
+ optimizer.zero_grad()
627
+
628
+ torch.cuda.synchronize()
629
+ backward_end = time.time()
630
+ backward_time = (backward_end - backward_start) / num_batches
631
+
632
+ return {
633
+ "forward_time": forward_time * 1000, # Convert to ms
634
+ "backward_time": backward_time * 1000, # Convert to ms
635
+ "total_time": (forward_time + backward_time) * 1000, # Convert to ms
636
+ }
637
+
638
+
639
+ def compare_methods(args):
640
+ """
641
+ Compare manual looping and MG GEMM implementations.
642
+ """
643
+ device = torch.device(args.device)
644
+
645
+ # Create models
646
+ manual_model = MoEModel(
647
+ vocab_size=args.vocab_size,
648
+ embed_dim=args.embed_dim,
649
+ hidden_dim=args.hidden_dim,
650
+ num_experts=args.num_experts,
651
+ top_k=args.top_k,
652
+ use_mg_gemm=False,
653
+ ).to(device)
654
+
655
+ if has_mg_gemm:
656
+ mg_model = MoEModel(
657
+ vocab_size=args.vocab_size,
658
+ embed_dim=args.embed_dim,
659
+ hidden_dim=args.hidden_dim,
660
+ num_experts=args.num_experts,
661
+ top_k=args.top_k,
662
+ use_mg_gemm=True,
663
+ ).to(device)
664
+ else:
665
+ mg_model = None
666
+
667
+ # Measure performance
668
+ logging.info("Measuring performance of manual looping method...")
669
+ manual_perf = measure_performance(
670
+ manual_model,
671
+ args.batch_size,
672
+ args.seq_len,
673
+ args.vocab_size,
674
+ args.perf_batches,
675
+ device,
676
+ )
677
+
678
+ if mg_model is not None:
679
+ logging.info("Measuring performance of MG GEMM method...")
680
+ mg_perf = measure_performance(
681
+ mg_model,
682
+ args.batch_size,
683
+ args.seq_len,
684
+ args.vocab_size,
685
+ args.perf_batches,
686
+ device,
687
+ )
688
+ else:
689
+ mg_perf = {"forward_time": 0, "backward_time": 0, "total_time": 0}
690
+
691
+ # Log results
692
+ logging.info("\n===== Performance Comparison =====")
693
+ logging.info("Model Configuration:")
694
+ logging.info(f" - Batch Size: {args.batch_size}")
695
+ logging.info(f" - Sequence Length: {args.seq_len}")
696
+ logging.info(f" - Embed Dimension: {args.embed_dim}")
697
+ logging.info(f" - Hidden Dimension: {args.hidden_dim}")
698
+ logging.info(f" - Number of Experts: {args.num_experts}")
699
+ logging.info(f" - Top-K: {args.top_k}")
700
+ logging.info("")
701
+
702
+ logging.info("Manual Looping Method:")
703
+ logging.info(f" - Forward Time: {manual_perf['forward_time']:.2f} ms")
704
+ logging.info(f" - Backward Time: {manual_perf['backward_time']:.2f} ms")
705
+ logging.info(f" - Total Time: {manual_perf['total_time']:.2f} ms")
706
+ logging.info("")
707
+
708
+ if mg_model is not None:
709
+ logging.info("MG GEMM Method:")
710
+ logging.info(f" - Forward Time: {mg_perf['forward_time']:.2f} ms")
711
+ logging.info(f" - Backward Time: {mg_perf['backward_time']:.2f} ms")
712
+ logging.info(f" - Total Time: {mg_perf['total_time']:.2f} ms")
713
+ logging.info("")
714
+
715
+ # Calculate speedup
716
+ forward_speedup = (
717
+ manual_perf["forward_time"] / mg_perf["forward_time"]
718
+ if mg_perf["forward_time"] > 0
719
+ else 0
720
+ )
721
+ backward_speedup = (
722
+ manual_perf["backward_time"] / mg_perf["backward_time"]
723
+ if mg_perf["backward_time"] > 0
724
+ else 0
725
+ )
726
+ total_speedup = (
727
+ manual_perf["total_time"] / mg_perf["total_time"]
728
+ if mg_perf["total_time"] > 0
729
+ else 0
730
+ )
731
+
732
+ logging.info("Speedup (MG GEMM vs Manual):")
733
+ logging.info(f" - Forward Speedup: {forward_speedup:.2f}x")
734
+ logging.info(f" - Backward Speedup: {backward_speedup:.2f}x")
735
+ logging.info(f" - Total Speedup: {total_speedup:.2f}x")
736
+ else:
737
+ logging.info("MG GEMM method not available.")
738
+
739
+
740
+ def train_model(args):
741
+ """
742
+ Train an MoE model.
743
+ """
744
+ device = torch.device(args.device)
745
+
746
+ # Create model
747
+ model = MoEModel(
748
+ vocab_size=args.vocab_size,
749
+ embed_dim=args.embed_dim,
750
+ hidden_dim=args.hidden_dim,
751
+ num_experts=args.num_experts,
752
+ top_k=args.top_k,
753
+ use_mg_gemm=args.use_mg_gemm and has_mg_gemm,
754
+ ).to(device)
755
+
756
+ # Create optimizer
757
+ optimizer = optim.Adam(model.parameters(), lr=args.lr)
758
+
759
+ # Log model information
760
+ logging.info("Model configuration:")
761
+ logging.info(f" - Vocabulary Size: {args.vocab_size}")
762
+ logging.info(f" - Embedding Dimension: {args.embed_dim}")
763
+ logging.info(f" - Hidden Dimension: {args.hidden_dim}")
764
+ logging.info(f" - Number of Experts: {args.num_experts}")
765
+ logging.info(f" - Top-K: {args.top_k}")
766
+ logging.info(f" - Using MG GEMM: {args.use_mg_gemm and has_mg_gemm}")
767
+
768
+ # Training loop
769
+ for epoch in range(args.epochs):
770
+ logging.info(f"\nEpoch {epoch + 1}/{args.epochs}")
771
+
772
+ # Train
773
+ train_metrics = train_epoch(
774
+ model=model,
775
+ optimizer=optimizer,
776
+ batch_size=args.batch_size,
777
+ seq_len=args.seq_len,
778
+ vocab_size=args.vocab_size,
779
+ num_batches=args.train_batches,
780
+ device=device,
781
+ load_balance_coef=args.load_balance_coef,
782
+ )
783
+
784
+ # Evaluate
785
+ eval_metrics = evaluate(
786
+ model=model,
787
+ batch_size=args.batch_size,
788
+ seq_len=args.seq_len,
789
+ vocab_size=args.vocab_size,
790
+ num_batches=args.eval_batches,
791
+ device=device,
792
+ )
793
+
794
+ # Log metrics
795
+ logging.info(
796
+ f"Train Loss: {train_metrics['loss']:.4f} | Train Acc: {train_metrics['acc']:.4f}"
797
+ )
798
+ logging.info(
799
+ f"Eval Loss: {eval_metrics['loss']:.4f} | Eval Acc: {eval_metrics['acc']:.4f}"
800
+ )
801
+ logging.info(f"Epoch Time: {train_metrics['time']:.2f} seconds")
802
+
803
+
804
+ if __name__ == "__main__":
805
+ parser = argparse.ArgumentParser(description="Train MoE model")
806
+
807
+ # Model parameters
808
+ parser.add_argument("--vocab_size", type=int, default=10000, help="Vocabulary size")
809
+ parser.add_argument(
810
+ "--embed_dim", type=int, default=512, help="Embedding dimension"
811
+ )
812
+ parser.add_argument(
813
+ "--hidden_dim", type=int, default=1024, help="Hidden dimension in experts"
814
+ )
815
+ parser.add_argument("--num_experts", type=int, default=8, help="Number of experts")
816
+ parser.add_argument(
817
+ "--top_k", type=int, default=2, help="Top-k experts to route to"
818
+ )
819
+
820
+ # Training parameters
821
+ parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
822
+ parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
823
+ parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
824
+ parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
825
+ parser.add_argument(
826
+ "--train_batches",
827
+ type=int,
828
+ default=100,
829
+ help="Number of training batches per epoch",
830
+ )
831
+ parser.add_argument(
832
+ "--eval_batches", type=int, default=20, help="Number of evaluation batches"
833
+ )
834
+ parser.add_argument(
835
+ "--perf_batches",
836
+ type=int,
837
+ default=50,
838
+ help="Number of batches for performance testing",
839
+ )
840
+ parser.add_argument(
841
+ "--load_balance_coef",
842
+ type=float,
843
+ default=0.01,
844
+ help="Load balancing loss coefficient",
845
+ )
846
+
847
+ # Runtime parameters
848
+ parser.add_argument(
849
+ "--device",
850
+ type=str,
851
+ default="cuda" if torch.cuda.is_available() else "cpu",
852
+ help="Device to use (cuda or cpu)",
853
+ )
854
+ parser.add_argument(
855
+ "--use_mg_gemm",
856
+ action="store_true",
857
+ help="Use MG GEMM implementation if available",
858
+ )
859
+ parser.add_argument(
860
+ "--compare",
861
+ action="store_true",
862
+ help="Compare manual and MG GEMM implementations",
863
+ )
864
+ parser.add_argument("--train", action="store_true", help="Train the model")
865
+
866
+ args = parser.parse_args()
867
+
868
+ # Check for CUDA
869
+ if args.device == "cuda" and not torch.cuda.is_available():
870
+ logging.warning("CUDA not available, using CPU instead.")
871
+ args.device = "cpu"
872
+
873
+ # Log basic information
874
+ logging.info(f"PyTorch version: {torch.__version__}")
875
+ logging.info(f"Device: {args.device}")
876
+ logging.info(f"MG GEMM available: {has_mg_gemm}")
877
+
878
+ # Run the requested action
879
+ if args.compare:
880
+ compare_methods(args)
881
+ elif args.train:
882
+ train_model(args)
883
+ else:
884
+ # Default to comparison if no action specified
885
+ compare_methods(args)
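The script is normally driven through argparse; below is a minimal sketch of calling the same entry points programmatically, assuming the file is importable as simpleMoE (the attribute names mirror the parser above, the values are arbitrary small settings):

    from argparse import Namespace

    from simpleMoE import compare_methods, train_model

    # the perf path calls torch.cuda.synchronize(), so a GPU is assumed here
    args = Namespace(
        vocab_size=1000, embed_dim=128, hidden_dim=256, num_experts=4, top_k=2,
        batch_size=8, seq_len=64, epochs=1, lr=1e-3,
        train_batches=10, eval_batches=5, perf_batches=10,
        load_balance_coef=0.01, device="cuda", use_mg_gemm=False,
        compare=True, train=False,
    )
    compare_methods(args)   # or train_model(args)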
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py ADDED
@@ -0,0 +1,299 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from reference_utils import (
14
+ analyze_tensor_differences,
15
+ compute_reference_backward,
16
+ compute_reference_forward,
17
+ )
18
+
19
+ # Configure logging
20
+ logging.basicConfig(
21
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
22
+ )
23
+
24
+ # Import grouped GEMM implementations
25
+ try:
26
+ from mg_grouped_gemm import grouped_gemm_backward, grouped_gemm_forward
27
+
28
+ except ImportError:
29
+ logging.error(
30
+ "Error importing grouped GEMM modules. Make sure the implementation files are in the correct path."
31
+ )
32
+ raise
33
+
34
+
35
+ def test_forward_pass():
36
+ """
37
+ A simple test for the M*G grouped GEMM forward pass with detailed error handling.
38
+
39
+ In M*G grouping:
40
+ - M dimension is partitioned into G groups (M_total = sum(M_sizes))
41
+ - N dimension is the same for all groups
42
+ """
43
+ try:
44
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+
46
+ # Test parameters for DeepSeek-like models
47
+ G = 1 # Number of groups
48
+ M_sizes = [
49
+ 2048,
50
+ ] # group sizes for this single-group sanity test
51
+ M_total = sum(M_sizes) # Total M dimension
52
+ N = 4096 # Output dimension (same for all groups)
53
+ K = 7168 # Hidden dimension
54
+
55
+ # Create group sizes tensor
56
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
57
+
58
+ # Create input and weight tensors - using float16 for higher precision
59
+ x = torch.randn(M_total, K, dtype=torch.float16, device=device)
60
+ w = torch.randn(N, K, dtype=torch.float16, device=device)
61
+
62
+ # Log the setup
63
+ logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
64
+ logging.info(f"Group sizes: {m_sizes}")
65
+ logging.info(f"Input x shape: {x.shape}")
66
+ logging.info(f"Weight w shape: {w.shape}")
67
+
68
+ # Run forward pass
69
+ logging.info("Running forward pass with grouped GEMM")
70
+ result = grouped_gemm_forward(x, w, m_sizes)
71
+ logging.info(f"Forward result shape: {result.shape}")
72
+
73
+ # Compute reference result
74
+ logging.info("Computing reference result with PyTorch")
75
+ reference_result = compute_reference_forward(x, w, m_sizes)
76
+
77
+ # Compare results
78
+ logging.info("Comparing with PyTorch reference")
79
+ forward_close = analyze_tensor_differences(
80
+ result, reference_result, "Forward output"
81
+ )
82
+
83
+ return forward_close
84
+
85
+ except Exception as e:
86
+ logging.error(f"Test failed with error: {e}")
87
+ import traceback
88
+
89
+ logging.error(traceback.format_exc())
90
+ return False
91
+
92
+
93
+ def test_backward_pass():
94
+ """
95
+ A simple test for the M*G grouped GEMM backward pass with detailed error handling.
96
+
97
+ In M*G grouping:
98
+ - M dimension is partitioned into G groups (M_total = sum(M_sizes))
99
+ - N dimension is the same for all groups
100
+ """
101
+ try:
102
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
103
+
104
+ # Test parameters for DeepSeek-like models
105
+ G = 4 # Number of groups
106
+ M_sizes = [2048, 2048, 2048, 2048] # Group sizes (will be adjusted)
107
+ M_total = sum(M_sizes) # Total M dimension
108
+ N = 4096 # Output dimension (same for all groups)
109
+ K = 7168 # Hidden dimension
110
+
111
+ # Create group sizes tensor
112
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
113
+
114
+ # Create input and weight tensors - using float16 for higher precision
115
+ x = torch.randn(
116
+ M_total, K, dtype=torch.float16, device=device, requires_grad=True
117
+ )
118
+ w = torch.randn(N, K, dtype=torch.float16, device=device, requires_grad=True)
119
+
120
+ # Log the setup
121
+ logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
122
+ logging.info(f"Group sizes: {m_sizes}")
123
+ logging.info(f"Input x shape: {x.shape}")
124
+ logging.info(f"Weight w shape: {w.shape}")
125
+
126
+ # Step 1: Run forward pass
127
+ logging.info("Running forward pass")
128
+ result = grouped_gemm_forward(x, w, m_sizes)
129
+ logging.info(f"Forward result shape: {result.shape}")
130
+
131
+ # Create a gradient for backpropagation
132
+ grad_output = torch.randn_like(result)
133
+ logging.info(f"Created gradient with shape: {grad_output.shape}")
134
+
135
+ # Step 2: Run backward pass directly
136
+ logging.info("Running backward pass directly")
137
+ grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)
138
+
139
+ # Verify gradient shapes
140
+ logging.info(
141
+ f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}"
142
+ )
143
+
144
+ # Step 3: Verify gradient computation using PyTorch's autograd
145
+ logging.info("Running PyTorch reference implementation")
146
+
147
+ # Compute reference gradients
148
+ x_ref_grad, w_ref_grad = compute_reference_backward(x, w, m_sizes, grad_output)
149
+
150
+ # Compare gradients
151
+ logging.info("Comparing gradients with PyTorch reference")
152
+ grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
153
+ grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")
154
+
155
+ # Log overall result
156
+ if grad_x_close and grad_w_close:
157
+ logging.info("✓ SUCCESS: Gradients match the PyTorch reference")
158
+ else:
159
+ logging.error("✗ FAILURE: Gradient mismatch detected")
160
+
161
+ return grad_x_close and grad_w_close
162
+
163
+ except Exception as e:
164
+ logging.error(f"Test failed with error: {e}")
165
+ import traceback
166
+
167
+ logging.error(traceback.format_exc())
168
+ return False
169
+
170
+
171
+ def test_multiple_deepseek_configs():
172
+ """
173
+ Test multiple DeepSeek model configurations with both forward and backward pass verification.
174
+ """
175
+ # DeepSeek configurations: (G, M, K, N)
176
+ configs = [
177
+ (4, 8192, 7168, 4096), # Config 1
178
+ (4, 8192, 2048, 7168), # Config 2
179
+ (8, 4096, 7168, 4096), # Config 3
180
+ (8, 4096, 2048, 7168), # Config 4
181
+ ]
182
+
183
+ results = []
184
+
185
+ for config_idx, (G, M, K, N) in enumerate(configs):
186
+ logging.info(f"\n\n===== Testing DeepSeek Config {config_idx+1} =====")
187
+ logging.info(f"G={G}, M={M}, K={K}, N={N}")
188
+
189
+ try:
190
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
191
+
192
+ # Create even group sizes
193
+ base_size = M // G
194
+ remainder = M % G
195
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
196
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
197
+
198
+ # Create input and weight tensors using float16 for higher precision
199
+ x = torch.randn(
200
+ M, K, dtype=torch.float16, device=device, requires_grad=True
201
+ )
202
+ w = torch.randn(
203
+ N, K, dtype=torch.float16, device=device, requires_grad=True
204
+ )
205
+
206
+ logging.info(f"Input x shape: {x.shape}, Weight w shape: {w.shape}")
207
+
208
+ # Run forward pass
209
+ result = grouped_gemm_forward(x, w, m_sizes)
210
+ logging.info(f"Forward result shape: {result.shape}")
211
+
212
+ # ===== FORWARD PASS VERIFICATION =====
213
+ # Compute reference forward result
214
+ reference_result = compute_reference_forward(x, w, m_sizes)
215
+
216
+ # Compare forward results
217
+ forward_close = analyze_tensor_differences(
218
+ result, reference_result, "Forward output"
219
+ )
220
+
221
+ # ===== BACKWARD PASS VERIFICATION =====
222
+ # Create gradient for backpropagation
223
+ grad_output = torch.randn_like(result)
224
+
225
+ # Run backward pass
226
+ grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)
227
+
228
+ # Compute reference gradients
229
+ x_ref_grad, w_ref_grad = compute_reference_backward(
230
+ x, w, m_sizes, grad_output
231
+ )
232
+
233
+ # Compare backward results
234
+ grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
235
+ grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")
236
+
237
+ # Overall config result
238
+ backward_close = grad_x_close and grad_w_close
239
+ config_success = forward_close and backward_close
240
+ results.append(
241
+ (config_idx + 1, config_success, forward_close, backward_close)
242
+ )
243
+
244
+ # Log overall config result
245
+ if config_success:
246
+ logging.info(f"✓ SUCCESS: Config {config_idx+1} passed all tests!")
247
+ else:
248
+ logging.error(
249
+ f"✗ FAILURE: Config {config_idx+1} failed one or more tests"
250
+ )
251
+
252
+ except Exception as e:
253
+ logging.error(f"Config {config_idx+1} test failed with error: {e}")
254
+ import traceback
255
+
256
+ logging.error(traceback.format_exc())
257
+ results.append((config_idx + 1, False, False, False))
258
+
259
+ # Summary
260
+ logging.info("\n===== Test Results Summary =====")
261
+ for config_idx, overall_success, forward_success, backward_success in results:
262
+ overall_status = "✓ PASSED" if overall_success else "✗ FAILED"
263
+ forward_status = "✓ PASSED" if forward_success else "✗ FAILED"
264
+ backward_status = "✓ PASSED" if backward_success else "✗ FAILED"
265
+
266
+ logging.info(f"Config {config_idx}: {overall_status}")
267
+ logging.info(f" - Forward pass: {forward_status}")
268
+ logging.info(f" - Backward pass: {backward_status}")
269
+
270
+ return all(overall_success for _, overall_success, _, _ in results)
271
+
272
+
273
+ if __name__ == "__main__":
274
+ logging.info(
275
+ "Running verification for both forward and backward pass of M*G grouped GEMM"
276
+ )
277
+
278
+ # Run basic forward pass test
279
+ logging.info("\n===== Running basic forward pass test =====")
280
+ success_forward = test_forward_pass()
281
+ logging.info(f"Basic forward test {'succeeded' if success_forward else 'failed'}")
282
+
283
+ # Run basic backward pass test
284
+ logging.info("\n===== Running basic backward pass test =====")
285
+ success_backward = test_backward_pass()
286
+ logging.info(f"Basic backward test {'succeeded' if success_backward else 'failed'}")
287
+
288
+ # Run multiple DeepSeek configs with forward and backward verification
289
+ logging.info("\n===== Running tests for all DeepSeek configs =====")
290
+ success_configs = test_multiple_deepseek_configs()
291
+ logging.info(
292
+ f"DeepSeek configs tests {'all succeeded' if success_configs else 'had failures'}"
293
+ )
294
+
295
+ # Overall result
296
+ overall_success = success_forward and success_backward and success_configs
297
+ logging.info(
298
+ f"\nOverall test result: {'SUCCESS' if overall_success else 'FAILURE'}"
299
+ )
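The "M*G grouping" described in the docstrings above partitions the rows of x into G contiguous blocks whose sizes are given by m_sizes (so M_total = sum(M_sizes)), with the N dimension the same for every group. A small sketch of just that row partitioning, independent of the Triton kernels and of reference_utils:

    import torch

    m_sizes = [3, 2, 4]               # G = 3 groups, M_total = 9
    K = 8
    x = torch.randn(sum(m_sizes), K)  # (M_total, K), rows grouped contiguously

    # contiguous per-group row blocks, which is all "M*G grouping" means for x
    for g, x_g in enumerate(torch.split(x, m_sizes, dim=0)):
        print(g, x_g.shape)           # (3, 8), (2, 8), (4, 8)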
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py ADDED
@@ -0,0 +1,1304 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # credit - flat index forward kernel is derived from FBGemm:
8
+ # https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
9
+
10
+ # pyre-unsafe
11
+ import functools
12
+ import logging
13
+
14
+ import os
15
+ import sys
16
+ from typing import Any, Dict, Optional, Tuple
17
+
18
+ import torch
19
+
20
+ import triton
21
+ import triton.language as tl
22
+ from triton import Config as TConfig
23
+
24
+ from triton.runtime import driver # @manual
25
+
26
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
27
+
28
+ from tma_autotuning import (
29
+ ALIGN_SIZE_M,
30
+ _NV_CONFIGS,
31
+ CudaUtils,
32
+ early_config_prune,
33
+ TmaDescriptorHelper,
34
+ )
35
+
36
+
37
+ # Configure logging
38
+ logging.basicConfig(
39
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
40
+ )
41
+
42
+ # ============== Start Triton Kernels ===============
43
+
44
+
45
+ @triton.autotune(
46
+ configs=_NV_CONFIGS,
47
+ key=["G", "M_BUCKET", "N", "K"],
48
+ prune_configs_by={"early_config_prune": early_config_prune},
49
+ )
50
+ @triton.jit
51
+ def _kernel_mg_forward_hopper(
52
+ a_desc_ptr,
53
+ b_desc_ptr,
54
+ c_ptr,
55
+ workspace,
56
+ m_sizes,
57
+ # problem sizes
58
+ G: tl.constexpr,
59
+ M_BUCKET: tl.constexpr,
60
+ N: tl.constexpr,
61
+ K: tl.constexpr,
62
+ # config
63
+ NUM_SMS: tl.constexpr,
64
+ TMA_SIZE: tl.constexpr,
65
+ USE_EPILOGUE_SUBTILING: tl.constexpr,
66
+ # tiles
67
+ BLOCK_SIZE_M: tl.constexpr,
68
+ BLOCK_SIZE_N: tl.constexpr,
69
+ BLOCK_SIZE_K: tl.constexpr,
70
+ ) -> None:
71
+ """
72
+ Flat index style forward kernel for Hopper.
73
+ For simplicity, we always use TMA Load and TMA Store
74
+ """
75
+ tbidx = tl.program_id(0) # thread block index
76
+
77
+ c_dtype = c_ptr.dtype.element_ty # output dtype
78
+
79
+ c_desc_ptr = workspace + (tbidx * TMA_SIZE) # for TMA Store
80
+
81
+ M_end = 0
82
+ M_start = 0
83
+ processed_tiles = 0
84
+ # Size of individual weight matrix
85
+ n_size = N // G
86
+ n_start = 0
87
+
88
+ for g in range(G):
89
+ # Move down along groups
90
+ # reset to new M offset
91
+ M_start = M_end
92
+ m_size = tl.load(m_sizes + g)
93
+ M_end = M_start + m_size
94
+ n_start = n_size * g
95
+
96
+ if m_size > 0:
97
+ # Process this group
98
+
99
+ # Acquire hold on c_desc_ptr for TMA Store
100
+ tl.extra.cuda.experimental_device_tensormap_create2d(
101
+ desc_ptr=c_desc_ptr,
102
+ global_address=c_ptr + M_start * n_size,
103
+ load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
104
+ global_size=[m_size, n_size],
105
+ element_ty=c_dtype,
106
+ )
107
+ tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
108
+
109
+ # tiles for this group
110
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
111
+ num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
112
+ group_num_tiles = num_m_tiles * num_n_tiles
113
+
114
+ while tbidx >= processed_tiles and tbidx < (
115
+ processed_tiles + group_num_tiles
116
+ ):
117
+ group_index = tbidx - processed_tiles
118
+
119
+ # columnwise
120
+ tile_m_index = group_index % num_m_tiles
121
+ tile_n_index = group_index // num_m_tiles
122
+
123
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
124
+
125
+ m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
126
+ n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
127
+ global_n_offset = (n_start + n_offset).to(tl.int32)
128
+
129
+ for k_offset in range(0, K, BLOCK_SIZE_K):
130
+ # input block [M,K]
131
+ a = tl._experimental_descriptor_load(
132
+ a_desc_ptr,
133
+ [m_offset, k_offset],
134
+ [BLOCK_SIZE_M, BLOCK_SIZE_K],
135
+ c_dtype,
136
+ )
137
+ # weight block [N, K]
138
+ b = tl._experimental_descriptor_load(
139
+ b_desc_ptr,
140
+ [global_n_offset, k_offset],
141
+ [BLOCK_SIZE_N, BLOCK_SIZE_K],
142
+ c_dtype,
143
+ )
144
+
145
+ accumulator += tl.dot(a, b.T)
146
+
147
+ # Store using TMA
148
+
149
+ m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
150
+
151
+ if USE_EPILOGUE_SUBTILING:
152
+ acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
153
+ acc = tl.permute(acc, (0, 2, 1))
154
+ acc0, acc1 = tl.split(acc)
155
+ c0 = acc0.to(c_dtype)
156
+ tl._experimental_descriptor_store(
157
+ c_desc_ptr, c0, [m_offset, n_offset]
158
+ )
159
+ c1 = acc1.to(c_dtype)
160
+ tl._experimental_descriptor_store(
161
+ c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2]
162
+ )
163
+ else:
164
+ tl._experimental_descriptor_store(
165
+ c_desc_ptr,
166
+ accumulator.to(c_dtype),
167
+ [m_offset, n_offset],
168
+ )
169
+ # move to next tile in group
170
+ tbidx += NUM_SMS
171
+ # Update the total tiles count for the next group
172
+ processed_tiles += group_num_tiles
173
+
174
+
175
+ @triton.autotune(
176
+ configs=_NV_CONFIGS,
177
+ key=["G", "M_BUCKET", "N", "K"],
178
+ prune_configs_by={"early_config_prune": early_config_prune},
179
+ )
180
+ @triton.jit
181
+ def _kernel_mg_forward_tma(
182
+ a_desc_ptr,
183
+ b_desc_ptr,
184
+ c_ptr,
185
+ workspace,
186
+ m_sizes,
187
+ a_scale_ptr,
188
+ b_scale_ptr,
189
+ # problem sizes
190
+ G: tl.constexpr,
191
+ M_BUCKET: tl.constexpr,
192
+ N: tl.constexpr,
193
+ K: tl.constexpr,
194
+ # config
195
+ NUM_SMS: tl.constexpr,
196
+ USE_TMA_LOAD: tl.constexpr,
197
+ USE_TMA_STORE: tl.constexpr,
198
+ TMA_SIZE: tl.constexpr,
199
+ USE_FP8: tl.constexpr,
200
+ # tiles
201
+ BLOCK_SIZE_M: tl.constexpr,
202
+ BLOCK_SIZE_N: tl.constexpr,
203
+ BLOCK_SIZE_K: tl.constexpr,
204
+ ) -> None:
205
+ """
206
+ Flat index style forward kernel.
207
+ For simplicity, we always use TMA Load and TMA Store
208
+ """
209
+ tbidx = tl.program_id(0) # thread block index
210
+
211
+ c_dtype = c_ptr.dtype.element_ty
212
+
213
+ c_desc_ptr = workspace + (tbidx * TMA_SIZE)
214
+
215
+ M_end = 0
216
+ processed_tiles = 0
217
+
218
+ for g in range(G):
219
+ # Move down along groups
220
+ # reset to new M offset
221
+ M_start = M_end
222
+ m_size = tl.load(m_sizes + g)
223
+ M_end = M_start + m_size
224
+
225
+ if m_size > 0:
226
+ # Process this group
227
+ n_size = N
228
+
229
+ # TMA Store prep
230
+ tl.extra.cuda.experimental_device_tensormap_create2d(
231
+ desc_ptr=c_desc_ptr,
232
+ global_address=c_ptr + M_start * N,
233
+ load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
234
+ global_size=[m_size, n_size],
235
+ element_ty=c_dtype,
236
+ )
237
+ tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
238
+
239
+ # tiles for this group
240
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
241
+ num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
242
+ group_num_tiles = num_m_tiles * num_n_tiles
243
+
244
+ while tbidx >= processed_tiles and tbidx < (
245
+ processed_tiles + group_num_tiles
246
+ ):
247
+ group_index = tbidx - processed_tiles
248
+
249
+ tile_m_index = group_index % num_m_tiles
250
+ tile_n_index = group_index // num_m_tiles
251
+
252
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
253
+
254
+ m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
255
+ n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
256
+
257
+ for k_offset in range(0, K, BLOCK_SIZE_K):
258
+ # input block [M,K]
259
+ a = tl._experimental_descriptor_load(
260
+ a_desc_ptr,
261
+ [m_offset, k_offset],
262
+ [BLOCK_SIZE_M, BLOCK_SIZE_K],
263
+ c_dtype,
264
+ )
265
+ # weight block [N, K]
266
+ b = tl._experimental_descriptor_load(
267
+ b_desc_ptr,
268
+ [n_offset, k_offset],
269
+ [BLOCK_SIZE_N, BLOCK_SIZE_K],
270
+ c_dtype,
271
+ )
272
+
273
+ accumulator += tl.dot(a, b.T)
274
+
275
+ # Store using TMA
276
+
277
+ m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
278
+ # n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
279
+
280
+ tl._experimental_descriptor_store(
281
+ c_desc_ptr,
282
+ accumulator.to(c_dtype),
283
+ [m_offset, n_offset],
284
+ )
285
+
286
+ # Move to the next tile
287
+ tbidx += NUM_SMS
288
+ # Update the total tiles count for the next group
289
+ processed_tiles += group_num_tiles
290
+
291
+
292
+ @triton.autotune(
293
+ configs=_NV_CONFIGS,
294
+ key=["G", "M_BUCKET", "N", "K"],
295
+ prune_configs_by={"early_config_prune": early_config_prune},
296
+ )
297
+ @triton.jit
298
+ def _kernel_mg_forward_no_tma(
299
+ a_ptr,
300
+ b_ptr,
301
+ c_ptr,
302
+ workspace,
303
+ m_sizes,
304
+ # problem sizes
305
+ G: tl.constexpr,
306
+ M_BUCKET: tl.constexpr,
307
+ N: tl.constexpr,
308
+ K: tl.constexpr,
309
+ # config
310
+ NUM_SMS: tl.constexpr,
311
+ USE_TMA_LOAD: tl.constexpr,
312
+ USE_TMA_STORE: tl.constexpr,
313
+ TMA_SIZE: tl.constexpr,
314
+ # tiles
315
+ BLOCK_SIZE_M: tl.constexpr,
316
+ BLOCK_SIZE_N: tl.constexpr,
317
+ BLOCK_SIZE_K: tl.constexpr,
318
+ ) -> None:
319
+ """
320
+ Flat index style forward kernel.
321
+ For older architectures (e.g., Ampere) we never use TMA Load or TMA Store
322
+ """
323
+ tbidx = tl.program_id(0) # thread block index
324
+
325
+ c_dtype = c_ptr.dtype.element_ty
326
+ c_desc_ptr = None
327
+
328
+ M_end = 0
329
+ processed_tiles = 0
330
+
331
+ for g in range(G):
332
+ # Move down along groups
333
+ # reset to new M offset
334
+ M_start = M_end
335
+ m_size = tl.load(m_sizes + g)
336
+ M_end = M_start + m_size
337
+
338
+ if m_size > 0:
339
+ # Process this group
340
+ n_size = N
341
+
342
+ # tiles for this group
343
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
344
+ num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
345
+ group_num_tiles = num_m_tiles * num_n_tiles
346
+
347
+ while tbidx >= processed_tiles and tbidx < (
348
+ processed_tiles + group_num_tiles
349
+ ):
350
+ group_index = tbidx - processed_tiles
351
+
352
+ tile_m_index = group_index % num_m_tiles
353
+ tile_n_index = group_index // num_m_tiles
354
+
355
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
356
+
357
+ m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
358
+ n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
359
+
360
+ offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
361
+ offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
362
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
363
+
364
+ a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :]
365
+ b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :]
366
+
367
+ for k_offset in range(0, K, BLOCK_SIZE_K):
368
+ # Load with bounds checking
369
+ a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size, other=0.0)
370
+ b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size, other=0.0)
371
+
372
+ # Main matmul
373
+ accumulator += tl.dot(a, b.T)
374
+
375
+ # Update pointers for next block
376
+ a_ptrs += BLOCK_SIZE_K
377
+ b_ptrs += BLOCK_SIZE_K
378
+
379
+ # Store without TMA
380
+ offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
381
+ offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
382
+
383
+ c = accumulator.to(c_dtype)
384
+
385
+ tl.store(
386
+ c_ptr
387
+ + (M_start + offs_am[:, None]) * N # Row stride is N
388
+ + offs_bn[None, :], # Column offset
389
+ c,
390
+ mask=(offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size),
391
+ )
392
+ # Move to the next tile
393
+ tbidx += NUM_SMS
394
+ # Update the total tiles count for the next group
395
+ processed_tiles += group_num_tiles
396
+
397
+
398
+ """
399
+ Backward pass for grouped GEMM with Triton, where grouping is M*G
400
+ We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
401
+ """
402
+
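As a quick sanity check on the formulas described above, a minimal PyTorch sketch (illustrative sizes, independent of the kernels below) confirms the shapes involved:

import torch

MG, N, K = 8, 6, 4                      # illustrative sizes: MG rows across all groups
grad_y = torch.randn(MG, N)             # upstream gradient [MG, N]
W = torch.randn(N, K)                   # weights [N, K]
X = torch.randn(MG, K)                  # inputs [MG, K]
grad_x = grad_y @ W                     # [MG, K] -- what the dx kernel computes
grad_w = grad_y.T @ X                   # [N, K]  -- what the dw kernel computes
assert grad_x.shape == (MG, K) and grad_w.shape == (N, K)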
403
+
404
+ # ---- dx flat linear indexed ----
405
+ @triton.autotune(
406
+ configs=_NV_CONFIGS,
407
+ key=["G", "M_BUCKET", "N", "K"],
408
+ prune_configs_by={"early_config_prune": early_config_prune},
409
+ )
410
+ @triton.jit
411
+ def _kernel_mg_dx_tma(
412
+ grad_output_desc_ptr, # [MG, N]
413
+ w_desc_ptr, # [N, K]
414
+ grad_input_ptr, # output grad_x [MG, K]
415
+ workspace, # for TMA store
416
+ m_sizes, # group sizes [G]
417
+ # problem sizes
418
+ G: tl.constexpr,
419
+ M_BUCKET: tl.constexpr,
420
+ N: tl.constexpr,
421
+ K: tl.constexpr,
422
+ # config
423
+ NUM_SMS: tl.constexpr,
424
+ USE_TMA_LOAD: tl.constexpr,
425
+ USE_TMA_STORE: tl.constexpr,
426
+ TMA_SIZE: tl.constexpr,
427
+ # tiles
428
+ BLOCK_SIZE_M: tl.constexpr,
429
+ BLOCK_SIZE_N: tl.constexpr,
430
+ BLOCK_SIZE_K: tl.constexpr,
431
+ ) -> None:
432
+ """
433
+ TMA-optimized kernel for computing gradients with respect to input (dx).
434
+ For the forward pass Y = X @ W.T, the backward for input is:
435
+ grad_X = grad_Y @ W
436
+
437
+ This maps to [MG, N] @ [N, K] -> [MG, K]
438
+
439
+ Key differences from forward:
440
+ 1. W is used directly and not transposed
441
+ 2. The reduction dimension is now N (not K)
442
+ 3. Output is [M, K] instead of [M, N]
443
+ """
444
+ tbidx = tl.program_id(0) # thread block index
445
+
446
+ c_dtype = grad_input_ptr.dtype.element_ty
447
+ c_desc_ptr = workspace + (tbidx * TMA_SIZE)
448
+
449
+ M_end = 0
450
+ processed_tiles = 0
451
+
452
+ for g in range(G):
453
+ # Move down along groups - same as forward
454
+ M_start = M_end
455
+ m_size = tl.load(m_sizes + g)
456
+ M_end = M_start + m_size
457
+
458
+ if m_size > 0:
459
+ # Process this group
460
+ # tiles for this group - now producing [M, K] output
461
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
462
+ num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
463
+ group_num_tiles = num_m_tiles * num_k_tiles
464
+
465
+ # TMA Store prep for [M, K] output
466
+ tl.extra.cuda.experimental_device_tensormap_create2d(
467
+ desc_ptr=c_desc_ptr,
468
+ global_address=grad_input_ptr + M_start * K,
469
+ load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
470
+ global_size=[m_size, K],
471
+ element_ty=c_dtype,
472
+ )
473
+ tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
474
+
475
+ while tbidx >= processed_tiles and tbidx < (
476
+ processed_tiles + group_num_tiles
477
+ ):
478
+ group_index = tbidx - processed_tiles
479
+
480
+ # Different tiling scheme for [M, K] output
481
+ tile_m_index = group_index % num_m_tiles
482
+ tile_k_index = group_index // num_m_tiles
483
+
484
+ # for grad_input block [M, K]
485
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
486
+
487
+ # Position in full matrix
488
+ m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
489
+ k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)
490
+
491
+ # reduce along N dimension (instead of K in forward)
492
+ for n_offset in range(0, N, BLOCK_SIZE_N):
493
+ # grad_output block [M, N]
494
+ grad_output = tl._experimental_descriptor_load(
495
+ grad_output_desc_ptr,
496
+ [m_offset, n_offset],
497
+ [BLOCK_SIZE_M, BLOCK_SIZE_N],
498
+ c_dtype,
499
+ )
500
+
501
+ # weight block [N, K] - no transpose needed
502
+ w = tl._experimental_descriptor_load(
503
+ w_desc_ptr,
504
+ [n_offset, k_offset],
505
+ [BLOCK_SIZE_N, BLOCK_SIZE_K],
506
+ c_dtype,
507
+ )
508
+
509
+ # grad_x = grad_output @ w
510
+ # reducing along N dimension
511
+ accumulator += tl.dot(grad_output, w)
512
+
513
+ # Store using TMA
514
+ m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
515
+ # k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)
516
+
517
+ tl._experimental_descriptor_store(
518
+ c_desc_ptr,
519
+ accumulator.to(c_dtype),
520
+ [m_offset, k_offset],
521
+ )
522
+
523
+ # Move to the next tile
524
+ tbidx += NUM_SMS
525
+
526
+ # Update the total tiles count for the next group
527
+ processed_tiles += group_num_tiles
528
+
529
+
530
+ # ---- dw flat linear indexed ----
531
+
532
+
533
+ @triton.autotune(
534
+ configs=_NV_CONFIGS,
535
+ key=["G", "M_BUCKET", "N", "K"],
536
+ prune_configs_by={"early_config_prune": early_config_prune},
537
+ )
538
+ @triton.jit
539
+ def _kernel_mg_dw_tma(
540
+ x_desc_ptr, # input descriptor [M_total, K]
541
+ grad_output_desc_ptr, # grad_output descriptor [M_total, N]
542
+ grad_weight_ptr, # output grad_w [N, K]
543
+ workspace, # workspace for TMA store
544
+ m_sizes, # group sizes [G]
545
+ # problem sizes
546
+ G: tl.constexpr,
547
+ M_BUCKET: tl.constexpr,
548
+ N: tl.constexpr,
549
+ K: tl.constexpr,
550
+ # config
551
+ NUM_SMS: tl.constexpr,
552
+ USE_TMA_LOAD: tl.constexpr,
553
+ USE_TMA_STORE: tl.constexpr,
554
+ TMA_SIZE: tl.constexpr,
555
+ # tiles
556
+ BLOCK_SIZE_N: tl.constexpr,
557
+ BLOCK_SIZE_K: tl.constexpr,
558
+ BLOCK_SIZE_M: tl.constexpr, # block size for reduction dimension
559
+ ) -> None:
560
+ """
561
+ Improved TMA-optimized kernel for computing gradients with respect to weights (dw).
562
+ Uses flat index structure similar to forward.
563
+
564
+ For the forward pass Y = X @ W.T,
565
+ the backward for weights is:
566
+ grad_W = grad_Y.T @ X
567
+
568
+ Where:
569
+ - grad_Y is [MG, N]
570
+ - X is [MG, K]
571
+ - grad_W is [N, K]
572
+ - we return [N,K]
573
+ """
574
+ # Get thread block index
575
+ tbidx = tl.program_id(0)
576
+
577
+ # Get output data type
578
+ c_dtype = grad_weight_ptr.dtype.element_ty
579
+
580
+ # Calculate number of output tiles
581
+ num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
582
+ num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
583
+ total_output_tiles = num_n_tiles * num_k_tiles
584
+
585
+ # Process tiles in strided manner across SMs
586
+ for tile_idx in range(tbidx, total_output_tiles, NUM_SMS):
587
+ # Calculate tile indices
588
+ tile_n_idx = tile_idx % num_n_tiles
589
+ tile_k_idx = tile_idx // num_n_tiles
590
+
591
+ # Calculate global offsets
592
+ n_offset = tile_n_idx * BLOCK_SIZE_N
593
+ k_offset = tile_k_idx * BLOCK_SIZE_K
594
+
595
+ # Initialize accumulator for this output tile [N, K]
596
+ accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)
597
+
598
+ # Process each group
599
+ M_end = 0
600
+ for g in range(G):
601
+ # Get group boundaries
602
+ M_start = M_end
603
+ m_size = tl.load(m_sizes + g)
604
+ M_end = M_start + m_size
605
+
606
+ # Only process if group is non-empty
607
+ if m_size > 0:
608
+ # Process this group in chunks along the M dimension
609
+ for m_offset in range(0, m_size, BLOCK_SIZE_M):
610
+ # Calculate actual block size (handling boundary)
611
+ m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset)
612
+
613
+ # Only process if we have actual work to do
614
+ if m_block_size > 0:
615
+ # Global offset for this chunk
616
+ m_global_offset = M_start + m_offset
617
+
618
+ if USE_TMA_LOAD:
619
+ # Load input chunk [M_chunk, K] using TMA
620
+ x_block = tl._experimental_descriptor_load(
621
+ x_desc_ptr,
622
+ [m_global_offset, k_offset],
623
+ [BLOCK_SIZE_M, BLOCK_SIZE_K],
624
+ c_dtype,
625
+ )
626
+
627
+ # Load grad_output chunk [M_chunk, N] using TMA
628
+ grad_output_block = tl._experimental_descriptor_load(
629
+ grad_output_desc_ptr,
630
+ [m_global_offset, n_offset],
631
+ [BLOCK_SIZE_M, BLOCK_SIZE_N],
632
+ c_dtype,
633
+ )
634
+
635
+ # Apply masks for valid regions
636
+ offs_m = tl.arange(0, BLOCK_SIZE_M)
637
+ m_mask = offs_m < m_block_size
638
+
639
+ # Zero out invalid elements
640
+ x_block = tl.where(m_mask[:, None], x_block, 0.0)
641
+ grad_output_block = tl.where(
642
+ m_mask[:, None], grad_output_block, 0.0
643
+ )
644
+ else:
645
+ # Manual load with bounds checking
646
+ offs_m = tl.arange(0, BLOCK_SIZE_M)
647
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
648
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
649
+
650
+ # Create masks
651
+ m_mask = offs_m < m_block_size
652
+ n_mask = offs_n < N - n_offset
653
+ k_mask = offs_k < K - k_offset
654
+
655
+ # Combined masks
656
+ mk_mask = m_mask[:, None] & k_mask[None, :]
657
+ mn_mask = m_mask[:, None] & n_mask[None, :]
658
+
659
+ # Global offsets for loading
660
+ m_global_offs = m_global_offset + offs_m
661
+
662
+ # Load x block [M_chunk, K]
663
+ x_block = tl.load(
664
+ x_desc_ptr
665
+ + m_global_offs[:, None] * K
666
+ + (k_offset + offs_k)[None, :],
667
+ mask=mk_mask,
668
+ other=0.0,
669
+ )
670
+
671
+ # Load grad_output block [M_chunk, N]
672
+ grad_output_block = tl.load(
673
+ grad_output_desc_ptr
674
+ + m_global_offs[:, None] * N
675
+ + (n_offset + offs_n)[None, :],
676
+ mask=mn_mask,
677
+ other=0.0,
678
+ )
679
+
680
+ # Compute partial contribution: grad_W += grad_Y.T @ X
681
+ # transpose grad_output for the matmul
682
+ contribution = tl.dot(
683
+ grad_output_block.to(tl.float32).T, # [N, M_chunk]
684
+ x_block.to(tl.float32), # [M_chunk, K]
685
+ )
686
+
687
+ # Accumulate
688
+ accumulator += contribution
689
+
690
+ # Store the result
691
+ if USE_TMA_STORE:
692
+ # Store using TMA
693
+ tl._experimental_descriptor_store(
694
+ workspace, # TMA store descriptor
695
+ accumulator.to(c_dtype),
696
+ [n_offset, k_offset],
697
+ )
698
+ else:
699
+ # Manual store with bounds checking
700
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
701
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
702
+
703
+ # Create masks for bounds checking
704
+ n_mask = offs_n < N - n_offset
705
+ k_mask = offs_k < K - k_offset
706
+ output_mask = n_mask[:, None] & k_mask[None, :]
707
+
708
+ # Store the result
709
+ tl.store(
710
+ grad_weight_ptr
711
+ + (n_offset + offs_n)[:, None] * K
712
+ + (k_offset + offs_k)[None, :],
713
+ accumulator.to(c_dtype),
714
+ mask=output_mask,
715
+ )
716
+
717
+
718
+ # ======== End Triton kernels ========
719
+
720
+ # ======== Triton wrapper functions ========
721
+
722
+ # ----- main forward pass wrapper -----
723
+
724
+
725
+ def grouped_gemm_forward(
726
+ x: torch.Tensor,
727
+ w: torch.Tensor,
728
+ m_sizes: torch.Tensor,
729
+ tma_size: int = 128,
730
+ ) -> torch.Tensor:
731
+ """
732
+ M*G style grouped GEMM with TMA.
733
+ # FP8 support (triggered by passing x_scale and w_scale tensors) has been removed for now.
734
+
735
+ """
736
+ if not CudaUtils.verify_tma():
737
+ raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
738
+
739
+ G = m_sizes.shape[0]
740
+
741
+ assert x.is_contiguous()
742
+ assert w.is_contiguous()
743
+ assert m_sizes.is_contiguous()
744
+
745
+ # Total input size is now [M_total, K] where M_total is the sum of all group sizes
746
+ M_total, K = x.shape
747
+ N = w.shape[0] # N is now the same for all groups
748
+
749
+ assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"
750
+
751
+ # Verify that all group sizes are multiples of ALIGN_SIZE_M
752
+ # This check is commented out because it will involve a GPU-CPU sync
753
+ # assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0, "Group sizes must be a multiple of ALIGN_SIZE_M"
754
+
755
+ # Create output tensor; the Hopper kernel writes per-group slices of width N // G, so y is [M_total, N // G]
756
+ y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)
757
+
758
+ if M_total == 0:
759
+ return y
760
+
761
+ NUM_SMS = CudaUtils.get_num_sms()
762
+ USE_TMA_LOAD = True
763
+ USE_TMA_STORE = True
764
+ USE_EPILOGUE_SUBTILING = False
765
+
766
+ # TMA descriptor helper
767
+ desc_helper = None
768
+ desc_x = x
769
+ desc_w = w
770
+ workspace = None
771
+
772
+ if USE_TMA_LOAD:
773
+ desc_helper = TmaDescriptorHelper(tma_size=tma_size)
774
+ desc_helper.init_tma_descriptor("x")
775
+ desc_helper.init_tma_descriptor("w")
776
+ desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
777
+ desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
778
+
779
+ if USE_TMA_STORE:
780
+ workspace = torch.empty(
781
+ NUM_SMS * desc_helper.tma_size,
782
+ device=x.device,
783
+ dtype=torch.uint8,
784
+ )
785
+
786
+ def grid(META):
787
+ if USE_TMA_LOAD:
788
+ nonlocal desc_helper
789
+ desc_helper.fill_2d_tma_descriptor(
790
+ "x",
791
+ x.data_ptr(),
792
+ M_total,
793
+ K,
794
+ META["BLOCK_SIZE_M"],
795
+ META["BLOCK_SIZE_K"],
796
+ x.element_size(),
797
+ )
798
+
799
+ desc_helper.fill_2d_tma_descriptor(
800
+ "w",
801
+ w.data_ptr(),
802
+ N,
803
+ K,
804
+ META["BLOCK_SIZE_N"],
805
+ META["BLOCK_SIZE_K"],
806
+ w.element_size(),
807
+ )
808
+ return (NUM_SMS,)
809
+
810
+ M_BUCKET = triton.next_power_of_2(M_total)
811
+
812
+ _kernel_mg_forward_hopper[grid](
813
+ desc_x,
814
+ desc_w,
815
+ y,
816
+ workspace,
817
+ m_sizes,
818
+ G,
819
+ M_BUCKET,
820
+ N,
821
+ K,
822
+ NUM_SMS,
823
+ TMA_SIZE=tma_size,
824
+ USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
825
+ )
826
+
827
+ return y
828
+
829
+
830
+ # ======== Improved Backward =============
831
+ def grouped_gemm_backward(
832
+ grad_output: torch.Tensor,
833
+ x: torch.Tensor,
834
+ w: torch.Tensor,
835
+ m_sizes: torch.Tensor,
836
+ use_tma: bool = True,
837
+ tma_size: int = 128,
838
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
839
+ """
840
+ Unified backward pass for grouped GeMM with M*G grouping.
841
+ Uses optimized TMA-based implementations for both dx and dw when available.
842
+
843
+ Args:
844
+ grad_output: Gradient of output, shape [M_total, N]
845
+ x: Input tensor from forward pass, shape [M_total, K]
846
+ w: Weight tensor from forward pass, shape [N, K]
847
+ m_sizes: Group sizes tensor, shape [G]
848
+ use_tma: Whether to try using TMA acceleration (if available)
849
+ tma_size: Size of TMA descriptor in bytes
850
+
851
+
852
+ Returns:
853
+ Tuple of gradients with respect to x and w: (grad_x, grad_w)
854
+ """
855
+ logging.info("Starting unified grouped_gemm_backward")
856
+
857
+ # do this once, seems expensive
858
+ NUM_SMS = CudaUtils.get_num_sms()
859
+
860
+ # Basic validation
861
+ G = m_sizes.shape[0]
862
+ M_total, K_x = x.shape
863
+ M_grad, N = grad_output.shape
864
+ N_w, K_w = w.shape
865
+
866
+ # Check dimensions
867
+ if K_x != K_w:
868
+ raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
869
+ if M_total != M_grad:
870
+ raise ValueError(
871
+ f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
872
+ )
873
+
874
+ # Check total M matches sum of group sizes
875
+ sum_m_sizes = m_sizes.sum().item()
876
+ if M_total != sum_m_sizes:
877
+ raise ValueError(
878
+ f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
879
+ )
880
+
881
+ # Make sure inputs are contiguous
882
+ grad_output = grad_output.contiguous()
883
+ x = x.contiguous()
884
+ w = w.contiguous()
885
+ m_sizes = m_sizes.contiguous()
886
+
887
+ # Check TMA support
888
+ can_use_tma = use_tma and CudaUtils.verify_tma()
889
+ if use_tma and not can_use_tma:
890
+ logging.info("TMA requested but not supported on this device")
891
+ use_tma = False
892
+
893
+ # Compute grad_x using flat linear implementation
894
+ try:
895
+ logging.info("Computing grad_x with flat linear kernel")
896
+
897
+ # Use TMA-optimized implementation
898
+ grad_x = grouped_gemm_dx_tma(
899
+ grad_output=grad_output,
900
+ w=w,
901
+ m_sizes=m_sizes,
902
+ num_sms=NUM_SMS,
903
+ tma_size=tma_size,
904
+ )
905
+
906
+ except Exception as e:
907
+ logging.error(f"Error in grad_x computation: {e}")
908
+ raise
909
+
910
+ # Compute grad_w using flat linear style implementation
911
+ try:
912
+ logging.info("Computing grad_w with flat linear kernel")
913
+
914
+ grad_w = grouped_gemm_dw_tma(
915
+ x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
916
+ )
917
+ except Exception as e:
918
+ logging.error(f"Error in grad_w computation: {e}")
919
+ raise
920
+
921
+ return grad_x, grad_w
922
+
923
+
924
+ # ----- dx backward pass wrapper -----
925
+
926
+
927
+ def grouped_gemm_dx_tma(
928
+ grad_output: torch.Tensor,
929
+ w: torch.Tensor,
930
+ m_sizes: torch.Tensor,
931
+ num_sms: int = 132,
932
+ tma_size: int = 128,
933
+ ) -> torch.Tensor:
934
+ """
935
+ Optimized backward pass wrapper for computing gradient with respect to input (dx)
936
+ using TMA patterns similar to the forward pass.
937
+
938
+ Args:
939
+ grad_output: Gradient of output, shape [M_total, N]
940
+ w: Weight tensor, shape [N, K]
941
+ m_sizes: Group sizes tensor, shape [G]
942
+ tma_size: Size of TMA descriptor
943
+ # using_fp8: Whether to use FP8 quantization
944
+ # grad_output_scale: Scale for grad_output in FP8 mode
945
+ # w_scale: Scale for w in FP8 mode
946
+
947
+ Returns:
948
+ grad_x: Gradient with respect to x, shape [M_total, K]
949
+ """
950
+ """
951
+ Optimized backward pass for computing gradient with respect to input (dx)
952
+ using TMA patterns similar to the forward pass.
953
+
954
+ Args:
955
+ grad_output: Gradient of output, shape [M_total, N]
956
+ w: Weight tensor, shape [N, K]
957
+ m_sizes: Group sizes tensor, shape [G]
958
+ tma_size: Size of TMA descriptor
959
+ using_fp8: Whether to use FP8 quantization
960
+ # grad_output_scale: Scale for grad_output in FP8 mode
961
+ # w_scale: Scale for w in FP8 mode
962
+
963
+ Returns:
964
+ grad_x: Gradient with respect to x, shape [M_total, K]
965
+ """
966
+ if not CudaUtils.verify_tma():
967
+ raise NotImplementedError("Optimized dx computation requires TMA support")
968
+
969
+ G = m_sizes.shape[0]
970
+
971
+ assert grad_output.is_contiguous()
972
+ assert w.is_contiguous()
973
+ assert m_sizes.is_contiguous()
974
+
975
+ M_total, N_grad = grad_output.shape
976
+ N_w, K = w.shape
977
+
978
+ # Check dimensions
979
+ assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})"
980
+
981
+ # Verify that the sum of m_sizes matches M_total
982
+ sum_m_sizes = m_sizes.sum().item()
983
+ assert (
984
+ M_total == sum_m_sizes
985
+ ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
986
+
987
+ # Create output tensor (grad_x) with shape [M_total, K]
988
+ grad_x = torch.empty(
989
+ (M_total, K), device=grad_output.device, dtype=grad_output.dtype
990
+ )
991
+
992
+ NUM_SMS = num_sms # CudaUtils.get_num_sms()
993
+ USE_TMA_LOAD = True
994
+ USE_TMA_STORE = True
995
+
996
+ # Set up TMA descriptors
997
+ desc_helper = TmaDescriptorHelper(tma_size=tma_size)
998
+ desc_helper.init_tma_descriptor("grad_output")
999
+ desc_helper.init_tma_descriptor("w")
1000
+ desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output")
1001
+ desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
1002
+
1003
+ # Allocate workspace for TMA store
1004
+ workspace = torch.empty(
1005
+ NUM_SMS * desc_helper.tma_size,
1006
+ device=grad_output.device,
1007
+ dtype=torch.uint8,
1008
+ )
1009
+
1010
+ def grid(META):
1011
+ # Fill TMA descriptors with appropriate dimensions
1012
+ desc_helper.fill_2d_tma_descriptor(
1013
+ "grad_output",
1014
+ grad_output.data_ptr(),
1015
+ M_total,
1016
+ N_grad,
1017
+ META["BLOCK_SIZE_M"],
1018
+ META["BLOCK_SIZE_N"],
1019
+ grad_output.element_size(),
1020
+ )
1021
+
1022
+ desc_helper.fill_2d_tma_descriptor(
1023
+ "w",
1024
+ w.data_ptr(),
1025
+ N_w,
1026
+ K,
1027
+ META["BLOCK_SIZE_N"],
1028
+ META["BLOCK_SIZE_K"],
1029
+ w.element_size(),
1030
+ )
1031
+ return (NUM_SMS,)
1032
+
1033
+ M_BUCKET = triton.next_power_of_2(M_total)
1034
+
1035
+ # Launch the flat linear kernel for computing grad_x
1036
+ _kernel_mg_dx_tma[grid](
1037
+ desc_grad_output,
1038
+ desc_w,
1039
+ grad_x,
1040
+ workspace,
1041
+ m_sizes,
1042
+ G,
1043
+ M_BUCKET,
1044
+ N_grad, # N dimension is now the reduction dimension
1045
+ K,
1046
+ NUM_SMS,
1047
+ USE_TMA_LOAD,
1048
+ USE_TMA_STORE,
1049
+ TMA_SIZE=tma_size,
1050
+ )
1051
+
1052
+ return grad_x
1053
+
1054
+
1055
+ # ======== dw wrapper function ==========
1056
+
1057
+
1058
+ def grouped_gemm_dw_tma(
1059
+ x: torch.Tensor,
1060
+ grad_output: torch.Tensor,
1061
+ m_sizes: torch.Tensor,
1062
+ num_sms: int = 132,
1063
+ tma_size: int = 128,
1064
+ ) -> torch.Tensor:
1065
+ """
1066
+ Optimized flat linear kernel computation of gradients with respect to weights (dw) using TMA.
1067
+ For the forward pass Y = X @ W.T, the backward for weights is:
1068
+ grad_W = grad_Y.T @ X
1069
+
1070
+ Args:
1071
+ x: Input tensor, shape [M_total, K]
1072
+ grad_output: Gradient of output, shape [M_total, N]
1073
+ m_sizes: Group sizes tensor, shape [G]
1074
+ tma_size: Size of TMA descriptor in bytes
1075
+
1076
+
1077
+ Returns:
1078
+ grad_w: Gradient with respect to weights, shape [N, K]
1079
+ """
1080
+ # Check TMA support
1081
+ has_tma_support = CudaUtils.verify_tma()
1082
+
1083
+ # Get group count
1084
+ G = m_sizes.shape[0]
1085
+
1086
+ # Ensure contiguous tensors
1087
+ x = x.contiguous()
1088
+ grad_output = grad_output.contiguous()
1089
+ m_sizes = m_sizes.contiguous()
1090
+
1091
+ # Get dimensions
1092
+ M_total, K_x = x.shape
1093
+ M_grad, N = grad_output.shape
1094
+
1095
+ # Check dimensions
1096
+ assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})"
1097
+
1098
+ # Verify that the sum of m_sizes matches M_total
1099
+ sum_m_sizes = m_sizes.sum().item()
1100
+ assert (
1101
+ sum_m_sizes == M_total
1102
+ ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
1103
+
1104
+ # Create output tensor (grad_w) with shape [N, K]
1105
+ grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype)
1106
+
1107
+ NUM_SMS = num_sms
1108
+
1109
+ # TODO - hardcoded for now...but should set TMA flags based on hardware support
1110
+ USE_TMA_LOAD = True # has_tma_support
1111
+ USE_TMA_STORE = True # has_tma_support
1112
+
1113
+ # Set up TMA descriptors or direct pointers
1114
+ if USE_TMA_LOAD or USE_TMA_STORE:
1115
+ desc_helper = TmaDescriptorHelper(tma_size=tma_size)
1116
+
1117
+ if USE_TMA_LOAD:
1118
+ desc_helper.init_tma_descriptor("x")
1119
+ desc_helper.init_tma_descriptor("grad_output")
1120
+ x_desc = desc_helper.get_tma_descriptor_kernel_param("x")
1121
+ grad_output_desc = desc_helper.get_tma_descriptor_kernel_param(
1122
+ "grad_output"
1123
+ )
1124
+ else:
1125
+ x_desc = x
1126
+ grad_output_desc = grad_output
1127
+
1128
+ if USE_TMA_STORE:
1129
+ desc_helper.init_tma_descriptor("grad_w")
1130
+ workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w")
1131
+ else:
1132
+ workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
1133
+ else:
1134
+ # If not using TMA, just use the tensors directly
1135
+ x_desc = x
1136
+ grad_output_desc = grad_output
1137
+ workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
1138
+
1139
+ # M_BUCKET for grid size
1140
+ M_BUCKET = triton.next_power_of_2(M_total)
1141
+
1142
+ # Define grid for kernel launch
1143
+ def grid(META):
1144
+ if USE_TMA_LOAD or USE_TMA_STORE:
1145
+
1146
+ if USE_TMA_LOAD:
1147
+ desc_helper.fill_2d_tma_descriptor(
1148
+ "x",
1149
+ x.data_ptr(),
1150
+ M_total,
1151
+ K_x,
1152
+ META["BLOCK_SIZE_M"],
1153
+ META["BLOCK_SIZE_K"],
1154
+ x.element_size(),
1155
+ )
1156
+
1157
+ desc_helper.fill_2d_tma_descriptor(
1158
+ "grad_output",
1159
+ grad_output.data_ptr(),
1160
+ M_total,
1161
+ N,
1162
+ META["BLOCK_SIZE_M"],
1163
+ META["BLOCK_SIZE_N"],
1164
+ grad_output.element_size(),
1165
+ )
1166
+
1167
+ if USE_TMA_STORE:
1168
+ desc_helper.fill_2d_tma_descriptor(
1169
+ "grad_w",
1170
+ grad_w.data_ptr(),
1171
+ N,
1172
+ K_x,
1173
+ META["BLOCK_SIZE_N"],
1174
+ META["BLOCK_SIZE_K"],
1175
+ grad_w.element_size(),
1176
+ )
1177
+
1178
+ # Return grid size - one block per SM for balanced work distribution
1179
+ return (NUM_SMS,)
1180
+
1181
+ # Launch the optimized kernel
1182
+ _kernel_mg_dw_tma[grid](
1183
+ x_desc,
1184
+ grad_output_desc,
1185
+ grad_w,
1186
+ workspace,
1187
+ m_sizes,
1188
+ G,
1189
+ M_BUCKET,
1190
+ N,
1191
+ K_x,
1192
+ NUM_SMS,
1193
+ USE_TMA_LOAD,
1194
+ USE_TMA_STORE,
1195
+ TMA_SIZE=tma_size,
1196
+ )
1197
+
1198
+ return grad_w
1199
+
1200
+
1201
+ # ======== End Backwards Wrapper Functions =============
1202
+
1203
+ # ======== PyTorch wrapper functions ========
1204
+
1205
+
1206
+ class GroupedGEMM_mg(torch.autograd.Function):
1207
+ """
1208
+ Autograd function for GroupedGEMM with M*G grouping.
1209
+ Currently supports standard-precision operation; the FP8 path is not wired up.
1210
+ """
1211
+
1212
+ @staticmethod
1213
+ def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128):
1214
+ """
1215
+ Forward pass of GroupedGEMM.
1216
+
1217
+ Args:
1218
+ x: Input tensor, shape [M_total, K]
1219
+ w: Weight tensor, shape [N, K]
1220
+ m_sizes: Tensor of shape [G] containing the size of each group
1221
+ use_tma: Whether to try using TMA acceleration (if available)
1222
+ tma_size: Size of TMA descriptor in bytes
1223
1224
+
1225
+ Returns:
1226
+ Output tensor, shape [M_total, N]
1227
+ """
1228
+
1229
+ # Use regular forward without quantization
1230
+ output = grouped_gemm_forward(
1231
+ x=x, w=w, m_sizes=m_sizes, tma_size=tma_size
1232
+ )
1233
+
1234
+ # Save inputs and parameters for backward pass
1235
+ ctx.save_for_backward(x, w, m_sizes)
1236
+ ctx.use_tma = use_tma
1237
+ ctx.tma_size = tma_size
1238
+
1239
1240
+
1241
+ return output
1242
+
1243
+ @staticmethod
1244
+ def backward(ctx, grad_output):
1245
+ """
1246
+ Backward pass of M*G GroupedGEMM.
1247
+
1248
+ Args:
1249
+ grad_output: Gradient of output, shape [M_total, N]
1250
+
1251
+ Returns:
1252
+ Tuple of gradients:
1253
+ - grad_x: Gradient with respect to x, shape [M_total, K]
1254
+ - grad_w: Gradient with respect to w, shape [N, K]
1255
+ - None: Gradient with respect to m_sizes (not differentiable)
1256
+ - None: Gradient with respect to use_tma (not differentiable)
1257
+ - None: Gradient with respect to tma_size (not differentiable)
1258
+
1259
+ """
1260
+ # Retrieve saved tensors and parameters
1261
+
1262
+ x, w, m_sizes = ctx.saved_tensors
1263
+
1264
+ use_tma = ctx.use_tma
1265
+ tma_size = ctx.tma_size
1266
+
1267
+ # Compute gradients using the unified implementation
1268
+ grad_x, grad_w = grouped_gemm_backward(
1269
+ grad_output=grad_output,
1270
+ x=x,
1271
+ w=w,
1272
+ m_sizes=m_sizes,
1273
+ use_tma=use_tma,
1274
+ tma_size=tma_size,
1275
+ )
1276
+
1277
+ # Return gradients for all inputs (None for non-differentiable parameters)
1278
+ return grad_x, grad_w, None, None, None
1279
+
1280
+
1281
+ def mg_grouped_gemm(
1282
+ x: torch.Tensor,
1283
+ w: torch.Tensor,
1284
+ m_sizes: torch.Tensor,
1285
+ use_tma: bool = True,
1286
+ tma_size: int = 128,
1287
+ using_fp8: bool = False,
1288
+ ) -> torch.Tensor:
1289
+ """
1290
+ Unified differentiable grouped GEMM operation for M*G grouped GEMM.
1291
+ Supports standard precision; FP8 quantization is not currently wired up.
1292
+
1293
+ Args:
1294
+ x: Input tensor, shape [M_total, K]
1295
+ w: Weight tensor, shape [N, K]
1296
+ m_sizes: Tensor of shape [G] containing the size of each group
1297
+ use_tma: Whether to try using TMA acceleration (if available)
1298
+ tma_size: Size of TMA descriptor in bytes
1299
+ using_fp8: Whether to use FP8 quantization (accepted but currently ignored)
1300
+
1301
+ Returns:
1302
+ Output tensor, shape [M_total, N]
1303
+ """
1304
+ # NOTE: using_fp8 is accepted for API compatibility but not forwarded; the FP8 path is not wired up.
+ return GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size)
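An end-to-end autograd sketch (hypothetical shapes; requires a TMA-capable GPU). A single group is used because, as written, the Hopper forward produces per-group outputs of width N // G while the backward contract expects width N, so G = 1 keeps the two paths shape-consistent for a smoke test:

x = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(1024, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
m_sizes = torch.tensor([512], dtype=torch.int32, device="cuda")
y = mg_grouped_gemm(x, w, m_sizes)    # forward via GroupedGEMM_mg
y.sum().backward()                    # routes through grouped_gemm_backward
print(x.grad.shape, w.grad.shape)     # torch.Size([512, 256]) torch.Size([1024, 256])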
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ # Configure logging
14
+ logging.basicConfig(
15
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
16
+ )
17
+
18
+
19
+ def compute_reference_forward(x, w, m_sizes):
20
+ """
21
+ Compute reference forward pass using PyTorch operations.
22
+
23
+ Args:
24
+ x (torch.Tensor): Input tensor of shape (M, K)
25
+ w (torch.Tensor): Weight tensor of shape (N, K)
26
+ m_sizes (torch.Tensor): Group sizes tensor of shape (G)
27
+
28
+ Returns:
29
+ torch.Tensor: Reference output tensor of shape (M, N)
30
+ """
31
+ result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)
32
+
33
+ m_start = 0
34
+ for g in range(len(m_sizes)):
35
+ m_size = m_sizes[g].item()
36
+ if m_size > 0:
37
+ m_end = m_start + m_size
38
+
39
+ # Extract group input
40
+ x_g = x[m_start:m_end]
41
+
42
+ # Compute group output: y_g = x_g @ w.T
43
+ y_g = torch.matmul(x_g, w.T)
44
+
45
+ # Store result
46
+ result[m_start:m_end] = y_g
47
+
48
+ # Update start index
49
+ m_start = m_end
50
+
51
+ return result
52
+
53
+
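For example, with two groups the loop above splits the rows of x and applies the same weight matrix to each slice (illustrative sizes):

x = torch.randn(5, 4)                        # 5 total rows, K = 4
w = torch.randn(3, 4)                        # N = 3
m_sizes = torch.tensor([2, 3])               # group 0 -> rows 0:2, group 1 -> rows 2:5
out = compute_reference_forward(x, w, m_sizes)
assert torch.allclose(out[0:2], x[0:2] @ w.T)
assert torch.allclose(out[2:5], x[2:5] @ w.T)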
54
+ def compute_reference_backward(x, w, m_sizes, grad_output):
55
+ """
56
+ Compute reference backward pass using PyTorch autograd.
57
+
58
+ Args:
59
+ x (torch.Tensor): Input tensor of shape (M, K)
60
+ w (torch.Tensor): Weight tensor of shape (N, K)
61
+ m_sizes (torch.Tensor): Group sizes tensor of shape (G)
62
+ grad_output (torch.Tensor): Gradient tensor of shape (M, N)
63
+
64
+ Returns:
65
+ tuple: (grad_x, grad_w) gradient tensors
66
+ """
67
+ # Create autograd-enabled copies
68
+ x_autograd = x.detach().clone().requires_grad_(True)
69
+ w_autograd = w.detach().clone().requires_grad_(True)
70
+
71
+ # Compute forward pass
72
+ output = compute_reference_forward(x_autograd, w_autograd, m_sizes)
73
+
74
+ # Backpropagate
75
+ output.backward(grad_output)
76
+
77
+ return x_autograd.grad, w_autograd.grad
78
+
79
+
80
+ def analyze_tensor_differences(actual, expected, name):
81
+ """
82
+ Analyze differences between actual and expected tensors.
83
+
84
+ Args:
85
+ actual (torch.Tensor): Actual tensor
86
+ expected (torch.Tensor): Expected tensor
87
+ name (str): Name of the tensor for logging
88
+
89
+ Returns:
90
+ bool: True if tensors are close enough
91
+ """
92
+ rtol = 0.5 # Relative tolerance for float16
93
+ atol = 0.5 # Absolute tolerance for float16
94
+
95
+ # Analyze differences
96
+ diff = (actual - expected).abs()
97
+ max_idx = diff.argmax().item()
98
+ idx = np.unravel_index(max_idx, actual.shape)
99
+ max_diff = diff.max().item()
100
+
101
+ logging.info(f"Largest {name} difference: {max_diff} at {idx}")
102
+ logging.info(f"Values: {actual[idx].item()} vs {expected[idx].item()}")
103
+
104
+ is_close = torch.allclose(actual, expected, rtol=rtol, atol=atol)
105
+
106
+ if is_close:
107
+ logging.info(f"✓ SUCCESS: {name} matches PyTorch reference")
108
+ else:
109
+ logging.error(f"✗ FAILURE: {name} mismatch detected")
110
+
111
+ # Count zeros
112
+ zeros_actual = (actual == 0).sum().item()
113
+ zeros_expected = (expected == 0).sum().item()
114
+ logging.info(
115
+ f"Zeros in {name} (actual): {zeros_actual}/{actual.numel()} ({zeros_actual/actual.numel()*100:.2f}%)"
116
+ )
117
+ logging.info(
118
+ f"Zeros in {name} (expected): {zeros_expected}/{expected.numel()} ({zeros_expected/expected.numel()*100:.2f}%)"
119
+ )
120
+
121
+ # Check for NaNs
122
+ nan_actual = torch.isnan(actual).sum().item()
123
+ if nan_actual > 0:
124
+ logging.error(f"NaN values detected in {name}: {nan_actual}")
125
+
126
+ return is_close
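These helpers are typically combined with the Triton grouped GEMM roughly as follows (a hedged sketch; the reference output stands in for the kernel result so the snippet runs without a TMA-capable GPU):

x = torch.randn(192, 64)
w = torch.randn(128, 64)
m_sizes = torch.tensor([64, 64, 64])
expected = compute_reference_forward(x, w, m_sizes)            # [192, 128]
grad_out = torch.randn_like(expected)
ref_dx, ref_dw = compute_reference_backward(x, w, m_sizes, grad_out)
# `actual` would normally be grouped_gemm_forward(x, w, m_sizes) run on a GPU;
# the reference is reused here purely as a stand-in.
actual = expected.clone()
analyze_tensor_differences(actual, expected, "forward output")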