zaydzuhri committed (verified)
Commit b97064d · 1 Parent(s): 30cc604

Add files using upload-large-folder tool

Files changed (50)
  1. fla/modules/__pycache__/fused_linear_listnet_loss.cpython-312.pyc +0 -0
  2. logs/none_zo1mfnl3/attempt_0/0/stderr.log +0 -0
  3. logs/none_zo1mfnl3/attempt_0/2/stderr.log +0 -0
  4. logs/none_zo1mfnl3/attempt_0/3/stderr.log +0 -0
  5. logs/none_zo1mfnl3/attempt_0/4/stderr.log +0 -0
  6. torchtitan/components/__pycache__/float8.cpython-312.pyc +0 -0
  7. torchtitan/components/__pycache__/loss.cpython-312.pyc +0 -0
  8. torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc +0 -0
  9. torchtitan/components/__pycache__/metrics.cpython-312.pyc +0 -0
  10. torchtitan/components/dataloader.py +92 -0
  11. torchtitan/distributed/__pycache__/__init__.cpython-312.pyc +0 -0
  12. torchtitan/experiments/deepseek_v3/LICENSE-CODE +21 -0
  13. torchtitan/experiments/deepseek_v3/attn_mask_utils.py +397 -0
  14. torchtitan/experiments/deepseek_v3/generate.py +308 -0
  15. torchtitan/experiments/deepseek_v3/model_config.py +204 -0
  16. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py +63 -0
  17. torchtitan/experiments/flux/README.md +23 -0
  18. torchtitan/experiments/flux/dataset/flux_dataset.py +267 -0
  19. torchtitan/experiments/flux/dataset/tokenizer.py +64 -0
  20. torchtitan/experiments/flux/model/hf_embedder.py +40 -0
  21. torchtitan/experiments/flux/model/math.py +38 -0
  22. torchtitan/experiments/flux/scripts/download_autoencoder.py +61 -0
  23. torchtitan/experiments/flux/tests/test_generate_image.py +252 -0
  24. torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
  25. torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py +630 -0
  26. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py +13 -0
  27. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py +299 -0
  28. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py +1304 -0
  29. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py +240 -0
  30. torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc +0 -0
  31. torchtitan/experiments/llama4/model/args.py +109 -0
  32. torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh +25 -0
  33. torchtitan/experiments/multimodal/tests/__init__.py +5 -0
  34. torchtitan/experiments/multimodal/tests/test_utils.py +58 -0
  35. torchtitan/experiments/multimodal/tokenizer/tiktoken.py +232 -0
  36. torchtitan/experiments/multimodal/utils.py +437 -0
  37. torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc +0 -0
  38. torchtitan/experiments/simple_fsdp/__pycache__/model.cpython-312.pyc +0 -0
  39. torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc +0 -0
  40. torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc +0 -0
  41. torchtitan/experiments/simple_fsdp/tests/__init__.py +5 -0
  42. torchtitan/experiments/simple_fsdp/tests/test_numerics.py +128 -0
  43. torchtitan/models/__pycache__/__init__.cpython-312.pyc +0 -0
  44. torchtitan/models/__pycache__/norms.cpython-312.pyc +0 -0
  45. torchtitan/models/llama3/__pycache__/__init__.cpython-312.pyc +0 -0
  46. torchtitan/models/llama3/__pycache__/parallelize_llama.cpython-312.pyc +0 -0
  47. torchtitan/models/llama3/parallelize_llama.py +398 -0
  48. torchtitan/models/llama3/train_configs/llama3_70b.toml +62 -0
  49. torchtitan/protocols/train_spec.py +115 -0
  50. train.sh +121 -0
fla/modules/__pycache__/fused_linear_listnet_loss.cpython-312.pyc ADDED
Binary file (17.8 kB).
 
logs/none_zo1mfnl3/attempt_0/0/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_zo1mfnl3/attempt_0/2/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_zo1mfnl3/attempt_0/3/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_zo1mfnl3/attempt_0/4/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
torchtitan/components/__pycache__/float8.cpython-312.pyc ADDED
Binary file (6.2 kB).
 
torchtitan/components/__pycache__/loss.cpython-312.pyc ADDED
Binary file (1.51 kB).
 
torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc ADDED
Binary file (7.71 kB).
 
torchtitan/components/__pycache__/metrics.cpython-312.pyc ADDED
Binary file (19.6 kB).
 
torchtitan/components/dataloader.py ADDED
@@ -0,0 +1,92 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ #
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+ import pickle
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable
+ from typing import Any
+
+ from torch.distributed.checkpoint.stateful import Stateful
+ from torch.utils.data import IterableDataset
+ from torchdata.stateful_dataloader import StatefulDataLoader
+ from torchtitan.tools.logging import logger
+
+
+ class BaseDataLoader(Stateful, ABC):
+     """Base class for all dataloaders.
+
+     This is used to enforce that all dataloaders have the methods defined in ``Stateful``,
+     ``state_dict()`` and ``load_state_dict()``.
+     """
+
+     @abstractmethod
+     def __iter__(self):
+         ...
+
+
+ class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
+     """Dataloader that is aware of distributed data parallelism.
+
+     This dataloader is used to load data in a distributed data parallel fashion. It also
+     utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary
+     methods such as ``__iter__``.
+
+     Args:
+         dataset (IterableDataset): The dataset to iterate over.
+         dp_rank: Data parallelism rank for this dataloader.
+         dp_world_size: The world size of the data parallelism.
+         batch_size: The batch size to use for each iteration.
+         collate_fn: Optional function to collate samples in a batch.
+     """
+
+     dp_rank: int
+     dp_world_size: int
+     batch_size: int
+
+     def __init__(
+         self,
+         dataset: IterableDataset,
+         dp_rank: int,
+         dp_world_size: int,
+         batch_size: int,
+         collate_fn: Callable | None = None,
+     ):
+         self.dp_world_size = dp_world_size
+         self.dp_rank = dp_rank
+         self.batch_size = batch_size
+         super().__init__(dataset, batch_size, collate_fn=collate_fn)
+         self._rank_id = f"dp_rank_{dp_rank}"
+
+     def state_dict(self) -> dict[str, Any]:
+         # Store state only for dp rank to avoid replicating the same state across other dimensions.
+         return {
+             # We don't have to use pickle as DCP will serialize the state_dict. However,
+             # we have to keep this for backward compatibility.
+             self._rank_id: pickle.dumps(super().state_dict()),
+             "world_size": self.dp_world_size,
+         }
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         # State being empty is valid.
+         if not state_dict:
+             return
+
+         if self._rank_id not in state_dict:
+             logger.warning(
+                 f"DataLoader state is empty for dp rank {self.dp_rank}, "
+                 f"expected key {self._rank_id}"
+             )
+             return
+
+         assert self.dp_world_size == state_dict["world_size"], (
+             "dp_degree is inconsistent before and after checkpoint, "
+             "dataloader resharding is not supported yet."
+         )
+         # We don't have to use pickle as DCP will serialize the state_dict. However, we have to
+         # keep this for backward compatibility.
+         super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
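A minimal usage sketch of the dataloader above; the toy iterable dataset and batch size here are illustrative only, not part of the commit:

```python
from torch.utils.data import IterableDataset
from torchtitan.components.dataloader import ParallelAwareDataloader


class _RangeDataset(IterableDataset):
    """Toy iterable dataset used only to exercise the dataloader."""

    def __iter__(self):
        yield from range(100)


# one dataloader per data-parallel rank; state_dict()/load_state_dict()
# make it resumable through distributed checkpointing
dl = ParallelAwareDataloader(_RangeDataset(), dp_rank=0, dp_world_size=1, batch_size=8)
batch = next(iter(dl))      # tensor of shape (8,)
state = dl.state_dict()     # {"dp_rank_0": <pickled loader state>, "world_size": 1}
dl.load_state_dict(state)   # resumes iteration from the saved position
```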
torchtitan/distributed/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (255 Bytes).
 
torchtitan/experiments/deepseek_v3/LICENSE-CODE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 DeepSeek
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
torchtitan/experiments/deepseek_v3/attn_mask_utils.py ADDED
@@ -0,0 +1,397 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # This code is based on src/transformers/modeling_attn_mask_utils.py of
+ # huggingface/transformers. It has been modified from its original forms to
+ # contain only the necessary utilities.
+
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+
+
+ @dataclass
+ class AttentionMaskConverter:
+     """
+     A utility attention mask class that allows one to:
+         - Create a causal 4d mask
+         - Create a causal 4d mask with sliding window
+         - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
+           key_value_length) that can be multiplied with attention scores
+
+     Examples:
+
+     ```python
+     >>> import torch
+     >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+     >>> converter = AttentionMaskConverter(True)
+     >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
+     tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
+               [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
+               [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
+               [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00, -3.4028e+38],
+               [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00]]]])
+     ```
+
+     Parameters:
+         is_causal (`bool`):
+             Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
+
+         sliding_window (`int`, *optional*):
+             Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
+     """
+
+     is_causal: bool
+     sliding_window: int
+
+     def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
+         self.is_causal = is_causal
+         self.sliding_window = sliding_window
+
+         if self.sliding_window is not None and self.sliding_window <= 0:
+             raise ValueError(
+                 "Make sure that when passing `sliding_window` that its value is a strictly positive integer, "
+                 f"not `{self.sliding_window}`"
+             )
+
+     def to_causal_4d(
+         self,
+         batch_size: int,
+         query_length: int,
+         key_value_length: int,
+         dtype: torch.dtype,
+         device: Union[torch.device, "str"] = "cpu",
+     ) -> Optional[torch.Tensor]:
+         """
+         Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
+         bias to upper right hand triangular matrix (causal mask).
+         """
+         if not self.is_causal:
+             raise ValueError(
+                 f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True."
+             )
+
+         # If shape is not cached, create a new causal mask and cache it
+         input_shape = (batch_size, query_length)
+         past_key_values_length = key_value_length - query_length
+
+         # create causal mask
+         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+         causal_4d_mask = None
+         if input_shape[-1] > 1 or self.sliding_window is not None:
+             causal_4d_mask = self._make_causal_mask(
+                 input_shape,
+                 dtype,
+                 device=device,
+                 past_key_values_length=past_key_values_length,
+                 sliding_window=self.sliding_window,
+             )
+
+         return causal_4d_mask
+
+     def to_4d(
+         self,
+         attention_mask_2d: torch.Tensor,
+         query_length: int,
+         dtype: torch.dtype,
+         key_value_length: Optional[int] = None,
+     ) -> torch.Tensor:
+         """
+         Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
+         key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
+         causal, a causal mask will be added.
+         """
+         input_shape = (attention_mask_2d.shape[0], query_length)
+
+         # create causal mask
+         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+         causal_4d_mask = None
+         if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
+             if key_value_length is None:
+                 raise ValueError(
+                     "This attention mask converter is causal. Make sure to pass "
+                     "`key_value_length` to correctly create a causal mask."
+                 )
+
+             past_key_values_length = key_value_length - query_length
+             causal_4d_mask = self._make_causal_mask(
+                 input_shape,
+                 dtype,
+                 device=attention_mask_2d.device,
+                 past_key_values_length=past_key_values_length,
+                 sliding_window=self.sliding_window,
+             )
+         elif self.sliding_window is not None:
+             raise NotImplementedError(
+                 "Sliding window is currently only implemented for causal masking"
+             )
+
+         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+         expanded_attn_mask = self._expand_mask(
+             attention_mask_2d, dtype, tgt_len=input_shape[-1]
+         ).to(attention_mask_2d.device)
+
+         if causal_4d_mask is not None:
+             expanded_attn_mask = causal_4d_mask.masked_fill(
+                 expanded_attn_mask.bool(), torch.finfo(dtype).min
+             )
+
+         # expanded_attn_mask + causal_4d_mask can cause some overflow
+         expanded_4d_mask = expanded_attn_mask
+
+         return expanded_4d_mask
+
+     @staticmethod
+     def _make_causal_mask(
+         input_ids_shape: torch.Size,
+         dtype: torch.dtype,
+         device: torch.device,
+         past_key_values_length: int = 0,
+         sliding_window: Optional[int] = None,
+     ):
+         """
+         Make causal mask used for bi-directional self-attention.
+         """
+         bsz, tgt_len = input_ids_shape
+         mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+         mask_cond = torch.arange(mask.size(-1), device=device)
+         mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+
+         mask = mask.to(dtype)
+
+         if past_key_values_length > 0:
+             mask = torch.cat(
+                 [
+                     torch.zeros(
+                         tgt_len, past_key_values_length, dtype=dtype, device=device
+                     ),
+                     mask,
+                 ],
+                 dim=-1,
+             )
+
+         # add lower triangular sliding window mask if necessary
+         if sliding_window is not None:
+             diagonal = past_key_values_length - sliding_window - 1
+
+             context_mask = torch.tril(
+                 torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal
+             )
+             mask.masked_fill_(context_mask, torch.finfo(dtype).min)
+
+         return mask[None, None, :, :].expand(
+             bsz, 1, tgt_len, tgt_len + past_key_values_length
+         )
+
+     @staticmethod
+     def _expand_mask(
+         mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+     ):
+         """
+         Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+         """
+         bsz, src_len = mask.size()
+         tgt_len = tgt_len if tgt_len is not None else src_len
+
+         expanded_mask = (
+             mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+         )
+
+         inverted_mask = 1.0 - expanded_mask
+
+         return inverted_mask.masked_fill(
+             inverted_mask.to(torch.bool), torch.finfo(dtype).min
+         )
+
+     @staticmethod
+     def _unmask_unattended(
+         expanded_mask: torch.FloatTensor,
+         min_dtype: float,
+     ):
+         # fmt: off
+         """
+         Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
+         using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+         Details: https://github.com/pytorch/pytorch/issues/110213
+
+         `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
+         `attention_mask` is [bsz, src_seq_len].
+
+         The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case
+         of alibi attention bias.
+
+         For example, if `expanded_mask` is (e.g. here left-padding case)
+         ```
+         [[[[0, 0, 0],
+            [0, 0, 0],
+            [0, 0, 1]]],
+          [[[1, 0, 0],
+            [1, 1, 0],
+            [1, 1, 1]]],
+          [[[0, 0, 0],
+            [0, 1, 0],
+            [0, 1, 1]]]]
+         ```
+         then the modified `expanded_mask` will be
+         ```
+         [[[[1, 1, 1],   <-- modified
+            [1, 1, 1],   <-- modified
+            [0, 0, 1]]],
+          [[[1, 0, 0],
+            [1, 1, 0],
+            [1, 1, 1]]],
+          [[[1, 1, 1],   <-- modified
+            [0, 1, 0],
+            [0, 1, 1]]]]
+         ```
+         """
+         # fmt: on
+         if expanded_mask.dtype == torch.bool:
+             raise ValueError(
+                 "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
+             )
+
+         return expanded_mask.mul(
+             ~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)
+         )
+
+     @staticmethod
+     def _ignore_causal_mask_sdpa(
+         attention_mask: Optional[torch.Tensor],
+         inputs_embeds: torch.Tensor,
+         past_key_values_length: int,
+         sliding_window: Optional[int] = None,
+         is_training: bool = False,
+     ) -> bool:
+         """
+         Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
+         ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
+
+         In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
+         `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
+         allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
+         passed).
+         """
+
+         _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
+         key_value_length = query_length + past_key_values_length
+
+         is_tracing = (
+             torch.jit.is_tracing()
+             or isinstance(inputs_embeds, torch.fx.Proxy)
+             or is_torchdynamo_compiling()
+         )
+
+         ignore_causal_mask = False
+
+         if attention_mask is None:
+             # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
+             # shape, thus SDPA's `is_causal` argument is rightfully updated
+             # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
+             # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
+             # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
+             # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
+             # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
+             #
+             # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
+             # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
+             if (
+                 (is_training or not is_tracing)
+                 and (query_length == 1 or key_value_length == query_length)
+                 and (sliding_window is None or key_value_length < sliding_window)
+             ):
+                 ignore_causal_mask = True
+         elif sliding_window is None or key_value_length < sliding_window:
+             if len(attention_mask.shape) == 4:
+                 return False
+             elif not is_tracing and torch.all(attention_mask == 1):
+                 if query_length == 1 or key_value_length == query_length:
+                     # For query_length == 1, causal attention and bi-directional attention are the same.
+                     ignore_causal_mask = True
+
+                 # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
+                 # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
+                 # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
+                 # Reference: https://github.com/pytorch/pytorch/issues/108108
+                 # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
+
+         return ignore_causal_mask
+
+
+ def _prepare_4d_causal_attention_mask(
+     attention_mask: Optional[torch.Tensor],
+     input_shape: Union[torch.Size, Tuple, List],
+     inputs_embeds: torch.Tensor,
+     past_key_values_length: int,
+     sliding_window: Optional[int] = None,
+ ):
+     """
+     Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+     `(batch_size, key_value_length)`
+
+     Args:
+         attention_mask (`torch.Tensor` or `None`):
+             A 2D attention mask of shape `(batch_size, key_value_length)`
+         input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
+             The input shape should be a tuple that defines `(batch_size, query_length)`.
+         inputs_embeds (`torch.Tensor`):
+             The embedded inputs as a torch Tensor.
+         past_key_values_length (`int`):
+             The length of the key value cache.
+         sliding_window (`int`, *optional*):
+             If the model uses windowed attention, a sliding window should be passed.
+     """
+     attn_mask_converter = AttentionMaskConverter(
+         is_causal=True, sliding_window=sliding_window
+     )
+
+     key_value_length = input_shape[-1] + past_key_values_length
+
+     # 4d mask is passed through the layers
+     if attention_mask is not None and len(attention_mask.shape) == 2:
+         attention_mask = attn_mask_converter.to_4d(
+             attention_mask,
+             input_shape[-1],
+             key_value_length=key_value_length,
+             dtype=inputs_embeds.dtype,
+         )
+     elif attention_mask is not None and len(attention_mask.shape) == 4:
+         expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+         if tuple(attention_mask.shape) != expected_shape:
+             raise ValueError(
+                 f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+             )
+         else:
+             # if the 4D mask has correct shape - invert it and fill with negative infinity
+             inverted_mask = 1.0 - attention_mask
+             attention_mask = inverted_mask.masked_fill(
+                 inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+             )
+     else:
+         attention_mask = attn_mask_converter.to_causal_4d(
+             input_shape[0],
+             input_shape[-1],
+             key_value_length,
+             dtype=inputs_embeds.dtype,
+             device=inputs_embeds.device,
+         )
+
+     return attention_mask
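A small sketch of how the module-level helper above is typically called; the shapes and the import path are only illustrative:

```python
import torch
from torchtitan.experiments.deepseek_v3.attn_mask_utils import _prepare_4d_causal_attention_mask

bsz, q_len, past, dim = 2, 5, 3, 16
attention_mask = torch.ones(bsz, q_len + past)   # 2D padding mask over prompt + cache
inputs_embeds = torch.randn(bsz, q_len, dim)

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask, (bsz, q_len), inputs_embeds, past_key_values_length=past
)
print(mask_4d.shape)  # torch.Size([2, 1, 5, 8]); 0 where attending, large negative bias elsewhere
```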
torchtitan/experiments/deepseek_v3/generate.py ADDED
@@ -0,0 +1,308 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # torchrun --standalone --nproc-per-node 4 generate.py
+
+ # use inference.sh "Your Question Here?" to run inference with a single prompt.
+
+ import sys
+ from dataclasses import dataclass
+
+ import torch
+ import torch.distributed as dist
+
+ from checkpoint import load_weights_from_hf
+ from model import DeepseekForCausalLM
+ from model_config import deepseek_config_registry
+ from torch.distributed.device_mesh import DeviceMesh
+ from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+ from torchtitan.tools.utils import Color
+ from transformers import AutoTokenizer
+
+ # Uncomment the model you want to run.
+ model_id, mesh_shape = "deepseek-ai/DeepSeek-V2-Lite-Chat", (1, 4)
+ # model_id, mesh_shape = "deepseek-ai/deepseek-v3", (8, 4)
+
+
+ def colorize_chat(text, user_color=None, assistant_color=None, output_color=None):
+     """Parse and colorize chat output with optional colors for each role."""
+     lines = text.split("\n")
+     result = []
+
+     current_role = None
+     current_content = []
+
+     def _process_current_content():
+         if not current_role or not current_content:
+             return None
+
+         content = "\n".join(current_content)
+         if current_role == "output":
+             return (
+                 f"Output: {output_color}{content}{color.reset}"
+                 if output_color
+                 else f"Output: {content}"
+             )
+         else:
+             try:
+                 prefix, rest = current_content[0].split(":", 1)
+                 role_color = user_color if current_role == "user" else assistant_color
+                 if role_color:
+                     formatted = f"{prefix}:{role_color}{rest}{color.reset}"
+                     if len(current_content) > 1:
+                         formatted += (
+                             f"{role_color}\n"
+                             + "\n".join(current_content[1:])
+                             + f"{color.reset}"
+                         )
+                     return formatted
+             except ValueError:
+                 pass
+             return content
+
+     for line in lines:
+         if line.startswith("Output:"):
+             if processed := _process_current_content():
+                 result.append(processed)
+             current_role = "output"
+             content = line[len("Output:") :].strip()
+             if output_color:
+                 content = f"Output: {output_color}{content}{color.reset}"
+             else:
+                 content = f"Output: {content}"
+             result.append(content)
+             current_content = []
+
+         elif line.startswith("User:"):
+             if processed := _process_current_content():
+                 result.append(processed)
+             current_role = "user"
+             current_content = [line]
+
+         elif line.startswith("Assistant:"):
+             if processed := _process_current_content():
+                 result.append(processed)
+             current_role = "assistant"
+             current_content = [line]
+
+         else:
+             if current_content:
+                 current_content.append(line)
+             elif line.strip() and current_role is None:
+                 # Handle system message at the beginning
+                 current_role = "output"
+                 if output_color:
+                     result.append(f"Output: {output_color}{line.strip()}{color.reset}")
+                 else:
+                     result.append(f"Output: {line.strip()}")
+
+     # Process the last segment
+     if processed := _process_current_content():
+         result.append(processed)
+
+     return "\n".join(result)
+
+
+ color = Color()
+
+
+ @dataclass
+ class DistConfig:
+     mesh: DeviceMesh
+     pp_mesh: DeviceMesh
+     ep_mesh: DeviceMesh
+     pp_size: int
+     ep_size: int
+     ep_rank: int
+     pp_rank: int
+     device: torch.device
+
+
+ def create_model(dist_config: DistConfig):
+     model_args = deepseek_config_registry[model_id]
+     model_args.ep_size = dist_config.ep_size
+     model_args.num_stages = dist_config.pp_size
+     model_args.stage_idx = dist_config.pp_rank
+     model_args.max_seq_len = 16384
+
+     with dist_config.device, dist_config.mesh:
+         model = DeepseekForCausalLM(model_args)
+     load_weights_from_hf(model, model_id, dist_config.device)
+     model.eval()
+     model.setup_symm_mem(torch.bfloat16, dist_config.device)
+
+     stage = PipelineStage(
+         model,
+         dist_config.pp_rank,
+         dist_config.pp_size,
+         dist_config.device,
+         group=dist_config.pp_mesh.get_group(),
+     )
+     pp_schedule = ScheduleGPipe(stage, dist_config.pp_size)
+     return model, pp_schedule
+
+
+ def create_dist_config(mesh: DeviceMesh):
+     rank = dist.get_rank()
+     device_count = torch.cuda.device_count()
+     device = torch.device("cuda", rank % device_count)
+
+     dist_config = DistConfig(
+         mesh=mesh,
+         pp_mesh=mesh["pp"],
+         ep_mesh=mesh["ep"],
+         pp_rank=mesh["pp"].get_local_rank(),
+         pp_size=mesh["pp"].size(),
+         ep_size=mesh["ep"].size(),
+         ep_rank=mesh["ep"].get_local_rank(),
+         device=device,
+     )
+     return dist_config
+
+
+ def decode(tokenizer, x):
+     output = tokenizer.decode(x[0])
+     # Clean up the output by removing special tokens
+     bos = tokenizer.bos_token
+     output = output.replace(bos, "")
+     # Truncate at end of sentence token
+     eos_token = tokenizer.eos_token
+     if eos_token and eos_token in output:
+         output = output.split(eos_token)[0]
+     colored_output = colorize_chat(
+         output,
+         user_color=color.green,
+         assistant_color=color.cyan,
+         output_color=color.blue,
+     )
+     return colored_output
+
+
+ @torch.inference_mode()
+ def generate(
+     model,
+     pp_schedule,
+     tokenizer,
+     dist_config,
+     messages: list[dict],
+     n_tokens: int = 50,
+ ):
+     rank = dist.get_rank()
+     device = dist_config.device
+     x = tokenizer.apply_chat_template(
+         [messages] * dist_config.pp_size,
+         add_generation_prompt=True,
+         return_tensors="pt",
+     )
+     next_idx = x.shape[-1]
+     x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
+     x = x.to(device)
+
+     for _ in range(n_tokens):
+         if dist_config.pp_size > 1:
+             if dist_config.pp_rank == 0:
+                 pp_schedule.step(x)
+                 torch.distributed.broadcast(
+                     x,
+                     group=dist_config.pp_mesh.get_group(),
+                     group_src=dist_config.pp_size - 1,
+                 )
+             elif dist_config.pp_rank == dist_config.pp_size - 1:
+                 preds = pp_schedule.step()
+                 next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
+                 x[:, next_idx] = next_token
+                 torch.distributed.broadcast(
+                     x,
+                     group=dist_config.pp_mesh.get_group(),
+                     group_src=dist_config.pp_size - 1,
+                 )
+             else:
+                 pp_schedule.step()
+                 torch.distributed.broadcast(
+                     x,
+                     group=dist_config.pp_mesh.get_group(),
+                     group_src=dist_config.pp_size - 1,
+                 )
+
+             next_idx += 1
+         else:
+             preds = model(x)
+             next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
+             x[:, next_idx] = next_token
+             next_idx += 1
+
+     if rank == 0:
+         colored_output = decode(tokenizer, x)
+         print(f"Without CUDA Graph:\n{colored_output}")
+
+
+ @torch.inference_mode()
+ def generate_with_cuda_graph(
+     model,
+     tokenizer,
+     dist_config,
+     messages: list[dict],
+     n_tokens: int = 10,
+ ):
+     rank = dist.get_rank()
+     device = dist_config.device
+     x = tokenizer.apply_chat_template(
+         [messages] * dist_config.pp_size,
+         add_generation_prompt=True,
+         return_tensors="pt",
+     )
+     next_idx = x.shape[-1]
+     x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
+     x = x.to(device)
+
+     torch.cuda.synchronize()
+
+     # Create CUDA graph
+     g = torch.cuda.CUDAGraph()
+     with torch.cuda.graph(g):
+         preds = model(x)
+
+     # Run CUDA graph
+     for _ in range(n_tokens):
+         g.replay()
+         next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
+         x[:, next_idx] = next_token
+         next_idx += 1
+
+     if rank == 0:
+         colored_output = decode(tokenizer, x)
+         print(f"With CUDA Graph:\n{colored_output}")
+
+
+ if __name__ == "__main__":
+     # Get user prompt from command line arguments
+     user_prompt = "What is 2+2?"  # Default prompt
+     if len(sys.argv) > 1:
+         user_prompt = sys.argv[1]
+
+     mesh = dist.init_device_mesh("cuda", mesh_shape, mesh_dim_names=("pp", "ep"))
+     rank = dist.get_rank()
+     if rank == 0:
+         print(
+             f"{color.yellow}Running inference with {model_id} on {mesh_shape} mesh{color.reset}"
+         )
+
+     dist_config = create_dist_config(mesh)
+     model, pp_schedule = create_model(dist_config)
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+     messages = [
+         {"role": "system", "content": "You are a helpful assistant."},
+         {"role": "user", "content": user_prompt},
+     ]
+
+     generate(model, pp_schedule, tokenizer, dist_config, messages)
+     generate_with_cuda_graph(model, tokenizer, dist_config, messages)
+
+     if rank == 0:
+         print(f"\n{color.yellow}Closing inference mesh...{color.reset}")
+
+     dist.destroy_process_group()
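When `pp_size == 1`, the loop above reduces to a plain greedy argmax decode. A minimal standalone sketch of that step, where `model` is any callable returning logits of shape `[batch, seq, vocab]` (names here are illustrative):

```python
import torch


@torch.inference_mode()
def greedy_decode(model, x: torch.Tensor, next_idx: int, n_tokens: int) -> torch.Tensor:
    # x already contains the prompt followed by n_tokens zero placeholders
    for _ in range(n_tokens):
        logits = model(x)                                   # [batch, seq_len, vocab]
        next_token = torch.argmax(logits[:, next_idx - 1], dim=-1)
        x[:, next_idx] = next_token                         # fill the next placeholder in place
        next_idx += 1
    return x
```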
torchtitan/experiments/deepseek_v3/model_config.py ADDED
@@ -0,0 +1,204 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from dataclasses import dataclass, field
+
+
+ @dataclass
+ class ModelArgs:
+     r"""
+     This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a DeepSeek
+     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+     defaults will yield a similar configuration to that of DeepSeek-V3.
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+     Args:
+         vocab_size (`int`, *optional*, defaults to 129280):
+             Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
+             `inputs_ids` passed when calling [`DeepseekV3Model`]
+         hidden_size (`int`, *optional*, defaults to 4096):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 11008):
+             Dimension of the MLP representations.
+         moe_intermediate_size (`int`, *optional*, defaults to 1407):
+             Dimension of the MoE representations.
+         num_hidden_layers (`int`, *optional*, defaults to 32):
+             Number of hidden layers in the Transformer decoder.
+         num_nextn_predict_layers (`int`, *optional*, defaults to 1):
+             Number of nextn predict layers in the DeepSeekV3 Model.
+         num_attention_heads (`int`, *optional*, defaults to 32):
+             Number of attention heads for each attention layer in the Transformer decoder.
+         n_shared_experts (`int`, *optional*, defaults to None):
+             Number of shared experts, None means dense model.
+         n_routed_experts (`int`, *optional*, defaults to None):
+             Number of routed experts, None means dense model.
+         routed_scaling_factor (`float`, *optional*, defaults to 1.0):
+             Scaling factor of routed experts.
+         topk_method (`str`, *optional*, defaults to `greedy`):
+             Topk method used in routed gate.
+         n_group (`int`, *optional*, defaults to None):
+             Number of groups for routed experts.
+         topk_group (`int`, *optional*, defaults to None):
+             Number of selected groups for each token (for each token, ensuring the selected experts is only within
+             `topk_group` groups).
+         num_experts_per_tok (`int`, *optional*, defaults to None):
+             Number of selected experts, None means dense model.
+         moe_layer_freq (`int`, *optional*, defaults to 1):
+             The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
+         first_k_dense_replace (`int`, *optional*, defaults to 0):
+             Number of dense layers in shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).
+                                                       \--k dense layers--/
+         norm_topk_prob (`bool`, *optional*, defaults to False):
+             Whether to normalize the weights of the routed experts.
+         scoring_func (`str`, *optional*, defaults to 'softmax'):
+             Method of computing expert weights.
+         aux_loss_alpha (`float`, *optional*, defaults to 0.001):
+             Auxiliary loss weight coefficient.
+         seq_aux (`bool`, *optional*, defaults to True):
+             Whether to compute the auxiliary loss for each individual sample.
+         num_key_value_heads (`int`, *optional*):
+             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+             `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+             `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+             by meanpooling all the original heads within that group. For more details checkout [this
+             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+             `num_attention_heads`.
+         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+             The non-linear activation function (function or string) in the decoder.
+         max_position_embeddings (`int`, *optional*, defaults to 2048):
+             The maximum sequence length that this model might ever be used with.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+             The epsilon used by the rms normalization layers.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models). Only
+             relevant if `config.is_decoder=True`.
+         pad_token_id (`int`, *optional*):
+             Padding token id.
+         bos_token_id (`int`, *optional*, defaults to 1):
+             Beginning of stream token id.
+         eos_token_id (`int`, *optional*, defaults to 2):
+             End of stream token id.
+         pretraining_tp (`int`, *optional*, defaults to 1):
+             Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+             document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+             necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+             issue](https://github.com/pytorch/pytorch/issues/76232).
+         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether to tie weight embeddings
+         rope_theta (`float`, *optional*, defaults to 10000.0):
+             The base period of the RoPE embeddings.
+         rope_scaling (`Dict`, *optional*):
+             Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+             strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+             `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+             `max_position_embeddings` to the expected new maximum.
+         attention_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use a bias in the query, key, value and output projection layers during self-attention.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+     """
+
+     vocab_size: int = 129280
+     hidden_size: int = 7168
+     intermediate_size: int = 18432
+     moe_intermediate_size: int = 2048
+     num_hidden_layers: int = 61
+     num_nextn_predict_layers: int = 1
+     num_attention_heads: int = 128
+     num_key_value_heads: int = 128
+     n_shared_experts: int = 1
+     n_routed_experts: int = 256
+     ep_size: int = 1
+     routed_scaling_factor: float = 2.5
+     kv_lora_rank: int = 512
+     q_lora_rank: int = 1536
+     qk_rope_head_dim: int = 64
+     v_head_dim: int = 128
+     qk_nope_head_dim: int = 128
+     topk_method: str = "noaux_tc"
+     n_group: int = 8
+     topk_group: int = 4
+     num_experts_per_tok: int = 8
+     moe_layer_freq: int = 1
+     first_k_dense_replace: int = 3
+     norm_topk_prob: bool = True
+     scoring_func: str = "sigmoid"
+     aux_loss_alpha: float = 0.001
+     seq_aux: bool = True
+     hidden_act: str = "silu"
+     max_position_embeddings: int = 163840
+     initializer_range: float = 0.02
+     rms_norm_eps: float = 1e-6
+     rope_theta: float = 10000.0
+     rope_scaling: dict = field(
+         default_factory=lambda: {
+             "beta_fast": 32,
+             "beta_slow": 1,
+             "factor": 40,
+             "mscale": 1.0,
+             "mscale_all_dim": 1.0,
+             "original_max_position_embeddings": 4096,
+             "type": "yarn",
+         }
+     )
+     attention_bias: bool = False
+     attention_dropout: float = 0.0
+     pad_token_id = None
+     # Added for symmetric memory
+     max_seq_len: int = 4096
+     dtype: str = "bfloat16"
+     # Added for pipeline parallel
+     num_stages: int = 1
+     stage_idx: int = 0
+
+
+ # This is the configuration for deepseek-ai/DeepSeek-V2-Lite.
+ deepseek_v2_lite_config = ModelArgs(
+     vocab_size=102400,
+     hidden_size=2048,
+     intermediate_size=10944,
+     moe_intermediate_size=1408,
+     num_hidden_layers=27,
+     num_attention_heads=16,
+     num_key_value_heads=16,
+     n_shared_experts=2,
+     n_routed_experts=64,
+     routed_scaling_factor=1.0,
+     kv_lora_rank=512,
+     q_lora_rank=None,
+     qk_rope_head_dim=64,
+     v_head_dim=128,
+     qk_nope_head_dim=128,
+     topk_method="greedy",
+     n_group=1,
+     topk_group=1,
+     num_experts_per_tok=6,
+     first_k_dense_replace=1,
+     norm_topk_prob=False,
+     scoring_func="softmax",
+     max_position_embeddings=4096,
+     rope_scaling={
+         "beta_fast": 32,
+         "beta_slow": 1,
+         "factor": 40,
+         "mscale": 0.707,
+         "mscale_all_dim": 0.707,
+         "original_max_position_embeddings": 4096,
+         "type": "yarn",
+     },
+ )
+
+
+ # Model configuration registry
+ # Key is the model distribution ID on HuggingFace Hub
+ deepseek_config_registry = {
+     "deepseek-ai/DeepSeek-V2-Lite": deepseek_v2_lite_config,
+     "deepseek-ai/DeepSeek-V2-Lite-Chat": deepseek_v2_lite_config,
+     "deepseek-ai/deepseek-v3": ModelArgs(),
+ }
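Since the configs are plain dataclasses, a registry entry can be looked up and tweaked with `dataclasses.replace`; the overrides below are only an example:

```python
from dataclasses import replace

from torchtitan.experiments.deepseek_v3.model_config import deepseek_config_registry

args = deepseek_config_registry["deepseek-ai/DeepSeek-V2-Lite-Chat"]
# shrink the run for debugging; field names come from ModelArgs above
debug_args = replace(args, num_hidden_layers=4, max_seq_len=2048, num_stages=2, stage_idx=0)
print(debug_args.n_routed_experts, debug_args.num_hidden_layers)  # 64 4
```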
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py ADDED
@@ -0,0 +1,63 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def get_tid():
+     return tl.inline_asm_elementwise(
+         """
+         mov.u32 $0, %tid.x;
+         mov.u32 $1, %tid.y;
+         mov.u32 $2, %tid.z;
+         """,
+         "=r,=r,=r",
+         [],
+         dtype=(tl.uint32, tl.uint32, tl.uint32),
+         is_pure=True,
+         pack=1,
+     )
+
+
+ @triton.jit
+ def get_ntid():
+     return tl.inline_asm_elementwise(
+         """
+         mov.u32 $0, %ntid.x;
+         mov.u32 $1, %ntid.y;
+         mov.u32 $2, %ntid.z;
+         """,
+         "=r,=r,=r",
+         [],
+         dtype=(tl.uint32, tl.uint32, tl.uint32),
+         is_pure=True,
+         pack=1,
+     )
+
+
+ @triton.jit
+ def get_flat_tid():
+     tid_x, tid_y, tid_z = get_tid()
+     ntid_x, ntid_y, _ = get_ntid()
+     return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x
+
+
+ @triton.jit
+ def get_flat_bid():
+     return (
+         tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0)
+         + tl.program_id(1) * tl.num_programs(0)
+         + tl.program_id(0)
+     )
+
+
+ @triton.jit
+ def sync_threads():
+     tl.inline_asm_elementwise(
+         "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
+     )
+ 
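A tiny kernel sketch showing how the block-index helper above can be exercised; this is purely illustrative and needs an NVIDIA GPU, since the helpers emit PTX:

```python
import torch
import triton
import triton.language as tl

from torchtitan.experiments.deepseek_v3.symm_mem_recipes.triton_utils import get_flat_bid


@triton.jit
def _record_block_ids(out_ptr):
    # each program (CTA) stores its flattened block index
    bid = get_flat_bid()
    tl.store(out_ptr + bid, bid)


out = torch.empty(4 * 2 * 3, dtype=torch.int32, device="cuda")
_record_block_ids[(4, 2, 3)](out)  # 3D launch grid
assert sorted(out.tolist()) == list(range(out.numel()))
```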
torchtitan/experiments/flux/README.md ADDED
@@ -0,0 +1,23 @@
+ # FLUX model in torchtitan
+
+ ## Overview
+
+ ## Usage
+ First, download the autoencoder model from HuggingFace with your own access token:
+ ```bash
+ python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
+ ```
+ This step downloads the autoencoder model from HuggingFace and saves it to `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors`.
+
+ Run the following command to train the model on a single GPU:
+ ```bash
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
+ ```
+
+ ## TODO
+ - [ ] Support for multiple GPUs is coming soon (FSDP, etc.)
+ - [ ] Implement test cases in CI for the FLUX model; add more unit tests (e.g., for the preprocessor)
+ - [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc.)
+ - [ ] Support for distributed checkpointing and loading
+ - [ ] Implement init_weights() to initialize the model weights
+ - [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops()
torchtitan/experiments/flux/dataset/flux_dataset.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import random
9
+ from dataclasses import dataclass
10
+ from typing import Any, Callable, Optional
11
+
12
+ import numpy as np
13
+
14
+ import torch
15
+
16
+ from datasets import Dataset, load_dataset
17
+ from datasets.distributed import split_dataset_by_node
18
+ from PIL import Image
19
+
20
+ from torch.distributed.checkpoint.stateful import Stateful
21
+
22
+ from torch.utils.data import IterableDataset
23
+ from torchtitan.components.dataloader import ParallelAwareDataloader
24
+
25
+ from torchtitan.config_manager import JobConfig
26
+ from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
27
+ from torchtitan.tools.logging import logger
28
+
29
+
30
+ def _process_cc12m_image(
31
+ img: Image.Image,
32
+ output_size: int = 256,
33
+ ) -> Optional[torch.Tensor]:
34
+ """Process CC12M image to the desired size."""
35
+
36
+ width, height = img.size
37
+ # Skip low resolution images
38
+ if width < output_size or height < output_size:
39
+ return None
40
+
41
+ if width >= height:
42
+ # resize height to be equal to output_size, then crop
43
+ new_width, new_height = math.ceil(output_size / height * width), output_size
44
+ img = img.resize((new_width, new_height))
45
+ left = random.randint(0, new_width - output_size)
46
+ resized_img = img.crop((left, 0, left + output_size, output_size))
47
+ else:
48
+ # resize width to be equal to output_size, the crop
49
+ new_width, new_height = (
50
+ output_size,
51
+ math.ceil(output_size / width * height),
52
+ )
53
+ img = img.resize((new_width, new_height))
54
+ lower = random.randint(0, new_width - output_size)
55
+ resized_img = img.crop((0, lower, output_size, lower + output_size))
56
+
57
+ assert resized_img.size[0] == resized_img.size[1] == output_size
58
+
59
+ # Skip grayscale images
60
+ if resized_img.mode == "L":
61
+ return None
62
+
63
+ np_img = np.array(resized_img).transpose((2, 0, 1))
64
+ tensor_img = torch.tensor(np_img).float() / 255.0
65
+
66
+ # NOTE: The following commented code is an alternative way
67
+ # img_transform = transforms.Compose(
68
+ # [
69
+ # transforms.Resize(max(output_size, output_size)),
70
+ # transforms.CenterCrop((output_size, output_size)),
71
+ # transforms.ToTensor(),
72
+ # ]
73
+ # )
74
+ # tensor_img = img_transform(img)
75
+
76
+ return tensor_img
77
+
78
+
79
+ def _flux_data_processor(
80
+ sample: dict[str, Any],
81
+ t5_tokenizer: FluxTokenizer,
82
+ clip_tokenizer: FluxTokenizer,
83
+ output_size: int = 256,
84
+ ) -> dict[str, Any]:
85
+ """
86
+ Preprocess CC12M dataset sample image and text for Flux model.
87
+
88
+ Args:
89
+ sample: A sample from dataset
90
+ t5_encoder: T5 encoder
91
+ clip_encoder: CLIP encoder
92
+ output_size: The output image size
93
+
94
+ """
95
+ img = _process_cc12m_image(sample["jpg"], output_size=output_size)
96
+ t5_tokens = t5_tokenizer.encode(sample["txt"])
97
+ clip_tokens = clip_tokenizer.encode(sample["txt"])
98
+
99
+ return {
100
+ "image": img,
101
+ "clip_tokens": clip_tokens, # type: List[int]
102
+ "t5_tokens": t5_tokens, # type: List[int]
103
+ }
104
+
105
+
106
+ @dataclass
107
+ class TextToImageDatasetConfig:
108
+ path: str
109
+ loader: Callable
110
+ data_processor: Callable
111
+
112
+
113
+ DATASETS = {
114
+ "cc12m": TextToImageDatasetConfig(
115
+ path="pixparse/cc12m-wds",
116
+ loader=lambda path: load_dataset(path, split="train", streaming=True),
117
+ data_processor=_flux_data_processor,
118
+ ),
119
+ }
120
+
121
+
122
+ def _validate_dataset(
123
+ dataset_name: str, dataset_path: Optional[str] = None
124
+ ) -> tuple[str, Callable, Callable]:
125
+ """Validate dataset name and path."""
126
+ if dataset_name not in DATASETS:
127
+ raise ValueError(
128
+ f"Dataset {dataset_name} is not supported. "
129
+ f"Supported datasets are: {list(DATASETS.keys())}"
130
+ )
131
+
132
+ config = DATASETS[dataset_name]
133
+ path = dataset_path or config.path
134
+ logger.info(f"Preparing {dataset_name} dataset from {path}")
135
+ return path, config.loader, config.data_processor
136
+
137
+
138
+ class FluxDataset(IterableDataset, Stateful):
139
+ """Dataset for FLUX text-to-image model.
140
+
141
+ Args:
142
+ dataset_name (str): Name of the dataset.
143
+ dataset_path (str): Path to the dataset.
144
+ model_transform (Transform): Callable that applies model-specific preprocessing to the sample.
145
+ dp_rank (int): Data parallel rank.
146
+ dp_world_size (int): Data parallel world size.
147
+ infinite (bool): Whether to loop over the dataset infinitely.
148
+ """
149
+
150
+ def __init__(
151
+ self,
152
+ dataset_name: str,
153
+ dataset_path: Optional[str],
154
+ t5_tokenizer: FluxTokenizer,
155
+ clip_tokenizer: FluxTokenizer,
156
+ job_config: Optional[JobConfig] = None,
157
+ dp_rank: int = 0,
158
+ dp_world_size: int = 1,
159
+ infinite: bool = False,
160
+ ) -> None:
161
+
162
+ # Force lowercase for consistent comparison
163
+ dataset_name = dataset_name.lower()
164
+
165
+ path, dataset_loader, data_processor = _validate_dataset(
166
+ dataset_name, dataset_path
167
+ )
168
+ ds = dataset_loader(path)
169
+
170
+ self.dataset_name = dataset_name
171
+ self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
172
+
173
+ self._t5_tokenizer = t5_tokenizer
174
+ self._clip_tokenizer = clip_tokenizer
175
+ self._data_processor = data_processor
176
+ self.job_config = job_config
177
+
178
+ self.infinite = infinite
179
+
180
+ # Variables for checkpointing
181
+ self._sample_idx = 0
182
+ self._all_samples: list[dict[str, Any]] = []
183
+
184
+ def _get_data_iter(self):
185
+ if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
186
+ return iter([])
187
+
188
+ it = iter(self._data)
189
+ for _ in range(self._sample_idx):
190
+ next(it)
191
+ return it
192
+
193
+ def __iter__(self):
194
+ while True:
195
+ for sample in self._get_data_iter():
196
+ # Use the dataset-specific preprocessor
197
+ sample_dict = self._data_processor(
198
+ sample, self._t5_tokenizer, self._clip_tokenizer, output_size=256
199
+ )
200
+
201
+ # skip low quality image or image with color channel = 1
202
+ if sample_dict["image"] is None:
203
+ logger.warning(
204
+ f"Low quality image {sample['__key__']} is skipped in Flux Dataloader"
205
+ )
206
+ continue
207
+
208
+ self._all_samples.extend(sample_dict)
209
+ self._sample_idx += 1
210
+
211
+ labels = sample_dict.pop("image")
212
+ yield sample_dict, labels
213
+
214
+ if not self.infinite:
215
+ logger.warning(f"Dataset {self.dataset_name} has run out of data")
216
+ break
217
+ else:
218
+ # Reset offset for the next iteration
219
+ self._sample_idx = 0
220
+ logger.warning(f"Dataset {self.dataset_name} is being re-looped")
221
+
222
+ def load_state_dict(self, state_dict):
223
+ self._sample_idx = state_dict["sample_idx"]
224
+ self._all_samples = state_dict["all_samples"]
225
+
226
+ def state_dict(self):
227
+ return {
228
+ "all_samples": self._all_samples,
229
+ "sample_idx": self._sample_idx,
230
+ }
231
+
232
+
233
+ def build_flux_dataloader(
234
+ dp_world_size: int,
235
+ dp_rank: int,
236
+ job_config: JobConfig,
237
+ # This parameter is not used, keep it for compatibility
238
+ tokenizer: FluxTokenizer | None,
239
+ infinite: bool = True,
240
+ ) -> ParallelAwareDataloader:
241
+ """Build a data loader for HuggingFace datasets."""
242
+ dataset_name = job_config.training.dataset
243
+ dataset_path = job_config.training.dataset_path
244
+ batch_size = job_config.training.batch_size
245
+
246
+ t5_encoder_name = job_config.encoder.t5_encoder
247
+ clip_encoder_name = job_config.encoder.clip_encoder
248
+ max_t5_encoding_len = job_config.encoder.max_t5_encoding_len
249
+
250
+ ds = FluxDataset(
251
+ dataset_name=dataset_name,
252
+ dataset_path=dataset_path,
253
+ t5_tokenizer=FluxTokenizer(t5_encoder_name, max_length=max_t5_encoding_len),
254
+ clip_tokenizer=FluxTokenizer(
255
+ clip_encoder_name, max_length=77
256
+ ), # fix max_length for CLIP
257
+ dp_rank=dp_rank,
258
+ dp_world_size=dp_world_size,
259
+ infinite=infinite,
260
+ )
261
+
262
+ return ParallelAwareDataloader(
263
+ dataset=ds,
264
+ dp_rank=dp_rank,
265
+ dp_world_size=dp_world_size,
266
+ batch_size=batch_size,
267
+ )
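Note: a minimal usage sketch of build_flux_dataloader (not part of this commit). The SimpleNamespace objects below are hypothetical stand-ins for torchtitan's real JobConfig; only the fields that build_flux_dataloader reads are filled in, with values taken from the debug config later in this upload.

from types import SimpleNamespace

from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader

job_config = SimpleNamespace(
    training=SimpleNamespace(dataset="cc12m", dataset_path=None, batch_size=32),
    encoder=SimpleNamespace(
        t5_encoder="google/t5-v1_1-small",
        clip_encoder="openai/clip-vit-large-patch14",
        max_t5_encoding_len=512,
    ),
)

dataloader = build_flux_dataloader(
    dp_world_size=1,
    dp_rank=0,
    job_config=job_config,
    tokenizer=None,  # unused, kept for interface compatibility
    infinite=True,
)

for input_dict, labels in dataloader:
    # input_dict carries the tokenized text fields; labels carries the image tensors
    break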
torchtitan/experiments/flux/dataset/tokenizer.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
9
+
10
+
11
+ from typing import List
12
+
13
+ from torchtitan.components.tokenizer import Tokenizer
14
+ from transformers import CLIPTokenizer, T5Tokenizer
15
+
16
+
17
+ class FluxTokenizer(Tokenizer):
18
+ """
19
+ Tokenizes and encodes/decodes text using either the T5 or the CLIP tokenizer.
20
+
21
+ Args:
22
+ model_path (str): Hugging Face path or identifier of the tokenizer.
+ max_length (int): Maximum token sequence length; prompts are padded/truncated to this length.
23
+
24
+ """
25
+
26
+ def __init__(self, model_path: str = "t5-small", max_length: int = 77):
27
+ super().__init__()
28
+ self._n_words = 8 # TODO(jianiw): check
29
+ self._max_length = max_length
30
+
31
+ self.is_clip = model_path.startswith("openai")
32
+
33
+ if self.is_clip:
34
+ self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
35
+ model_path, max_length=max_length
36
+ )
37
+ else:
38
+ self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
39
+ model_path, max_length=max_length
40
+ )
41
+
42
+ def encode(
43
+ self,
44
+ s: str,
45
+ ) -> List[int]:
46
+ """
47
+ Encode the prompt text into token ids. Because return_tensors="pt" is used, this returns a tensor of shape [1, max_length] rather than a plain list.
48
+ """
49
+ tokens = self._tokenizer(
50
+ s,
51
+ truncation=True,
52
+ max_length=self._max_length,
53
+ return_length=False,
54
+ return_overflowing_tokens=False,
55
+ padding="max_length",
56
+ return_tensors="pt", # return pytorch tensors, default return List[int]
57
+ )["input_ids"]
58
+ return tokens
59
+
60
+ def decode(self, t: List[int]) -> str:
61
+ """
62
+ Decode token ids back into text. Not used during training; kept for interface completeness.
63
+ """
64
+ return self._tokenizer.decode(t)
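Note: a quick sketch of how the two tokenizers behave (not part of this commit). The model names are the ones used by the tests and debug config elsewhere in this upload; the Hugging Face tokenizers must be downloadable for this to run.

from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer

t5_tok = FluxTokenizer("google/t5-v1_1-small", max_length=512)
clip_tok = FluxTokenizer("openai/clip-vit-large-patch14", max_length=77)

prompt = "a photo of a forest with mist"
t5_ids = t5_tok.encode(prompt)      # tensor of shape [1, 512], padded to max_length
clip_ids = clip_tok.encode(prompt)  # tensor of shape [1, 77]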
torchtitan/experiments/flux/model/hf_embedder.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import nn, Tensor
8
+ from transformers import CLIPTextModel, T5EncoderModel
9
+
10
+
11
+ class FluxEmbedder(nn.Module):
12
+ def __init__(self, version: str, **hf_kwargs):
13
+ super().__init__()
14
+ self.is_clip = version.startswith("openai")
15
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
16
+
17
+ if self.is_clip:
18
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
19
+ version, **hf_kwargs
20
+ )
21
+ else:
22
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
23
+ version, **hf_kwargs
24
+ )
25
+
26
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
27
+
28
+ def forward(self, batch_tokens: Tensor) -> Tensor:
29
+ """
30
+ batch_tokens: [bsz, embedding_length]
31
+
32
+ For T5 Encoder, embeding_length is 768
33
+ For CLIP, embedding_length is 256
34
+ """
35
+ outputs = self.hf_module(
36
+ input_ids=batch_tokens.to(self.hf_module.device),
37
+ attention_mask=None,
38
+ output_hidden_states=False,
39
+ )
40
+ return outputs[self.output_key]
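Note: a sketch of pairing FluxTokenizer with FluxEmbedder (not part of this commit). As in the code above, the CLIP branch returns the pooled sentence embedding and the T5 branch returns per-token hidden states; exact hidden sizes depend on the chosen checkpoints.

import torch

from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder

clip_tok = FluxTokenizer("openai/clip-vit-large-patch14", max_length=77)
clip_enc = FluxEmbedder("openai/clip-vit-large-patch14")

tokens = clip_tok.encode("a photo of a forest")  # [1, 77]
with torch.no_grad():
    pooled = clip_enc(tokens)  # pooler_output, shape [1, hidden_size]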
torchtitan/experiments/flux/model/math.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from einops import rearrange
9
+ from torch import Tensor
10
+
11
+
12
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
13
+ q, k = apply_rope(q, k, pe)
14
+
15
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
16
+ x = rearrange(x, "B H L D -> B L (H D)")
17
+
18
+ return x
19
+
20
+
21
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
22
+ assert dim % 2 == 0
23
+ scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
24
+ omega = 1.0 / (theta**scale)
25
+ out = torch.einsum("...n,d->...nd", pos, omega)
26
+ out = torch.stack(
27
+ [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
28
+ )
29
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
30
+ return out.float()
31
+
32
+
33
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
34
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
35
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
36
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
37
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
38
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
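Note: restating the code above in equation form, rope() builds one 2x2 rotation per position n and frequency index d, and apply_rope() rotates consecutive (even, odd) feature pairs of q and k before attention:

\omega_d = \theta^{-2d/D}, \qquad
R(n, d) = \begin{pmatrix} \cos(n\,\omega_d) & -\sin(n\,\omega_d) \\ \sin(n\,\omega_d) & \cos(n\,\omega_d) \end{pmatrix}, \qquad
\begin{pmatrix} x'_{2d} \\ x'_{2d+1} \end{pmatrix} = R(n, d) \begin{pmatrix} x_{2d} \\ x_{2d+1} \end{pmatrix},

where D is the (even) per-head dimension passed to rope().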
torchtitan/experiments/flux/scripts/download_autoencoder.py ADDED
@@ -0,0 +1,61 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ from requests.exceptions import HTTPError
10
+
11
+
12
+ def hf_download(
13
+ repo_id: str, file_path: str, local_dir: str, hf_token: Optional[str] = None
14
+ ) -> None:
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ try:
18
+ hf_hub_download(
19
+ repo_id=repo_id,
20
+ filename=file_path,
21
+ local_dir=local_dir,
22
+ local_dir_use_symlinks=False,
23
+ token=hf_token,
24
+ )
25
+ except HTTPError as e:
26
+ if e.response.status_code == 401:
27
+ print(
28
+ "You need to pass a valid `--hf_token=...` to download private checkpoints."
29
+ )
30
+ else:
31
+ raise e
32
+
33
+
34
+ if __name__ == "__main__":
35
+ import argparse
36
+
37
+ parser = argparse.ArgumentParser(description="Download the Flux autoencoder from HuggingFace.")
38
+ parser.add_argument(
39
+ "--repo_id",
40
+ type=str,
41
+ default="black-forest-labs/FLUX.1-dev",
42
+ help="Repository ID to download from. default to Flux-dev model",
43
+ )
44
+ parser.add_argument(
45
+ "--ae_path",
46
+ type=str,
47
+ default="ae.safetensors",
48
+ help="the autoencoder path relative to repo_id",
49
+ )
50
+ parser.add_argument(
51
+ "--hf_token", type=str, default=None, help="HuggingFace API token"
52
+ )
53
+ parser.add_argument(
54
+ "--local_dir",
55
+ type=str,
56
+ default="torchtitan/experiments/flux/assets/autoencoder/",
57
+ help="local directory to save the autoencoder",
58
+ )
59
+
60
+ args = parser.parse_args()
61
+ hf_download(args.repo_id, args.ae_path, args.local_dir, args.hf_token)
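Note: the same download can be triggered programmatically with the parser defaults above (a sketch; it assumes huggingface_hub is installed and the repository is accessible with the given token).

from torchtitan.experiments.flux.scripts.download_autoencoder import hf_download

hf_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    file_path="ae.safetensors",
    local_dir="torchtitan/experiments/flux/assets/autoencoder/",
    hf_token=None,  # pass a token for gated or private checkpoints
)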
torchtitan/experiments/flux/tests/test_generate_image.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import os
9
+ import time
10
+ from typing import Callable
11
+
12
+ import torch
13
+ from einops import rearrange
14
+
15
+ from PIL import ExifTags, Image
16
+
17
+ from torch import Tensor
18
+
19
+ from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
20
+
21
+ from torchtitan.experiments.flux.model.autoencoder import (
22
+ AutoEncoder,
23
+ AutoEncoderParams,
24
+ load_ae,
25
+ )
26
+ from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
27
+
28
+ from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs
29
+ from torchtitan.experiments.flux.utils import (
30
+ create_position_encoding_for_latents,
31
+ generate_noise_latent,
32
+ pack_latents,
33
+ preprocess_flux_data,
34
+ unpack_latents,
35
+ )
36
+
37
+
38
+ def time_shift(mu: float, sigma: float, t: Tensor):
39
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
40
+
41
+
42
+ def get_lin_function(
43
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
44
+ ) -> Callable[[float], float]:
45
+ m = (y2 - y1) / (x2 - x1)
46
+ b = y1 - m * x1
47
+ return lambda x: m * x + b
48
+
49
+
50
+ def get_schedule(
51
+ num_steps: int,
52
+ image_seq_len: int,
53
+ base_shift: float = 0.5,
54
+ max_shift: float = 1.15,
55
+ shift: bool = True,
56
+ ) -> list[float]:
57
+ # extra step for zero
58
+ timesteps = torch.linspace(1, 0, num_steps + 1)
59
+
60
+ # shifting the schedule to favor high timesteps for higher signal images
61
+ if shift:
62
+ # estimate mu based on linear estimation between two points
63
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
64
+ timesteps = time_shift(mu, 1.0, timesteps)
65
+
66
+ return timesteps.tolist()
67
+
68
+
69
+ class TestGenerateImage:
70
+ def test_generate_image(self):
71
+ """
72
+ Run a forward pass of flux model to generate an image.
73
+ """
74
+ name = "flux-dev"
75
+ img_width = 512
76
+ img_height = 512
77
+ seed = None
78
+ prompt = (
79
+ "a photo of a forest with mist swirling around the tree trunks. The word "
80
+ '"FLUX" is painted over it in big, red brush strokes with visible texture'
81
+ )
82
+ device = "cuda"
83
+ num_steps = None
84
+ loop = False
85
+ guidance = 3.5
86
+ output_dir = "output"
87
+ add_sampling_metadata = True
88
+
89
+ prompt = prompt.split("|")
90
+ if len(prompt) == 1:
91
+ prompt = prompt[0]
92
+ additional_prompts = None
93
+ else:
94
+ additional_prompts = prompt[1:]
95
+ prompt = prompt[0]
96
+
97
+ assert not (
98
+ (additional_prompts is not None) and loop
99
+ ), "Do not provide additional prompts and set loop to True"
100
+
101
+ torch_device = torch.device(device)
102
+ if num_steps is None:
103
+ num_steps = 30
104
+
105
+ # allow for packing and conversion to latent space
106
+ img_height = 16 * (img_height // 16)
107
+ img_width = 16 * (img_width // 16)
108
+
109
+ # init all components
110
+ model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16)
111
+
112
+ ae = load_ae(
113
+ ckpt_path="assets/autoencoder/ae.safetensors",
114
+ autoencoder_params=AutoEncoderParams(),
115
+ device=torch_device,
116
+ dtype=torch.bfloat16,
117
+ )
118
+ clip_tokenizer = FluxTokenizer(
119
+ model_path="openai/clip-vit-large-patch14", max_length=77
120
+ )
121
+ t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512)
122
+ clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to(
123
+ torch_device, dtype=torch.bfloat16
124
+ )
125
+ t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to(
126
+ torch_device, dtype=torch.bfloat16
127
+ )
128
+
129
+ rng = torch.Generator(device="cpu")
130
+
131
+ if seed is None:
132
+ seed = rng.seed()
133
+ print(f"Generating with seed {seed}:\n{prompt}")
134
+ t0 = time.perf_counter()
135
+ output_name = os.path.join(output_dir, f"img_{seed}.jpg")
136
+
137
+ # Tokenize the prompt, on CPU
138
+ clip_tokens = clip_tokenizer.encode(prompt)
139
+ t5_tokens = t5_tokenizer.encode(prompt)
140
+
141
+ batch = preprocess_flux_data(
142
+ device=torch_device,
143
+ dtype=torch.bfloat16,
144
+ autoencoder=None,
145
+ clip_encoder=clip_encoder,
146
+ t5_encoder=t5_encoder,
147
+ batch={
148
+ "clip_tokens": clip_tokens,
149
+ "t5_tokens": t5_tokens,
150
+ },
151
+ )
152
+
153
+ img = self._generate_images(
154
+ device=torch_device,
155
+ dtype=torch.bfloat16,
156
+ model=model,
157
+ decoder=ae,
158
+ img_width=img_width,
159
+ img_height=img_height,
160
+ denoising_steps=num_steps,
161
+ seed=seed,
162
+ clip_encodings=batch["clip_encodings"],
163
+ t5_encodings=batch["t5_encodings"],
164
+ guidance=guidance,
165
+ )
166
+
167
+ if torch.cuda.is_available():
168
+ torch.cuda.synchronize()
169
+ t1 = time.perf_counter()
170
+
171
+ print(f"Done in {t1 - t0:.1f}s.")
172
+
173
+ self._save_image(name, output_name, img, add_sampling_metadata, prompt)
174
+
175
+ def _generate_images(
176
+ self,
177
+ device: torch.device,
178
+ dtype: torch.dtype,
179
+ model: FluxModel,
180
+ decoder: AutoEncoder,
181
+ # image params:
182
+ img_width: int,
183
+ img_height: int,
184
+ # sampling params:
185
+ denoising_steps: int,
186
+ seed: int,
187
+ clip_encodings: torch.Tensor,
188
+ t5_encodings: torch.Tensor,
189
+ guidance: float = 4.0,
190
+ ):
191
+
192
+ bsz = clip_encodings.shape[0]
193
+ latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed)
194
+ _, latent_channels, latent_height, latent_width = latents.shape
195
+
196
+ # create denoising schedule
197
+ timesteps = get_schedule(denoising_steps, latent_channels, shift=True)
198
+
199
+ # create positional encodings
200
+ POSITION_DIM = 3 # constant for Flux flow model
201
+ latent_pos_enc = create_position_encoding_for_latents(
202
+ bsz, latent_height, latent_width, POSITION_DIM
203
+ ).to(latents)
204
+ text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)
205
+
206
+ # convert img-like latents into sequences of patches
207
+ latents = pack_latents(latents)
208
+
209
+ # this is ignored for schnell
210
+ guidance_vec = torch.full((bsz,), guidance, device=device, dtype=dtype)
211
+ for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
212
+ t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
213
+ pred = model(
214
+ img=latents,
215
+ img_ids=latent_pos_enc,
216
+ txt=t5_encodings,
217
+ txt_ids=text_pos_enc,
218
+ y=clip_encodings,
219
+ timesteps=t_vec,
220
+ guidance=guidance_vec,
221
+ )
222
+
223
+ latents = latents + (t_prev - t_curr) * pred
224
+
225
+ # convert sequences of patches into img-like latents
226
+ latents = unpack_latents(latents, latent_height, latent_width)
227
+
228
+ img = decoder.decode(latents)
229
+ return img
230
+
231
+ def _save_image(
232
+ self,
233
+ name: str,
234
+ output_name: str,
235
+ x: torch.Tensor,
236
+ add_sampling_metadata: bool,
237
+ prompt: str,
238
+ ):
239
+ print(f"Saving {output_name}")
240
+ # bring into PIL format and save
241
+ x = x.clamp(-1, 1)
242
+ x = rearrange(x[0], "c h w -> h w c")
243
+
244
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
245
+
246
+ exif_data = Image.Exif()
247
+ exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
248
+ exif_data[ExifTags.Base.Make] = "Black Forest Labs"
249
+ exif_data[ExifTags.Base.Model] = name
250
+ if add_sampling_metadata:
251
+ exif_data[ExifTags.Base.ImageDescription] = prompt
252
+ img.save(output_name, exif=exif_data, quality=95, subsampling=0)
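Note: the sampling loop above is plain Euler integration of the rectified-flow ODE over a shifted schedule; written out (restating time_shift, get_lin_function, and the update inside the loop):

\mu = m L + b \ \ \text{(linear in the image sequence length } L\text{)}, \qquad
t' = \frac{e^{\mu}}{e^{\mu} + (1/t - 1)^{\sigma}} \ \ (\sigma = 1),

z_{t_{\mathrm{prev}}} = z_{t_{\mathrm{curr}}} + (t_{\mathrm{prev}} - t_{\mathrm{curr}})\, v_\theta\!\left(z_{t_{\mathrm{curr}}},\, t_{\mathrm{curr}},\, \text{text}\right),

where v_theta is the FluxModel prediction and the text conditioning bundles the T5 and CLIP encodings.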
torchtitan/experiments/flux/train_configs/debug_model.toml ADDED
@@ -0,0 +1,68 @@
1
+
2
+ [job]
3
+ dump_folder = "./outputs"
4
+ description = "Flux debug model"
5
+ print_args = false
6
+ use_for_integration_test = true
7
+
8
+ [profiling]
9
+ enable_profiling = false
10
+ save_traces_folder = "profile_trace"
11
+ profile_freq = 10
12
+ enable_memory_snapshot = false
13
+ save_memory_snapshot_folder = "memory_snapshot"
14
+
15
+ [metrics]
16
+ log_freq = 1
17
+ disable_color_printing = false
18
+ enable_tensorboard = false
19
+ save_tb_folder = "tb"
20
+ enable_wandb = false
21
+
22
+ [model]
23
+ name = "flux"
24
+ flavor = "flux-debug"
25
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
26
+ # test tokenizer.model, for debug purpose only
27
+ # tokenizer_path = "./tests/assets/test_tiktoken.model"
28
+ # converters = "float8"
29
+
30
+
31
+ [optimizer]
32
+ name = "AdamW"
33
+ lr = 8e-4
34
+ eps = 1e-8
35
+
36
+ [lr_scheduler]
37
+ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
38
+ decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
39
+ decay_type = "linear"
40
+ lr_min = 0.0
41
+
42
+ [training]
43
+ batch_size = 32
44
+ seq_len = 512
45
+ max_norm = 1.0 # grad norm clipping
46
+ steps = 10
47
+ compile = false
48
+ dataset = "cc12m"
49
+ guidance = 3.5
50
+ seed = 0
51
+
52
+ [encoder]
53
+ t5_encoder="google/t5-v1_1-small"
54
+ clip_encoder="openai/clip-vit-large-patch14"
55
+ max_t5_encoding_len=512
56
+ auto_encoder_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
57
+
58
+ [parallelism]
59
+ data_parallel_replicate_degree = 1
60
+ data_parallel_shard_degree = 1
61
+ fsdp_reshard_after_forward = "default" # default / never / always
62
+ tensor_parallel_degree = 1
63
+ enable_async_tensor_parallel = false
64
+ pipeline_parallel_degree = 1
65
+ context_parallel_degree = 1
66
+
67
+ [experimental]
68
+ custom_args_module = "torchtitan.experiments.flux.flux_argparser"
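Note: a minimal sketch for inspecting this config outside torchtitan's own argument parser (assumes Python 3.11+ for the stdlib tomllib module).

import tomllib

with open("torchtitan/experiments/flux/train_configs/debug_model.toml", "rb") as f:
    cfg = tomllib.load(f)

print(cfg["training"]["batch_size"])   # 32
print(cfg["encoder"]["t5_encoder"])    # "google/t5-v1_1-small"
print(cfg["parallelism"]["data_parallel_shard_degree"])  # 1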
torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py ADDED
@@ -0,0 +1,630 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # All rights reserved.
9
+ #
10
+ # Benchmark comparing reference PyTorch vs optimized M*G group GEMM implementation
11
+
12
+ import argparse
13
+ import logging
14
+ import time
15
+
16
+ # from typing import Dict, List, Optional, Tuple
17
+
18
+ import matplotlib.pyplot as plt
19
+ import numpy as np
20
+ import torch
21
+ import triton
22
+
23
+ # import triton.language as tl
24
+
25
+ # Configure logging
26
+ logging.basicConfig(
27
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
28
+ )
29
+
30
+ # Try to import the optimized implementations
31
+ try:
32
+ from torchao_pr.mg_grouped_gemm import grouped_gemm_forward
33
+
34
+ except ImportError:
35
+ logging.error(
36
+ "Error importing MG grouped GEMM modules. Make sure the implementation files are in the correct path."
37
+ )
38
+ raise
39
+
40
+
41
+ def compute_reference_forward(x, w, m_sizes):
42
+ """
43
+ Reference PyTorch implementation of M*G grouped GEMM forward pass.
44
+
45
+ Args:
46
+ x (torch.Tensor): Input tensor of shape (M, K)
47
+ w (torch.Tensor): Weight tensor of shape (N, K)
48
+ m_sizes (torch.Tensor): Group sizes tensor of shape (G)
49
+
50
+ Returns:
51
+ torch.Tensor: Output tensor of shape (M, N)
52
+ """
53
+ result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)
54
+
55
+ m_start = 0
56
+ for g in range(len(m_sizes)):
57
+ m_size = m_sizes[g].item()
58
+ if m_size > 0:
59
+ m_end = m_start + m_size
60
+
61
+ # Extract group input
62
+ x_g = x[m_start:m_end]
63
+
64
+ # Compute group output
65
+ y_g = torch.matmul(x_g, w.T)
66
+
67
+ # Store result
68
+ result[m_start:m_end] = y_g
69
+
70
+ # Update start index
71
+ m_start = m_end
72
+
73
+ return result
74
+
75
+
76
+ @triton.testing.perf_report(
77
+ triton.testing.Benchmark(
78
+ x_names=["N"], # We'll vary the output dimension
79
+ x_vals=[1024, 2048, 4096, 8192, 16384], # Different output dimensions to test
80
+ # x_vals=[8192, 16384],
81
+ line_arg="provider", # We'll compare different providers
82
+ line_vals=["pytorch_reference", "M*G grouped GEMM"],
83
+ line_names=["PyTorch Reference", "M*G grouped Kernel"],
84
+ styles=[("blue", "-"), ("red", "-")],
85
+ ylabel="TFLOPS", # We'll measure TFLOPS
86
+ plot_name="mg_grouped_gemm_comparison",
87
+ args={
88
+ "M": 8192, # Batch dimension, fixed for all tests
89
+ "K": 7168, # Hidden dimension, fixed for all tests
90
+ "G": 8, # Number of groups
91
+ "dtype": torch.float16,
92
+ "device": "cuda",
93
+ },
94
+ )
95
+ )
96
+ def benchmark_forward(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
97
+ """
98
+ Benchmark the forward pass of the grouped GEMM implementation.
99
+
100
+ Args:
101
+ M (int): Total batch size dimension
102
+ K (int): Hidden dimension
103
+ N (int): Output dimension
104
+ G (int): Number of groups
105
+ provider (str): Provider to use ('pytorch_reference' or 'M*G grouped GEMM')
106
+ dtype (torch.dtype): Data type to use
107
+ device (str): Device to use
108
+
109
+ Returns:
110
+ float: Performance in TFLOPS
111
+ """
112
+ # Create group sizes for M dimension (balanced across groups)
113
+ base_size = M // G
114
+ remainder = M % G
115
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
116
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
117
+
118
+ print(f"N: {N}, M: {M}, K: {K}, G: {G}, dtype: {dtype}, device: {device}")
119
+
120
+ # Create input and weight tensors
121
+ x = torch.randn(M, K, dtype=dtype, device=device)
122
+ w = torch.randn(N, K, dtype=dtype, device=device)
123
+
124
+ # Pre-compute for PyTorch reference to ensure fair comparison
125
+ if provider == "pytorch_reference":
126
+ # Warmup
127
+ torch.cuda.synchronize()
128
+ compute_reference_forward(x, w, m_sizes)
129
+ torch.cuda.synchronize()
130
+
131
+ # Benchmark
132
+ start_time = time.time()
133
+ for _ in range(10): # Average over 10 runs
134
+ compute_reference_forward(x, w, m_sizes)
135
+ torch.cuda.synchronize()
136
+ end_time = time.time()
137
+ else: # Optimized kernel
138
+ # Warmup
139
+ torch.cuda.synchronize()
140
+ grouped_gemm_forward(x, w, m_sizes)
141
+ torch.cuda.synchronize()
142
+
143
+ # Benchmark
144
+ start_time = time.time()
145
+ for _ in range(10): # Average over 10 runs
146
+ grouped_gemm_forward(x, w, m_sizes)
147
+ torch.cuda.synchronize()
148
+ end_time = time.time()
149
+
150
+ # Calculate FLOPs
151
+ # For GEMM: 2 * M * N * K FLOPs (multiply-add counts as 2 FLOPs)
152
+ flops = 2 * M * N * K
153
+
154
+ # Convert to TFLOPS (tera-FLOPS)
155
+ avg_time = (end_time - start_time) / 10 # Average time per run
156
+ tflops = flops / avg_time / 1e12
157
+
158
+ return tflops
159
+
160
+
161
+ @triton.testing.perf_report(
162
+ triton.testing.Benchmark(
163
+ x_names=["G"], # We'll vary the number of groups
164
+ x_vals=[1, 2, 4, 8, 16], # Different numbers of groups to test
165
+ line_arg="provider", # We'll compare different providers
166
+ line_vals=["pytorch_reference", "optimized_kernel"],
167
+ line_names=["PyTorch Reference", "Optimized Kernel"],
168
+ styles=[("blue", "-"), ("red", "-")],
169
+ ylabel="TFLOPS", # We'll measure TFLOPS
170
+ plot_name="mg_grouped_gemm_group_scaling",
171
+ args={
172
+ "M": 8192, # Batch dimension, fixed for all tests
173
+ "K": 4096, # Hidden dimension, fixed for all tests
174
+ "N": 8192, # Output dimension, fixed for all tests
175
+ "dtype": torch.float16,
176
+ "device": "cuda",
177
+ },
178
+ )
179
+ )
180
+ def benchmark_forward_groups(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
181
+ """
182
+ Benchmark how performance scales with number of groups.
183
+
184
+ Args:
185
+ M (int): Total batch size dimension
186
+ K (int): Hidden dimension
187
+ N (int): Output dimension
188
+ G (int): Number of groups
189
+ provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
190
+ dtype (torch.dtype): Data type to use
191
+ device (str): Device to use
192
+
193
+ Returns:
194
+ float: Performance in TFLOPS
195
+ """
196
+ # Create group sizes for M dimension (balanced across groups)
197
+ base_size = M // G
198
+ remainder = M % G
199
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
200
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
201
+
202
+ # Create input and weight tensors
203
+ x = torch.randn(M, K, dtype=dtype, device=device)
204
+ w = torch.randn(N, K, dtype=dtype, device=device)
205
+
206
+ # Benchmark logic - same as previous function
207
+ if provider == "pytorch_reference":
208
+ torch.cuda.synchronize()
209
+ compute_reference_forward(x, w, m_sizes)
210
+ torch.cuda.synchronize()
211
+
212
+ start_time = time.time()
213
+ for _ in range(10):
214
+ compute_reference_forward(x, w, m_sizes)
215
+ torch.cuda.synchronize()
216
+ end_time = time.time()
217
+ else:
218
+ torch.cuda.synchronize()
219
+ grouped_gemm_forward(x, w, m_sizes)
220
+ torch.cuda.synchronize()
221
+
222
+ start_time = time.time()
223
+ for _ in range(10):
224
+ grouped_gemm_forward(x, w, m_sizes)
225
+ torch.cuda.synchronize()
226
+ end_time = time.time()
227
+
228
+ # Calculate FLOPs and TFLOPS
229
+ flops = 2 * M * N * K
230
+ avg_time = (end_time - start_time) / 10
231
+ tflops = flops / avg_time / 1e12
232
+
233
+ return tflops
234
+
235
+
236
+ @triton.testing.perf_report(
237
+ triton.testing.Benchmark(
238
+ x_names=["group_balance"], # We'll vary the group balance factor
239
+ x_vals=[
240
+ 0.0,
241
+ 0.25,
242
+ 0.5,
243
+ 0.75,
244
+ 0.9,
245
+ ], # Different imbalance factors (0 = balanced, 1 = max imbalance)
246
+ line_arg="provider", # We'll compare different providers
247
+ line_vals=["pytorch_reference", "optimized_kernel"],
248
+ line_names=["PyTorch Reference", "Optimized Kernel"],
249
+ styles=[("blue", "-"), ("red", "-")],
250
+ ylabel="TFLOPS", # We'll measure TFLOPS
251
+ plot_name="mg_grouped_gemm_imbalance",
252
+ args={
253
+ "M": 8192, # Batch dimension, fixed for all tests
254
+ "K": 4096, # Hidden dimension, fixed for all tests
255
+ "N": 8192, # Output dimension, fixed for all tests
256
+ "G": 4, # Number of groups
257
+ "dtype": torch.float16,
258
+ "device": "cuda",
259
+ },
260
+ )
261
+ )
262
+ def benchmark_imbalance(
263
+ M, K, N, G, group_balance, provider, dtype=torch.float16, device="cuda"
264
+ ):
265
+ """
266
+ Benchmark how performance is affected by imbalanced group sizes.
267
+
268
+ Args:
269
+ M (int): Total batch size dimension
270
+ K (int): Hidden dimension
271
+ N (int): Output dimension
272
+ G (int): Number of groups
273
+ group_balance (float): Balance factor from 0 to 1 (0 = balanced, 1 = max imbalance)
274
+ provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
275
+ dtype (torch.dtype): Data type to use
276
+ device (str): Device to use
277
+
278
+ Returns:
279
+ float: Performance in TFLOPS
280
+ """
281
+ # Create imbalanced group sizes for M dimension
282
+ if group_balance == 0:
283
+ # Balanced case
284
+ base_size = M // G
285
+ remainder = M % G
286
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
287
+ else:
288
+ # Imbalanced case
289
+ # First group gets more elements, last group gets fewer
290
+ # The imbalance is controlled by the group_balance factor
291
+ remaining = M
292
+ M_sizes = []
293
+ for g in range(G):
294
+ # Interpolate from balanced to imbalanced based on group_balance
295
+ # For balanced (group_balance=0), each group gets M/G
296
+ # For imbalanced (group_balance=1), first group gets much more than last group
297
+ balanced_size = remaining // (G - g)
298
+
299
+ # Adjusting size based on position and imbalance factor
300
+ # First groups get more, last groups get less
301
+ if g < G // 2:
302
+ # First half of groups get more
303
+ adjustment = int(balanced_size * group_balance * (1 - g / (G - 1)))
304
+ size = balanced_size + adjustment
305
+ else:
306
+ # Second half of groups get less
307
+ adjustment = int(balanced_size * group_balance * ((g / (G - 1)) - 0.5))
308
+ size = balanced_size - adjustment
309
+
310
+ # Ensure we don't go below 1 or take more than remaining
311
+ size = max(1, min(size, remaining))
312
+ M_sizes.append(size)
313
+ remaining -= size
314
+
315
+ # Handle any remaining elements
316
+ if remaining > 0:
317
+ M_sizes[-1] += remaining
318
+
319
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
320
+
321
+ # Create input and weight tensors
322
+ x = torch.randn(M, K, dtype=dtype, device=device)
323
+ w = torch.randn(N, K, dtype=dtype, device=device)
324
+
325
+ # Benchmark logic
326
+ if provider == "pytorch_reference":
327
+ torch.cuda.synchronize()
328
+ compute_reference_forward(x, w, m_sizes)
329
+ torch.cuda.synchronize()
330
+
331
+ start_time = time.time()
332
+ for _ in range(10):
333
+ compute_reference_forward(x, w, m_sizes)
334
+ torch.cuda.synchronize()
335
+ end_time = time.time()
336
+ else:
337
+ torch.cuda.synchronize()
338
+ grouped_gemm_forward(x, w, m_sizes)
339
+ torch.cuda.synchronize()
340
+
341
+ start_time = time.time()
342
+ for _ in range(10):
343
+ grouped_gemm_forward(x, w, m_sizes)
344
+ torch.cuda.synchronize()
345
+ end_time = time.time()
346
+
347
+ # Calculate FLOPs and TFLOPS
348
+ flops = 2 * M * N * K
349
+ avg_time = (end_time - start_time) / 10
350
+ tflops = flops / avg_time / 1e12
351
+
352
+ return tflops
353
+
354
+
355
+ def benchmark_model_configs():
356
+ """
357
+ Benchmark common model configurations used in DeepSeek-like models.
358
+ """
359
+ # Model configurations: (M, K, N, G)
360
+ configs = [
361
+ (8192, 7168, 4096, 4), # Config 1
362
+ (8192, 2048, 7168, 4), # Config 2
363
+ (4096, 7168, 4096, 8), # Config 3
364
+ (4096, 2048, 7168, 8), # Config 4
365
+ ]
366
+
367
+ results = []
368
+
369
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
370
+ dtype = torch.float16
371
+
372
+ for config_idx, (M, K, N, G) in enumerate(configs):
373
+ logging.info(f"\n===== Benchmarking DeepSeek Config {config_idx + 1} =====")
374
+ logging.info(f"M={M}, K={K}, N={N}, G={G}")
375
+
376
+ # Create group sizes for M dimension
377
+ base_size = M // G
378
+ remainder = M % G
379
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
380
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
381
+
382
+ # Create tensors
383
+ x = torch.randn(M, K, dtype=dtype, device=device)
384
+ w = torch.randn(N, K, dtype=dtype, device=device)
385
+
386
+ # Benchmark PyTorch reference
387
+ torch.cuda.synchronize()
388
+ compute_reference_forward(x, w, m_sizes) # Warmup
389
+ torch.cuda.synchronize()
390
+
391
+ logging.info("Benchmarking PyTorch reference...")
392
+ torch.cuda.reset_peak_memory_stats()
393
+ start_time = time.time()
394
+ for _ in range(10):
395
+ compute_reference_forward(x, w, m_sizes)
396
+ torch.cuda.synchronize()
397
+ end_time = time.time()
398
+ pt_time = (end_time - start_time) / 10
399
+ pt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB
400
+
401
+ # Benchmark optimized kernel
402
+ torch.cuda.synchronize()
403
+ grouped_gemm_forward(x, w, m_sizes) # Warmup
404
+ torch.cuda.synchronize()
405
+
406
+ logging.info("Benchmarking optimized kernel...")
407
+ torch.cuda.reset_peak_memory_stats()
408
+ start_time = time.time()
409
+ for _ in range(10):
410
+ grouped_gemm_forward(x, w, m_sizes)
411
+ torch.cuda.synchronize()
412
+ end_time = time.time()
413
+ opt_time = (end_time - start_time) / 10
414
+ opt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB
415
+
416
+ # Calculate FLOPs and speedup
417
+ flops = 2 * M * N * K
418
+ pt_tflops = flops / pt_time / 1e12
419
+ opt_tflops = flops / opt_time / 1e12
420
+ speedup = pt_time / opt_time
421
+
422
+ # Store results
423
+ results.append(
424
+ {
425
+ "config": f"Config {config_idx + 1}",
426
+ "dimensions": f"M={M}, K={K}, N={N}, G={G}",
427
+ "pt_time_ms": pt_time * 1000,
428
+ "opt_time_ms": opt_time * 1000,
429
+ "pt_tflops": pt_tflops,
430
+ "opt_tflops": opt_tflops,
431
+ "speedup": speedup,
432
+ "pt_memory_mb": pt_memory,
433
+ "opt_memory_mb": opt_memory,
434
+ "memory_savings": (
435
+ (pt_memory - opt_memory) / pt_memory * 100 if pt_memory > 0 else 0
436
+ ),
437
+ }
438
+ )
439
+
440
+ logging.info(
441
+ f"PyTorch Reference: {pt_time * 1000:.2f} ms, {pt_tflops:.2f} TFLOPS, {pt_memory:.2f} MB"
442
+ )
443
+ logging.info(
444
+ f"Optimized Kernel: {opt_time * 1000:.2f} ms, {opt_tflops:.2f} TFLOPS, {opt_memory:.2f} MB"
445
+ )
446
+ logging.info(
447
+ f"Speedup: {speedup:.2f}x, Memory savings: {results[-1]['memory_savings']:.2f}%"
448
+ )
449
+
450
+ # Print summary table
451
+ logging.info("\n===== Benchmark Results Summary =====")
452
+ logging.info(
453
+ f"{'Config':<10} | {'Time (ms)':<20} | {'TFLOPS':<20} | {'Speedup':<10} | {'Memory (MB)':<20} | {'Memory Saved':<12}"
454
+ )
455
+ logging.info(
456
+ f"{'':<10} | {'PyTorch':<9} {'Kernel':<9} | {'PyTorch':<9} {'Kernel':<9} | {'':<10} | "
457
+ f"{'PyTorch':<9} {'Kernel':<9} | {'':<12}"
458
+ )
459
+ logging.info("-" * 100)
460
+
461
+ for result in results:
462
+ logging.info(
463
+ f"{result['config']:<10} | "
464
+ f"{result['pt_time_ms']:<9.2f} {result['opt_time_ms']:<9.2f} | "
465
+ f"{result['pt_tflops']:<9.2f} {result['opt_tflops']:<9.2f} | "
466
+ f"{result['speedup']:<10.2f} | "
467
+ f"{result['pt_memory_mb']:<9.2f} {result['opt_memory_mb']:<9.2f} | "
468
+ f"{result['memory_savings']:<12.2f}%"
469
+ )
470
+
471
+ return results
472
+
473
+
474
+ def plot_benchmark_results(results):
475
+ """
476
+ Plot benchmark results as bar charts.
477
+ """
478
+ # Extract data
479
+ configs = [r["config"] for r in results]
480
+ pt_tflops = [r["pt_tflops"] for r in results]
481
+ opt_tflops = [r["opt_tflops"] for r in results]
482
+ speedups = [r["speedup"] for r in results]
483
+
484
+ # Create figure with subplots
485
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
486
+
487
+ # Plot TFLOPS comparison
488
+ x = np.arange(len(configs))
489
+ width = 0.35
490
+ ax1.bar(x - width / 2, pt_tflops, width, label="PyTorch Reference")
491
+ ax1.bar(x + width / 2, opt_tflops, width, label="Optimized Kernel")
492
+ ax1.set_xlabel("Model Configuration")
493
+ ax1.set_ylabel("TFLOPS")
494
+ ax1.set_title("Performance Comparison (Higher is Better)")
495
+ ax1.set_xticks(x)
496
+ ax1.set_xticklabels(configs)
497
+ ax1.legend()
498
+ ax1.grid(axis="y", linestyle="--", alpha=0.7)
499
+
500
+ # Plot speedup
501
+ ax2.bar(x, speedups, width=0.6, color="green")
502
+ ax2.set_xlabel("Model Configuration")
503
+ ax2.set_ylabel("Speedup (x)")
504
+ ax2.set_title("Speedup Factor (Higher is Better)")
505
+ ax2.set_xticks(x)
506
+ ax2.set_xticklabels(configs)
507
+ ax2.grid(axis="y", linestyle="--", alpha=0.7)
508
+
509
+ # Add speedup values on top of bars
510
+ for i, v in enumerate(speedups):
511
+ ax2.text(i, v + 0.1, f"{v:.2f}x", ha="center")
512
+
513
+ plt.tight_layout()
514
+ plt.savefig("mg_grouped_gemm_benchmark_results.png")
515
+ logging.info(
516
+ "Benchmark results plot saved to 'mg_grouped_gemm_benchmark_results.png'"
517
+ )
518
+
519
+
520
+ def compare_mg_implementations():
521
+ """
522
+ Combine the M*G and N*G benchmark results for comparison.
523
+ """
524
+ # Only run this if both NG and MG benchmarks have been run
525
+ try:
526
+ import pandas as pd
527
+
528
+ # Try to load previous benchmark results
529
+ mg_results = pd.read_csv("mg_grouped_gemm_benchmark_results.csv")
530
+ ng_results = pd.read_csv("ng_grouped_gemm_benchmark_results.csv")
531
+
532
+ # Create comparison plot
533
+ fig, axes = plt.subplots(1, 2, figsize=(14, 6))
534
+
535
+ # Plot speedup comparison
536
+ configs = mg_results["config"].unique()
537
+ mg_speedups = mg_results.groupby("config")["speedup"].mean()
538
+ ng_speedups = ng_results.groupby("config")["speedup"].mean()
539
+
540
+ x = np.arange(len(configs))
541
+ width = 0.35
542
+
543
+ axes[0].bar(x - width / 2, mg_speedups, width, label="M*G Grouping")
544
+ axes[0].bar(x + width / 2, ng_speedups, width, label="N*G Grouping")
545
+ axes[0].set_xlabel("Model Configuration")
546
+ axes[0].set_ylabel("Speedup (x)")
547
+ axes[0].set_title("Speedup Comparison: M*G vs N*G")
548
+ axes[0].set_xticks(x)
549
+ axes[0].set_xticklabels(configs)
550
+ axes[0].legend()
551
+ axes[0].grid(axis="y", linestyle="--", alpha=0.7)
552
+
553
+ # Plot TFLOPS comparison for optimized kernels
554
+ mg_tflops = (
555
+ mg_results[mg_results["implementation"] == "optimized"]
556
+ .groupby("config")["tflops"]
557
+ .mean()
558
+ )
559
+ ng_tflops = (
560
+ ng_results[ng_results["implementation"] == "optimized"]
561
+ .groupby("config")["tflops"]
562
+ .mean()
563
+ )
564
+
565
+ axes[1].bar(x - width / 2, mg_tflops, width, label="M*G Grouping")
566
+ axes[1].bar(x + width / 2, ng_tflops, width, label="N*G Grouping")
567
+ axes[1].set_xlabel("Model Configuration")
568
+ axes[1].set_ylabel("TFLOPS")
569
+ axes[1].set_title("Performance Comparison: M*G vs N*G")
570
+ axes[1].set_xticks(x)
571
+ axes[1].set_xticklabels(configs)
572
+ axes[1].legend()
573
+ axes[1].grid(axis="y", linestyle="--", alpha=0.7)
574
+
575
+ plt.tight_layout()
576
+ plt.savefig("mg_vs_ng_comparison.png")
577
+ logging.info("Comparison plot saved to 'mg_vs_ng_comparison.png'")
578
+
579
+ except Exception as e:
580
+ logging.error(f"Could not create comparison plot: {e}")
581
+ logging.info(
582
+ "Run both M*G and N*G benchmarks first to generate comparison plots"
583
+ )
584
+
585
+
586
+ if __name__ == "__main__":
587
+ parser = argparse.ArgumentParser(
588
+ description="Benchmark M*G Grouped GEMM implementations"
589
+ )
590
+ parser.add_argument("--run-all", action="store_true", help="Run all benchmarks")
591
+ parser.add_argument(
592
+ "--triton-bench", action="store_true", help="Run Triton performance reports"
593
+ )
594
+ parser.add_argument(
595
+ "--model-configs", action="store_true", help="Benchmark model configurations"
596
+ )
597
+ parser.add_argument(
598
+ "--compare-mg-ng",
599
+ action="store_true",
600
+ help="Compare M*G and N*G implementations",
601
+ )
602
+ args = parser.parse_args()
603
+
604
+ # Check if CUDA is available
605
+ if not torch.cuda.is_available():
606
+ logging.error(
607
+ "CUDA is not available. This benchmark requires a CUDA-capable GPU."
608
+ )
609
+ exit(1)
610
+
611
+ if args.run_all or args.model_configs:
612
+ # Benchmark model configurations
613
+ logging.info("Running benchmark for model configurations...")
614
+ results = benchmark_model_configs()
615
+ plot_benchmark_results(results)
616
+
617
+ if args.run_all or args.triton_bench:
618
+ # Run Triton performance reports
619
+ logging.info("Running Triton performance reports...")
620
+ benchmark_forward.run(save_path="mg_grouped_gemm_benchmark_results")
621
+ benchmark_forward_groups.run(save_path="mg_grouped_gemm_benchmark_results")
622
+ benchmark_imbalance.run(save_path="mg_grouped_gemm_benchmark_results")
623
+ logging.info(
624
+ "Triton performance reports saved to 'mg_grouped_gemm_benchmark_results' directory"
625
+ )
626
+
627
+ if args.run_all or args.compare_mg_ng:
628
+ # Compare M*G and N*G implementations
629
+ logging.info("Comparing M*G and N*G implementations...")
630
+ compare_mg_implementations()
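Note: for reference, the TFLOPS numbers reported by the benchmarks above use the standard dense-GEMM operation count,

\text{FLOPs} = 2MNK, \qquad \text{TFLOPS} = \frac{2MNK}{\bar{t} \cdot 10^{12}},

where \bar{t} is the wall-clock time averaged over the 10 timed runs. For Config 1 (M=8192, K=7168, N=4096) this is 2 * 8192 * 4096 * 7168 ≈ 4.8e11 FLOPs, i.e. roughly half a teraflop per forward call.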
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mg_grouped_gemm import grouped_gemm_forward
8
+ from .tma_autotuning import ALIGN_SIZE_M
9
+
10
+ __all__ = [
11
+ "grouped_gemm_forward",
12
+ "ALIGN_SIZE_M",
13
+ ]
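Note: a usage sketch mirroring the setup in benchmark.py (not part of this commit; it assumes the intermediate directories are importable as packages and a CUDA device is available). ALIGN_SIZE_M is also exported for callers that need it.

import torch

from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import grouped_gemm_forward

# Inputs of shape (M, K), weights of shape (N, K), per-group row counts summing to M.
M, K, N, G = 8192, 7168, 4096, 4
x = torch.randn(M, K, dtype=torch.float16, device="cuda")
w = torch.randn(N, K, dtype=torch.float16, device="cuda")
m_sizes = torch.full((G,), M // G, dtype=torch.int32, device="cuda")
y = grouped_gemm_forward(x, w, m_sizes)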
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py ADDED
@@ -0,0 +1,299 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from reference_utils import (
14
+ analyze_tensor_differences,
15
+ compute_reference_backward,
16
+ compute_reference_forward,
17
+ )
18
+
19
+ # Configure logging
20
+ logging.basicConfig(
21
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
22
+ )
23
+
24
+ # Import grouped GEMM implementations
25
+ try:
26
+ from mg_grouped_gemm import grouped_gemm_backward, grouped_gemm_forward
27
+
28
+ except ImportError:
29
+ logging.error(
30
+ "Error importing grouped GEMM modules. Make sure the implementation files are in the correct path."
31
+ )
32
+ raise
33
+
34
+
35
+ def test_forward_pass():
36
+ """
37
+ A simple test for the M*G grouped GEMM forward pass with detailed error handling.
38
+
39
+ In M*G grouping:
40
+ - M dimension is partitioned into G groups (M_total = sum(M_sizes))
41
+ - N dimension is the same for all groups
42
+ """
43
+ try:
44
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+
46
+ # Test parameters for DeepSeek-like models
47
+ G = 1 # Number of groups
48
+ M_sizes = [
49
+ 2048,
50
+ ]  # group sizes; this basic forward test uses a single group (G = 1)
51
+ M_total = sum(M_sizes) # Total M dimension
52
+ N = 4096 # Output dimension (same for all groups)
53
+ K = 7168 # Hidden dimension
54
+
55
+ # Create group sizes tensor
56
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
57
+
58
+ # Create input and weight tensors - using float16 for higher precision
59
+ x = torch.randn(M_total, K, dtype=torch.float16, device=device)
60
+ w = torch.randn(N, K, dtype=torch.float16, device=device)
61
+
62
+ # Log the setup
63
+ logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
64
+ logging.info(f"Group sizes: {m_sizes}")
65
+ logging.info(f"Input x shape: {x.shape}")
66
+ logging.info(f"Weight w shape: {w.shape}")
67
+
68
+ # Run forward pass
69
+ logging.info("Running forward pass with grouped GEMM")
70
+ result = grouped_gemm_forward(x, w, m_sizes)
71
+ logging.info(f"Forward result shape: {result.shape}")
72
+
73
+ # Compute reference result
74
+ logging.info("Computing reference result with PyTorch")
75
+ reference_result = compute_reference_forward(x, w, m_sizes)
76
+
77
+ # Compare results
78
+ logging.info("Comparing with PyTorch reference")
79
+ forward_close = analyze_tensor_differences(
80
+ result, reference_result, "Forward output"
81
+ )
82
+
83
+ return forward_close
84
+
85
+ except Exception as e:
86
+ logging.error(f"Test failed with error: {e}")
87
+ import traceback
88
+
89
+ logging.error(traceback.format_exc())
90
+ return False
91
+
92
+
93
+ def test_backward_pass():
94
+ """
95
+ A simple test for the M*G grouped GEMM backward pass with detailed error handling.
96
+
97
+ In M*G grouping:
98
+ - M dimension is partitioned into G groups (M_total = sum(M_sizes))
99
+ - N dimension is the same for all groups
100
+ """
101
+ try:
102
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
103
+
104
+ # Test parameters for DeepSeek-like models
105
+ G = 4 # Number of groups
106
+ M_sizes = [2048, 2048, 2048, 2048] # Group sizes (will be adjusted)
107
+ M_total = sum(M_sizes) # Total M dimension
108
+ N = 4096 # Output dimension (same for all groups)
109
+ K = 7168 # Hidden dimension
110
+
111
+ # Create group sizes tensor
112
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
113
+
114
+ # Create input and weight tensors - using float16 for higher precision
115
+ x = torch.randn(
116
+ M_total, K, dtype=torch.float16, device=device, requires_grad=True
117
+ )
118
+ w = torch.randn(N, K, dtype=torch.float16, device=device, requires_grad=True)
119
+
120
+ # Log the setup
121
+ logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
122
+ logging.info(f"Group sizes: {m_sizes}")
123
+ logging.info(f"Input x shape: {x.shape}")
124
+ logging.info(f"Weight w shape: {w.shape}")
125
+
126
+ # Step 1: Run forward pass
127
+ logging.info("Running forward pass")
128
+ result = grouped_gemm_forward(x, w, m_sizes)
129
+ logging.info(f"Forward result shape: {result.shape}")
130
+
131
+ # Create a gradient for backpropagation
132
+ grad_output = torch.randn_like(result)
133
+ logging.info(f"Created gradient with shape: {grad_output.shape}")
134
+
135
+ # Step 2: Run backward pass directly
136
+ logging.info("Running backward pass directly")
137
+ grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)
138
+
139
+ # Verify gradient shapes
140
+ logging.info(
141
+ f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}"
142
+ )
143
+
144
+ # Step 3: Verify gradient computation using PyTorch's autograd
145
+ logging.info("Running PyTorch reference implementation")
146
+
147
+ # Compute reference gradients
148
+ x_ref_grad, w_ref_grad = compute_reference_backward(x, w, m_sizes, grad_output)
149
+
150
+ # Compare gradients
151
+ logging.info("Comparing gradients with PyTorch reference")
152
+ grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
153
+ grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")
154
+
155
+ # Log overall result
156
+ if grad_x_close and grad_w_close:
157
+ logging.info("✓ SUCCESS: Gradients match the PyTorch reference")
158
+ else:
159
+ logging.error("✗ FAILURE: Gradient mismatch detected")
160
+
161
+ return grad_x_close and grad_w_close
162
+
163
+ except Exception as e:
164
+ logging.error(f"Test failed with error: {e}")
165
+ import traceback
166
+
167
+ logging.error(traceback.format_exc())
168
+ return False
169
+
170
+
171
+ def test_multiple_deepseek_configs():
172
+ """
173
+ Test multiple DeepSeek model configurations with both forward and backward pass verification.
174
+ """
175
+ # DeepSeek configurations: (G, M, K, N)
176
+ configs = [
177
+ (4, 8192, 7168, 4096), # Config 1
178
+ (4, 8192, 2048, 7168), # Config 2
179
+ (8, 4096, 7168, 4096), # Config 3
180
+ (8, 4096, 2048, 7168), # Config 4
181
+ ]
182
+
183
+ results = []
184
+
185
+ for config_idx, (G, M, K, N) in enumerate(configs):
186
+ logging.info(f"\n\n===== Testing DeepSeek Config {config_idx+1} =====")
187
+ logging.info(f"G={G}, M={M}, K={K}, N={N}")
188
+
189
+ try:
190
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
191
+
192
+ # Create even group sizes
193
+ base_size = M // G
194
+ remainder = M % G
195
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
196
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
197
+
198
+ # Create input and weight tensors using float16 for higher precision
199
+ x = torch.randn(
200
+ M, K, dtype=torch.float16, device=device, requires_grad=True
201
+ )
202
+ w = torch.randn(
203
+ N, K, dtype=torch.float16, device=device, requires_grad=True
204
+ )
205
+
206
+ logging.info(f"Input x shape: {x.shape}, Weight w shape: {w.shape}")
207
+
208
+ # Run forward pass
209
+ result = grouped_gemm_forward(x, w, m_sizes)
210
+ logging.info(f"Forward result shape: {result.shape}")
211
+
212
+ # ===== FORWARD PASS VERIFICATION =====
213
+ # Compute reference forward result
214
+ reference_result = compute_reference_forward(x, w, m_sizes)
215
+
216
+ # Compare forward results
217
+ forward_close = analyze_tensor_differences(
218
+ result, reference_result, "Forward output"
219
+ )
220
+
221
+ # ===== BACKWARD PASS VERIFICATION =====
222
+ # Create gradient for backpropagation
223
+ grad_output = torch.randn_like(result)
224
+
225
+ # Run backward pass
226
+ grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)
227
+
228
+ # Compute reference gradients
229
+ x_ref_grad, w_ref_grad = compute_reference_backward(
230
+ x, w, m_sizes, grad_output
231
+ )
232
+
233
+ # Compare backward results
234
+ grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
235
+ grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")
236
+
237
+ # Overall config result
238
+ backward_close = grad_x_close and grad_w_close
239
+ config_success = forward_close and backward_close
240
+ results.append(
241
+ (config_idx + 1, config_success, forward_close, backward_close)
242
+ )
243
+
244
+ # Log overall config result
245
+ if config_success:
246
+ logging.info(f"✓ SUCCESS: Config {config_idx+1} passed all tests!")
247
+ else:
248
+ logging.error(
249
+ f"✗ FAILURE: Config {config_idx+1} failed one or more tests"
250
+ )
251
+
252
+ except Exception as e:
253
+ logging.error(f"Config {config_idx+1} test failed with error: {e}")
254
+ import traceback
255
+
256
+ logging.error(traceback.format_exc())
257
+ results.append((config_idx + 1, False, False, False))
258
+
259
+ # Summary
260
+ logging.info("\n===== Test Results Summary =====")
261
+ for config_idx, overall_success, forward_success, backward_success in results:
262
+ overall_status = "✓ PASSED" if overall_success else "✗ FAILED"
263
+ forward_status = "✓ PASSED" if forward_success else "✗ FAILED"
264
+ backward_status = "✓ PASSED" if backward_success else "✗ FAILED"
265
+
266
+ logging.info(f"Config {config_idx}: {overall_status}")
267
+ logging.info(f" - Forward pass: {forward_status}")
268
+ logging.info(f" - Backward pass: {backward_status}")
269
+
270
+ return all(overall_success for _, overall_success, _, _ in results)
271
+
272
+
273
+ if __name__ == "__main__":
274
+ logging.info(
275
+ "Running verification for both forward and backward pass of M*G grouped GEMM"
276
+ )
277
+
278
+ # Run basic forward pass test
279
+ logging.info("\n===== Running basic forward pass test =====")
280
+ success_forward = test_forward_pass()
281
+ logging.info(f"Basic forward test {'succeeded' if success_forward else 'failed'}")
282
+
283
+ # Run basic backward pass test
284
+ logging.info("\n===== Running basic backward pass test =====")
285
+ success_backward = test_backward_pass()
286
+ logging.info(f"Basic backward test {'succeeded' if success_backward else 'failed'}")
287
+
288
+ # Run multiple DeepSeek configs with forward and backward verification
289
+ logging.info("\n===== Running tests for all DeepSeek configs =====")
290
+ success_configs = test_multiple_deepseek_configs()
291
+ logging.info(
292
+ f"DeepSeek configs tests {'all succeeded' if success_configs else 'had failures'}"
293
+ )
294
+
295
+ # Overall result
296
+ overall_success = success_forward and success_backward and success_configs
297
+ logging.info(
298
+ f"\nOverall test result: {'SUCCESS' if overall_success else 'FAILURE'}"
299
+ )
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py ADDED
@@ -0,0 +1,1304 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # credit - flat index forward kernel is derived from FBGemm:
8
+ # https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
9
+
10
+ # pyre-unsafe
11
+ import functools
12
+ import logging
13
+
14
+ import os
15
+ import sys
16
+ from typing import Any, Dict, Optional, Tuple
17
+
18
+ import torch
19
+
20
+ import triton
21
+ import triton.language as tl
22
+ from triton import Config as TConfig
23
+
24
+ from triton.runtime import driver # @manual
25
+
26
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
27
+
28
+ from tma_autotuning import (
29
+ ALIGN_SIZE_M,
30
+ _NV_CONFIGS,
31
+ CudaUtils,
32
+ early_config_prune,
33
+ TmaDescriptorHelper,
34
+ )
35
+
36
+
37
+ # Configure logging
38
+ logging.basicConfig(
39
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
40
+ )
41
+
42
+ # ============== Start Triton Kernels ===============
43
+
44
+
45
+ @triton.autotune(
46
+ configs=_NV_CONFIGS,
47
+ key=["G", "M_BUCKET", "N", "K"],
48
+ prune_configs_by={"early_config_prune": early_config_prune},
49
+ )
50
+ @triton.jit
51
+ def _kernel_mg_forward_hopper(
52
+ a_desc_ptr,
53
+ b_desc_ptr,
54
+ c_ptr,
55
+ workspace,
56
+ m_sizes,
57
+ # problem sizes
58
+ G: tl.constexpr,
59
+ M_BUCKET: tl.constexpr,
60
+ N: tl.constexpr,
61
+ K: tl.constexpr,
62
+ # config
63
+ NUM_SMS: tl.constexpr,
64
+ TMA_SIZE: tl.constexpr,
65
+ USE_EPILOGUE_SUBTILING: tl.constexpr,
66
+ # tiles
67
+ BLOCK_SIZE_M: tl.constexpr,
68
+ BLOCK_SIZE_N: tl.constexpr,
69
+ BLOCK_SIZE_K: tl.constexpr,
70
+ ) -> None:
71
+ """
72
+ Flat index style forward kernel for Hopper.
73
+ For simplicity, we always use TMA Load and TMA Store
74
+ """
75
+ tbidx = tl.program_id(0) # thread block index
76
+
77
+ c_dtype = c_ptr.dtype.element_ty # output dtype
78
+
79
+ c_desc_ptr = workspace + (tbidx * TMA_SIZE) # for TMA Store
80
+
81
+ M_end = 0
82
+ M_start = 0
83
+ processed_tiles = 0
84
+ # Size of individual weight matrix
85
+ n_size = N // G
86
+ n_start = 0
87
+
88
+ for g in range(G):
89
+ # Move down along groups
90
+ # reset to new M offset
91
+ M_start = M_end
92
+ m_size = tl.load(m_sizes + g)
93
+ M_end = M_start + m_size
94
+ n_start = n_size * g
95
+
96
+ if m_size > 0:
97
+ # Process this group
98
+
99
+ # Acquire hold on c_desc_ptr for TMA Store
100
+ tl.extra.cuda.experimental_device_tensormap_create2d(
101
+ desc_ptr=c_desc_ptr,
102
+ global_address=c_ptr + M_start * n_size,
103
+ load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
104
+ global_size=[m_size, n_size],
105
+ element_ty=c_dtype,
106
+ )
107
+ tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
108
+
109
+ # tiles for this group
110
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
111
+ num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
112
+ group_num_tiles = num_m_tiles * num_n_tiles
113
+
114
+ while tbidx >= processed_tiles and tbidx < (
115
+ processed_tiles + group_num_tiles
116
+ ):
117
+ group_index = tbidx - processed_tiles
118
+
119
+ # columnwise
120
+ tile_m_index = group_index % num_m_tiles
121
+ tile_n_index = group_index // num_m_tiles
122
+
123
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
124
+
125
+ m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
126
+ n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
127
+ global_n_offset = (n_start + n_offset).to(tl.int32)
128
+
129
+ for k_offset in range(0, K, BLOCK_SIZE_K):
130
+ # input block [M,K]
131
+ a = tl._experimental_descriptor_load(
132
+ a_desc_ptr,
133
+ [m_offset, k_offset],
134
+ [BLOCK_SIZE_M, BLOCK_SIZE_K],
135
+ c_dtype,
136
+ )
137
+ # weight block [N, K]
138
+ b = tl._experimental_descriptor_load(
139
+ b_desc_ptr,
140
+ [global_n_offset, k_offset],
141
+ [BLOCK_SIZE_N, BLOCK_SIZE_K],
142
+ c_dtype,
143
+ )
144
+
145
+ accumulator += tl.dot(a, b.T)
146
+
147
+ # Store using TMA
148
+
149
+ m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
150
+
151
+ if USE_EPILOGUE_SUBTILING:
152
+ acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
153
+ acc = tl.permute(acc, (0, 2, 1))
154
+ acc0, acc1 = tl.split(acc)
155
+ c0 = acc0.to(c_dtype)
156
+ tl._experimental_descriptor_store(
157
+ c_desc_ptr, c0, [m_offset, n_offset]
158
+ )
159
+ c1 = acc1.to(c_dtype)
160
+ tl._experimental_descriptor_store(
161
+ c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2]
162
+ )
163
+ else:
164
+ tl._experimental_descriptor_store(
165
+ c_desc_ptr,
166
+ accumulator.to(c_dtype),
167
+ [m_offset, n_offset],
168
+ )
169
+ # move to next tile in group
170
+ tbidx += NUM_SMS
171
+ # Update the total tiles count for the next group
172
+ processed_tiles += group_num_tiles
173
+
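# A hedged, host-side sketch of the flat-index scheduling used by the kernel above: it mirrors
# how a persistent program id (`tbidx`) walks the per-group tile space. All sizes and the helper
# name `tiles_for_program` are illustrative assumptions, not values taken from the kernel.
import math

def tiles_for_program(tbidx, m_sizes, n_size, BLOCK_M=128, BLOCK_N=128, NUM_SMS=8):
    """Return the (group, tile_m, tile_n) tuples a persistent program `tbidx` would own."""
    owned = []
    processed_tiles = 0
    for g, m_size in enumerate(m_sizes):
        if m_size == 0:
            continue  # empty groups contribute no tiles, mirroring the `if m_size > 0` branch
        num_m_tiles = math.ceil(m_size / BLOCK_M)
        num_n_tiles = math.ceil(n_size / BLOCK_N)
        group_num_tiles = num_m_tiles * num_n_tiles
        # same strided walk as the kernel: within a group, tbidx advances by NUM_SMS
        while processed_tiles <= tbidx < processed_tiles + group_num_tiles:
            group_index = tbidx - processed_tiles
            owned.append((g, group_index % num_m_tiles, group_index // num_m_tiles))
            tbidx += NUM_SMS
        processed_tiles += group_num_tiles
    return owned

# e.g. tiles_for_program(0, m_sizes=[256, 0, 384], n_size=512) lists the tiles program 0 handles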
174
+
175
+ @triton.autotune(
176
+ configs=_NV_CONFIGS,
177
+ key=["G", "M_BUCKET", "N", "K"],
178
+ prune_configs_by={"early_config_prune": early_config_prune},
179
+ )
180
+ @triton.jit
181
+ def _kernel_mg_forward_tma(
182
+ a_desc_ptr,
183
+ b_desc_ptr,
184
+ c_ptr,
185
+ workspace,
186
+ m_sizes,
187
+ a_scale_ptr,
188
+ b_scale_ptr,
189
+ # problem sizes
190
+ G: tl.constexpr,
191
+ M_BUCKET: tl.constexpr,
192
+ N: tl.constexpr,
193
+ K: tl.constexpr,
194
+ # config
195
+ NUM_SMS: tl.constexpr,
196
+ USE_TMA_LOAD: tl.constexpr,
197
+ USE_TMA_STORE: tl.constexpr,
198
+ TMA_SIZE: tl.constexpr,
199
+ USE_FP8: tl.constexpr,
200
+ # tiles
201
+ BLOCK_SIZE_M: tl.constexpr,
202
+ BLOCK_SIZE_N: tl.constexpr,
203
+ BLOCK_SIZE_K: tl.constexpr,
204
+ ) -> None:
205
+ """
206
+ Flat index style forward kernel.
207
+ For simplicity, we always use TMA Load and TMA Store
208
+ """
209
+ tbidx = tl.program_id(0) # thread block index
210
+
211
+ c_dtype = c_ptr.dtype.element_ty
212
+
213
+ c_desc_ptr = workspace + (tbidx * TMA_SIZE)
214
+
215
+ M_end = 0
216
+ processed_tiles = 0
217
+
218
+ for g in range(G):
219
+ # Move down along groups
220
+ # reset to new M offset
221
+ M_start = M_end
222
+ m_size = tl.load(m_sizes + g)
223
+ M_end = M_start + m_size
224
+
225
+ if m_size > 0:
226
+ # Process this group
227
+ n_size = N
228
+
229
+ # TMA Store prep
230
+ tl.extra.cuda.experimental_device_tensormap_create2d(
231
+ desc_ptr=c_desc_ptr,
232
+ global_address=c_ptr + M_start * N,
233
+ load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
234
+ global_size=[m_size, n_size],
235
+ element_ty=c_dtype,
236
+ )
237
+ tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
238
+
239
+ # tiles for this group
240
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
241
+ num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
242
+ group_num_tiles = num_m_tiles * num_n_tiles
243
+
244
+ while tbidx >= processed_tiles and tbidx < (
245
+ processed_tiles + group_num_tiles
246
+ ):
247
+ group_index = tbidx - processed_tiles
248
+
249
+ tile_m_index = group_index % num_m_tiles
250
+ tile_n_index = group_index // num_m_tiles
251
+
252
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
253
+
254
+ m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
255
+ n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
256
+
257
+ for k_offset in range(0, K, BLOCK_SIZE_K):
258
+ # input block [M,K]
259
+ a = tl._experimental_descriptor_load(
260
+ a_desc_ptr,
261
+ [m_offset, k_offset],
262
+ [BLOCK_SIZE_M, BLOCK_SIZE_K],
263
+ c_dtype,
264
+ )
265
+ # weight block [N, K]
266
+ b = tl._experimental_descriptor_load(
267
+ b_desc_ptr,
268
+ [n_offset, k_offset],
269
+ [BLOCK_SIZE_N, BLOCK_SIZE_K],
270
+ c_dtype,
271
+ )
272
+
273
+ accumulator += tl.dot(a, b.T)
274
+
275
+ # Store using TMA
276
+
277
+ m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
278
+ # n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
279
+
280
+ tl._experimental_descriptor_store(
281
+ c_desc_ptr,
282
+ accumulator.to(c_dtype),
283
+ [m_offset, n_offset],
284
+ )
285
+
286
+ # Move to the next tile
287
+ tbidx += NUM_SMS
288
+ # Update the total tiles count for the next group
289
+ processed_tiles += group_num_tiles
290
+
291
+
292
+ @triton.autotune(
293
+ configs=_NV_CONFIGS,
294
+ key=["G", "M_BUCKET", "N", "K"],
295
+ prune_configs_by={"early_config_prune": early_config_prune},
296
+ )
297
+ @triton.jit
298
+ def _kernel_mg_forward_no_tma(
299
+ a_ptr,
300
+ b_ptr,
301
+ c_ptr,
302
+ workspace,
303
+ m_sizes,
304
+ # problem sizes
305
+ G: tl.constexpr,
306
+ M_BUCKET: tl.constexpr,
307
+ N: tl.constexpr,
308
+ K: tl.constexpr,
309
+ # config
310
+ NUM_SMS: tl.constexpr,
311
+ USE_TMA_LOAD: tl.constexpr,
312
+ USE_TMA_STORE: tl.constexpr,
313
+ TMA_SIZE: tl.constexpr,
314
+ # tiles
315
+ BLOCK_SIZE_M: tl.constexpr,
316
+ BLOCK_SIZE_N: tl.constexpr,
317
+ BLOCK_SIZE_K: tl.constexpr,
318
+ ) -> None:
319
+ """
320
+ Flat index style forward kernel.
321
+ For pre-Hopper architectures (e.g., Ampere), we never use TMA Load or TMA Store
322
+ """
323
+ tbidx = tl.program_id(0) # thread block index
324
+
325
+ c_dtype = c_ptr.dtype.element_ty
326
+ c_desc_ptr = None
327
+
328
+ M_end = 0
329
+ processed_tiles = 0
330
+
331
+ for g in range(G):
332
+ # Move down along groups
333
+ # reset to new M offset
334
+ M_start = M_end
335
+ m_size = tl.load(m_sizes + g)
336
+ M_end = M_start + m_size
337
+
338
+ if m_size > 0:
339
+ # Process this group
340
+ n_size = N
341
+
342
+ # tiles for this group
343
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
344
+ num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
345
+ group_num_tiles = num_m_tiles * num_n_tiles
346
+
347
+ while tbidx >= processed_tiles and tbidx < (
348
+ processed_tiles + group_num_tiles
349
+ ):
350
+ group_index = tbidx - processed_tiles
351
+
352
+ tile_m_index = group_index % num_m_tiles
353
+ tile_n_index = group_index // num_m_tiles
354
+
355
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
356
+
357
+ m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
358
+ n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
359
+
360
+ offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
361
+ offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
362
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
363
+
364
+ a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :]
365
+ b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :]
366
+
367
+ for k_offset in range(0, K, BLOCK_SIZE_K):
368
+ # Load with bounds checking
369
+ a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size, other=0.0)
370
+ b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size, other=0.0)
371
+
372
+ # Main matmul
373
+ accumulator += tl.dot(a, b.T)
374
+
375
+ # Update pointers for next block
376
+ a_ptrs += BLOCK_SIZE_K
377
+ b_ptrs += BLOCK_SIZE_K
378
+
379
+ # Store without TMA
380
+ offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
381
+ offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
382
+
383
+ c = accumulator.to(c_dtype)
384
+
385
+ tl.store(
386
+ c_ptr
387
+ + (M_start + offs_am[:, None]) * N # Row stride is N
388
+ + offs_bn[None, :], # Column offset
389
+ c,
390
+ mask=(offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size),
391
+ )
392
+ # Move to the next tile
393
+ tbidx += NUM_SMS
394
+ # Update the total tiles count for the next group
395
+ processed_tiles += group_num_tiles
396
+
397
+
398
+ """
399
+ Backward pass for grouped GEMM with Triton, where grouping is M*G
400
+ We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
401
+ """
402
+
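# A minimal pure-PyTorch reference for these gradients, useful for sanity-checking the Triton
# kernels numerically. It assumes the shared-weight layout the dx/dw kernels below use (a single
# W of shape [N, K] for all groups); the grouping via m_sizes only changes how tiles are
# scheduled, not the math. The function name is an illustrative assumption.
import torch

def reference_mg_backward(grad_y: torch.Tensor, x: torch.Tensor, w: torch.Tensor):
    # Forward was Y = X @ W.T, so:
    grad_x = grad_y @ w       # [M_total, N] @ [N, K] -> [M_total, K]
    grad_w = grad_y.t() @ x   # [N, M_total] @ [M_total, K] -> [N, K]
    return grad_x, grad_w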
403
+
404
+ # ---- dx flat linear indexed ----
405
+ @triton.autotune(
406
+ configs=_NV_CONFIGS,
407
+ key=["G", "M_BUCKET", "N", "K"],
408
+ prune_configs_by={"early_config_prune": early_config_prune},
409
+ )
410
+ @triton.jit
411
+ def _kernel_mg_dx_tma(
412
+ grad_output_desc_ptr, # [MG, N]
413
+ w_desc_ptr, # [N, K]
414
+ grad_input_ptr, # output grad_x [MG, K]
415
+ workspace, # for TMA store
416
+ m_sizes, # group sizes [G]
417
+ # problem sizes
418
+ G: tl.constexpr,
419
+ M_BUCKET: tl.constexpr,
420
+ N: tl.constexpr,
421
+ K: tl.constexpr,
422
+ # config
423
+ NUM_SMS: tl.constexpr,
424
+ USE_TMA_LOAD: tl.constexpr,
425
+ USE_TMA_STORE: tl.constexpr,
426
+ TMA_SIZE: tl.constexpr,
427
+ # tiles
428
+ BLOCK_SIZE_M: tl.constexpr,
429
+ BLOCK_SIZE_N: tl.constexpr,
430
+ BLOCK_SIZE_K: tl.constexpr,
431
+ ) -> None:
432
+ """
433
+ TMA-optimized kernel for computing gradients with respect to input (dx).
434
+ For the forward pass Y = X @ W.T, the backward for input is:
435
+ grad_X = grad_Y @ W
436
+
437
+ This maps to [MG, N] @ [N, K] -> [MG, K]
438
+
439
+ Key differences from forward:
440
+ 1. W is used directly and not transposed
441
+ 2. The reduction dimension is now N (not K)
442
+ 3. Output is [M, K] instead of [M, N]
443
+ """
444
+ tbidx = tl.program_id(0) # thread block index
445
+
446
+ c_dtype = grad_input_ptr.dtype.element_ty
447
+ c_desc_ptr = workspace + (tbidx * TMA_SIZE)
448
+
449
+ M_end = 0
450
+ processed_tiles = 0
451
+
452
+ for g in range(G):
453
+ # Move down along groups - same as forward
454
+ M_start = M_end
455
+ m_size = tl.load(m_sizes + g)
456
+ M_end = M_start + m_size
457
+
458
+ if m_size > 0:
459
+ # Process this group
460
+ # tiles for this group - now producing [M, K] output
461
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
462
+ num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
463
+ group_num_tiles = num_m_tiles * num_k_tiles
464
+
465
+ # TMA Store prep for [M, K] output
466
+ tl.extra.cuda.experimental_device_tensormap_create2d(
467
+ desc_ptr=c_desc_ptr,
468
+ global_address=grad_input_ptr + M_start * K,
469
+ load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
470
+ global_size=[m_size, K],
471
+ element_ty=c_dtype,
472
+ )
473
+ tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
474
+
475
+ while tbidx >= processed_tiles and tbidx < (
476
+ processed_tiles + group_num_tiles
477
+ ):
478
+ group_index = tbidx - processed_tiles
479
+
480
+ # Different tiling scheme for [M, K] output
481
+ tile_m_index = group_index % num_m_tiles
482
+ tile_k_index = group_index // num_m_tiles
483
+
484
+ # for grad_input block [M, K]
485
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
486
+
487
+ # Position in full matrix
488
+ m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
489
+ k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)
490
+
491
+ # reduce along N dimension (instead of K in forward)
492
+ for n_offset in range(0, N, BLOCK_SIZE_N):
493
+ # grad_output block [M, N]
494
+ grad_output = tl._experimental_descriptor_load(
495
+ grad_output_desc_ptr,
496
+ [m_offset, n_offset],
497
+ [BLOCK_SIZE_M, BLOCK_SIZE_N],
498
+ c_dtype,
499
+ )
500
+
501
+ # weight block [N, K] - no transpose needed
502
+ w = tl._experimental_descriptor_load(
503
+ w_desc_ptr,
504
+ [n_offset, k_offset],
505
+ [BLOCK_SIZE_N, BLOCK_SIZE_K],
506
+ c_dtype,
507
+ )
508
+
509
+ # grad_x = grad_output @ w
510
+ # reducing along N dimension
511
+ accumulator += tl.dot(grad_output, w)
512
+
513
+ # Store using TMA
514
+ m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
515
+ # k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)
516
+
517
+ tl._experimental_descriptor_store(
518
+ c_desc_ptr,
519
+ accumulator.to(c_dtype),
520
+ [m_offset, k_offset],
521
+ )
522
+
523
+ # Move to the next tile
524
+ tbidx += NUM_SMS
525
+
526
+ # Update the total tiles count for the next group
527
+ processed_tiles += group_num_tiles
528
+
529
+
530
+ # ---- dw flat linear indexed ----
531
+
532
+
533
+ @triton.autotune(
534
+ configs=_NV_CONFIGS,
535
+ key=["G", "M_BUCKET", "N", "K"],
536
+ prune_configs_by={"early_config_prune": early_config_prune},
537
+ )
538
+ @triton.jit
539
+ def _kernel_mg_dw_tma(
540
+ x_desc_ptr, # input descriptor [M_total, K]
541
+ grad_output_desc_ptr, # grad_output descriptor [M_total, N]
542
+ grad_weight_ptr, # output grad_w [N, K]
543
+ workspace, # workspace for TMA store
544
+ m_sizes, # group sizes [G]
545
+ # problem sizes
546
+ G: tl.constexpr,
547
+ M_BUCKET: tl.constexpr,
548
+ N: tl.constexpr,
549
+ K: tl.constexpr,
550
+ # config
551
+ NUM_SMS: tl.constexpr,
552
+ USE_TMA_LOAD: tl.constexpr,
553
+ USE_TMA_STORE: tl.constexpr,
554
+ TMA_SIZE: tl.constexpr,
555
+ # tiles
556
+ BLOCK_SIZE_N: tl.constexpr,
557
+ BLOCK_SIZE_K: tl.constexpr,
558
+ BLOCK_SIZE_M: tl.constexpr, # block size for reduction dimension
559
+ ) -> None:
560
+ """
561
+ Improved TMA-optimized kernel for computing gradients with respect to weights (dw).
562
+ Uses flat index structure similar to forward.
563
+
564
+ For the forward pass Y = X @ W.T,
565
+ the backward for weights is:
566
+ grad_W = grad_Y.T @ X
567
+
568
+ Where:
569
+ - grad_Y is [MG, N]
570
+ - X is [MG, K]
571
+ - grad_W is [N, K]
572
+ - we return [N,K]
573
+ """
574
+ # Get thread block index
575
+ tbidx = tl.program_id(0)
576
+
577
+ # Get output data type
578
+ c_dtype = grad_weight_ptr.dtype.element_ty
579
+
580
+ # Calculate number of output tiles
581
+ num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
582
+ num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
583
+ total_output_tiles = num_n_tiles * num_k_tiles
584
+
585
+ # Process tiles in strided manner across SMs
586
+ for tile_idx in range(tbidx, total_output_tiles, NUM_SMS):
587
+ # Calculate tile indices
588
+ tile_n_idx = tile_idx % num_n_tiles
589
+ tile_k_idx = tile_idx // num_n_tiles
590
+
591
+ # Calculate global offsets
592
+ n_offset = tile_n_idx * BLOCK_SIZE_N
593
+ k_offset = tile_k_idx * BLOCK_SIZE_K
594
+
595
+ # Initialize accumulator for this output tile [N, K]
596
+ accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)
597
+
598
+ # Process each group
599
+ M_end = 0
600
+ for g in range(G):
601
+ # Get group boundaries
602
+ M_start = M_end
603
+ m_size = tl.load(m_sizes + g)
604
+ M_end = M_start + m_size
605
+
606
+ # Only process if group is non-empty
607
+ if m_size > 0:
608
+ # Process this group in chunks along the M dimension
609
+ for m_offset in range(0, m_size, BLOCK_SIZE_M):
610
+ # Calculate actual block size (handling boundary)
611
+ m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset)
612
+
613
+ # Only process if we have actual work to do
614
+ if m_block_size > 0:
615
+ # Global offset for this chunk
616
+ m_global_offset = M_start + m_offset
617
+
618
+ if USE_TMA_LOAD:
619
+ # Load input chunk [M_chunk, K] using TMA
620
+ x_block = tl._experimental_descriptor_load(
621
+ x_desc_ptr,
622
+ [m_global_offset, k_offset],
623
+ [BLOCK_SIZE_M, BLOCK_SIZE_K],
624
+ c_dtype,
625
+ )
626
+
627
+ # Load grad_output chunk [M_chunk, N] using TMA
628
+ grad_output_block = tl._experimental_descriptor_load(
629
+ grad_output_desc_ptr,
630
+ [m_global_offset, n_offset],
631
+ [BLOCK_SIZE_M, BLOCK_SIZE_N],
632
+ c_dtype,
633
+ )
634
+
635
+ # Apply masks for valid regions
636
+ offs_m = tl.arange(0, BLOCK_SIZE_M)
637
+ m_mask = offs_m < m_block_size
638
+
639
+ # Zero out invalid elements
640
+ x_block = tl.where(m_mask[:, None], x_block, 0.0)
641
+ grad_output_block = tl.where(
642
+ m_mask[:, None], grad_output_block, 0.0
643
+ )
644
+ else:
645
+ # Manual load with bounds checking
646
+ offs_m = tl.arange(0, BLOCK_SIZE_M)
647
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
648
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
649
+
650
+ # Create masks
651
+ m_mask = offs_m < m_block_size
652
+ n_mask = offs_n < N - n_offset
653
+ k_mask = offs_k < K - k_offset
654
+
655
+ # Combined masks
656
+ mk_mask = m_mask[:, None] & k_mask[None, :]
657
+ mn_mask = m_mask[:, None] & n_mask[None, :]
658
+
659
+ # Global offsets for loading
660
+ m_global_offs = m_global_offset + offs_m
661
+
662
+ # Load x block [M_chunk, K]
663
+ x_block = tl.load(
664
+ x_desc_ptr
665
+ + m_global_offs[:, None] * K
666
+ + (k_offset + offs_k)[None, :],
667
+ mask=mk_mask,
668
+ other=0.0,
669
+ )
670
+
671
+ # Load grad_output block [M_chunk, N]
672
+ grad_output_block = tl.load(
673
+ grad_output_desc_ptr
674
+ + m_global_offs[:, None] * N
675
+ + (n_offset + offs_n)[None, :],
676
+ mask=mn_mask,
677
+ other=0.0,
678
+ )
679
+
680
+ # Compute partial contribution: grad_W += grad_Y.T @ X
681
+ # transpose grad_output for the matmul
682
+ contribution = tl.dot(
683
+ grad_output_block.to(tl.float32).T, # [N, M_chunk]
684
+ x_block.to(tl.float32), # [M_chunk, K]
685
+ )
686
+
687
+ # Accumulate
688
+ accumulator += contribution
689
+
690
+ # Store the result
691
+ if USE_TMA_STORE:
692
+ # Store using TMA
693
+ tl._experimental_descriptor_store(
694
+ workspace, # TMA store descriptor
695
+ accumulator.to(c_dtype),
696
+ [n_offset, k_offset],
697
+ )
698
+ else:
699
+ # Manual store with bounds checking
700
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
701
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
702
+
703
+ # Create masks for bounds checking
704
+ n_mask = offs_n < N - n_offset
705
+ k_mask = offs_k < K - k_offset
706
+ output_mask = n_mask[:, None] & k_mask[None, :]
707
+
708
+ # Store the result
709
+ tl.store(
710
+ grad_weight_ptr
711
+ + (n_offset + offs_n)[:, None] * K
712
+ + (k_offset + offs_k)[None, :],
713
+ accumulator.to(c_dtype),
714
+ mask=output_mask,
715
+ )
716
+
717
+
718
+ # ======== End Triton kernels ========
719
+
720
+ # ======== Triton wrapper functions ========
721
+
722
+ # ----- main forward pass wrapper -----
723
+
724
+
725
+ def grouped_gemm_forward(
726
+ x: torch.Tensor,
727
+ w: torch.Tensor,
728
+ m_sizes: torch.Tensor,
729
+ tma_size: int = 128,
730
+ ) -> torch.Tensor:
731
+ """
732
+ M*G style grouped GEMM with TMA support.
733
+ (FP8 support, triggered by passing x_scale and w_scale tensors, has been removed for now.)
734
+
735
+ """
736
+ if not CudaUtils.verify_tma():
737
+ raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
738
+
739
+ G = m_sizes.shape[0]
740
+
741
+ assert x.is_contiguous()
742
+ assert w.is_contiguous()
743
+ assert m_sizes.is_contiguous()
744
+
745
+ # Total input size is now [M_total, K] where M_total is the sum of all group sizes
746
+ M_total, K = x.shape
747
+ N = w.shape[0] # N is now the same for all groups
748
+
749
+ assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"
750
+
751
+ # Verify that all group sizes are multiples of ALIGN_SIZE_M
752
+ # This check is commented out because it will involve a GPU-CPU sync
753
+ # assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0, "Group sizes must be a multiple of ALIGN_SIZE_M"
754
+
755
+ # Create output tensor; the hopper kernel writes one slice of width N // G per group
756
+ y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)
757
+
758
+ if M_total == 0:
759
+ return y
760
+
761
+ NUM_SMS = CudaUtils.get_num_sms()
762
+ USE_TMA_LOAD = True
763
+ USE_TMA_STORE = True
764
+ USE_EPILOGUE_SUBTILING = False
765
+
766
+ # TMA descriptor helper
767
+ desc_helper = None
768
+ desc_x = x
769
+ desc_w = w
770
+ workspace = None
771
+
772
+ if USE_TMA_LOAD:
773
+ desc_helper = TmaDescriptorHelper(tma_size=tma_size)
774
+ desc_helper.init_tma_descriptor("x")
775
+ desc_helper.init_tma_descriptor("w")
776
+ desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
777
+ desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
778
+
779
+ if USE_TMA_STORE:
780
+ workspace = torch.empty(
781
+ NUM_SMS * desc_helper.tma_size,
782
+ device=x.device,
783
+ dtype=torch.uint8,
784
+ )
785
+
786
+ def grid(META):
787
+ if USE_TMA_LOAD:
788
+ nonlocal desc_helper
789
+ desc_helper.fill_2d_tma_descriptor(
790
+ "x",
791
+ x.data_ptr(),
792
+ M_total,
793
+ K,
794
+ META["BLOCK_SIZE_M"],
795
+ META["BLOCK_SIZE_K"],
796
+ x.element_size(),
797
+ )
798
+
799
+ desc_helper.fill_2d_tma_descriptor(
800
+ "w",
801
+ w.data_ptr(),
802
+ N,
803
+ K,
804
+ META["BLOCK_SIZE_N"],
805
+ META["BLOCK_SIZE_K"],
806
+ w.element_size(),
807
+ )
808
+ return (NUM_SMS,)
809
+
810
+ M_BUCKET = triton.next_power_of_2(M_total)
811
+
812
+ _kernel_mg_forward_hopper[grid](
813
+ desc_x,
814
+ desc_w,
815
+ y,
816
+ workspace,
817
+ m_sizes,
818
+ G,
819
+ M_BUCKET,
820
+ N,
821
+ K,
822
+ NUM_SMS,
823
+ TMA_SIZE=tma_size,
824
+ USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
825
+ )
826
+
827
+ return y
828
+
829
+
830
+ # ======== Improved Backward =============
831
+ def grouped_gemm_backward(
832
+ grad_output: torch.Tensor,
833
+ x: torch.Tensor,
834
+ w: torch.Tensor,
835
+ m_sizes: torch.Tensor,
836
+ use_tma: bool = True,
837
+ tma_size: int = 128,
838
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
839
+ """
840
+ Unified backward pass for grouped GeMM with M*G grouping.
841
+ Uses optimized TMA-based implementations for both dx and dw when available.
842
+
843
+ Args:
844
+ grad_output: Gradient of output, shape [M_total, N]
845
+ x: Input tensor from forward pass, shape [M_total, K]
846
+ w: Weight tensor from forward pass, shape [N, K]
847
+ m_sizes: Group sizes tensor, shape [G]
848
+ use_tma: Whether to try using TMA acceleration (if available)
849
+ tma_size: Size of TMA descriptor in bytes
850
+
851
+
852
+ Returns:
853
+ Tuple of gradients with respect to x and w: (grad_x, grad_w)
854
+ """
855
+ logging.info("Starting unified grouped_gemm_backward")
856
+
857
+ # do this once, seems expensive
858
+ NUM_SMS = CudaUtils.get_num_sms()
859
+
860
+ # Basic validation
861
+ G = m_sizes.shape[0]
862
+ M_total, K_x = x.shape
863
+ M_grad, N = grad_output.shape
864
+ N_w, K_w = w.shape
865
+
866
+ # Check dimensions
867
+ if K_x != K_w:
868
+ raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
869
+ if M_total != M_grad:
870
+ raise ValueError(
871
+ f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
872
+ )
873
+
874
+ # Check total M matches sum of group sizes
875
+ sum_m_sizes = m_sizes.sum().item()
876
+ if M_total != sum_m_sizes:
877
+ raise ValueError(
878
+ f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
879
+ )
880
+
881
+ # Make sure inputs are contiguous
882
+ grad_output = grad_output.contiguous()
883
+ x = x.contiguous()
884
+ w = w.contiguous()
885
+ m_sizes = m_sizes.contiguous()
886
+
887
+ # Check TMA support
888
+ can_use_tma = use_tma and CudaUtils.verify_tma()
889
+ if use_tma and not can_use_tma:
890
+ logging.info("TMA requested but not supported on this device")
891
+ use_tma = False
892
+
893
+ # Compute grad_x using flat linear implementation
894
+ try:
895
+ logging.info(f"Computing grad_x with flat linear kernel")
896
+
897
+ # Use TMA-optimized implementation
898
+ grad_x = grouped_gemm_dx_tma(
899
+ grad_output=grad_output,
900
+ w=w,
901
+ m_sizes=m_sizes,
902
+ num_sms=NUM_SMS,
903
+ tma_size=tma_size,
904
+ )
905
+
906
+ except Exception as e:
907
+ logging.error(f"Error in grad_x computation: {e}")
908
+ raise
909
+
910
+ # Compute grad_w using flat linear style implementation
911
+ try:
912
+ logging.info(f"Computing grad_w with flat linear kernel")
913
+
914
+ grad_w = grouped_gemm_dw_tma(
915
+ x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
916
+ )
917
+ except Exception as e:
918
+ logging.error(f"Error in grad_w computation: {e}")
919
+ raise
920
+
921
+ return grad_x, grad_w
922
+
923
+
924
+ # ----- dx backward pass wrapper -----
925
+
926
+
927
+ def grouped_gemm_dx_tma(
928
+ grad_output: torch.Tensor,
929
+ w: torch.Tensor,
930
+ m_sizes: torch.Tensor,
931
+ num_sms: int = 132,
932
+ tma_size: int = 128,
933
+ ) -> torch.Tensor:
934
+ """
935
+ Optimized backward pass wrapper for computing gradient with respect to input (dx)
936
+ using TMA patterns similar to the forward pass.
937
+
938
+ Args:
939
+ grad_output: Gradient of output, shape [M_total, N]
940
+ w: Weight tensor, shape [N, K]
941
+ m_sizes: Group sizes tensor, shape [G]
942
+ tma_size: Size of TMA descriptor
943
+ # using_fp8: Whether to use FP8 quantization
944
+ # grad_output_scale: Scale for grad_output in FP8 mode
945
+ # w_scale: Scale for w in FP8 mode
946
+
947
+ Returns:
948
+ grad_x: Gradient with respect to x, shape [M_total, K]
949
+ """
+ if not CudaUtils.verify_tma():
967
+ raise NotImplementedError("Optimized dx computation requires TMA support")
968
+
969
+ G = m_sizes.shape[0]
970
+
971
+ assert grad_output.is_contiguous()
972
+ assert w.is_contiguous()
973
+ assert m_sizes.is_contiguous()
974
+
975
+ M_total, N_grad = grad_output.shape
976
+ N_w, K = w.shape
977
+
978
+ # Check dimensions
979
+ assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})"
980
+
981
+ # Verify that the sum of m_sizes matches M_total
982
+ sum_m_sizes = m_sizes.sum().item()
983
+ assert (
984
+ M_total == sum_m_sizes
985
+ ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
986
+
987
+ # Create output tensor (grad_x) with shape [M_total, K]
988
+ grad_x = torch.empty(
989
+ (M_total, K), device=grad_output.device, dtype=grad_output.dtype
990
+ )
991
+
992
+ NUM_SMS = num_sms # CudaUtils.get_num_sms()
993
+ USE_TMA_LOAD = True
994
+ USE_TMA_STORE = True
995
+
996
+ # Set up TMA descriptors
997
+ desc_helper = TmaDescriptorHelper(tma_size=tma_size)
998
+ desc_helper.init_tma_descriptor("grad_output")
999
+ desc_helper.init_tma_descriptor("w")
1000
+ desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output")
1001
+ desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
1002
+
1003
+ # Allocate workspace for TMA store
1004
+ workspace = torch.empty(
1005
+ NUM_SMS * desc_helper.tma_size,
1006
+ device=grad_output.device,
1007
+ dtype=torch.uint8,
1008
+ )
1009
+
1010
+ def grid(META):
1011
+ # Fill TMA descriptors with appropriate dimensions
1012
+ desc_helper.fill_2d_tma_descriptor(
1013
+ "grad_output",
1014
+ grad_output.data_ptr(),
1015
+ M_total,
1016
+ N_grad,
1017
+ META["BLOCK_SIZE_M"],
1018
+ META["BLOCK_SIZE_N"],
1019
+ grad_output.element_size(),
1020
+ )
1021
+
1022
+ desc_helper.fill_2d_tma_descriptor(
1023
+ "w",
1024
+ w.data_ptr(),
1025
+ N_w,
1026
+ K,
1027
+ META["BLOCK_SIZE_N"],
1028
+ META["BLOCK_SIZE_K"],
1029
+ w.element_size(),
1030
+ )
1031
+ return (NUM_SMS,)
1032
+
1033
+ M_BUCKET = triton.next_power_of_2(M_total)
1034
+
1035
+ # Launch the flat linear kernel for computing grad_x
1036
+ _kernel_mg_dx_tma[grid](
1037
+ desc_grad_output,
1038
+ desc_w,
1039
+ grad_x,
1040
+ workspace,
1041
+ m_sizes,
1042
+ G,
1043
+ M_BUCKET,
1044
+ N_grad, # N dimension is now the reduction dimension
1045
+ K,
1046
+ NUM_SMS,
1047
+ USE_TMA_LOAD,
1048
+ USE_TMA_STORE,
1049
+ TMA_SIZE=tma_size,
1050
+ )
1051
+
1052
+ return grad_x
1053
+
1054
+
1055
+ # ======== dw wrapper function ==========
1056
+
1057
+
1058
+ def grouped_gemm_dw_tma(
1059
+ x: torch.Tensor,
1060
+ grad_output: torch.Tensor,
1061
+ m_sizes: torch.Tensor,
1062
+ num_sms: int = 132,
1063
+ tma_size: int = 128,
1064
+ ) -> torch.Tensor:
1065
+ """
1066
+ Optimized flat linear kernel computation of gradients with respect to weights (dw) using TMA.
1067
+ For the forward pass Y = X @ W.T, the backward for weights is:
1068
+ grad_W = grad_Y.T @ X
1069
+
1070
+ Args:
1071
+ x: Input tensor, shape [M_total, K]
1072
+ grad_output: Gradient of output, shape [M_total, N]
1073
+ m_sizes: Group sizes tensor, shape [G]
1074
+ tma_size: Size of TMA descriptor in bytes
1075
+
1076
+
1077
+ Returns:
1078
+ grad_w: Gradient with respect to weights, shape [N, K]
1079
+ """
1080
+ # Check TMA support
1081
+ has_tma_support = CudaUtils.verify_tma()
1082
+
1083
+ # Get group count
1084
+ G = m_sizes.shape[0]
1085
+
1086
+ # Ensure contiguous tensors
1087
+ x = x.contiguous()
1088
+ grad_output = grad_output.contiguous()
1089
+ m_sizes = m_sizes.contiguous()
1090
+
1091
+ # Get dimensions
1092
+ M_total, K_x = x.shape
1093
+ M_grad, N = grad_output.shape
1094
+
1095
+ # Check dimensions
1096
+ assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})"
1097
+
1098
+ # Verify that the sum of m_sizes matches M_total
1099
+ sum_m_sizes = m_sizes.sum().item()
1100
+ assert (
1101
+ sum_m_sizes == M_total
1102
+ ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
1103
+
1104
+ # Create output tensor (grad_w) with shape [N, K]
1105
+ grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype)
1106
+
1107
+ NUM_SMS = num_sms
1108
+
1109
+ # TODO - hardcoded for now...but should set TMA flags based on hardware support
1110
+ USE_TMA_LOAD = True # has_tma_support
1111
+ USE_TMA_STORE = True # has_tma_support
1112
+
1113
+ # Set up TMA descriptors or direct pointers
1114
+ if USE_TMA_LOAD or USE_TMA_STORE:
1115
+ desc_helper = TmaDescriptorHelper(tma_size=tma_size)
1116
+
1117
+ if USE_TMA_LOAD:
1118
+ desc_helper.init_tma_descriptor("x")
1119
+ desc_helper.init_tma_descriptor("grad_output")
1120
+ x_desc = desc_helper.get_tma_descriptor_kernel_param("x")
1121
+ grad_output_desc = desc_helper.get_tma_descriptor_kernel_param(
1122
+ "grad_output"
1123
+ )
1124
+ else:
1125
+ x_desc = x
1126
+ grad_output_desc = grad_output
1127
+
1128
+ if USE_TMA_STORE:
1129
+ desc_helper.init_tma_descriptor("grad_w")
1130
+ workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w")
1131
+ else:
1132
+ workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
1133
+ else:
1134
+ # If not using TMA, just use the tensors directly
1135
+ x_desc = x
1136
+ grad_output_desc = grad_output
1137
+ workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
1138
+
1139
+ # M_BUCKET for grid size
1140
+ M_BUCKET = triton.next_power_of_2(M_total)
1141
+
1142
+ # Define grid for kernel launch
1143
+ def grid(META):
1144
+ if USE_TMA_LOAD or USE_TMA_STORE:
1145
+
1146
+ if USE_TMA_LOAD:
1147
+ desc_helper.fill_2d_tma_descriptor(
1148
+ "x",
1149
+ x.data_ptr(),
1150
+ M_total,
1151
+ K_x,
1152
+ META["BLOCK_SIZE_M"],
1153
+ META["BLOCK_SIZE_K"],
1154
+ x.element_size(),
1155
+ )
1156
+
1157
+ desc_helper.fill_2d_tma_descriptor(
1158
+ "grad_output",
1159
+ grad_output.data_ptr(),
1160
+ M_total,
1161
+ N,
1162
+ META["BLOCK_SIZE_M"],
1163
+ META["BLOCK_SIZE_N"],
1164
+ grad_output.element_size(),
1165
+ )
1166
+
1167
+ if USE_TMA_STORE:
1168
+ desc_helper.fill_2d_tma_descriptor(
1169
+ "grad_w",
1170
+ grad_w.data_ptr(),
1171
+ N,
1172
+ K_x,
1173
+ META["BLOCK_SIZE_N"],
1174
+ META["BLOCK_SIZE_K"],
1175
+ grad_w.element_size(),
1176
+ )
1177
+
1178
+ # Return grid size - one block per SM for balanced work distribution
1179
+ return (NUM_SMS,)
1180
+
1181
+ # Launch the optimized kernel
1182
+ _kernel_mg_dw_tma[grid](
1183
+ x_desc,
1184
+ grad_output_desc,
1185
+ grad_w,
1186
+ workspace,
1187
+ m_sizes,
1188
+ G,
1189
+ M_BUCKET,
1190
+ N,
1191
+ K_x,
1192
+ NUM_SMS,
1193
+ USE_TMA_LOAD,
1194
+ USE_TMA_STORE,
1195
+ TMA_SIZE=tma_size,
1196
+ )
1197
+
1198
+ return grad_w
1199
+
1200
+
1201
+ # ======== End Backwards Wrapper Functions =============
1202
+
1203
+ # ======== PyTorch wrapper functions ========
1204
+
1205
+
1206
+ class GroupedGEMM_mg(torch.autograd.Function):
1207
+ """
1208
+ Autograd function for GroupedGEMM with M*G grouping.
1209
+ Supports both standard and FP8 quantized operations.
1210
+ """
1211
+
1212
+ @staticmethod
1213
+ def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128):
1214
+ """
1215
+ Forward pass of GroupedGEMM.
1216
+
1217
+ Args:
1218
+ x: Input tensor, shape [M_total, K]
1219
+ w: Weight tensor, shape [N, K]
1220
+ m_sizes: Tensor of shape [G] containing the size of each group
1221
+ use_tma: Whether to try using TMA acceleration (if available)
1222
+ tma_size: Size of TMA descriptor in bytes
1224
+
1225
+ Returns:
1226
+ Output tensor, shape [M_total, N]
1227
+ """
1228
+
1229
+ # Use regular forward without quantization
1230
+ output = grouped_gemm_forward(
1231
+ x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False
1232
+ )
1233
+
1234
+ # Save inputs and parameters for backward pass
1235
+ ctx.save_for_backward(x, w, m_sizes)
1236
+ ctx.use_tma = use_tma
1237
+ ctx.tma_size = tma_size
1238
+
1239
+ ctx.save_for_backward(x, w, m_sizes)
1240
+
1241
+ return output
1242
+
1243
+ @staticmethod
1244
+ def backward(ctx, grad_output):
1245
+ """
1246
+ Backward pass of M*G GroupedGEMM.
1247
+
1248
+ Args:
1249
+ grad_output: Gradient of output, shape [M_total, N]
1250
+
1251
+ Returns:
1252
+ Tuple of gradients:
1253
+ - grad_x: Gradient with respect to x, shape [M_total, K]
1254
+ - grad_w: Gradient with respect to w, shape [N, K]
1255
+ - None: Gradient with respect to m_sizes (not differentiable)
1256
+ - None: Gradient with respect to use_tma (not differentiable)
1257
+ - None: Gradient with respect to tma_size (not differentiable)
1258
+
1259
+ """
1260
+ # Retrieve saved tensors and parameters
1261
+
1262
+ x, w, m_sizes = ctx.saved_tensors
1263
+
1264
+ use_tma = ctx.use_tma
1265
+ tma_size = ctx.tma_size
1266
+
1267
+ # Compute gradients using the unified implementation
1268
+ grad_x, grad_w = grouped_gemm_backward(
1269
+ grad_output=grad_output,
1270
+ x=x,
1271
+ w=w,
1272
+ m_sizes=m_sizes,
1273
+ use_tma=use_tma,
1274
+ tma_size=tma_size,
1275
+ )
1276
+
1277
+ # Return gradients for all inputs (None for non-differentiable parameters)
1278
+ return grad_x, grad_w, None, None, None
1279
+
1280
+
1281
+ def mg_grouped_gemm(
1282
+ x: torch.Tensor,
1283
+ w: torch.Tensor,
1284
+ m_sizes: torch.Tensor,
1285
+ use_tma: bool = True,
1286
+ tma_size: int = 128,
1287
+ using_fp8: bool = False,
1288
+ ) -> torch.Tensor:
1289
+ """
1290
+ Unified differentiable grouped GEMM operation for M*G grouped GEMM.
1291
+ Supports both standard precision and FP8 quantized operations.
1292
+
1293
+ Args:
1294
+ x: Input tensor, shape [M_total, K]
1295
+ w: Weight tensor, shape [N, K]
1296
+ m_sizes: Tensor of shape [G] containing the size of each group
1297
+ use_tma: Whether to try using TMA acceleration (if available)
1298
+ tma_size: Size of TMA descriptor in bytes
1299
+ using_fp8: Whether to use FP8 quantization
1300
+
1301
+ Returns:
1302
+ Output tensor, shape [M_total, N]
1303
+ """
1304
+ return GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size)  # using_fp8 is currently unused (FP8 path removed for now)
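A hedged usage sketch of the forward wrapper above (illustrative shapes only; it assumes a Hopper-class GPU, since `grouped_gemm_forward` raises without TMA support, and it relies on the hopper kernel's per-group weight layout, where the output width is `N // G`):

import torch

if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9:
    G, K, n_per_group = 4, 256, 128
    N = G * n_per_group                       # w stacks one [n_per_group, K] block per group
    m_sizes = torch.tensor([128, 128, 256, 0], device="cuda", dtype=torch.int32)
    x = torch.randn(int(m_sizes.sum()), K, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

    y = grouped_gemm_forward(x, w, m_sizes)   # -> [M_total, n_per_group]
    print(y.shape)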
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py ADDED
@@ -0,0 +1,240 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # credit - TMAHelper class, AutoTuning are derived from FBGemm:
8
+ # https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
9
+
10
+ # pyre-unsafe
11
+ import functools
12
+
13
+ import os
14
+ import sys
15
+ from typing import Any, Dict, Optional, Tuple
16
+
17
+ import torch
18
+
19
+ import triton
20
+ import triton.language as tl
21
+ from triton import Config as TConfig
22
+
23
+ from triton.runtime import driver # @manual
24
+
25
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
26
+
27
+
28
+ # ===== Supporting utils, CUDA and TMA =====
29
+
30
+
31
+ class CudaUtils:
32
+ @staticmethod
33
+ def is_cuda() -> bool:
34
+ """Check if Triton is running on CUDA backend."""
35
+ return driver.active.get_current_target().backend == "cuda"
36
+
37
+ @staticmethod
38
+ def verify_tma() -> bool:
39
+ """Check if TMA is supported on the current device."""
40
+ return (
41
+ CudaUtils.is_cuda()
42
+ and torch.cuda.is_available()
43
+ and torch.cuda.get_device_capability()[0] >= 9
44
+ )
45
+
46
+ @staticmethod
47
+ def get_num_sms() -> int:
48
+ """Get the number of streaming multiprocessors on the current device."""
49
+ if not CudaUtils.is_cuda():
50
+ raise RuntimeError("Triton is not running on CUDA backend")
51
+ if not torch.cuda.is_available():
52
+ raise RuntimeError("CUDA is not available")
53
+ return torch.cuda.get_device_properties("cuda").multi_processor_count
54
+
55
+
56
+ class TmaDescriptorHelper:
57
+ """Helper class for managing TMA descriptors in Triton kernels."""
58
+
59
+ class KernelParamWrapper:
60
+ """Wrapper to implement the TmaDescKernelParam interface."""
61
+
62
+ def __init__(self, desc: torch.Tensor):
63
+ self.desc = desc
64
+
65
+ def tma_desc_cpu_ptr(self) -> int:
66
+ """Return the CPU pointer to the TMA descriptor."""
67
+ return self.desc.data_ptr()
68
+
69
+ def __init__(self, tma_size: int = 128):
70
+ """Initialize the TMA descriptor helper.
71
+
72
+ Args:
73
+ tma_size: Size of the TMA descriptor in bytes
74
+ """
75
+ if not CudaUtils.verify_tma():
76
+ raise RuntimeError(
77
+ "TMA not supported on this device (requires Hopper or newer)"
78
+ )
79
+ if "nv_tma_desc_type" not in dir(tl):
80
+ raise RuntimeError(
81
+ "TMA grid constant descriptors not supported in your Triton version"
82
+ )
83
+
84
+ self.tma_size = tma_size
85
+ self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
86
+ self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
87
+ self.descriptors: Dict[str, torch.Tensor] = {}
88
+
89
+ def init_tma_descriptor(self, name: str) -> None:
90
+ """Initialize a TMA descriptor with the given name.
91
+
92
+ Call this method outside of the lambda function for grid size.
93
+ """
94
+ self.descriptors[name] = torch.empty(
95
+ self.tma_size, device="cpu", dtype=torch.int8
96
+ )
97
+
98
+ def fill_1d_tma_descriptor(
99
+ self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
100
+ ) -> None:
101
+ """Fill a 1D TMA descriptor.
102
+
103
+ Call this method inside the lambda function for grid size.
104
+ """
105
+ if name not in self.descriptors:
106
+ raise ValueError(f"TMA descriptor '{name}' not initialized")
107
+
108
+ desc_x = self.descriptors[name]
109
+ if desc_x.data_ptr() % 64 != 0:
110
+ raise ValueError("TMA descriptor must be 64-byte aligned")
111
+ self.fill_1d_tma_descriptor_inner(
112
+ ptr, dim, block_dim, element_size, desc_x.data_ptr()
113
+ )
114
+
115
+ def fill_2d_tma_descriptor(
116
+ self,
117
+ name: str,
118
+ ptr: int,
119
+ dim1: int,
120
+ dim0: int,
121
+ block_dim1: int,
122
+ block_dim0: int,
123
+ element_size: int,
124
+ ) -> None:
125
+ """Fill a 2D TMA descriptor.
126
+
127
+ Call this method inside the lambda function for grid size.
128
+ """
129
+ if name not in self.descriptors:
130
+ raise ValueError(f"TMA descriptor '{name}' not initialized")
131
+
132
+ desc_x = self.descriptors[name]
133
+ if desc_x.data_ptr() % 64 != 0:
134
+ raise ValueError("TMA descriptor must be 64-byte aligned")
135
+ self.fill_2d_tma_descriptor_inner(
136
+ ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
137
+ )
138
+
139
+ def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
140
+ """Get the TMA descriptor kernel parameter for the given name."""
141
+ if name not in self.descriptors or self.descriptors[name] is None:
142
+ raise ValueError(f"TMA descriptor '{name}' not initialized")
143
+ return self.KernelParamWrapper(self.descriptors[name])
144
+
145
+
146
+ # ====== Autotuning utilities ======
147
+ ALIGN_SIZE_M = 128
148
+
149
+ _NV_CONFIGS = [
150
+ triton.Config(
151
+ {
152
+ "BLOCK_SIZE_M": block_size_m,
153
+ "BLOCK_SIZE_N": block_size_n,
154
+ "BLOCK_SIZE_K": block_size_k,
155
+ },
156
+ num_stages=num_stages,
157
+ num_warps=num_warps,
158
+ num_ctas=num_ctas,
159
+ )
160
+ for block_size_m in [ALIGN_SIZE_M]
161
+ for block_size_n in [64, 128, 256]
162
+ for block_size_k in [64, 128, 256]
163
+ for num_stages in [3, 4]
164
+ for num_warps in [4, 8]
165
+ for num_ctas in [1]
166
+ ]
167
+
168
+
169
+ def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
170
+ device = torch.cuda.current_device()
171
+ # Check for all possible pointer parameter names
172
+ if "grad_input_ptr" in named_args:
173
+ ptr_name = "grad_input_ptr"
174
+ elif "c_ptr" in named_args:
175
+ ptr_name = "c_ptr"
176
+ elif "grad_weight_ptr" in named_args:
177
+ ptr_name = "grad_weight_ptr"
178
+ else:
179
+ raise KeyError("No recognized pointer parameter found in kernel arguments")
180
+
181
+ if dtsize is None:
182
+ dtsize = named_args[ptr_name].element_size()
183
+ if dtype is None:
184
+ dtype = named_args[ptr_name].dtype
185
+
186
+ pruned_configs = []
187
+ for config in configs:
188
+ kw = config.kwargs
189
+ BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
190
+ kw["BLOCK_SIZE_M"],
191
+ kw["BLOCK_SIZE_N"],
192
+ kw["BLOCK_SIZE_K"],
193
+ config.num_stages,
194
+ )
195
+ G, M, N, K = (
196
+ named_args["G"],
197
+ named_args["M_BUCKET"],
198
+ named_args["N"],
199
+ named_args["K"],
200
+ )
201
+
202
+ # 1. make sure we have enough smem
203
+ max_shared_memory = driver.active.utils.get_device_properties(device)[
204
+ "max_shared_mem"
205
+ ]
206
+
207
+ required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
208
+ if required_shared_memory > max_shared_memory:
209
+ continue
210
+
211
+ M_PER_GROUP = M // G
212
+ MIN_M_TILES = 64
213
+ # 2. make sure we don't load M tiles that are too big
214
+ if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
215
+ continue
216
+ # 3. make sure we don't load N tiles that are too small
217
+ if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
218
+ continue
219
+
220
+ num_sm = driver.active.utils.get_device_properties(device)[
221
+ "multiprocessor_count"
222
+ ]
223
+ N_TILES = N // BLOCK_N
224
+ MIN_N_TILES = 64
225
+ # 4. make sure we don't load N tiles that are too big
226
+ if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
227
+ continue
228
+ # 5. make sure we don't load N tiles that are too small
229
+ if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
230
+ continue
231
+ # 6. make sure K can be evenly divided
232
+ if K % BLOCK_K != 0:
233
+ continue
234
+
235
+ pruned_configs.append(config)
236
+
237
+ return pruned_configs
238
+
239
+
240
+ # ======== End Autotuning utilities ========
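As a hedged illustration of pruning rule 1 above, the shared-memory estimate for a single candidate config is `(BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize` bytes; the numbers below are example values, not a statement about any particular GPU's limit:

BLOCK_M, BLOCK_N, BLOCK_K, num_stages = 128, 128, 128, 4
dtsize = 2  # bytes per element for bf16
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
print(required_shared_memory)  # 262144 bytes; the config is pruned if this exceeds max_shared_mem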
torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc ADDED
Binary file (10.5 kB).
 
torchtitan/experiments/llama4/model/args.py ADDED
@@ -0,0 +1,109 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
+ from torch import nn
12
+ from torchtitan.components.tokenizer import Tokenizer
13
+ from torchtitan.config_manager import JobConfig
14
+
15
+ from torchtitan.protocols.train_spec import BaseModelArgs
16
+ from torchtitan.tools.logging import logger
17
+
18
+
19
+ @dataclass
20
+ class TransformerModelArgs(BaseModelArgs):
21
+ dim: int = 4096
22
+ n_layers: int = 32
23
+ n_heads: int = 32
24
+ n_kv_heads: Optional[int] = None
25
+ vocab_size: int = -1 # defined later by tokenizer
26
+ multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
27
+ ffn_dim_multiplier: Optional[float] = None
28
+ norm_eps: float = 1e-5
29
+ rope_theta: float = 10000
30
+
31
+ max_seq_len: int = 2048
32
+ # If `True`, then each transformer block init uses its layer ID, and if
33
+ # `False`, each uses the total number of transformer blocks
34
+ depth_init: bool = True
35
+ norm_type: str = "rmsnorm"
36
+
37
+ use_flex_attn: bool = False
38
+ attn_mask_type: str = "causal"
39
+ eos_id: int = 0
40
+
41
+ # MoE args
42
+ moe_enabled: bool = True
43
+ num_experts: int = 8
44
+ use_shared_expert: bool = True
45
+ auto_scale_hidden_dim: bool = True
46
+ # frequency of using MoE layer instead of feedforward layer in a transformer block
47
+ interleave_moe_layer_step: int = 2
48
+ # token-choice
49
+ top_k: int = 1
50
+
51
+ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
52
+ self.norm_type = job_config.model.norm_type
53
+ self.vocab_size = tokenizer.n_words
54
+ self.max_seq_len = job_config.training.seq_len
55
+ self.use_flex_attn = job_config.model.use_flex_attn
56
+
57
+ def get_nparams_and_flops(
58
+ self, model: nn.Module, seq_len: int
59
+ ) -> tuple[int, float]:
60
+ nparams_embedding = 0
61
+ nparams_moe_router = 0
62
+ nparams_shared_expert = 0
63
+ nparams_experts = 0
64
+ nparams_dense = 0
65
+
66
+ for name, p in model.named_parameters():
67
+ if "embedding" in name:
68
+ nparams_embedding += p.numel()
69
+ nparams_dense += p.numel()
70
+ elif "moe.shared_expert" in name:
71
+ nparams_shared_expert += p.numel()
72
+ elif "moe.router" in name:
73
+ nparams_moe_router += p.numel()
74
+ elif "moe.experts" in name:
75
+ nparams_experts += p.numel()
76
+ else:
77
+ nparams_dense += p.numel()
78
+
79
+ nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
80
+ nparams = nparams_dense + nparams_sparse
81
+ nparams_sparse_active = (
82
+ nparams_moe_router
83
+ + nparams_shared_expert
84
+ + nparams_experts * self.top_k // self.num_experts
85
+ )
86
+
87
+ logger.info(
88
+ f"Total parameter count: dense {nparams_dense:,}, "
89
+ f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
90
+ )
91
+
92
+ l, h, q, t = (
93
+ self.n_layers,
94
+ self.n_heads,
95
+ self.dim // self.n_heads,
96
+ seq_len,
97
+ )
98
+ # Reasoning behind the factor of 12 for the self-attention part of the formula:
99
+ # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
100
+ # 2. the flash attention does 1 more matmul recomputation in the backward
101
+ # but recomputation should not be counted in calculating MFU (+0)
102
+ # 3. each matmul performs 1 multiplication and 1 addition (*2)
103
+ # 4. we follow the convention and do not account for sparsity in causal attention
104
+ num_flops_per_token = (
105
+ 6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
106
+ + 12 * l * h * q * t
107
+ )
108
+
109
+ return nparams, num_flops_per_token
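A hedged worked example of the FLOPs-per-token formula above, using made-up round numbers rather than the parameters of any released model:

l, h, q, t = 32, 32, 128, 2048                # layers, heads, head dim, sequence length
active_params = 7_000_000_000                 # stands in for (nparams_dense - nparams_embedding + nparams_sparse_active)
num_flops_per_token = 6 * active_params + 12 * l * h * q * t
print(f"{num_flops_per_token / 1e9:.1f} GFLOPs per token")  # ~45.2 for these numbers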
torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh ADDED
@@ -0,0 +1,25 @@
1
+ #!/usr/bin/bash
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ set -ex
9
+
10
+ # use envs as local overrides for convenience
11
+ # e.g.
12
+ # LOG_RANK=0,1 NGPU=4 ./convert_meta_to_dcp_with_gpus.sh
13
+ NGPU=${NGPU:-"8"}
14
+ LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7}
15
+ CONFIG_FILE=${CONFIG_FILE:-"../train_configs/llama4_17bx16e.toml"}
16
+
17
+ overrides=""
18
+ if [ $# -ne 0 ]; then
19
+ overrides="$*"
20
+ fi
21
+
22
+ PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
23
+ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
24
+ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
25
+ convert_meta_to_dcp_with_gpus_meta.py --job.config_file ${CONFIG_FILE} $overrides
torchtitan/experiments/multimodal/tests/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
torchtitan/experiments/multimodal/tests/test_utils.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ from typing import Optional, Union
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+
15
+ def fixed_init_tensor(
16
+ shape: torch.Size,
17
+ min_val: Union[float, int] = 0.0,
18
+ max_val: Union[float, int] = 1.0,
19
+ nonlinear: bool = False,
20
+ dtype: torch.dtype = torch.float,
21
+ ):
22
+ """
23
+ Utility for generating deterministic tensors of a given shape. In general stuff
24
+ like torch.ones, torch.eye, etc can result in trivial outputs. This utility
25
+ generates a range tensor [min_val, max_val) of a specified dtype, applies
26
+ a sine function if nonlinear=True, then reshapes to the appropriate shape.
27
+ """
28
+ n_elements = math.prod(shape)
29
+ step_size = (max_val - min_val) / n_elements
30
+ x = torch.arange(min_val, max_val, step_size, dtype=dtype)
31
+ x = x.reshape(shape)
32
+ if nonlinear:
33
+ return torch.sin(x)
34
+ return x
35
+
36
+
37
+ @torch.no_grad
38
+ def fixed_init_model(
39
+ model: nn.Module,
40
+ min_val: Union[float, int] = 0.0,
41
+ max_val: Union[float, int] = 1.0,
42
+ nonlinear: bool = False,
43
+ dtype: Optional[torch.dtype] = None,
44
+ ):
45
+ """
46
+ This utility initializes all parameters of a model deterministically using the
47
+ function fixed_init_tensor above. See that docstring for details of each parameter.
48
+ """
49
+ for _, param in model.named_parameters():
50
+ param.copy_(
51
+ fixed_init_tensor(
52
+ param.shape,
53
+ min_val=min_val,
54
+ max_val=max_val,
55
+ nonlinear=nonlinear,
56
+ dtype=param.dtype if dtype is None else dtype,
57
+ )
58
+ )
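A hedged usage sketch for the helpers above, assuming `fixed_init_model` is imported from this module; the tiny `nn.Linear` shape is an arbitrary example, and the point is only that repeated runs produce identical weights and outputs:

import torch
from torch import nn

layer = nn.Linear(4, 2, bias=False)
fixed_init_model(layer, min_val=-1.0, max_val=1.0, nonlinear=True)
out = layer(torch.ones(1, 4))
print(out)  # deterministic across runs, since the weights come from a fixed sine ramp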
torchtitan/experiments/multimodal/tokenizer/tiktoken.py ADDED
@@ -0,0 +1,232 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
9
+
10
+ import os
11
+ from pathlib import Path
12
+ from typing import (
13
+ AbstractSet,
14
+ Any,
15
+ cast,
16
+ Collection,
17
+ Dict,
18
+ Iterator,
19
+ List,
20
+ Literal,
21
+ Mapping,
22
+ Optional,
23
+ Sequence,
24
+ Union,
25
+ )
26
+
27
+ import tiktoken
28
+ import torch
29
+ from tiktoken.load import load_tiktoken_bpe
30
+
31
+ from torchtitan.components.tokenizer import Tokenizer
32
+ from torchtitan.config_manager import JobConfig
33
+ from torchtitan.tools.logging import logger
34
+
35
+ IMAGE_TOKEN_ID = 128256
36
+ IGNORE_INDEX = -100
37
+
38
+
39
+ class TikTokenizer(Tokenizer):
40
+ """
41
+ Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
42
+
43
+ Args:
44
+ model_path (str): The path to the Tiktoken model file.
45
+ """
46
+
47
+ special_tokens: Dict[str, int]
48
+
49
+ num_reserved_special_tokens = 256
50
+
51
+ pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501, B950
52
+
53
+ def __init__(self, model_path: str):
54
+ super().__init__(model_path)
55
+ assert os.path.isfile(model_path), model_path
56
+
57
+ mergeable_ranks = load_tiktoken_bpe(model_path)
58
+ num_base_tokens = len(mergeable_ranks)
59
+ special_tokens = [
60
+ "<|begin_of_text|>",
61
+ "<|end_of_text|>",
62
+ "<|reserved_special_token_0|>",
63
+ "<|reserved_special_token_1|>",
64
+ "<|reserved_special_token_2|>",
65
+ "<|reserved_special_token_3|>",
66
+ "<|start_header_id|>",
67
+ "<|end_header_id|>",
68
+ "<|reserved_special_token_4|>",
69
+ "<|eot_id|>", # end of turn
70
+ ] + [
71
+ f"<|reserved_special_token_{i}|>"
72
+ for i in range(5, self.num_reserved_special_tokens - 5)
73
+ ]
74
+ self.special_tokens = {
75
+ token: num_base_tokens + i for i, token in enumerate(special_tokens)
76
+ }
77
+ self.special_tokens["<|image|>"] = IMAGE_TOKEN_ID
78
+ self.model = tiktoken.Encoding(
79
+ name=Path(model_path).name,
80
+ pat_str=self.pat_str,
81
+ mergeable_ranks=mergeable_ranks,
82
+ special_tokens=self.special_tokens,
83
+ )
84
+
85
+ self._n_words: int = self.model.n_vocab
86
+ # BOS / EOS token IDs
87
+ self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
88
+ self.eos_id: int = self.special_tokens["<|end_of_text|>"]
89
+ self.pad_id: int = -1
90
+ self.image_id = IMAGE_TOKEN_ID
91
+ self.stop_tokens = {
92
+ self.special_tokens["<|end_of_text|>"],
93
+ self.special_tokens["<|eot_id|>"],
94
+ }
95
+ logger.info(
96
+ f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}, IMAGE ID {self.image_id}"
97
+ )
98
+
99
+ def encode(
100
+ self,
101
+ s: str,
102
+ *,
103
+ bos: bool,
104
+ eos: bool,
105
+ allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
106
+ disallowed_special: Optional[Union[Literal["all"], Collection[str]]] = None,
107
+ ) -> List[int]:
108
+ """
109
+ Encodes a string into a list of token IDs.
110
+
111
+ Args:
112
+ s (str): The input string to be encoded.
113
+ bos (bool): Whether to prepend the beginning-of-sequence token.
114
+ eos (bool): Whether to append the end-of-sequence token.
115
+ allowed_special ("all"|set[str]): special tokens allowed to appear in the string
116
+ disallowed_special ("all"|Collection[str]): special tokens that raise an error when found in the string
117
+
118
+ Returns:
119
+ list[int]: A list of token IDs.
120
+
121
+ By default, setting disallowed_special=() encodes a string by ignoring
122
+ special tokens. Specifically:
123
+ - Setting `disallowed_special` to () will cause all text corresponding
124
+ to special tokens to be encoded as natural text (instead of raising
125
+ an error).
126
+ - Setting `allowed_special` to "all" will cause all text corresponding
127
+ to special tokens to be encoded as special tokens.
128
+ """
129
+ assert type(s) is str
130
+ allowed_special = allowed_special or set()
131
+ disallowed_special = disallowed_special or ()
132
+
133
+ # The tiktoken tokenizer can handle <=400k chars without
134
+ # pyo3_runtime.PanicException.
135
+ TIKTOKEN_MAX_ENCODE_CHARS = 400_000
136
+
137
+ # https://github.com/openai/tiktoken/issues/195
138
+ # Here we iterate over subsequences and split if we exceed the limit
139
+ # of max consecutive non-whitespace or whitespace characters.
140
+ MAX_NO_WHITESPACES_CHARS = 25_000
141
+
142
+ substrs = (
143
+ substr
144
+ for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
145
+ for substr in self._split_whitespaces_or_nonwhitespaces(
146
+ s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
147
+ )
148
+ )
149
+ t: List[int] = []
150
+ for substr in substrs:
151
+ t.extend(
152
+ self.model.encode(
153
+ substr,
154
+ allowed_special=allowed_special,
155
+ disallowed_special=disallowed_special,
156
+ )
157
+ )
158
+ if bos:
159
+ t.insert(0, self.bos_id)
160
+ if eos:
161
+ t.append(self.eos_id)
162
+ return t
163
+
164
+ def decode(self, t: Sequence[int]) -> str:
165
+ """
166
+ Decodes a list of token IDs into a string.
167
+
168
+ Args:
169
+ t (List[int]): The list of token IDs to be decoded.
170
+
171
+ Returns:
172
+ str: The decoded string.
173
+ """
174
+ # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
175
+ return self.model.decode(cast(List[int], t))
176
+
177
+ @staticmethod
178
+ def _split_whitespaces_or_nonwhitespaces(
179
+ s: str, max_consecutive_slice_len: int
180
+ ) -> Iterator[str]:
181
+ """
182
+ Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
183
+ consecutive whitespaces or consecutive non-whitespaces.
184
+ """
185
+ current_slice_len = 0
186
+ current_slice_is_space = s[0].isspace() if len(s) > 0 else False
187
+ slice_start = 0
188
+
189
+ for i in range(len(s)):
190
+ is_now_space = s[i].isspace()
191
+
192
+ if current_slice_is_space ^ is_now_space:
193
+ current_slice_len = 1
194
+ current_slice_is_space = is_now_space
195
+ else:
196
+ current_slice_len += 1
197
+ if current_slice_len > max_consecutive_slice_len:
198
+ yield s[slice_start:i]
199
+ slice_start = i
200
+ current_slice_len = 1
201
+ yield s[slice_start:]
202
+
203
+ def encode_multimodal(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
204
+ """
205
+ Tokenizes the sample's text and returns the sample updated with `tokens` and `labels`, where BOS, EOS and `image_id` positions in `labels` are masked with IGNORE_INDEX.
206
+ """
207
+ # TODO(tj.solergibert) Should we keep `input_ids` OR `tokens` across this class, VisionCrossAttentionMask & the collator?
208
+ # For me it makes more sense to split `tokens` between `input_ids` & `labels` as in train.py BUT the `MultimodalDecoder`
209
+ # & everything else expects `tokens`
210
+ text = sample["text"]
211
+ tokens = self.encode(
212
+ text, bos=True, eos=True, allowed_special=set(["<|image|>"])
213
+ )
214
+ input_ids = torch.LongTensor(tokens[:-1])
215
+ labels = torch.LongTensor(tokens[1:])
216
+ labels = torch.where(
217
+ torch.isin(
218
+ labels, torch.LongTensor([self.bos_id, self.eos_id, self.image_id])
219
+ ),
220
+ IGNORE_INDEX,
221
+ labels,
222
+ )
223
+
224
+ assert len(input_ids) == len(labels) # TODO(tj.solergibert) Delete
225
+
226
+ sample.update({"tokens": input_ids, "labels": labels})
227
+
228
+ return sample
229
+
230
+
231
+ def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer:
232
+ return TikTokenizer(job_config.model.tokenizer_path)
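
NOTE: a minimal usage sketch for TikTokenizer, not part of the uploaded file. It assumes the module path torchtitan.experiments.multimodal.tokenizer.tiktoken and a hypothetical tokenizer.model location; both must exist in your checkout for this to run.

from torchtitan.experiments.multimodal.tokenizer.tiktoken import TikTokenizer

# Hypothetical path to a Llama 3 tiktoken BPE file.
tok = TikTokenizer("./assets/tokenizer/original/tokenizer.model")

# Allow <|image|> so it is encoded as the single special ID IMAGE_TOKEN_ID (128256).
ids = tok.encode(
    "A photo of a bird <|image|>", bos=True, eos=True, allowed_special={"<|image|>"}
)
assert ids[0] == tok.bos_id and ids[-1] == tok.eos_id

# decode() maps the IDs back to text, including the literal special-token strings.
text = tok.decode(ids)
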
torchtitan/experiments/multimodal/utils.py ADDED
@@ -0,0 +1,437 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ from collections import defaultdict
10
+
11
+ from pathlib import Path
12
+ from typing import List, Optional, Set, Tuple, Union
13
+ from urllib import request
14
+
15
+ import torch
16
+ import torchvision
17
+ from torchvision.transforms.v2 import functional as F
18
+
19
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.tile_crop.py
20
+ def tile_crop(image: torch.Tensor, tile_size: int) -> torch.Tensor:
21
+ """
22
+ Divides a tensor into equally sized tiles. The tensor should be divisible by tile_size.
23
+
24
+ Args:
25
+ image (torch.Tensor): Input image to crop into tiles.
26
+ tile_size (int): Size of each tile.
27
+
28
+ Returns:
29
+ torch.Tensor: torch.Tensor of shape [num_tiles, channel_size, tile_size, tile_size]
30
+
31
+ Examples:
32
+ >>> image = torch.rand(3, 200, 300)
33
+ >>> tiles = tile_crop(image, tile_size=50)
34
+ >>> tiles.shape # 4x6 = 24 tiles
35
+ torch.Size([24, 3, 50, 50])
36
+
37
+ >>> image = torch.rand(3, 400, 600)
38
+ >>> tiles = tile_crop(image, tile_size=200)
39
+ >>> tiles.shape # 2x3 = 6 tiles
40
+ torch.Size([6, 3, 200, 200])
41
+ """
42
+
43
+ channel_size, height, width = image.shape
44
+
45
+ # assert sizes are divisible
46
+ assert (
47
+ height % tile_size == 0 and width % tile_size == 0
48
+ ), f"Image size {height}x{width} is not divisible by tile size {tile_size}"
49
+
50
+ # Reshape to split height and width into tile_size blocks
51
+ tiles_height = height // tile_size
52
+ tiles_width = width // tile_size
53
+
54
+ reshaped = image.view(channel_size, tiles_height, tile_size, tiles_width, tile_size)
55
+
56
+ # Transpose to bring tiles together
57
+ # We want [tiles_height, tiles_width, channel_size, tile_size, tile_size]
58
+ transposed = reshaped.permute(1, 3, 0, 2, 4)
59
+
60
+ # Flatten the tiles
61
+ tiles = transposed.contiguous().view(
62
+ tiles_height * tiles_width, channel_size, tile_size, tile_size
63
+ )
64
+
65
+ return tiles
66
+
67
+
68
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
69
+ def resize_with_pad(
70
+ image: torch.Tensor,
71
+ target_size: Tuple[int, int],
72
+ resample: torchvision.transforms.InterpolationMode,
73
+ max_size: Optional[int] = None,
74
+ ) -> torch.Tensor:
75
+ """
76
+ Resizes and pads an image to target_size without causing distortion.
77
+ The user can set max_size to limit upscaling when target_size exceeds image_size.
78
+
79
+ Args:
80
+ image (torch.Tensor): The input image tensor in the format [..., H, W].
81
+ target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width].
82
+ resample (torchvision.transforms.InterpolationMode): Resampling method used when resizing images.
83
+ Supports torchvision.transforms.InterpolationMode.NEAREST, InterpolationMode.NEAREST_EXACT,
84
+ InterpolationMode.BILINEAR and InterpolationMode.BICUBIC.
85
+ max_size (Optional[int]): The maximum size to upscale the image to.
86
+ If None, will upscale up to target_size.
87
+
88
+ Returns:
89
+ torch.Tensor: The resized and padded image tensor in the format [..., H, W].
90
+
91
+ Examples:
92
+
93
+ Example 1: The image will be upscaled from (300, 800) to (448, 1194), since 448 is the limiting side,
94
+ and then padded from (448, 1194) to (448, 1344).
95
+
96
+ >>> max_size = None
97
+ >>> image = torch.rand([3, 300, 800])
98
+ >>> target_size = (448, 1344)
99
+ >>> resample = torchvision.transforms.InterpolationMode.BILINEAR
100
+ >>> output = resize_with_pad(image, target_size, resample, max_size)
101
+
102
+ Example 2: The image will stay as is, since its width (800) already exceeds max_size (600), and then be padded from (300, 800) to (448, 1344).
103
+
104
+ >>> max_size = 600
105
+ >>> image = torch.rand([3, 300, 800])
106
+ >>> target_size = (448, 1344)
107
+ >>> resample = torchvision.transforms.InterpolationMode.BILINEAR
108
+ >>> output = resize_with_pad(image, target_size, resample, max_size)
109
+
110
+ Example 3: The image will be downscaled from (500, 1000) to (224, 448),
111
+ and padded from (224, 448) to (448, 448).
112
+
113
+ >>> max_size = 600
114
+ >>> image = torch.rand([3, 500, 1000])
115
+ >>> target_size = (448, 448)
116
+ >>> resample = torchvision.transforms.InterpolationMode.BILINEAR
117
+ >>> output = resize_with_pad(image, target_size, resample, max_size)
118
+
119
+ """
120
+
121
+ image_height, image_width = image.shape[-2:]
122
+ image_size = (image_height, image_width)
123
+
124
+ # If target_size requires upscaling, we might want to limit the upscaling to max_size
125
+ if max_size is not None:
126
+ new_target_height = min(max(image_height, max_size), target_size[0])
127
+ new_target_width = min(max(image_width, max_size), target_size[1])
128
+ target_size_resize = (new_target_height, new_target_width)
129
+ else:
130
+ target_size_resize = target_size
131
+
132
+ # resize to target_size while preserving aspect ratio
133
+ new_size_preserving_aspect_ratio = _get_max_res_without_distortion(
134
+ image_size=image_size,
135
+ target_size=target_size_resize,
136
+ )
137
+
138
+ image = F.resize(
139
+ inpt=image,
140
+ size=list(new_size_preserving_aspect_ratio),
141
+ interpolation=resample,
142
+ antialias=True,
143
+ )
144
+
145
+ image = _pad_image_top_left(image=image, target_size=target_size)
146
+
147
+ return image
148
+
149
+
150
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
151
+ def _pad_image_top_left(
152
+ image: torch.Tensor,
153
+ target_size: Tuple[int, int],
154
+ ) -> torch.Tensor:
155
+ """
156
+ Places the image at the top left of the canvas and pads with 0 the right and bottom
157
+ to fit to the target resolution. If target_size < image_size, it will crop the image.
158
+
159
+ Args:
160
+ image (torch.Tensor): The input image tensor in the format [..., H, W].
161
+ target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width].
162
+
163
+ Returns:
164
+ torch.Tensor: The padded image tensor in the format [..., H, W].
165
+ """
166
+
167
+ image_size = image.shape[-2:]
168
+
169
+ height, width = image_size
170
+ target_height, target_width = target_size
171
+
172
+ pad_x = target_width - width
173
+ pad_y = target_height - height
174
+
175
+ padding = [0, 0, pad_x, pad_y]
176
+ return F.pad(inpt=image, padding=padding)
177
+
178
+
179
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
180
+ def _get_max_res_without_distortion(
181
+ image_size: Tuple[int, int],
182
+ target_size: Tuple[int, int],
183
+ ) -> Tuple[int, int]:
184
+ """
185
+ Determines the maximum resolution to which an image can be resized to without distorting its
186
+ aspect ratio, based on the target resolution.
187
+
188
+ For example, if image_size = (200,400) and target_size = (600,800),
189
+ scale_h = 600/200 = 3
190
+ scale_w = 800/400 = 2
191
+ So the maximum that we can upscale without distortion is min(scale_h, scale_w) = 2
192
+
193
+ Since scale_w is the limiting side, then new_w = target_w, and new_h = old_h*scale_w
194
+
195
+ Args:
196
+ image_size (Tuple[int, int]): The original resolution of the image.
197
+ target_size (Tuple[int, int]): The desired resolution to fit the image into.
198
+ Returns:
199
+ Tuple[int, int]: The optimal dimensions to which the image should be resized.
200
+ Examples:
201
+ >>> _get_max_res_without_distortion([200, 300], target_size = (450, 200))
202
+ (133, 200)
203
+ >>> _get_max_res_without_distortion([800, 600], target_size = (450, 1300))
204
+ (450, 337)
205
+ """
206
+
207
+ original_height, original_width = image_size
208
+ target_height, target_width = target_size
209
+
210
+ scale_w = target_width / original_width
211
+ scale_h = target_height / original_height
212
+
213
+ if scale_w < scale_h:
214
+ new_width = target_width
215
+ new_height = min(math.floor(original_height * scale_w), target_height)
216
+ else:
217
+ new_height = target_height
218
+ new_width = min(math.floor(original_width * scale_h), target_width)
219
+
220
+ return new_height, new_width
221
+
222
+
223
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
224
+ def _get_factors(n: int) -> Set[int]:
225
+ """
226
+ Calculate all factors of a given number, i.e. every divisor that leaves no remainder.
227
+
228
+ Args:
229
+ n (int): The number to find factors for.
230
+
231
+ Returns:
232
+ set: A set containing all factors of the number.
233
+
234
+ Examples:
235
+ >>> _get_factors(n=12)
236
+ {1, 2, 3, 4, 6, 12}
237
+ """
238
+ factors_set = set()
239
+
240
+ for i in range(1, int(n**0.5) + 1):
241
+ if n % i == 0:
242
+ factors_set.add(i)
243
+ factors_set.add(n // i)
244
+ return factors_set
245
+
246
+
247
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
248
+ def get_canvas_best_fit(
249
+ image: torch.Tensor, possible_resolutions: torch.Tensor, resize_to_max_canvas: bool
250
+ ) -> Tuple[int, int]:
251
+ """
252
+ Determines the best canvas possible from a list of possible resolutions to
253
+ resize an image to, without distortion.
254
+
255
+ For each possible resolution, calculates the scaling factors for
256
+ width and height, and selects the smallest one, which is the limiting side.
257
+ E.g. if to match a canvas shape you have to upscale an image's height by 2x, and width by 1.5x,
258
+ then the maximum upscaling without distortion is min(2, 1.5) = 1.5.
259
+
260
+ If there are multiple canvases that satisfy the conditions,
261
+ we pick the one with the lowest area to minimize padding.
262
+
263
+ Args:
264
+ image (torch.Tensor): The image we want to fit into a canvas.
265
+ possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
266
+ row represents a possible canvas.
267
+ resize_to_max_canvas (bool): If True, pick the canvas that allows maximum scaling.
268
+ If False, pick the canvas that minimizes downscaling, including no downscaling at all.
269
+
270
+ Returns:
271
+ Tuple[int, int]: The best resolution to fit the image into.
272
+
273
+ Examples:
274
+ >>> image = torch.rand(3, 200, 300)
275
+ >>> possible_resolutions = torch.tensor([
276
+ ... [224, 672],
277
+ ... [672, 224],
278
+ ... [224, 448],
279
+ ... [448, 224],
280
+ ... [224, 224]
281
+ ... ])
282
+ >>> get_canvas_best_fit(image, possible_resolutions, resize_to_max_canvas=False)
283
+ (224, 448)
284
+
285
+ In the example above, we calculate the scaling factors for each possible resolution
286
+
287
+ >>> scale_height = torch.tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
288
+ >>> scale_width = torch.tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
289
+ >>> scales = torch.tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
290
+
291
+ Two options have scaling_factor > 1; since resize_to_max_canvas is False, we pick the smallest
292
+
293
+ >>> upscaling_options = torch.tensor([1.1200, 1.1200])
294
+ >>> selected_scale = torch.tensor(1.1200)
295
+
296
+ There are two possible options, so we pick the one with the smallest area
297
+
298
+ >>> areas = torch.tensor([150528, 100352]) # for resolutions [672, 224] and [224, 448], respectively
299
+ >>> optimal_canvas = torch.tensor([224, 448]) # resolution with the smallest area
300
+ """
301
+
302
+ original_height, original_width = image.shape[-2:]
303
+
304
+ # possible resolutions heights/widths
305
+ target_heights, target_widths = (
306
+ possible_resolutions[:, 0],
307
+ possible_resolutions[:, 1],
308
+ )
309
+
310
+ # scaling factors to resize the image without distortion
311
+ scale_w = target_widths / original_width
312
+ scale_h = target_heights / original_height
313
+
314
+ # get limiting side scaling -> no distortion
315
+ scales = torch.where(scale_w > scale_h, scale_h, scale_w)
316
+
317
+ # filter only scales that allow upscaling
318
+ upscaling_options = scales[scales >= 1]
319
+ if len(upscaling_options) > 0:
320
+ if resize_to_max_canvas:
321
+ selected_scale = torch.max(upscaling_options)
322
+ else:
323
+ selected_scale = torch.min(upscaling_options)
324
+ else:
325
+ # no upscaling possible,
326
+ # get the minimum downscaling (max scale for scales<1)
327
+ downscaling_options = scales[scales < 1]
328
+ selected_scale = torch.max(downscaling_options)
329
+
330
+ # get all resolutions that support this scaling factor,
331
+ # e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
332
+ chosen_canvas = possible_resolutions[scales == selected_scale]
333
+
334
+ # if there are multiple resolutions,
335
+ # get the one with minimum area to reduce padding
336
+ if len(chosen_canvas) > 1:
337
+ areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
338
+ optimal_idx = torch.argmin(areas)
339
+ optimal_canvas = chosen_canvas[optimal_idx]
340
+ else:
341
+ optimal_canvas = chosen_canvas[0]
342
+
343
+ return tuple(optimal_canvas.tolist())
344
+
345
+
346
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
347
+ def find_supported_resolutions(
348
+ max_num_tiles: int, tile_size: int
349
+ ) -> List[Tuple[int, int]]:
350
+ """
351
+ Computes all resolutions that are multiples of tile_size and
352
+ contain at most max_num_tiles tiles. Useful when dividing an image into tiles.
353
+
354
+ For example, if we want at most 2 tiles per image, then we can support the
355
+ following resolutions: (1x1, 1x2, 2x1) * tile_size
356
+
357
+ Args:
358
+ max_num_tiles (int): Maximum number of tiles.
359
+ tile_size (int): Size of the side of the tile.
360
+
361
+ Returns:
362
+ List[Tuple[int, int]]: List of possible resolutions as tuples (height, width).
363
+
364
+ Examples:
365
+
366
+ >>> max_num_tiles = 4
367
+ >>> tile_size = 224
368
+ >>> find_supported_resolutions(max_num_tiles, tile_size)
369
+ [(224, 896), (448, 448), (224, 224), (896, 224), (224, 672), (672, 224), (224, 448), (448, 224)]
370
+ """
371
+
372
+ # create dictionary {aspect_ratio: [resolution1, ..., resolution n]}
373
+ # example {0.25: [(1,4)], 1.0: [(2,2), (1,1)], 4.0: [(4,1)]}
374
+ asp_dict = defaultdict(list)
375
+ for _tile_size in range(max_num_tiles, 0, -1):
376
+ factors = sorted(_get_factors(_tile_size))
377
+ asp_ratios = [(factor, _tile_size // factor) for factor in factors]
378
+ for height, width in asp_ratios:
379
+ ratio_float = height / width
380
+ asp_dict[ratio_float].append((height, width))
381
+
382
+ # get the resolutions multiplied by the tile_size
383
+ possible_resolutions = []
384
+ for ar, resolution in asp_dict.items():
385
+ for height, width in resolution:
386
+ possible_resolutions.append((height * tile_size, width * tile_size))
387
+
388
+ return possible_resolutions
389
+
390
+
391
+ # NOTE Copied from torchtune.data._utils.py
392
+ def load_image(image_loc: Union[Path, str]) -> torch.Tensor:
393
+ """
394
+ Convenience method to load an image in torch.Tensor format from a local file path or remote source.
395
+
396
+ Args:
397
+ image_loc (Union[Path, str]): Local file path or remote source pointing to the image
398
+ which will be loaded in PIL format.
399
+
400
+ Note:
401
+ If loading an image from a remote source, the function expects the URL provided in ``image_loc``
402
+ to start with "http" or "https" e.g. "https://www.wikipedia.org/en/bird.jpg".
403
+
404
+ Raises:
405
+ ValueError: If the image cannot be loaded from remote source, **or**
406
+ if the image cannot be opened as a :class:`~torch.Tensor`.
407
+
408
+ Examples:
409
+ >>> # Load from remote source
410
+ >>> image = load_image("https://www.wikipedia.org/en/bird.jpg")
411
+
412
+ >>> # Load from local file path
413
+ >>> image = load_image(Path("/home/user/bird.jpg"))
414
+
415
+ Returns:
416
+ torch.Tensor: The loaded image.
417
+ """
418
+
419
+ # If pointing to remote source, try to load to local
420
+ if isinstance(image_loc, str) and image_loc.startswith("http"):
421
+ try:
422
+ image_loc = request.urlopen(image_loc).read()
423
+ image = torchvision.io.decode_image(
424
+ torch.frombuffer(image_loc, dtype=torch.uint8),
425
+ mode="RGB",
426
+ )
427
+ except Exception as e:
428
+ raise ValueError("Failed to load remote image as torch.Tensor") from e
429
+
430
+ # Open the local image as a Tensor image
431
+ else:
432
+ try:
433
+ image = torchvision.io.decode_image(image_loc, mode="RGB")
434
+ except Exception as e:
435
+ raise ValueError("Failed to load local image as torch.Tensor") from e
436
+
437
+ return image
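
NOTE: a minimal preprocessing sketch chaining the helpers above, not part of the uploaded file; the max_num_tiles/tile_size values are illustrative, not a required model configuration.

import torch
import torchvision
from torchtitan.experiments.multimodal.utils import (
    find_supported_resolutions,
    get_canvas_best_fit,
    resize_with_pad,
    tile_crop,
)

image = torch.rand(3, 300, 800)  # [C, H, W]

# All canvases made of at most 4 tiles of 224x224.
possible = torch.tensor(find_supported_resolutions(max_num_tiles=4, tile_size=224))

# Pick the canvas that needs the least distortion-free rescaling, then resize + pad.
canvas = get_canvas_best_fit(image, possible, resize_to_max_canvas=False)
resized = resize_with_pad(
    image,
    target_size=canvas,
    resample=torchvision.transforms.InterpolationMode.BILINEAR,
)

# Split the padded canvas into tiles: [num_tiles, 3, 224, 224].
tiles = tile_crop(resized, tile_size=224)
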
torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.11 kB)
 
torchtitan/experiments/simple_fsdp/__pycache__/model.cpython-312.pyc ADDED
Binary file (1.14 kB)
 
torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc ADDED
Binary file (2.61 kB)
 
torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc ADDED
Binary file (6.83 kB)
 
torchtitan/experiments/simple_fsdp/tests/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
torchtitan/experiments/simple_fsdp/tests/test_numerics.py ADDED
@@ -0,0 +1,128 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import copy
7
+
8
+ import torch
9
+ from torch.distributed._composable.fsdp import fully_shard
10
+
11
+ from torch.testing._internal.common_fsdp import FSDPTest
12
+
13
+ from torchtitan.components.loss import cross_entropy_loss
14
+ from torchtitan.distributed import ParallelDims
15
+ from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel
16
+
17
+
18
+ class TestSimpleFSDP(FSDPTest):
19
+ def init_test(self):
20
+ self.optimizer = torch.optim.Adam
21
+ self.loss_fn = cross_entropy_loss
22
+ data_parallel_shard_degree = -1
23
+ if self.mode == "replicate":
24
+ self.dp_mesh_dim_names = ("dp_replicate",)
25
+ data_parallel_replicate_degree = self.world_size
26
+ elif self.mode == "fully_shard":
27
+ self.dp_mesh_dim_names = ("dp_shard_cp",)
28
+ data_parallel_replicate_degree = 1
29
+ elif self.mode == "hybrid_shard":
30
+ self.dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
31
+ data_parallel_replicate_degree = self.world_size // 2
32
+ else:
33
+ raise ValueError(f"Unsupported mode {mode}")
34
+
35
+ self.parallel_dims = ParallelDims(
36
+ dp_shard=data_parallel_shard_degree,
37
+ dp_replicate=data_parallel_replicate_degree,
38
+ cp=1,
39
+ tp=1,
40
+ pp=1,
41
+ world_size=self.world_size,
42
+ enable_loss_parallel=True,
43
+ )
44
+ self.device_mesh = self.parallel_dims.build_mesh(device_type="cuda")
45
+
46
+ def get_input(self):
47
+ inputs = torch.randn(8, 8).cuda()
48
+ labels = torch.randn(8, 8).cuda()
49
+ model = torch.nn.Linear(8, 8)
50
+ return model, inputs, labels
51
+
52
+ def run_fsdp2(self, model, inputs, labels, epoch=20):
53
+ fully_shard(model, mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)])
54
+ optim = self.optimizer(model.parameters(), lr=1e-4)
55
+ losses = []
56
+ for _ in range(epoch):
57
+ optim.zero_grad()
58
+ out = model(inputs)
59
+ loss = self.loss_fn(out, labels)
60
+ loss.backward()
61
+ optim.step()
62
+ losses.append(loss)
63
+ return losses
64
+
65
+ def run_simple_fsdp(self, model, inputs, labels, epoch=20):
66
+ model = data_parallel(
67
+ model,
68
+ device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)],
69
+ mode=self.mode,
70
+ )
71
+ optim = self.optimizer(model.parameters(), lr=1e-4)
72
+ losses = []
73
+ for _ in range(epoch):
74
+ optim.zero_grad()
75
+ out = model(inputs)
76
+ loss = self.loss_fn(out, labels)
77
+ loss.backward()
78
+ optim.step()
79
+ losses.append(loss)
80
+ return losses
81
+
82
+ def test_replicate_convergence(self):
83
+ # unit test for replicate mode
84
+ self.mode = "replicate"
85
+ self.init_test()
86
+ model, inputs, labels = self.get_input()
87
+
88
+ fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
89
+ simple_fsdp_replicate_losses = self.run_simple_fsdp(
90
+ copy.deepcopy(model), inputs, labels
91
+ )
92
+
93
+ for fsdp2_loss, simple_fsdp_replicate_loss in zip(
94
+ fsdp2_losses, simple_fsdp_replicate_losses
95
+ ):
96
+ assert torch.allclose(fsdp2_loss, simple_fsdp_replicate_loss)
97
+
98
+ def test_fullyshard_convergence(self):
99
+ # unit test for fully_shard mode
100
+ self.mode = "fully_shard"
101
+ self.init_test()
102
+ model, inputs, labels = self.get_input()
103
+
104
+ fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
105
+ simple_fsdp_fullyshard_losses = self.run_simple_fsdp(
106
+ copy.deepcopy(model), inputs, labels
107
+ )
108
+
109
+ for fsdp2_loss, simple_fsdp_fullyshard_loss in zip(
110
+ fsdp2_losses, simple_fsdp_fullyshard_losses
111
+ ):
112
+ assert torch.allclose(fsdp2_loss, simple_fsdp_fullyshard_loss)
113
+
114
+ def test_hybridshard_convergence(self):
115
+ # unit test for hybrid_shard mode
116
+ self.mode = "hybrid_shard"
117
+ self.init_test()
118
+ model, inputs, labels = self.get_input()
119
+
120
+ fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
121
+ simple_fsdp_hybridshard_losses = self.run_simple_fsdp(
122
+ copy.deepcopy(model), inputs, labels
123
+ )
124
+
125
+ for fsdp2_loss, simple_fsdp_hybridshard_loss in zip(
126
+ fsdp2_losses, simple_fsdp_hybridshard_losses
127
+ ):
128
+ assert torch.allclose(fsdp2_loss, simple_fsdp_hybridshard_loss)
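
NOTE: these FSDPTest cases need a multi-GPU host; a hypothetical entry point, not part of the uploaded file, would go through torch's internal test runner, which spawns one process per GPU.

# Sketch of a __main__ hook for the convergence tests above; assumes
# torch.testing._internal.common_utils is importable (it ships with PyTorch).
if __name__ == "__main__":
    from torch.testing._internal.common_utils import run_tests

    run_tests()
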
torchtitan/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (195 Bytes)
 
torchtitan/models/__pycache__/norms.cpython-312.pyc ADDED
Binary file (1.39 kB)
 
torchtitan/models/llama3/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.57 kB)
 
torchtitan/models/llama3/__pycache__/parallelize_llama.cpython-312.pyc ADDED
Binary file (15.1 kB)
 
torchtitan/models/llama3/parallelize_llama.py ADDED
@@ -0,0 +1,398 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D parallelisms (except pipeline parallelism) and various
8
+ # training techniques (e.g. activation checkpointing and compile) to the Llama model.
9
+
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.distributed._composable.replicate import replicate
15
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
16
+ checkpoint_wrapper as ptd_checkpoint_wrapper,
17
+ )
18
+
19
+ from torch.distributed.device_mesh import DeviceMesh
20
+ from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
21
+ from torch.distributed.tensor import Replicate, Shard
22
+ from torch.distributed.tensor.parallel import (
23
+ ColwiseParallel,
24
+ parallelize_module,
25
+ PrepareModuleInput,
26
+ RowwiseParallel,
27
+ SequenceParallel,
28
+ )
29
+
30
+ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
31
+ from torchtitan.distributed import ParallelDims
32
+ from torchtitan.tools.logging import logger
33
+
34
+
35
+ def parallelize_llama(
36
+ model: nn.Module,
37
+ world_mesh: DeviceMesh,
38
+ parallel_dims: ParallelDims,
39
+ job_config: JobConfig,
40
+ ):
41
+ """
42
+ Apply tensor parallelism, activation checkpointing, torch.compile, and data
43
+ parallelism to the model.
44
+
45
+ NOTE: The passed-in model preferably should be on meta device. Otherwise,
46
+ the model must fit on GPU or CPU memory.
47
+ """
48
+
49
+ if parallel_dims.tp_enabled:
50
+ if (
51
+ job_config.parallelism.enable_async_tensor_parallel
52
+ and not job_config.training.compile
53
+ ):
54
+ raise RuntimeError("Async TP requires --training.compile")
55
+
56
+ enable_float8_linear = "float8" in job_config.model.converters
57
+ float8_is_rowwise = job_config.float8.recipe_name in (
58
+ "rowwise",
59
+ "rowwise_with_gw_hp",
60
+ )
61
+
62
+ # For now, float8 all-gather with TP is only supported for tensorwise
63
+ # float8 scaling recipes. For rowwise recipes, we use regular TP and
64
+ # all-gather happens in high precision.
65
+ enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
66
+
67
+ apply_tp(
68
+ model,
69
+ world_mesh["tp"],
70
+ loss_parallel=parallel_dims.loss_parallel_enabled,
71
+ enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
72
+ enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
73
+ )
74
+
75
+ if job_config.model.use_flex_attn:
76
+ if job_config.activation_checkpoint.mode == "selective":
77
+ raise ValueError(
78
+ "FlexAttention is not compatible with selective AC yet. "
79
+ "See https://github.com/pytorch/pytorch/issues/147879"
80
+ )
81
+
82
+ if parallel_dims.cp_enabled:
83
+ raise ValueError(
84
+ "FlexAttention is not compatible with CP yet. "
85
+ "We are still working on this."
86
+ )
87
+
88
+ if job_config.activation_checkpoint.mode != "none":
89
+ apply_ac(model, job_config.activation_checkpoint)
90
+
91
+ # turn on per-TransformerBlock compile after AC wrapping and before FSDP
92
+ if job_config.training.compile:
93
+ apply_compile(model)
94
+
95
+ if (
96
+ parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
97
+ ): # apply FSDP or HSDP, potentially with Context Parallel
98
+ if parallel_dims.dp_replicate_enabled:
99
+ dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
100
+ else:
101
+ dp_mesh_dim_names = ("dp_shard_cp",)
102
+
103
+ apply_fsdp(
104
+ model,
105
+ world_mesh[tuple(dp_mesh_dim_names)],
106
+ param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
107
+ reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
108
+ pp_enabled=parallel_dims.pp_enabled,
109
+ cpu_offload=job_config.training.enable_cpu_offload,
110
+ reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
111
+ )
112
+
113
+ if parallel_dims.dp_replicate_enabled:
114
+ logger.info("Applied HSDP to the model")
115
+ else:
116
+ logger.info("Applied FSDP to the model")
117
+
118
+ if parallel_dims.cp_enabled:
119
+ logger.info("Applied Context Parallel to the model")
120
+
121
+ if job_config.training.enable_cpu_offload:
122
+ logger.info("Applied CPU Offloading to the model")
123
+ elif parallel_dims.dp_replicate_enabled:
124
+ if world_mesh.ndim > 1:
125
+ raise RuntimeError("DDP has not supported > 1D parallelism")
126
+ apply_ddp(
127
+ model,
128
+ world_mesh,
129
+ enable_compile=job_config.training.compile,
130
+ enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
131
+ )
132
+
133
+ return model
134
+
135
+
136
+ def apply_tp(
137
+ model: nn.Module,
138
+ tp_mesh: DeviceMesh,
139
+ loss_parallel: bool,
140
+ enable_float8_tensorwise_tp: bool,
141
+ enable_async_tp: bool,
142
+ ):
143
+ """Apply tensor parallelism."""
144
+ # 1. Parallelize the embedding and shard its outputs (which are the first
145
+ # transformer block's inputs)
146
+ # 2. Parallelize the root norm layer over the sequence dim
147
+ # 3. Parallelize the final linear output layer
148
+ parallelize_module(
149
+ model,
150
+ tp_mesh,
151
+ {
152
+ "tok_embeddings": RowwiseParallel(
153
+ input_layouts=Replicate(),
154
+ output_layouts=Shard(1),
155
+ ),
156
+ "norm": SequenceParallel(),
157
+ "output": ColwiseParallel(
158
+ input_layouts=Shard(1),
159
+ output_layouts=Shard(-1) if loss_parallel else Replicate(),
160
+ use_local_output=not loss_parallel,
161
+ ),
162
+ },
163
+ )
164
+
165
+ # Parallel styles used for transformer block linear weights and their
166
+ # inputs may be different for float8 linears with tensorwise scaling.
167
+ if enable_float8_tensorwise_tp:
168
+ # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
169
+ from torchao.float8.float8_tensor_parallel import (
170
+ Float8ColwiseParallel,
171
+ Float8RowwiseParallel,
172
+ PrepareFloat8ModuleInput,
173
+ )
174
+
175
+ rowwise_parallel, colwise_parallel, prepare_module_input = (
176
+ Float8RowwiseParallel,
177
+ Float8ColwiseParallel,
178
+ PrepareFloat8ModuleInput,
179
+ )
180
+ else:
181
+ rowwise_parallel, colwise_parallel, prepare_module_input = (
182
+ RowwiseParallel,
183
+ ColwiseParallel,
184
+ PrepareModuleInput,
185
+ )
186
+
187
+ # Apply tensor + sequence parallelism to every transformer block
188
+ # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
189
+ # by folding (and unfolding) the batch dimension and the sequence dimension.
190
+ # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
191
+ for layer_id, transformer_block in model.layers.items():
192
+ layer_plan = {
193
+ "attention_norm": SequenceParallel(),
194
+ "attention": prepare_module_input(
195
+ input_layouts=(Shard(1), None),
196
+ desired_input_layouts=(Replicate(), None),
197
+ ),
198
+ "attention.wq": colwise_parallel(),
199
+ "attention.wk": colwise_parallel(),
200
+ "attention.wv": colwise_parallel(),
201
+ "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
202
+ "ffn_norm": SequenceParallel(),
203
+ "feed_forward": prepare_module_input(
204
+ input_layouts=(Shard(1),),
205
+ desired_input_layouts=(Replicate(),),
206
+ ),
207
+ "feed_forward.w1": colwise_parallel(),
208
+ "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
209
+ "feed_forward.w3": colwise_parallel(),
210
+ }
211
+
212
+ parallelize_module(
213
+ module=transformer_block,
214
+ device_mesh=tp_mesh,
215
+ parallelize_plan=layer_plan,
216
+ )
217
+
218
+ if enable_async_tp:
219
+ from torch.distributed._symmetric_memory import enable_symm_mem_for_group
220
+
221
+ torch._inductor.config._micro_pipeline_tp = True
222
+ enable_symm_mem_for_group(tp_mesh.get_group().group_name)
223
+
224
+ logger.info(
225
+ f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
226
+ "Tensor Parallelism to the model"
227
+ )
228
+
229
+
230
+ # for selective op activation checkpointing
231
+ _save_list = {
232
+ torch.ops.aten.mm.default,
233
+ torch.ops.aten._scaled_dot_product_efficient_attention.default,
234
+ torch.ops.aten._scaled_dot_product_flash_attention.default,
235
+ # for low precision training, it's useful to always save
236
+ # the result of max, since the absolute maximum is
237
+ # used to compute the scaling factor for quantization.
238
+ torch.ops.aten.max.default,
239
+ }
240
+
241
+
242
+ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
243
+ valid_ac_modes = ("full", "selective")
244
+ if ac_config.mode not in valid_ac_modes:
245
+ raise ValueError(
246
+ f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
247
+ )
248
+
249
+ if ac_config.mode == "full":
250
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
251
+
252
+ assert ac_config.mode == "selective", f"{ac_config.mode}"
253
+ use_op_sac = ac_config.selective_ac_option == "op"
254
+ use_layer_sac = ac_config.selective_ac_option.isdigit()
255
+ if not use_op_sac and not use_layer_sac:
256
+ raise ValueError(
257
+ f"Invalid selective AC option: {ac_config.selective_ac_option}. "
258
+ f"Valid options: 'op' or a positive int representing layer frequency"
259
+ )
260
+ if use_op_sac:
261
+ from torch.utils.checkpoint import (
262
+ CheckpointPolicy,
263
+ create_selective_checkpoint_contexts,
264
+ )
265
+
266
+ def _get_custom_policy(meta):
267
+ def _custom_policy(ctx, func, *args, **kwargs):
268
+ mode = "recompute" if ctx.is_recompute else "forward"
269
+ mm_count_key = f"{mode}_mm_count"
270
+ if func == torch.ops.aten.mm.default:
271
+ meta[mm_count_key] += 1
272
+ # Saves output of all compute ops, except every second mm
273
+ to_save = func in _save_list and not (
274
+ func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
275
+ )
276
+ return (
277
+ CheckpointPolicy.MUST_SAVE
278
+ if to_save
279
+ else CheckpointPolicy.PREFER_RECOMPUTE
280
+ )
281
+
282
+ return _custom_policy
283
+
284
+ def selective_checkpointing_context_fn():
285
+ meta = defaultdict(int)
286
+ return create_selective_checkpoint_contexts(_get_custom_policy(meta))
287
+
288
+ return ptd_checkpoint_wrapper(
289
+ module,
290
+ context_fn=selective_checkpointing_context_fn,
291
+ preserve_rng_state=False,
292
+ )
293
+ elif use_layer_sac:
294
+ # Checkpoint every `ac_freq` of the modules passed to this function
295
+ ac_freq = int(ac_config.selective_ac_option)
296
+ ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
297
+ ptd_checkpoint_wrapper._count += 1
298
+ if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
299
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
300
+ else:
301
+ return module
302
+
303
+
304
+ def apply_ac(model: nn.Module, ac_config):
305
+ """Apply activation checkpointing to the model."""
306
+ for layer_id, transformer_block in model.layers.named_children():
307
+ transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
308
+ model.layers.register_module(layer_id, transformer_block)
309
+
310
+ logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
311
+
312
+
313
+ def apply_compile(model: nn.Module):
314
+ """
315
+ Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
316
+ repeated structure. Alternatively one can compile the whole model (after applying DP).
317
+ """
318
+ for layer_id, transformer_block in model.layers.named_children():
319
+ transformer_block = torch.compile(transformer_block, fullgraph=True)
320
+ model.layers.register_module(layer_id, transformer_block)
321
+
322
+ logger.info("Compiling each TransformerBlock with torch.compile")
323
+
324
+
325
+ def apply_fsdp(
326
+ model: nn.Module,
327
+ dp_mesh: DeviceMesh,
328
+ param_dtype: torch.dtype,
329
+ reduce_dtype: torch.dtype,
330
+ pp_enabled: bool,
331
+ cpu_offload: bool = False,
332
+ reshard_after_forward_policy: str = "default",
333
+ ):
334
+ """
335
+ Apply data parallelism (via FSDP2) to the model.
336
+
337
+ Args:
338
+ model (nn.Module): The model to apply data parallelism to.
339
+ dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
340
+ param_dtype (torch.dtype): The data type to use for model parameters.
341
+ reduce_dtype (torch.dtype): The data type to use for reduction operations.
342
+ pp_enabled (bool): Whether pipeline parallelism is enabled.
343
+ cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
344
+ reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default".
345
+ Other options: "never", "always".
346
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
347
+ - "always" will enable `reshard_after_forward` for all forward passes.
348
+ - "never" will disable `reshard_after_forward` for all forward passes.
349
+
350
+ """
351
+ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
352
+ fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
353
+ if cpu_offload:
354
+ fsdp_config["offload_policy"] = CPUOffloadPolicy()
355
+
356
+ for layer_id, transformer_block in model.layers.items():
357
+ if reshard_after_forward_policy == "always":
358
+ reshard_after_forward = True
359
+ elif reshard_after_forward_policy == "never":
360
+ reshard_after_forward = False
361
+ elif reshard_after_forward_policy == "default":
362
+ if pp_enabled:
363
+ # For PP, do not reshard after forward to avoid per-microbatch
364
+ # all-gathers, which can be expensive and non-overlapped
365
+ reshard_after_forward = False
366
+ else:
367
+ # As an optimization, do not reshard after forward for the last
368
+ # transformer block since FSDP would prefetch it immediately
369
+ reshard_after_forward = int(layer_id) < len(model.layers) - 1
370
+ else:
371
+ raise ValueError(
372
+ f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
373
+ )
374
+ fully_shard(
375
+ transformer_block,
376
+ **fsdp_config,
377
+ reshard_after_forward=reshard_after_forward,
378
+ )
379
+ fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
380
+
381
+
382
+ def apply_ddp(
383
+ model: nn.Module,
384
+ dp_mesh: DeviceMesh,
385
+ enable_compile: bool,
386
+ enable_compiled_autograd: bool,
387
+ ):
388
+ if enable_compile:
389
+ if enable_compiled_autograd:
390
+ torch._dynamo.config.optimize_ddp = (
391
+ "python_reducer_without_compiled_forward"
392
+ )
393
+ else:
394
+ torch._dynamo.config.optimize_ddp = "ddp_optimizer"
395
+
396
+ replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
397
+
398
+ logger.info("Applied DDP to the model")
torchtitan/models/llama3/train_configs/llama3_70b.toml ADDED
@@ -0,0 +1,62 @@
1
+ # torchtitan Config.toml
2
+ # NOTE: this toml config is a preset for 64 A100 GPUs.
3
+
4
+ [job]
5
+ dump_folder = "./outputs"
6
+ description = "Llama 3 70B training"
7
+
8
+ [profiling]
9
+ enable_profiling = true
10
+ save_traces_folder = "profile_trace"
11
+ profile_freq = 100
12
+
13
+ [metrics]
14
+ log_freq = 10
15
+ enable_tensorboard = true
16
+ save_tb_folder = "tb"
17
+
18
+ [model]
19
+ name = "llama3"
20
+ flavor = "70B"
21
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
22
+ tokenizer_path = "./assets/tokenizer/original/tokenizer.model"
23
+ # converters = "float8"
24
+
25
+ [optimizer]
26
+ name = "AdamW"
27
+ lr = 1.5e-4
28
+ eps = 1e-8
29
+
30
+ [lr_scheduler]
31
+ warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps
32
+
33
+ [training]
34
+ batch_size = 8
35
+ seq_len = 8192
36
+ max_norm = 1.0 # grad norm clipping
37
+ steps = 1000
38
+ compile = false
39
+ dataset = "c4"
40
+
41
+ [parallelism]
42
+ data_parallel_replicate_degree = 1
43
+ data_parallel_shard_degree = -1
44
+ tensor_parallel_degree = 8 # 8-way TP
45
+ pipeline_parallel_degree = 1
46
+ context_parallel_degree = 1
47
+
48
+ [checkpoint]
49
+ enable_checkpoint = false
50
+ folder = "checkpoint"
51
+ interval = 500
52
+ model_weights_only = false
53
+ export_dtype = "float32"
54
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
55
+
56
+ [activation_checkpoint]
57
+ mode = 'full'
58
+
59
+ [float8]
60
+ enable_fsdp_float8_all_gather = false
61
+ precompute_float8_dynamic_scale_for_fsdp = false
62
+ filter_fqns = "output"
torchtitan/protocols/train_spec.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
8
+
9
+ from abc import abstractmethod
10
+ from collections.abc import Callable, Mapping
11
+ from dataclasses import dataclass
12
+ from typing import Protocol, TypeAlias
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.distributed.pipelining.schedules import _PipelineSchedule
17
+
18
+ from torchtitan.components.dataloader import BaseDataLoader
19
+ from torchtitan.components.ft import FTManager
20
+ from torchtitan.components.loss import LossFunction
21
+ from torchtitan.components.lr_scheduler import LRSchedulersContainer
22
+ from torchtitan.components.metrics import MetricsProcessor
23
+ from torchtitan.components.optimizer import OptimizersContainer
24
+ from torchtitan.components.tokenizer import Tokenizer
25
+ from torchtitan.config_manager import JobConfig
26
+
27
+ DeviceType = int | str | torch.device
28
+
29
+
30
+ @dataclass
31
+ class BaseModelArgs:
32
+ """All ModelArgs should inherit from this class.
33
+
34
+ The only usage of this class is type checking but allows us to extend common
35
+ arguments to all models in the future.
36
+ """
37
+
38
+ _enforced: str = "This field is used to enforce all fields have defaults."
39
+
40
+ @abstractmethod
41
+ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
42
+ pass
43
+
44
+ @abstractmethod
45
+ def get_nparams_and_flops(
46
+ self, model: nn.Module, seq_len: int
47
+ ) -> tuple[int, float]:
48
+ pass
49
+
50
+
51
+ class ModelProtocol(Protocol):
52
+ """Defines the interface for a model class.
53
+
54
+ This is used to enforce that all model classes have some methods that are
55
+ required by the TorchTitan trainer.
56
+ """
57
+
58
+ @classmethod
59
+ def from_model_args(cls, args: BaseModelArgs) -> nn.Module:
60
+ ...
61
+
62
+
63
+ ParallelizeFunction: TypeAlias = Callable[..., nn.Module]
64
+ PipeliningFunction: TypeAlias = Callable[
65
+ ..., tuple[_PipelineSchedule, list[nn.Module], bool, bool]
66
+ ]
67
+ DataLoaderBuilder: TypeAlias = Callable[..., BaseDataLoader]
68
+ TokenizerBuilder: TypeAlias = Callable[..., Tokenizer]
69
+ MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor]
70
+ OptimizersBuilder: TypeAlias = Callable[
71
+ [list[nn.Module], JobConfig, FTManager], OptimizersContainer
72
+ ]
73
+ LRSchedulersBuilder: TypeAlias = Callable[
74
+ [OptimizersContainer, JobConfig], LRSchedulersContainer
75
+ ]
76
+ LossFunctionBuilder: TypeAlias = Callable[..., LossFunction]
77
+
78
+
79
+ @dataclass
80
+ class TrainSpec:
81
+ name: str
82
+ cls: type[nn.Module]
83
+ config: Mapping[str, BaseModelArgs]
84
+ parallelize_fn: ParallelizeFunction
85
+ pipelining_fn: PipeliningFunction | None
86
+ build_optimizers_fn: OptimizersBuilder
87
+ build_lr_schedulers_fn: LRSchedulersBuilder
88
+ build_dataloader_fn: DataLoaderBuilder
89
+ build_tokenizer_fn: TokenizerBuilder | None
90
+ build_loss_fn: LossFunctionBuilder
91
+ build_metrics_processor_fn: MetricsProcessorBuilder | None = None
92
+
93
+
94
+ _train_specs = {}
95
+
96
+
97
+ def register_train_spec(train_spec: TrainSpec) -> None:
98
+ global _train_specs
99
+ if train_spec.name in _train_specs:
100
+ raise ValueError(f"Model {train_spec.name} is already registered.")
101
+
102
+ _train_specs[train_spec.name] = train_spec
103
+
104
+
105
+ def get_train_spec(name: str) -> TrainSpec:
106
+ global _train_specs
107
+ if name not in _train_specs:
108
+ raise ValueError(f"Model {name} is not registered.")
109
+ return _train_specs[name]
110
+
111
+
112
+ def apply_to_train_specs(func: Callable[[TrainSpec], TrainSpec]) -> None:
113
+ global _train_specs
114
+ for name, train_spec in _train_specs.items():
115
+ _train_specs[name] = func(train_spec)
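
NOTE: a minimal registration sketch, not part of the uploaded file; every builder below is a hypothetical placeholder, only meant to show how a TrainSpec is wired into the registry.

import torch.nn as nn

from torchtitan.protocols.train_spec import (
    get_train_spec,
    register_train_spec,
    TrainSpec,
)


class ToyModel(nn.Linear):
    @classmethod
    def from_model_args(cls, args):  # satisfies ModelProtocol
        return cls(8, 8)


register_train_spec(
    TrainSpec(
        name="toy",
        cls=ToyModel,
        config={},  # flavor name -> BaseModelArgs instances
        parallelize_fn=lambda model, *args, **kwargs: model,  # placeholder
        pipelining_fn=None,
        build_optimizers_fn=lambda *args, **kwargs: None,  # placeholder
        build_lr_schedulers_fn=lambda *args, **kwargs: None,  # placeholder
        build_dataloader_fn=lambda *args, **kwargs: None,  # placeholder
        build_tokenizer_fn=None,
        build_loss_fn=lambda *args, **kwargs: None,  # placeholder
    )
)

assert get_train_spec("toy").cls is ToyModel
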
train.sh ADDED
@@ -0,0 +1,121 @@
1
+ #!/usr/bin/bash
2
+
3
+ params=""
4
+ if [ $# -ne 0 ]; then
5
+ params="$*"
6
+ fi
7
+
8
+ # use envs as local params for convenience
9
+ # e.g.
10
+ # NNODE=1 NGPU=8 LOG_RANK=0 ./train.sh
11
+ NNODE=${NNODE:-"1"}
12
+ NGPU=${NGPU:-"8"}
13
+ LOG_RANK=${LOG_RANK:-0}
14
+
15
+ if [[ -z "${MASTER_ADDR}" ]]; then
16
+ export MASTER_ADDR="localhost"
17
+ fi
18
+ if [[ -z "${MASTER_PORT}" ]]; then
19
+ export MASTER_PORT="0"
20
+ fi
21
+
22
+ : '
23
+ Usage:
24
+
25
+ bash train.sh -h
26
+
27
+ Training a 340M model:
28
+
29
+ NNODE=1 NGPU=8 LOG_RANK=0 bash train.sh \
30
+ --job.config_file flame/models/fla.toml \
31
+ --job.dump_folder exp/transformer-340M-10B/batch32.seqlen2048.warmup1024.update1.steps20480.lr3e-4 \
32
+ --model.config configs/transformer_340M.json \
33
+ --model.tokenizer_path fla-hub/transformer-1.3B-100B \
34
+ --optimizer.name AdamW \
35
+ --optimizer.eps 1e-15 \
36
+ --optimizer.lr 3e-4 \
37
+ --lr_scheduler.warmup_steps 1024 \
38
+ --lr_scheduler.lr_min 0.1 \
39
+ --lr_scheduler.decay_type cosine \
40
+ --training.batch_size 32 \
41
+ --training.seq_len 2048 \
42
+ --training.gradient_accumulation_steps 1 \
43
+ --training.steps 20480 \
44
+ --training.max_norm 1.0 \
45
+ --training.skip_nan_inf \
46
+ --training.dataset HuggingFaceFW/fineweb-edu \
47
+ --training.dataset_name default \
48
+ --training.dataset_split train \
49
+ --training.streaming \
50
+ --training.num_workers 32 \
51
+ --training.prefetch_factor 2 \
52
+ --training.seed 42 \
53
+ --training.compile \
54
+ --training.tensor_parallel_degree 1 \
55
+ --training.disable_loss_parallel \
56
+ --checkpoint.interval 2048 \
57
+ --checkpoint.load_step -1 \
58
+ --metrics.log_freq 1
59
+ '
60
+
61
+ echo "Launching training..."
62
+
63
+ set -x
64
+ path=$(grep -oP '(?<=--job.dump_folder )[^ ]+' <<< "$params")
65
+ steps=$(grep -oP '(?<=--training.steps )[^ ]+' <<< "$params")
66
+ config=$(grep -oP '(?<=--model.config )[^ ]+' <<< "$params")
67
+ tokenizer=$(grep -oP '(?<=--model.tokenizer_path )[^ ]+' <<< "$params")
68
+ model=$(
69
+ python -c "import fla, sys; from transformers import AutoConfig; print(AutoConfig.from_pretrained(sys.argv[1]).to_json_string())" "$config" | jq -r '.model_type'
70
+ )
71
+
72
+ mkdir -p $path
73
+ cp * $path
74
+ cp -r configs $path
75
+ cp -r flame $path
76
+ cp -r 3rdparty/flash-linear-attention/fla $path
77
+ cp -r 3rdparty/torchtitan/torchtitan $path
78
+
79
+ # for offline systems
80
+ # export TRANSFORMERS_OFFLINE=1
81
+ # export HF_DATASETS_OFFLINE=1
82
+ # export HF_HUB_OFFLINE=1
83
+ if [ "$date" == "" ]; then
84
+ date=$(date +%Y%m%d%H%M)
85
+ fi
86
+ RUN_NAME="$model-$(basename $path)"
87
+ RUN_ID="$RUN_NAME-$date"
88
+
89
+ export WANDB_RESUME=allow
90
+ if [[ -z "${WANDB_PROJECT}" ]]; then
91
+ export WANDB_PROJECT="fla"
92
+ fi
93
+ if [[ -z "${WANDB_NAME}" ]]; then
94
+ export WANDB_NAME="$RUN_NAME"
95
+ fi
96
+ if [[ -z "${WANDB_RUN_ID}" ]]; then
97
+ export WANDB_RUN_ID="$RUN_ID"
98
+ fi
99
+
100
+ PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
101
+ torchrun --nnodes=${NNODE} \
102
+ --nproc_per_node=${NGPU} \
103
+ --rdzv_backend c10d \
104
+ --rdzv_endpoint "${MASTER_ADDR}:${MASTER_PORT}" \
105
+ --local-ranks-filter ${LOG_RANK} \
106
+ --role rank \
107
+ --tee 3 \
108
+ --log-dir $path/logs \
109
+ -m flame.train \
110
+ $params
111
+
112
+ echo "TRAINING DONE!"
113
+ echo "Converting the DCP checkpoints to HF format..."
114
+
115
+ python -m flame.utils.convert_dcp_to_hf \
116
+ --path $path \
117
+ --step $steps \
118
+ --config $config \
119
+ --tokenizer $tokenizer
120
+
121
+ echo "RUNNING DONE!"