Add files using upload-large-folder tool
- fla/modules/__pycache__/fused_linear_listnet_loss.cpython-312.pyc +0 -0
- logs/none_zo1mfnl3/attempt_0/0/stderr.log +0 -0
- logs/none_zo1mfnl3/attempt_0/2/stderr.log +0 -0
- logs/none_zo1mfnl3/attempt_0/3/stderr.log +0 -0
- logs/none_zo1mfnl3/attempt_0/4/stderr.log +0 -0
- torchtitan/components/__pycache__/float8.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/loss.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/metrics.cpython-312.pyc +0 -0
- torchtitan/components/dataloader.py +92 -0
- torchtitan/distributed/__pycache__/__init__.cpython-312.pyc +0 -0
- torchtitan/experiments/deepseek_v3/LICENSE-CODE +21 -0
- torchtitan/experiments/deepseek_v3/attn_mask_utils.py +397 -0
- torchtitan/experiments/deepseek_v3/generate.py +308 -0
- torchtitan/experiments/deepseek_v3/model_config.py +204 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py +63 -0
- torchtitan/experiments/flux/README.md +23 -0
- torchtitan/experiments/flux/dataset/flux_dataset.py +267 -0
- torchtitan/experiments/flux/dataset/tokenizer.py +64 -0
- torchtitan/experiments/flux/model/hf_embedder.py +40 -0
- torchtitan/experiments/flux/model/math.py +38 -0
- torchtitan/experiments/flux/scripts/download_autoencoder.py +61 -0
- torchtitan/experiments/flux/tests/test_generate_image.py +252 -0
- torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py +630 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py +13 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py +299 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py +1304 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py +240 -0
- torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc +0 -0
- torchtitan/experiments/llama4/model/args.py +109 -0
- torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh +25 -0
- torchtitan/experiments/multimodal/tests/__init__.py +5 -0
- torchtitan/experiments/multimodal/tests/test_utils.py +58 -0
- torchtitan/experiments/multimodal/tokenizer/tiktoken.py +232 -0
- torchtitan/experiments/multimodal/utils.py +437 -0
- torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc +0 -0
- torchtitan/experiments/simple_fsdp/__pycache__/model.cpython-312.pyc +0 -0
- torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc +0 -0
- torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc +0 -0
- torchtitan/experiments/simple_fsdp/tests/__init__.py +5 -0
- torchtitan/experiments/simple_fsdp/tests/test_numerics.py +128 -0
- torchtitan/models/__pycache__/__init__.cpython-312.pyc +0 -0
- torchtitan/models/__pycache__/norms.cpython-312.pyc +0 -0
- torchtitan/models/llama3/__pycache__/__init__.cpython-312.pyc +0 -0
- torchtitan/models/llama3/__pycache__/parallelize_llama.cpython-312.pyc +0 -0
- torchtitan/models/llama3/parallelize_llama.py +398 -0
- torchtitan/models/llama3/train_configs/llama3_70b.toml +62 -0
- torchtitan/protocols/train_spec.py +115 -0
- train.sh +121 -0
fla/modules/__pycache__/fused_linear_listnet_loss.cpython-312.pyc
ADDED
Binary file (17.8 kB).
logs/none_zo1mfnl3/attempt_0/0/stderr.log
ADDED
The diff for this file is too large to render.
logs/none_zo1mfnl3/attempt_0/2/stderr.log
ADDED
The diff for this file is too large to render.
logs/none_zo1mfnl3/attempt_0/3/stderr.log
ADDED
The diff for this file is too large to render.
logs/none_zo1mfnl3/attempt_0/4/stderr.log
ADDED
The diff for this file is too large to render.
torchtitan/components/__pycache__/float8.cpython-312.pyc
ADDED
Binary file (6.2 kB).
torchtitan/components/__pycache__/loss.cpython-312.pyc
ADDED
Binary file (1.51 kB).
torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc
ADDED
Binary file (7.71 kB).
torchtitan/components/__pycache__/metrics.cpython-312.pyc
ADDED
Binary file (19.6 kB).
torchtitan/components/dataloader.py
ADDED
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import pickle
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any

from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader
from torchtitan.tools.logging import logger


class BaseDataLoader(Stateful, ABC):
    """Base class for all dataloaders.

    This is used to enforce that all dataloaders have the methods defined in ``Stateful``,
    ``state_dict()`` and ``load_state_dict()``.
    """

    @abstractmethod
    def __iter__(self):
        ...


class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
    """Dataloader that is aware of distributed data parallelism.

    This dataloader is used to load data in a distributed data parallel fashion. It also
    utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary
    methods such as ``__iter__``.

    Args:
        dataset (IterableDataset): The dataset to iterate over.
        dp_rank: Data parallelism rank for this dataloader.
        dp_world_size: The world size of the data parallelism.
        batch_size: The batch size to use for each iteration.
        collate_fn: Optional function to collate samples in a batch.
    """

    dp_rank: int
    dp_world_size: int
    batch_size: int

    def __init__(
        self,
        dataset: IterableDataset,
        dp_rank: int,
        dp_world_size: int,
        batch_size: int,
        collate_fn: Callable | None = None,
    ):
        self.dp_world_size = dp_world_size
        self.dp_rank = dp_rank
        self.batch_size = batch_size
        super().__init__(dataset, batch_size, collate_fn=collate_fn)
        self._rank_id = f"dp_rank_{dp_rank}"

    def state_dict(self) -> dict[str, Any]:
        # Store state only for dp rank to avoid replicating the same state across other dimensions.
        return {
            # We don't have to use pickle as DCP will serialize the state_dict. However,
            # we have to keep this for backward compatibility.
            self._rank_id: pickle.dumps(super().state_dict()),
            "world_size": self.dp_world_size,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # State being empty is valid.
        if not state_dict:
            return

        if self._rank_id not in state_dict:
            logger.warning(
                f"DataLoader state is empty for dp rank {self.dp_rank}, "
                f"expected key {self._rank_id}"
            )
            return

        assert self.dp_world_size == state_dict["world_size"], (
            "dp_degree is inconsistent before and after checkpoint, "
            "dataloader resharding is not supported yet."
        )
        # We don't have to use pickle as DCP will serialize the state_dict. However, we have to
        # keep this for backward compatibility.
        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
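For orientation, a minimal usage sketch of the ParallelAwareDataloader added above; the dataset, sizes, and variable names are illustrative assumptions, not part of this commit.

# Hypothetical sketch: wrap an IterableDataset for data-parallel rank 0 out of 2 ranks.
loader = ParallelAwareDataloader(
    dataset=my_iterable_dataset,  # assumed torch IterableDataset instance
    dp_rank=0,
    dp_world_size=2,
    batch_size=8,
)
state = loader.state_dict()    # {"dp_rank_0": <pickled StatefulDataLoader state>, "world_size": 2}
loader.load_state_dict(state)  # resumes iteration; asserts the dp world size is unchanged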
torchtitan/distributed/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (255 Bytes).
torchtitan/experiments/deepseek_v3/LICENSE-CODE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 DeepSeek

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
torchtitan/experiments/deepseek_v3/attn_mask_utils.py
ADDED
@@ -0,0 +1,397 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This code is based on src/transformers/modeling_attn_mask_utils.py of
# huggingface/transformers. It has been modified from its original forms to
# contain only the necessary utilities.

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch


@dataclass
class AttentionMaskConverter:
    """
    A utility attention mask class that allows one to:
        - Create a causal 4d mask
        - Create a causal 4d mask with sliding window
        - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
          key_value_length) that can be multiplied with attention scores

    Examples:

    ```python
    >>> import torch
    >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter

    >>> converter = AttentionMaskConverter(True)
    >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
    tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
              [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
              [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
              [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00, -3.4028e+38],
              [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00]]]])
    ```

    Parameters:
        is_causal (`bool`):
            Whether the attention mask should be a uni-directional (causal) or bi-directional mask.

        sliding_window (`int`, *optional*):
            Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
    """

    is_causal: bool
    sliding_window: int

    def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
        self.is_causal = is_causal
        self.sliding_window = sliding_window

        if self.sliding_window is not None and self.sliding_window <= 0:
            raise ValueError(
                "Make sure that when passing `sliding_window` that its value is a strictly positive integer, "
                f"not `{self.sliding_window}`"
            )

    def to_causal_4d(
        self,
        batch_size: int,
        query_length: int,
        key_value_length: int,
        dtype: torch.dtype,
        device: Union[torch.device, "str"] = "cpu",
    ) -> Optional[torch.Tensor]:
        """
        Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
        bias to upper right hand triangular matrix (causal mask).
        """
        if not self.is_causal:
            raise ValueError(
                f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True."
            )

        # If shape is not cached, create a new causal mask and cache it
        input_shape = (batch_size, query_length)
        past_key_values_length = key_value_length - query_length

        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if input_shape[-1] > 1 or self.sliding_window is not None:
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )

        return causal_4d_mask

    def to_4d(
        self,
        attention_mask_2d: torch.Tensor,
        query_length: int,
        dtype: torch.dtype,
        key_value_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
        key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
        causal, a causal mask will be added.
        """
        input_shape = (attention_mask_2d.shape[0], query_length)

        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
            if key_value_length is None:
                raise ValueError(
                    "This attention mask converter is causal. Make sure to pass "
                    "`key_value_length` to correctly create a causal mask."
                )

            past_key_values_length = key_value_length - query_length
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=attention_mask_2d.device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )
        elif self.sliding_window is not None:
            raise NotImplementedError(
                "Sliding window is currently only implemented for causal masking"
            )

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        expanded_attn_mask = self._expand_mask(
            attention_mask_2d, dtype, tgt_len=input_shape[-1]
        ).to(attention_mask_2d.device)

        if causal_4d_mask is not None:
            expanded_attn_mask = causal_4d_mask.masked_fill(
                expanded_attn_mask.bool(), torch.finfo(dtype).min
            )

        # expanded_attn_mask + causal_4d_mask can cause some overflow
        expanded_4d_mask = expanded_attn_mask

        return expanded_4d_mask

    @staticmethod
    def _make_causal_mask(
        input_ids_shape: torch.Size,
        dtype: torch.dtype,
        device: torch.device,
        past_key_values_length: int = 0,
        sliding_window: Optional[int] = None,
    ):
        """
        Make causal mask used for bi-directional self-attention.
        """
        bsz, tgt_len = input_ids_shape
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
        mask_cond = torch.arange(mask.size(-1), device=device)
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

        mask = mask.to(dtype)

        if past_key_values_length > 0:
            mask = torch.cat(
                [
                    torch.zeros(
                        tgt_len, past_key_values_length, dtype=dtype, device=device
                    ),
                    mask,
                ],
                dim=-1,
            )

        # add lower triangular sliding window mask if necessary
        if sliding_window is not None:
            diagonal = past_key_values_length - sliding_window - 1

            context_mask = torch.tril(
                torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal
            )
            mask.masked_fill_(context_mask, torch.finfo(dtype).min)

        return mask[None, None, :, :].expand(
            bsz, 1, tgt_len, tgt_len + past_key_values_length
        )

    @staticmethod
    def _expand_mask(
        mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
    ):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
        """
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = (
            mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        )

        inverted_mask = 1.0 - expanded_mask

        return inverted_mask.masked_fill(
            inverted_mask.to(torch.bool), torch.finfo(dtype).min
        )

    @staticmethod
    def _unmask_unattended(
        expanded_mask: torch.FloatTensor,
        min_dtype: float,
    ):
        # fmt: off
        """
        Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
        using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        Details: https://github.com/pytorch/pytorch/issues/110213

        `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
        `attention_mask` is [bsz, src_seq_len].

        The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case
        of alibi attention bias.

        For example, if `expanded_mask` is (e.g. here left-padding case)
        ```
        [[[[0, 0, 0],
           [0, 0, 0],
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[0, 0, 0],
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        then the modified `expanded_mask` will be
        ```
        [[[[1, 1, 1],  <-- modified
           [1, 1, 1],  <-- modified
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[1, 1, 1],  <-- modified
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        """
        # fmt: on
        if expanded_mask.dtype == torch.bool:
            raise ValueError(
                "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
            )

        return expanded_mask.mul(
            ~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)
        )

    @staticmethod
    def _ignore_causal_mask_sdpa(
        attention_mask: Optional[torch.Tensor],
        inputs_embeds: torch.Tensor,
        past_key_values_length: int,
        sliding_window: Optional[int] = None,
        is_training: bool = False,
    ) -> bool:
        """
        Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
        ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.

        In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
        `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
        allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
        passed).
        """

        _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
        key_value_length = query_length + past_key_values_length

        is_tracing = (
            torch.jit.is_tracing()
            or isinstance(inputs_embeds, torch.fx.Proxy)
            or is_torchdynamo_compiling()
        )

        ignore_causal_mask = False

        if attention_mask is None:
            # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
            # shape, thus SDPA's `is_causal` argument is rightfully updated
            # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
            # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
            # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
            # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
            # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
            #
            # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
            # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
            if (
                (is_training or not is_tracing)
                and (query_length == 1 or key_value_length == query_length)
                and (sliding_window is None or key_value_length < sliding_window)
            ):
                ignore_causal_mask = True
        elif sliding_window is None or key_value_length < sliding_window:
            if len(attention_mask.shape) == 4:
                return False
            elif not is_tracing and torch.all(attention_mask == 1):
                if query_length == 1 or key_value_length == query_length:
                    # For query_length == 1, causal attention and bi-directional attention are the same.
                    ignore_causal_mask = True

                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
                # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
                # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
                # Reference: https://github.com/pytorch/pytorch/issues/108108
                # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.

        return ignore_causal_mask


def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`

    Args:
        attention_mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        inputs_embeds (`torch.Tensor`):
            The embedded inputs as a torch Tensor.
        past_key_values_length (`int`):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
    """
    attn_mask_converter = AttentionMaskConverter(
        is_causal=True, sliding_window=sliding_window
    )

    key_value_length = input_shape[-1] + past_key_values_length

    # 4d mask is passed through the layers
    if attention_mask is not None and len(attention_mask.shape) == 2:
        attention_mask = attn_mask_converter.to_4d(
            attention_mask,
            input_shape[-1],
            key_value_length=key_value_length,
            dtype=inputs_embeds.dtype,
        )
    elif attention_mask is not None and len(attention_mask.shape) == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        else:
            # if the 4D mask has correct shape - invert it and fill with negative infinity
            inverted_mask = 1.0 - attention_mask
            attention_mask = inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
            )
    else:
        attention_mask = attn_mask_converter.to_causal_4d(
            input_shape[0],
            input_shape[-1],
            key_value_length,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )

    return attention_mask
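A small illustrative sketch (shapes and values are assumptions, not part of this commit) of how the helper above turns a 2D padding mask into a 4D additive mask:

# Hypothetical example: one sequence of length 4 whose last position is padding.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])  # [batch, key_value_length]
inputs_embeds = torch.zeros(1, 4, 8)           # [batch, query_length, hidden]; hidden size is an assumption
mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask,
    input_shape=(1, 4),
    inputs_embeds=inputs_embeds,
    past_key_values_length=0,
)
# mask_4d has shape [1, 1, 4, 4]; positions above the diagonal and the padded
# column hold torch.finfo(torch.float32).min, attended positions hold 0.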
torchtitan/experiments/deepseek_v3/generate.py
ADDED
@@ -0,0 +1,308 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# torchrun --standalone --nproc-per-node 4 generate.py

# use inference.sh "Your Question Here?" to run inference with a single prompt.

import sys
from dataclasses import dataclass

import torch
import torch.distributed as dist

from checkpoint import load_weights_from_hf
from model import DeepseekForCausalLM
from model_config import deepseek_config_registry
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
from torchtitan.tools.utils import Color
from transformers import AutoTokenizer

# Uncomment the model you want to run.
model_id, mesh_shape = "deepseek-ai/DeepSeek-V2-Lite-Chat", (1, 4)
# model_id, mesh_shape = "deepseek-ai/deepseek-v3", (8, 4)


def colorize_chat(text, user_color=None, assistant_color=None, output_color=None):
    """Parse and colorize chat output with optional colors for each role."""
    lines = text.split("\n")
    result = []

    current_role = None
    current_content = []

    def _process_current_content():
        if not current_role or not current_content:
            return None

        content = "\n".join(current_content)
        if current_role == "output":
            return (
                f"Output: {output_color}{content}{color.reset}"
                if output_color
                else f"Output: {content}"
            )
        else:
            try:
                prefix, rest = current_content[0].split(":", 1)
                role_color = user_color if current_role == "user" else assistant_color
                if role_color:
                    formatted = f"{prefix}:{role_color}{rest}{color.reset}"
                    if len(current_content) > 1:
                        formatted += (
                            f"{role_color}\n"
                            + "\n".join(current_content[1:])
                            + f"{color.reset}"
                        )
                    return formatted
            except ValueError:
                pass
            return content

    for line in lines:
        if line.startswith("Output:"):
            if processed := _process_current_content():
                result.append(processed)
            current_role = "output"
            content = line[len("Output:") :].strip()
            if output_color:
                content = f"Output: {output_color}{content}{color.reset}"
            else:
                content = f"Output: {content}"
            result.append(content)
            current_content = []

        elif line.startswith("User:"):
            if processed := _process_current_content():
                result.append(processed)
            current_role = "user"
            current_content = [line]

        elif line.startswith("Assistant:"):
            if processed := _process_current_content():
                result.append(processed)
            current_role = "assistant"
            current_content = [line]

        else:
            if current_content:
                current_content.append(line)
            elif line.strip() and current_role is None:
                # Handle system message at the beginning
                current_role = "output"
                if output_color:
                    result.append(f"Output: {output_color}{line.strip()}{color.reset}")
                else:
                    result.append(f"Output: {line.strip()}")

    # Process the last segment
    if processed := _process_current_content():
        result.append(processed)

    return "\n".join(result)


color = Color()


@dataclass
class DistConfig:
    mesh: DeviceMesh
    pp_mesh: DeviceMesh
    ep_mesh: DeviceMesh
    pp_size: int
    ep_size: int
    ep_rank: int
    pp_rank: int
    device: torch.device


def create_model(dist_config: DistConfig):
    model_args = deepseek_config_registry[model_id]
    model_args.ep_size = dist_config.ep_size
    model_args.num_stages = dist_config.pp_size
    model_args.stage_idx = dist_config.pp_rank
    model_args.max_seq_len = 16384

    with dist_config.device, dist_config.mesh:
        model = DeepseekForCausalLM(model_args)
        load_weights_from_hf(model, model_id, dist_config.device)
        model.eval()
        model.setup_symm_mem(torch.bfloat16, dist_config.device)

    stage = PipelineStage(
        model,
        dist_config.pp_rank,
        dist_config.pp_size,
        dist_config.device,
        group=dist_config.pp_mesh.get_group(),
    )
    pp_schedule = ScheduleGPipe(stage, dist_config.pp_size)
    return model, pp_schedule


def create_dist_config(mesh: DeviceMesh):
    rank = dist.get_rank()
    device_count = torch.cuda.device_count()
    device = torch.device("cuda", rank % device_count)

    dist_config = DistConfig(
        mesh=mesh,
        pp_mesh=mesh["pp"],
        ep_mesh=mesh["ep"],
        pp_rank=mesh["pp"].get_local_rank(),
        pp_size=mesh["pp"].size(),
        ep_size=mesh["ep"].size(),
        ep_rank=mesh["ep"].get_local_rank(),
        device=device,
    )
    return dist_config


def decode(tokenizer, x):
    output = tokenizer.decode(x[0])
    # Clean up the output by removing special tokens
    bos = tokenizer.bos_token
    output = output.replace(bos, "")
    # Truncate at end of sentence token
    eos_token = tokenizer.eos_token
    if eos_token and eos_token in output:
        output = output.split(eos_token)[0]
    colored_output = colorize_chat(
        output,
        user_color=color.green,
        assistant_color=color.cyan,
        output_color=color.blue,
    )
    return colored_output


@torch.inference_mode()
def generate(
    model,
    pp_schedule,
    tokenizer,
    dist_config,
    messages: list[dict],
    n_tokens: int = 50,
):
    rank = dist.get_rank()
    device = dist_config.device
    x = tokenizer.apply_chat_template(
        [messages] * dist_config.pp_size,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    next_idx = x.shape[-1]
    x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
    x = x.to(device)

    for _ in range(n_tokens):
        if dist_config.pp_size > 1:
            if dist_config.pp_rank == 0:
                pp_schedule.step(x)
                torch.distributed.broadcast(
                    x,
                    group=dist_config.pp_mesh.get_group(),
                    group_src=dist_config.pp_size - 1,
                )
            elif dist_config.pp_rank == dist_config.pp_size - 1:
                preds = pp_schedule.step()
                next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
                x[:, next_idx] = next_token
                torch.distributed.broadcast(
                    x,
                    group=dist_config.pp_mesh.get_group(),
                    group_src=dist_config.pp_size - 1,
                )
            else:
                pp_schedule.step()
                torch.distributed.broadcast(
                    x,
                    group=dist_config.pp_mesh.get_group(),
                    group_src=dist_config.pp_size - 1,
                )

            next_idx += 1
        else:
            preds = model(x)
            next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
            x[:, next_idx] = next_token
            next_idx += 1

    if rank == 0:
        colored_output = decode(tokenizer, x)
        print(f"Without CUDA Graph:\n{colored_output}")


@torch.inference_mode()
def generate_with_cuda_graph(
    model,
    tokenizer,
    dist_config,
    messages: list[dict],
    n_tokens: int = 10,
):
    rank = dist.get_rank()
    device = dist_config.device
    x = tokenizer.apply_chat_template(
        [messages] * dist_config.pp_size,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    next_idx = x.shape[-1]
    x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
    x = x.to(device)

    torch.cuda.synchronize()

    # Create CUDA graph
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        preds = model(x)

    # Run CUDA graph
    for _ in range(n_tokens):
        g.replay()
        next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
        x[:, next_idx] = next_token
        next_idx += 1

    if rank == 0:
        colored_output = decode(tokenizer, x)
        print(f"With CUDA Graph:\n{colored_output}")


if __name__ == "__main__":
    # Get user prompt from command line arguments
    user_prompt = "What is 2+2?"  # Default prompt
    if len(sys.argv) > 1:
        user_prompt = sys.argv[1]

    mesh = dist.init_device_mesh("cuda", mesh_shape, mesh_dim_names=("pp", "ep"))
    rank = dist.get_rank()
    if rank == 0:
        print(
            f"{color.yellow}Running inference with {model_id} on {mesh_shape} mesh{color.reset}"
        )

    dist_config = create_dist_config(mesh)
    model, pp_schedule = create_model(dist_config)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt},
    ]

    generate(model, pp_schedule, tokenizer, dist_config, messages)
    generate_with_cuda_graph(model, tokenizer, dist_config, messages)

    if rank == 0:
        print(f"\n{color.yellow}Closing inference mesh...{color.reset}")

    dist.destroy_process_group()
torchtitan/experiments/deepseek_v3/model_config.py
ADDED
@@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field


@dataclass
class ModelArgs:
    r"""
    This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a DeepSeek
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the DeepSeek-V3.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 129280):
            Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`DeepseekV3Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 1407):
            Dimension of the MoE representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
            Number of nextn predict layers in the DeepSeekV3 Model.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        n_shared_experts (`int`, *optional*, defaults to None):
            Number of shared experts, None means dense model.
        n_routed_experts (`int`, *optional*, defaults to None):
            Number of routed experts, None means dense model.
        routed_scaling_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for routed experts.
        topk_method (`str`, *optional*, defaults to `greedy`):
            Topk method used in routed gate.
        n_group (`int`, *optional*, defaults to None):
            Number of groups for routed experts.
        topk_group (`int`, *optional*, defaults to None):
            Number of selected groups for each token (for each token, ensuring the selected experts is only within
            `topk_group` groups).
        num_experts_per_tok (`int`, *optional*, defaults to None):
            Number of selected experts, None means dense model.
        moe_layer_freq (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
        first_k_dense_replace (`int`, *optional*, defaults to 0):
            Number of dense layers in shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).
                                                      \--k dense layers--/
        norm_topk_prob (`bool`, *optional*, defaults to False):
            Whether to normalize the weights of the routed experts.
        scoring_func (`str`, *optional*, defaults to 'softmax'):
            Method of computing expert weights.
        aux_loss_alpha (`float`, *optional*, defaults to 0.001):
            Auxiliary loss weight coefficient.
        seq_aux (`bool`, *optional*, defaults to True):
            Whether to compute the auxiliary loss for each individual sample.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
    """

    vocab_size: int = 129280
    hidden_size: int = 7168
    intermediate_size: int = 18432
    moe_intermediate_size: int = 2048
    num_hidden_layers: int = 61
    num_nextn_predict_layers: int = 1
    num_attention_heads: int = 128
    num_key_value_heads: int = 128
    n_shared_experts: int = 1
    n_routed_experts: int = 256
    ep_size: int = 1
    routed_scaling_factor: float = 2.5
    kv_lora_rank: int = 512
    q_lora_rank: int = 1536
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    qk_nope_head_dim: int = 128
    topk_method: str = "noaux_tc"
    n_group: int = 8
    topk_group: int = 4
    num_experts_per_tok: int = 8
    moe_layer_freq: int = 1
    first_k_dense_replace: int = 3
    norm_topk_prob: bool = True
    scoring_func: str = "sigmoid"
    aux_loss_alpha: float = 0.001
    seq_aux: bool = True
    hidden_act: str = "silu"
    max_position_embeddings: int = 163840
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    rope_scaling: dict = field(
        default_factory=lambda: {
            "beta_fast": 32,
            "beta_slow": 1,
            "factor": 40,
            "mscale": 1.0,
            "mscale_all_dim": 1.0,
            "original_max_position_embeddings": 4096,
            "type": "yarn",
        }
    )
    attention_bias: bool = False
    attention_dropout: float = 0.0
    pad_token_id = None
    # Added for symmetric memory
    max_seq_len: int = 4096
    dtype: str = "bfloat16"
    # Added for pipeline parallel
    num_stages: int = 1
    stage_idx: int = 0


# This is the configuration for deepseek-ai/DeepSeek-V2-Lite.
deepseek_v2_lite_config = ModelArgs(
    vocab_size=102400,
    hidden_size=2048,
    intermediate_size=10944,
    moe_intermediate_size=1408,
    num_hidden_layers=27,
    num_attention_heads=16,
    num_key_value_heads=16,
    n_shared_experts=2,
    n_routed_experts=64,
    routed_scaling_factor=1.0,
    kv_lora_rank=512,
    q_lora_rank=None,
    qk_rope_head_dim=64,
    v_head_dim=128,
    qk_nope_head_dim=128,
    topk_method="greedy",
    n_group=1,
    topk_group=1,
    num_experts_per_tok=6,
    first_k_dense_replace=1,
    norm_topk_prob=False,
    scoring_func="softmax",
    max_position_embeddings=4096,
    rope_scaling={
        "beta_fast": 32,
        "beta_slow": 1,
        "factor": 40,
        "mscale": 0.707,
        "mscale_all_dim": 0.707,
        "original_max_position_embeddings": 4096,
        "type": "yarn",
    },
)


# Model configuration registry
# Key is the model distribution ID on HuggingFace Hub
deepseek_config_registry = {
    "deepseek-ai/DeepSeek-V2-Lite": deepseek_v2_lite_config,
    "deepseek-ai/DeepSeek-V2-Lite-Chat": deepseek_v2_lite_config,
    "deepseek-ai/deepseek-v3": ModelArgs(),
}
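As a rough sketch of how this registry is consumed (mirroring create_model() in generate.py above; the override values here are assumptions):

# Hypothetical example: fetch a registered config and adjust the parallelism fields in place.
args = deepseek_config_registry["deepseek-ai/DeepSeek-V2-Lite-Chat"]
args.ep_size = 4        # expert-parallel degree (assumed)
args.num_stages = 1     # pipeline stages (assumed)
args.stage_idx = 0
args.max_seq_len = 16384

Note that both DeepSeek-V2-Lite keys share the same ModelArgs instance, so in-place overrides like these apply to both registry entries.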
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py
ADDED
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import triton
import triton.language as tl


@triton.jit
def get_tid():
    return tl.inline_asm_elementwise(
        """
        mov.u32 $0, %tid.x;
        mov.u32 $1, %tid.y;
        mov.u32 $2, %tid.z;
        """,
        "=r,=r,=r",
        [],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )


@triton.jit
def get_ntid():
    return tl.inline_asm_elementwise(
        """
        mov.u32 $0, %ntid.x;
        mov.u32 $1, %ntid.y;
        mov.u32 $2, %ntid.z;
        """,
        "=r,=r,=r",
        [],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )


@triton.jit
def get_flat_tid():
    tid_x, tid_y, tid_z = get_tid()
    ntid_x, ntid_y, _ = get_ntid()
    return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x


@triton.jit
def get_flat_bid():
    return (
        tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0)
        + tl.program_id(1) * tl.num_programs(0)
        + tl.program_id(0)
    )


@triton.jit
def sync_threads():
    tl.inline_asm_elementwise(
        "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
    )
torchtitan/experiments/flux/README.md
ADDED
@@ -0,0 +1,23 @@
# FLUX model in torchtitan

## Overview

## Usage
First, download the autoencoder model from HuggingFace with your own access token:
```bash
python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
```
This step will download the autoencoder model from HuggingFace and save it to the `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors` file.

Run the following command to train the model on a single GPU:
```bash
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
```

## TODO
- [ ] Support for multiple GPUs is coming soon (FSDP, etc.)
- [ ] Implement test cases in CI for the FLUX model. Add more unit tests for the FLUX model (e.g., unit tests for the preprocessor, etc.)
- [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc.)
- [ ] Support for distributed checkpointing and loading
- [ ] Implement init_weights() function to initialize the model weights
- [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops() function
torchtitan/experiments/flux/dataset/flux_dataset.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import random
from dataclasses import dataclass
from typing import Any, Callable, Optional

import numpy as np

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from PIL import Image

from torch.distributed.checkpoint.stateful import Stateful

from torch.utils.data import IterableDataset
from torchtitan.components.dataloader import ParallelAwareDataloader

from torchtitan.config_manager import JobConfig
from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
from torchtitan.tools.logging import logger


def _process_cc12m_image(
    img: Image.Image,
    output_size: int = 256,
) -> Optional[torch.Tensor]:
    """Process a CC12M image to the desired size."""

    width, height = img.size
    # Skip low resolution images
    if width < output_size or height < output_size:
        return None

    if width >= height:
        # resize height to be equal to output_size, then crop
        new_width, new_height = math.ceil(output_size / height * width), output_size
        img = img.resize((new_width, new_height))
        left = random.randint(0, new_width - output_size)
        resized_img = img.crop((left, 0, left + output_size, output_size))
    else:
        # resize width to be equal to output_size, then crop
        new_width, new_height = (
            output_size,
            math.ceil(output_size / width * height),
        )
        img = img.resize((new_width, new_height))
        # random vertical offset within the resized image
        lower = random.randint(0, new_height - output_size)
        resized_img = img.crop((0, lower, output_size, lower + output_size))

    assert resized_img.size[0] == resized_img.size[1] == output_size

    # Skip grayscale images
    if resized_img.mode == "L":
        return None

    np_img = np.array(resized_img).transpose((2, 0, 1))
    tensor_img = torch.tensor(np_img).float() / 255.0

    # NOTE: The following commented code is an alternative way
    # img_transform = transforms.Compose(
    #     [
    #         transforms.Resize(max(output_size, output_size)),
    #         transforms.CenterCrop((output_size, output_size)),
    #         transforms.ToTensor(),
    #     ]
    # )
    # tensor_img = img_transform(img)

    return tensor_img


def _flux_data_processor(
    sample: dict[str, Any],
    t5_tokenizer: FluxTokenizer,
    clip_tokenizer: FluxTokenizer,
    output_size: int = 256,
) -> dict[str, Any]:
    """
    Preprocess a CC12M dataset sample (image and text) for the Flux model.

    Args:
        sample: A sample from the dataset
        t5_tokenizer: Tokenizer for the T5 text encoder
        clip_tokenizer: Tokenizer for the CLIP text encoder
        output_size: The output image size
    """
    img = _process_cc12m_image(sample["jpg"], output_size=output_size)
    t5_tokens = t5_tokenizer.encode(sample["txt"])
    clip_tokens = clip_tokenizer.encode(sample["txt"])

    return {
        "image": img,
        "clip_tokens": clip_tokens,  # type: List[int]
        "t5_tokens": t5_tokens,  # type: List[int]
    }


@dataclass
class TextToImageDatasetConfig:
    path: str
    loader: Callable
    data_processor: Callable


DATASETS = {
    "cc12m": TextToImageDatasetConfig(
        path="pixparse/cc12m-wds",
        loader=lambda path: load_dataset(path, split="train", streaming=True),
        data_processor=_flux_data_processor,
    ),
}


def _validate_dataset(
    dataset_name: str, dataset_path: Optional[str] = None
) -> tuple[str, Callable, Callable]:
    """Validate dataset name and path."""
    if dataset_name not in DATASETS:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(DATASETS.keys())}"
        )

    config = DATASETS[dataset_name]
    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.data_processor


class FluxDataset(IterableDataset, Stateful):
    """Dataset for the FLUX text-to-image model.

    Args:
        dataset_name (str): Name of the dataset.
        dataset_path (str): Path to the dataset.
        t5_tokenizer (FluxTokenizer): Tokenizer for the T5 text encoder.
        clip_tokenizer (FluxTokenizer): Tokenizer for the CLIP text encoder.
        job_config (JobConfig): Optional job configuration.
        dp_rank (int): Data parallel rank.
        dp_world_size (int): Data parallel world size.
        infinite (bool): Whether to loop over the dataset infinitely.
    """

    def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str],
        t5_tokenizer: FluxTokenizer,
        clip_tokenizer: FluxTokenizer,
        job_config: Optional[JobConfig] = None,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:

        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, data_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)

        self._t5_tokenizer = t5_tokenizer
        self._clip_tokenizer = clip_tokenizer
        self._data_processor = data_processor
        self.job_config = job_config

        self.infinite = infinite

        # Variables for checkpointing
        self._sample_idx = 0
        self._all_samples: list[dict[str, Any]] = []

    def _get_data_iter(self):
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        for _ in range(self._sample_idx):
            next(it)
        return it

    def __iter__(self):
        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific preprocessor
                sample_dict = self._data_processor(
                    sample, self._t5_tokenizer, self._clip_tokenizer, output_size=256
                )

                # skip low quality images or images with a single color channel
                if sample_dict["image"] is None:
                    logger.warning(
                        f"Low quality image {sample['__key__']} is skipped in Flux Dataloader"
                    )
                    continue

                self._all_samples.append(sample_dict)
                self._sample_idx += 1

                labels = sample_dict.pop("image")
                yield sample_dict, labels

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]
        self._all_samples = state_dict["all_samples"]

    def state_dict(self):
        return {
            "all_samples": self._all_samples,
            "sample_idx": self._sample_idx,
        }


def build_flux_dataloader(
    dp_world_size: int,
    dp_rank: int,
    job_config: JobConfig,
    # This parameter is not used, keep it for compatibility
    tokenizer: FluxTokenizer | None,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.batch_size

    t5_encoder_name = job_config.encoder.t5_encoder
    clip_encoder_name = job_config.encoder.clip_encoder
    max_t5_encoding_len = job_config.encoder.max_t5_encoding_len

    ds = FluxDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        t5_tokenizer=FluxTokenizer(t5_encoder_name, max_length=max_t5_encoding_len),
        clip_tokenizer=FluxTokenizer(
            clip_encoder_name, max_length=77
        ),  # fixed max_length for CLIP
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
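For context, here is a minimal sketch of how `build_flux_dataloader` above could be driven from a script. The `SimpleNamespace` stand-in for `JobConfig` is hypothetical and only mirrors the fields the function actually reads; a real run would pass torchtitan's parsed `JobConfig` (e.g., from `debug_model.toml`).

```python
# Minimal sketch (not part of this change): drive build_flux_dataloader with a
# stand-in config object exposing only the fields the function reads.
from types import SimpleNamespace

from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader

job_config = SimpleNamespace(
    training=SimpleNamespace(dataset="cc12m", dataset_path=None, batch_size=4),
    encoder=SimpleNamespace(
        t5_encoder="google/t5-v1_1-small",
        clip_encoder="openai/clip-vit-large-patch14",
        max_t5_encoding_len=512,
    ),
)

dataloader = build_flux_dataloader(
    dp_world_size=1, dp_rank=0, job_config=job_config, tokenizer=None
)

# Each batch is (input_dict, labels): tokenized text plus image tensors.
input_dict, labels = next(iter(dataloader))
print(labels.shape)  # e.g. [4, 3, 256, 256]
print(input_dict["t5_tokens"].shape, input_dict["clip_tokens"].shape)
```

Note that this streams `pixparse/cc12m-wds` from the HuggingFace Hub, so it needs network access and downloads data on first use.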
torchtitan/experiments/flux/dataset/tokenizer.py
ADDED
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.


from typing import List

from torchtitan.components.tokenizer import Tokenizer
from transformers import CLIPTokenizer, T5Tokenizer


class FluxTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the T5 or CLIP tokenizer.

    Args:
        model_path (str): Path to the tokenizer on HuggingFace.
    """

    def __init__(self, model_path: str = "t5-small", max_length: int = 77):
        super().__init__()
        self._n_words = 8  # TODO(jianiw): check
        self._max_length = max_length

        self.is_clip = model_path.startswith("openai")

        if self.is_clip:
            self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                model_path, max_length=max_length
            )
        else:
            self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                model_path, max_length=max_length
            )

    def encode(
        self,
        s: str,
    ) -> List[int]:
        """
        Encode the prompt text into tokens.
        """
        tokens = self._tokenizer(
            s,
            truncation=True,
            max_length=self._max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",  # return PyTorch tensors (default would be List[int])
        )["input_ids"]
        return tokens

    def decode(self, t: List[int]) -> str:
        """
        Decode function. This function will not be called.
        """
        return self._tokenizer.decode(t)
torchtitan/experiments/flux/model/hf_embedder.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn, Tensor
from transformers import CLIPTextModel, T5EncoderModel


class FluxEmbedder(nn.Module):
    def __init__(self, version: str, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
                version, **hf_kwargs
            )
        else:
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                version, **hf_kwargs
            )

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, batch_tokens: Tensor) -> Tensor:
        """
        batch_tokens: [bsz, embedding_length]

        For the T5 encoder, embedding_length is 768
        For CLIP, embedding_length is 256
        """
        outputs = self.hf_module(
            input_ids=batch_tokens.to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
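As a quick illustration (a sketch, not part of the diff), `FluxTokenizer` and `FluxEmbedder` compose as follows; the model names match the ones used in `debug_model.toml`, and the commented shapes assume those specific HF checkpoints.

```python
# Sketch: tokenize a prompt and run it through the frozen HF text encoders.
# Downloads the HF checkpoints on first use.
import torch

from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder

prompt = "a photo of a forest with mist"

clip_tokenizer = FluxTokenizer("openai/clip-vit-large-patch14", max_length=77)
t5_tokenizer = FluxTokenizer("google/t5-v1_1-small", max_length=512)

clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14")
t5_encoder = FluxEmbedder(version="google/t5-v1_1-small")

with torch.no_grad():
    clip_vec = clip_encoder(clip_tokenizer.encode(prompt))  # pooled: [1, 768]
    t5_seq = t5_encoder(t5_tokenizer.encode(prompt))        # sequence: [1, 512, 512]

print(clip_vec.shape, t5_seq.shape)
```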
torchtitan/experiments/flux/model/math.py
ADDED
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from einops import rearrange
from torch import Tensor


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack(
        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
    )
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
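For a quick sanity check of the tensor shapes involved (a sketch, assuming the module is importable as `torchtitan.experiments.flux.model.math`): `rope` produces per-position 2x2 rotation blocks that `attention` applies to the queries and keys via `apply_rope` before scaled dot-product attention.

```python
import torch

from torchtitan.experiments.flux.model.math import attention, rope

B, H, L, D = 2, 4, 16, 64                                # batch, heads, tokens, head dim (even)
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

# Rotation table for one 1-D position axis; the Flux model builds one of these
# per axis of its 3-D position ids and concatenates them along the head dim.
pos = torch.arange(L, dtype=torch.float32).unsqueeze(0)  # (1, L), broadcast over batch
pe = rope(pos, D, theta=10_000)                          # (1, L, D/2, 2, 2)
pe = pe.unsqueeze(1)                                     # add a heads axis for broadcasting

out = attention(q, k, v, pe)
print(pe.shape, out.shape)                               # (1, 1, 16, 32, 2, 2) and (2, 16, 256)
```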
torchtitan/experiments/flux/scripts/download_autoencoder.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

from requests.exceptions import HTTPError


def hf_download(
    repo_id: str, file_path: str, local_dir: str, hf_token: Optional[str] = None
) -> None:
    from huggingface_hub import hf_hub_download

    try:
        hf_hub_download(
            repo_id=repo_id,
            filename=file_path,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            token=hf_token,
        )
    except HTTPError as e:
        if e.response.status_code == 401:
            print(
                "You need to pass a valid `--hf_token=...` to download private checkpoints."
            )
        else:
            raise e


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Download the autoencoder from HuggingFace."
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Repository ID to download from. Defaults to the FLUX.1-dev model",
    )
    parser.add_argument(
        "--ae_path",
        type=str,
        default="ae.safetensors",
        help="the autoencoder path relative to repo_id",
    )
    parser.add_argument(
        "--hf_token", type=str, default=None, help="HuggingFace API token"
    )
    parser.add_argument(
        "--local_dir",
        type=str,
        default="torchtitan/experiments/flux/assets/autoencoder/",
        help="local directory to save the autoencoder",
    )

    args = parser.parse_args()
    hf_download(args.repo_id, args.ae_path, args.local_dir, args.hf_token)
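Besides the CLI invocation shown in the README, `hf_download` can also be called programmatically; a small sketch, with the token string as a hypothetical placeholder:

```python
# Sketch: download the autoencoder weights from Python instead of the CLI.
from torchtitan.experiments.flux.scripts.download_autoencoder import hf_download

hf_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    file_path="ae.safetensors",
    local_dir="torchtitan/experiments/flux/assets/autoencoder/",
    hf_token="<your_access_token>",  # hypothetical placeholder
)
```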
torchtitan/experiments/flux/tests/test_generate_image.py
ADDED
@@ -0,0 +1,252 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import time
|
10 |
+
from typing import Callable
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from einops import rearrange
|
14 |
+
|
15 |
+
from PIL import ExifTags, Image
|
16 |
+
|
17 |
+
from torch import Tensor
|
18 |
+
|
19 |
+
from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
|
20 |
+
|
21 |
+
from torchtitan.experiments.flux.model.autoencoder import (
|
22 |
+
AutoEncoder,
|
23 |
+
AutoEncoderParams,
|
24 |
+
load_ae,
|
25 |
+
)
|
26 |
+
from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
|
27 |
+
|
28 |
+
from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs
|
29 |
+
from torchtitan.experiments.flux.utils import (
|
30 |
+
create_position_encoding_for_latents,
|
31 |
+
generate_noise_latent,
|
32 |
+
pack_latents,
|
33 |
+
preprocess_flux_data,
|
34 |
+
unpack_latents,
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
def time_shift(mu: float, sigma: float, t: Tensor):
|
39 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
40 |
+
|
41 |
+
|
42 |
+
def get_lin_function(
|
43 |
+
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
|
44 |
+
) -> Callable[[float], float]:
|
45 |
+
m = (y2 - y1) / (x2 - x1)
|
46 |
+
b = y1 - m * x1
|
47 |
+
return lambda x: m * x + b
|
48 |
+
|
49 |
+
|
50 |
+
def get_schedule(
|
51 |
+
num_steps: int,
|
52 |
+
image_seq_len: int,
|
53 |
+
base_shift: float = 0.5,
|
54 |
+
max_shift: float = 1.15,
|
55 |
+
shift: bool = True,
|
56 |
+
) -> list[float]:
|
57 |
+
# extra step for zero
|
58 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
59 |
+
|
60 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
61 |
+
if shift:
|
62 |
+
# estimate mu based on linear estimation between two points
|
63 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
64 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
65 |
+
|
66 |
+
return timesteps.tolist()
|
67 |
+
|
68 |
+
|
69 |
+
class TestGenerateImage:
|
70 |
+
def test_generate_image(self):
|
71 |
+
"""
|
72 |
+
Run a forward pass of flux model to generate an image.
|
73 |
+
"""
|
74 |
+
name = "flux-dev"
|
75 |
+
img_width = 512
|
76 |
+
img_height = 512
|
77 |
+
seed = None
|
78 |
+
prompt = (
|
79 |
+
"a photo of a forest with mist swirling around the tree trunks. The word "
|
80 |
+
'"FLUX" is painted over it in big, red brush strokes with visible texture'
|
81 |
+
)
|
82 |
+
device = "cuda"
|
83 |
+
num_steps = None
|
84 |
+
loop = False
|
85 |
+
guidance = 3.5
|
86 |
+
output_dir = "output"
|
87 |
+
add_sampling_metadata = True
|
88 |
+
|
89 |
+
prompt = prompt.split("|")
|
90 |
+
if len(prompt) == 1:
|
91 |
+
prompt = prompt[0]
|
92 |
+
additional_prompts = None
|
93 |
+
else:
|
94 |
+
additional_prompts = prompt[1:]
|
95 |
+
prompt = prompt[0]
|
96 |
+
|
97 |
+
assert not (
|
98 |
+
(additional_prompts is not None) and loop
|
99 |
+
), "Do not provide additional prompts and set loop to True"
|
100 |
+
|
101 |
+
torch_device = torch.device(device)
|
102 |
+
if num_steps is None:
|
103 |
+
num_steps = 30
|
104 |
+
|
105 |
+
# allow for packing and conversion to latent space
|
106 |
+
img_height = 16 * (img_height // 16)
|
107 |
+
img_width = 16 * (img_width // 16)
|
108 |
+
|
109 |
+
# init all components
|
110 |
+
model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16)
|
111 |
+
|
112 |
+
ae = load_ae(
|
113 |
+
ckpt_path="assets/autoencoder/ae.safetensors",
|
114 |
+
autoencoder_params=AutoEncoderParams(),
|
115 |
+
device=torch_device,
|
116 |
+
dtype=torch.bfloat16,
|
117 |
+
)
|
118 |
+
clip_tokenizer = FluxTokenizer(
|
119 |
+
model_path="openai/clip-vit-large-patch14", max_length=77
|
120 |
+
)
|
121 |
+
t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512)
|
122 |
+
clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to(
|
123 |
+
torch_device, dtype=torch.bfloat16
|
124 |
+
)
|
125 |
+
t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to(
|
126 |
+
torch_device, dtype=torch.bfloat16
|
127 |
+
)
|
128 |
+
|
129 |
+
rng = torch.Generator(device="cpu")
|
130 |
+
|
131 |
+
if seed is None:
|
132 |
+
seed = rng.seed()
|
133 |
+
print(f"Generating with seed {seed}:\n{prompt}")
|
134 |
+
t0 = time.perf_counter()
|
135 |
+
output_name = os.path.join(output_dir, f"img_{seed}.jpg")
|
136 |
+
|
137 |
+
# Tokenize the prompt, on CPU
|
138 |
+
clip_tokens = clip_tokenizer.encode(prompt)
|
139 |
+
t5_tokens = t5_tokenizer.encode(prompt)
|
140 |
+
|
141 |
+
batch = preprocess_flux_data(
|
142 |
+
device=torch_device,
|
143 |
+
dtype=torch.bfloat16,
|
144 |
+
autoencoder=None,
|
145 |
+
clip_encoder=clip_encoder,
|
146 |
+
t5_encoder=t5_encoder,
|
147 |
+
batch={
|
148 |
+
"clip_tokens": clip_tokens,
|
149 |
+
"t5_tokens": t5_tokens,
|
150 |
+
},
|
151 |
+
)
|
152 |
+
|
153 |
+
img = self._generate_images(
|
154 |
+
device=torch_device,
|
155 |
+
dtype=torch.bfloat16,
|
156 |
+
model=model,
|
157 |
+
decoder=ae,
|
158 |
+
img_width=img_width,
|
159 |
+
img_height=img_height,
|
160 |
+
denoising_steps=num_steps,
|
161 |
+
seed=seed,
|
162 |
+
clip_encodings=batch["clip_encodings"],
|
163 |
+
t5_encodings=batch["t5_encodings"],
|
164 |
+
guidance=guidance,
|
165 |
+
)
|
166 |
+
|
167 |
+
if torch.cuda.is_available():
|
168 |
+
torch.cuda.synchronize()
|
169 |
+
t1 = time.perf_counter()
|
170 |
+
|
171 |
+
print(f"Done in {t1 - t0:.1f}s.")
|
172 |
+
|
173 |
+
self._save_image(name, output_name, img, add_sampling_metadata, prompt)
|
174 |
+
|
175 |
+
def _generate_images(
|
176 |
+
self,
|
177 |
+
device: torch.device,
|
178 |
+
dtype: torch.dtype,
|
179 |
+
model: FluxModel,
|
180 |
+
decoder: AutoEncoder,
|
181 |
+
# image params:
|
182 |
+
img_width: int,
|
183 |
+
img_height: int,
|
184 |
+
# sampling params:
|
185 |
+
denoising_steps: int,
|
186 |
+
seed: int,
|
187 |
+
clip_encodings: torch.Tensor,
|
188 |
+
t5_encodings: torch.Tensor,
|
189 |
+
guidance: float = 4.0,
|
190 |
+
):
|
191 |
+
|
192 |
+
bsz = clip_encodings.shape[0]
|
193 |
+
latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed)
|
194 |
+
_, latent_channels, latent_height, latent_width = latents.shape
|
195 |
+
|
196 |
+
# create denoising schedule
|
197 |
+
timesteps = get_schedule(denoising_steps, latent_channels, shift=True)
|
198 |
+
|
199 |
+
# create positional encodings
|
200 |
+
POSITION_DIM = 3 # constant for Flux flow model
|
201 |
+
latent_pos_enc = create_position_encoding_for_latents(
|
202 |
+
bsz, latent_height, latent_width, POSITION_DIM
|
203 |
+
).to(latents)
|
204 |
+
text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)
|
205 |
+
|
206 |
+
# convert img-like latents into sequences of patches
|
207 |
+
latents = pack_latents(latents)
|
208 |
+
|
209 |
+
# this is ignored for schnell
|
210 |
+
guidance_vec = torch.full((bsz,), guidance, device=device, dtype=dtype)
|
211 |
+
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
|
212 |
+
t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
|
213 |
+
pred = model(
|
214 |
+
img=latents,
|
215 |
+
img_ids=latent_pos_enc,
|
216 |
+
txt=t5_encodings,
|
217 |
+
txt_ids=text_pos_enc,
|
218 |
+
y=clip_encodings,
|
219 |
+
timesteps=t_vec,
|
220 |
+
guidance=guidance_vec,
|
221 |
+
)
|
222 |
+
|
223 |
+
latents = latents + (t_prev - t_curr) * pred
|
224 |
+
|
225 |
+
# convert sequences of patches into img-like latents
|
226 |
+
latents = unpack_latents(latents, latent_height, latent_width)
|
227 |
+
|
228 |
+
img = decoder.decode(latents)
|
229 |
+
return img
|
230 |
+
|
231 |
+
def _save_image(
|
232 |
+
self,
|
233 |
+
name: str,
|
234 |
+
output_name: str,
|
235 |
+
x: torch.Tensor,
|
236 |
+
add_sampling_metadata: bool,
|
237 |
+
prompt: str,
|
238 |
+
):
|
239 |
+
print(f"Saving {output_name}")
|
240 |
+
# bring into PIL format and save
|
241 |
+
x = x.clamp(-1, 1)
|
242 |
+
x = rearrange(x[0], "c h w -> h w c")
|
243 |
+
|
244 |
+
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
|
245 |
+
|
246 |
+
exif_data = Image.Exif()
|
247 |
+
exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
|
248 |
+
exif_data[ExifTags.Base.Make] = "Black Forest Labs"
|
249 |
+
exif_data[ExifTags.Base.Model] = name
|
250 |
+
if add_sampling_metadata:
|
251 |
+
exif_data[ExifTags.Base.ImageDescription] = prompt
|
252 |
+
img.save(output_name, exif=exif_data, quality=95, subsampling=0)
|
torchtitan/experiments/flux/train_configs/debug_model.toml
ADDED
@@ -0,0 +1,68 @@
[job]
dump_folder = "./outputs"
description = "Flux debug model"
print_args = false
use_for_integration_test = true

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "flux"
flavor = "flux-debug"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
# test tokenizer.model, for debug purpose only
# tokenizer_path = "./tests/assets/test_tiktoken.model"
# converters = "float8"


[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.8  # lr scheduler decay ratio, 80% of the train steps
decay_type = "linear"
lr_min = 0.0

[training]
batch_size = 32
seq_len = 512
max_norm = 1.0  # grad norm clipping
steps = 10
compile = false
dataset = "cc12m"
guidance = 3.5
seed = 0

[encoder]
t5_encoder = "google/t5-v1_1-small"
clip_encoder = "openai/clip-vit-large-patch14"
max_t5_encoding_len = 512
auto_encoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors"  # Autoencoder to use for image

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 1
fsdp_reshard_after_forward = "default"  # default / never / always
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1

[experimental]
custom_args_module = "torchtitan.experiments.flux.flux_argparser"
torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py
ADDED
@@ -0,0 +1,630 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
8 |
+
# All rights reserved.
|
9 |
+
#
|
10 |
+
# Benchmark comparing reference PyTorch vs optimized M*G group GEMM implementation
|
11 |
+
|
12 |
+
import argparse
|
13 |
+
import logging
|
14 |
+
import time
|
15 |
+
|
16 |
+
# from typing import Dict, List, Optional, Tuple
|
17 |
+
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
import triton
|
22 |
+
|
23 |
+
# import triton.language as tl
|
24 |
+
|
25 |
+
# Configure logging
|
26 |
+
logging.basicConfig(
|
27 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
28 |
+
)
|
29 |
+
|
30 |
+
# Try to import the optimized implementations
|
31 |
+
try:
|
32 |
+
from torchao_pr.mg_grouped_gemm import grouped_gemm_forward
|
33 |
+
|
34 |
+
except ImportError:
|
35 |
+
logging.error(
|
36 |
+
"Error importing MG grouped GEMM modules. Make sure the implementation files are in the correct path."
|
37 |
+
)
|
38 |
+
raise
|
39 |
+
|
40 |
+
|
41 |
+
def compute_reference_forward(x, w, m_sizes):
|
42 |
+
"""
|
43 |
+
Reference PyTorch implementation of M*G grouped GEMM forward pass.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
x (torch.Tensor): Input tensor of shape (M, K)
|
47 |
+
w (torch.Tensor): Weight tensor of shape (N, K)
|
48 |
+
m_sizes (torch.Tensor): Group sizes tensor of shape (G)
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
torch.Tensor: Output tensor of shape (M, N)
|
52 |
+
"""
|
53 |
+
result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)
|
54 |
+
|
55 |
+
m_start = 0
|
56 |
+
for g in range(len(m_sizes)):
|
57 |
+
m_size = m_sizes[g].item()
|
58 |
+
if m_size > 0:
|
59 |
+
m_end = m_start + m_size
|
60 |
+
|
61 |
+
# Extract group input
|
62 |
+
x_g = x[m_start:m_end]
|
63 |
+
|
64 |
+
# Compute group output
|
65 |
+
y_g = torch.matmul(x_g, w.T)
|
66 |
+
|
67 |
+
# Store result
|
68 |
+
result[m_start:m_end] = y_g
|
69 |
+
|
70 |
+
# Update start index
|
71 |
+
m_start = m_end
|
72 |
+
|
73 |
+
return result
|
74 |
+
|
75 |
+
|
76 |
+
@triton.testing.perf_report(
|
77 |
+
triton.testing.Benchmark(
|
78 |
+
x_names=["N"], # We'll vary the output dimension
|
79 |
+
x_vals=[1024, 2048, 4096, 8192, 16384], # Different output dimensions to test
|
80 |
+
# x_vals=[8192, 16384],
|
81 |
+
line_arg="provider", # We'll compare different providers
|
82 |
+
line_vals=["pytorch_reference", "M*G grouped GEMM"],
|
83 |
+
line_names=["PyTorch Reference", "M*G grouped Kernel"],
|
84 |
+
styles=[("blue", "-"), ("red", "-")],
|
85 |
+
ylabel="TFLOPS", # We'll measure TFLOPS
|
86 |
+
plot_name="mg_grouped_gemm_comparison",
|
87 |
+
args={
|
88 |
+
"M": 8192, # Batch dimension, fixed for all tests
|
89 |
+
"K": 7168, # Hidden dimension, fixed for all tests
|
90 |
+
"G": 8, # Number of groups
|
91 |
+
"dtype": torch.float16,
|
92 |
+
"device": "cuda",
|
93 |
+
},
|
94 |
+
)
|
95 |
+
)
|
96 |
+
def benchmark_forward(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
|
97 |
+
"""
|
98 |
+
Benchmark the forward pass of the grouped GEMM implementation.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
M (int): Total batch size dimension
|
102 |
+
K (int): Hidden dimension
|
103 |
+
N (int): Output dimension
|
104 |
+
G (int): Number of groups
|
105 |
+
provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
|
106 |
+
dtype (torch.dtype): Data type to use
|
107 |
+
device (str): Device to use
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
float: Performance in TFLOPS
|
111 |
+
"""
|
112 |
+
# Create group sizes for M dimension (balanced across groups)
|
113 |
+
base_size = M // G
|
114 |
+
remainder = M % G
|
115 |
+
M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
|
116 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
117 |
+
|
118 |
+
print(f"N: {N}, M: {M}, K: {K}, G: {G}, dtype: {dtype}, device: {device}")
|
119 |
+
|
120 |
+
# Create input and weight tensors
|
121 |
+
x = torch.randn(M, K, dtype=dtype, device=device)
|
122 |
+
w = torch.randn(N, K, dtype=dtype, device=device)
|
123 |
+
|
124 |
+
# Pre-compute for PyTorch reference to ensure fair comparison
|
125 |
+
if provider == "pytorch_reference":
|
126 |
+
# Warmup
|
127 |
+
torch.cuda.synchronize()
|
128 |
+
compute_reference_forward(x, w, m_sizes)
|
129 |
+
torch.cuda.synchronize()
|
130 |
+
|
131 |
+
# Benchmark
|
132 |
+
start_time = time.time()
|
133 |
+
for _ in range(10): # Average over 10 runs
|
134 |
+
compute_reference_forward(x, w, m_sizes)
|
135 |
+
torch.cuda.synchronize()
|
136 |
+
end_time = time.time()
|
137 |
+
else: # Optimized kernel
|
138 |
+
# Warmup
|
139 |
+
torch.cuda.synchronize()
|
140 |
+
grouped_gemm_forward(x, w, m_sizes)
|
141 |
+
torch.cuda.synchronize()
|
142 |
+
|
143 |
+
# Benchmark
|
144 |
+
start_time = time.time()
|
145 |
+
for _ in range(10): # Average over 10 runs
|
146 |
+
grouped_gemm_forward(x, w, m_sizes)
|
147 |
+
torch.cuda.synchronize()
|
148 |
+
end_time = time.time()
|
149 |
+
|
150 |
+
# Calculate FLOPs
|
151 |
+
# For GEMM: 2 * M * N * K FLOPs (multiply-add counts as 2 FLOPs)
|
152 |
+
flops = 2 * M * N * K
|
153 |
+
|
154 |
+
# Convert to TFLOPS (tera-FLOPS)
|
155 |
+
avg_time = (end_time - start_time) / 10 # Average time per run
|
156 |
+
tflops = flops / avg_time / 1e12
|
157 |
+
|
158 |
+
return tflops
|
159 |
+
|
160 |
+
|
161 |
+
@triton.testing.perf_report(
|
162 |
+
triton.testing.Benchmark(
|
163 |
+
x_names=["G"], # We'll vary the number of groups
|
164 |
+
x_vals=[1, 2, 4, 8, 16], # Different numbers of groups to test
|
165 |
+
line_arg="provider", # We'll compare different providers
|
166 |
+
line_vals=["pytorch_reference", "optimized_kernel"],
|
167 |
+
line_names=["PyTorch Reference", "Optimized Kernel"],
|
168 |
+
styles=[("blue", "-"), ("red", "-")],
|
169 |
+
ylabel="TFLOPS", # We'll measure TFLOPS
|
170 |
+
plot_name="mg_grouped_gemm_group_scaling",
|
171 |
+
args={
|
172 |
+
"M": 8192, # Batch dimension, fixed for all tests
|
173 |
+
"K": 4096, # Hidden dimension, fixed for all tests
|
174 |
+
"N": 8192, # Output dimension, fixed for all tests
|
175 |
+
"dtype": torch.float16,
|
176 |
+
"device": "cuda",
|
177 |
+
},
|
178 |
+
)
|
179 |
+
)
|
180 |
+
def benchmark_forward_groups(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
|
181 |
+
"""
|
182 |
+
Benchmark how performance scales with number of groups.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
M (int): Total batch size dimension
|
186 |
+
K (int): Hidden dimension
|
187 |
+
N (int): Output dimension
|
188 |
+
G (int): Number of groups
|
189 |
+
provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
|
190 |
+
dtype (torch.dtype): Data type to use
|
191 |
+
device (str): Device to use
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
float: Performance in TFLOPS
|
195 |
+
"""
|
196 |
+
# Create group sizes for M dimension (balanced across groups)
|
197 |
+
base_size = M // G
|
198 |
+
remainder = M % G
|
199 |
+
M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
|
200 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
201 |
+
|
202 |
+
# Create input and weight tensors
|
203 |
+
x = torch.randn(M, K, dtype=dtype, device=device)
|
204 |
+
w = torch.randn(N, K, dtype=dtype, device=device)
|
205 |
+
|
206 |
+
# Benchmark logic - same as previous function
|
207 |
+
if provider == "pytorch_reference":
|
208 |
+
torch.cuda.synchronize()
|
209 |
+
compute_reference_forward(x, w, m_sizes)
|
210 |
+
torch.cuda.synchronize()
|
211 |
+
|
212 |
+
start_time = time.time()
|
213 |
+
for _ in range(10):
|
214 |
+
compute_reference_forward(x, w, m_sizes)
|
215 |
+
torch.cuda.synchronize()
|
216 |
+
end_time = time.time()
|
217 |
+
else:
|
218 |
+
torch.cuda.synchronize()
|
219 |
+
grouped_gemm_forward(x, w, m_sizes)
|
220 |
+
torch.cuda.synchronize()
|
221 |
+
|
222 |
+
start_time = time.time()
|
223 |
+
for _ in range(10):
|
224 |
+
grouped_gemm_forward(x, w, m_sizes)
|
225 |
+
torch.cuda.synchronize()
|
226 |
+
end_time = time.time()
|
227 |
+
|
228 |
+
# Calculate FLOPs and TFLOPS
|
229 |
+
flops = 2 * M * N * K
|
230 |
+
avg_time = (end_time - start_time) / 10
|
231 |
+
tflops = flops / avg_time / 1e12
|
232 |
+
|
233 |
+
return tflops
|
234 |
+
|
235 |
+
|
236 |
+
@triton.testing.perf_report(
|
237 |
+
triton.testing.Benchmark(
|
238 |
+
x_names=["group_balance"], # We'll vary the group balance factor
|
239 |
+
x_vals=[
|
240 |
+
0.0,
|
241 |
+
0.25,
|
242 |
+
0.5,
|
243 |
+
0.75,
|
244 |
+
0.9,
|
245 |
+
], # Different imbalance factors (0 = balanced, 1 = max imbalance)
|
246 |
+
line_arg="provider", # We'll compare different providers
|
247 |
+
line_vals=["pytorch_reference", "optimized_kernel"],
|
248 |
+
line_names=["PyTorch Reference", "Optimized Kernel"],
|
249 |
+
styles=[("blue", "-"), ("red", "-")],
|
250 |
+
ylabel="TFLOPS", # We'll measure TFLOPS
|
251 |
+
plot_name="mg_grouped_gemm_imbalance",
|
252 |
+
args={
|
253 |
+
"M": 8192, # Batch dimension, fixed for all tests
|
254 |
+
"K": 4096, # Hidden dimension, fixed for all tests
|
255 |
+
"N": 8192, # Output dimension, fixed for all tests
|
256 |
+
"G": 4, # Number of groups
|
257 |
+
"dtype": torch.float16,
|
258 |
+
"device": "cuda",
|
259 |
+
},
|
260 |
+
)
|
261 |
+
)
|
262 |
+
def benchmark_imbalance(
|
263 |
+
M, K, N, G, group_balance, provider, dtype=torch.float16, device="cuda"
|
264 |
+
):
|
265 |
+
"""
|
266 |
+
Benchmark how performance is affected by imbalanced group sizes.
|
267 |
+
|
268 |
+
Args:
|
269 |
+
M (int): Total batch size dimension
|
270 |
+
K (int): Hidden dimension
|
271 |
+
N (int): Output dimension
|
272 |
+
G (int): Number of groups
|
273 |
+
group_balance (float): Balance factor from 0 to 1 (0 = balanced, 1 = max imbalance)
|
274 |
+
provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
|
275 |
+
dtype (torch.dtype): Data type to use
|
276 |
+
device (str): Device to use
|
277 |
+
|
278 |
+
Returns:
|
279 |
+
float: Performance in TFLOPS
|
280 |
+
"""
|
281 |
+
# Create imbalanced group sizes for M dimension
|
282 |
+
if group_balance == 0:
|
283 |
+
# Balanced case
|
284 |
+
base_size = M // G
|
285 |
+
remainder = M % G
|
286 |
+
M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
|
287 |
+
else:
|
288 |
+
# Imbalanced case
|
289 |
+
# First group gets more elements, last group gets fewer
|
290 |
+
# The imbalance is controlled by the group_balance factor
|
291 |
+
remaining = M
|
292 |
+
M_sizes = []
|
293 |
+
for g in range(G):
|
294 |
+
# Interpolate from balanced to imbalanced based on group_balance
|
295 |
+
# For balanced (group_balance=0), each group gets M/G
|
296 |
+
# For imbalanced (group_balance=1), first group gets much more than last group
|
297 |
+
balanced_size = remaining // (G - g)
|
298 |
+
|
299 |
+
# Adjusting size based on position and imbalance factor
|
300 |
+
# First groups get more, last groups get less
|
301 |
+
if g < G // 2:
|
302 |
+
# First half of groups get more
|
303 |
+
adjustment = int(balanced_size * group_balance * (1 - g / (G - 1)))
|
304 |
+
size = balanced_size + adjustment
|
305 |
+
else:
|
306 |
+
# Second half of groups get less
|
307 |
+
adjustment = int(balanced_size * group_balance * ((g / (G - 1)) - 0.5))
|
308 |
+
size = balanced_size - adjustment
|
309 |
+
|
310 |
+
# Ensure we don't go below 1 or take more than remaining
|
311 |
+
size = max(1, min(size, remaining))
|
312 |
+
M_sizes.append(size)
|
313 |
+
remaining -= size
|
314 |
+
|
315 |
+
# Handle any remaining elements
|
316 |
+
if remaining > 0:
|
317 |
+
M_sizes[-1] += remaining
|
318 |
+
|
319 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
320 |
+
|
321 |
+
# Create input and weight tensors
|
322 |
+
x = torch.randn(M, K, dtype=dtype, device=device)
|
323 |
+
w = torch.randn(N, K, dtype=dtype, device=device)
|
324 |
+
|
325 |
+
# Benchmark logic
|
326 |
+
if provider == "pytorch_reference":
|
327 |
+
torch.cuda.synchronize()
|
328 |
+
compute_reference_forward(x, w, m_sizes)
|
329 |
+
torch.cuda.synchronize()
|
330 |
+
|
331 |
+
start_time = time.time()
|
332 |
+
for _ in range(10):
|
333 |
+
compute_reference_forward(x, w, m_sizes)
|
334 |
+
torch.cuda.synchronize()
|
335 |
+
end_time = time.time()
|
336 |
+
else:
|
337 |
+
torch.cuda.synchronize()
|
338 |
+
grouped_gemm_forward(x, w, m_sizes)
|
339 |
+
torch.cuda.synchronize()
|
340 |
+
|
341 |
+
start_time = time.time()
|
342 |
+
for _ in range(10):
|
343 |
+
grouped_gemm_forward(x, w, m_sizes)
|
344 |
+
torch.cuda.synchronize()
|
345 |
+
end_time = time.time()
|
346 |
+
|
347 |
+
# Calculate FLOPs and TFLOPS
|
348 |
+
flops = 2 * M * N * K
|
349 |
+
avg_time = (end_time - start_time) / 10
|
350 |
+
tflops = flops / avg_time / 1e12
|
351 |
+
|
352 |
+
return tflops
|
353 |
+
|
354 |
+
|
355 |
+
def benchmark_model_configs():
|
356 |
+
"""
|
357 |
+
Benchmark common model configurations used in DeepSeek-like models.
|
358 |
+
"""
|
359 |
+
# Model configurations: (M, K, N, G)
|
360 |
+
configs = [
|
361 |
+
(8192, 7168, 4096, 4), # Config 1
|
362 |
+
(8192, 2048, 7168, 4), # Config 2
|
363 |
+
(4096, 7168, 4096, 8), # Config 3
|
364 |
+
(4096, 2048, 7168, 8), # Config 4
|
365 |
+
]
|
366 |
+
|
367 |
+
results = []
|
368 |
+
|
369 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
370 |
+
dtype = torch.float16
|
371 |
+
|
372 |
+
for config_idx, (M, K, N, G) in enumerate(configs):
|
373 |
+
logging.info(f"\n===== Benchmarking DeepSeek Config {config_idx + 1} =====")
|
374 |
+
logging.info(f"M={M}, K={K}, N={N}, G={G}")
|
375 |
+
|
376 |
+
# Create group sizes for M dimension
|
377 |
+
base_size = M // G
|
378 |
+
remainder = M % G
|
379 |
+
M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
|
380 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
381 |
+
|
382 |
+
# Create tensors
|
383 |
+
x = torch.randn(M, K, dtype=dtype, device=device)
|
384 |
+
w = torch.randn(N, K, dtype=dtype, device=device)
|
385 |
+
|
386 |
+
# Benchmark PyTorch reference
|
387 |
+
torch.cuda.synchronize()
|
388 |
+
compute_reference_forward(x, w, m_sizes) # Warmup
|
389 |
+
torch.cuda.synchronize()
|
390 |
+
|
391 |
+
logging.info("Benchmarking PyTorch reference...")
|
392 |
+
torch.cuda.reset_peak_memory_stats()
|
393 |
+
start_time = time.time()
|
394 |
+
for _ in range(10):
|
395 |
+
compute_reference_forward(x, w, m_sizes)
|
396 |
+
torch.cuda.synchronize()
|
397 |
+
end_time = time.time()
|
398 |
+
pt_time = (end_time - start_time) / 10
|
399 |
+
pt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB
|
400 |
+
|
401 |
+
# Benchmark optimized kernel
|
402 |
+
torch.cuda.synchronize()
|
403 |
+
grouped_gemm_forward(x, w, m_sizes) # Warmup
|
404 |
+
torch.cuda.synchronize()
|
405 |
+
|
406 |
+
logging.info("Benchmarking optimized kernel...")
|
407 |
+
torch.cuda.reset_peak_memory_stats()
|
408 |
+
start_time = time.time()
|
409 |
+
for _ in range(10):
|
410 |
+
grouped_gemm_forward(x, w, m_sizes)
|
411 |
+
torch.cuda.synchronize()
|
412 |
+
end_time = time.time()
|
413 |
+
opt_time = (end_time - start_time) / 10
|
414 |
+
opt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB
|
415 |
+
|
416 |
+
# Calculate FLOPs and speedup
|
417 |
+
flops = 2 * M * N * K
|
418 |
+
pt_tflops = flops / pt_time / 1e12
|
419 |
+
opt_tflops = flops / opt_time / 1e12
|
420 |
+
speedup = pt_time / opt_time
|
421 |
+
|
422 |
+
# Store results
|
423 |
+
results.append(
|
424 |
+
{
|
425 |
+
"config": f"Config {config_idx + 1}",
|
426 |
+
"dimensions": f"M={M}, K={K}, N={N}, G={G}",
|
427 |
+
"pt_time_ms": pt_time * 1000,
|
428 |
+
"opt_time_ms": opt_time * 1000,
|
429 |
+
"pt_tflops": pt_tflops,
|
430 |
+
"opt_tflops": opt_tflops,
|
431 |
+
"speedup": speedup,
|
432 |
+
"pt_memory_mb": pt_memory,
|
433 |
+
"opt_memory_mb": opt_memory,
|
434 |
+
"memory_savings": (
|
435 |
+
(pt_memory - opt_memory) / pt_memory * 100 if pt_memory > 0 else 0
|
436 |
+
),
|
437 |
+
}
|
438 |
+
)
|
439 |
+
|
440 |
+
logging.info(
|
441 |
+
f"PyTorch Reference: {pt_time * 1000:.2f} ms, {pt_tflops:.2f} TFLOPS, {pt_memory:.2f} MB"
|
442 |
+
)
|
443 |
+
logging.info(
|
444 |
+
f"Optimized Kernel: {opt_time * 1000:.2f} ms, {opt_tflops:.2f} TFLOPS, {opt_memory:.2f} MB"
|
445 |
+
)
|
446 |
+
logging.info(
|
447 |
+
f"Speedup: {speedup:.2f}x, Memory savings: {results[-1]['memory_savings']:.2f}%"
|
448 |
+
)
|
449 |
+
|
450 |
+
# Print summary table
|
451 |
+
logging.info("\n===== Benchmark Results Summary =====")
|
452 |
+
logging.info(
|
453 |
+
f"{'Config':<10} | {'Time (ms)':<20} | {'TFLOPS':<20} | {'Speedup':<10} | {'Memory (MB)':<20} | {'Memory Saved':<12}"
|
454 |
+
)
|
455 |
+
logging.info(
|
456 |
+
f"{'':<10} | {'PyTorch':<9} {'Kernel':<9} | {'PyTorch':<9} {'Kernel':<9} | {'':<10} | "
|
457 |
+
f"{'PyTorch':<9} {'Kernel':<9} | {'':<12}"
|
458 |
+
)
|
459 |
+
logging.info("-" * 100)
|
460 |
+
|
461 |
+
for result in results:
|
462 |
+
logging.info(
|
463 |
+
f"{result['config']:<10} | "
|
464 |
+
f"{result['pt_time_ms']:<9.2f} {result['opt_time_ms']:<9.2f} | "
|
465 |
+
f"{result['pt_tflops']:<9.2f} {result['opt_tflops']:<9.2f} | "
|
466 |
+
f"{result['speedup']:<10.2f} | "
|
467 |
+
f"{result['pt_memory_mb']:<9.2f} {result['opt_memory_mb']:<9.2f} | "
|
468 |
+
f"{result['memory_savings']:<12.2f}%"
|
469 |
+
)
|
470 |
+
|
471 |
+
return results
|
472 |
+
|
473 |
+
|
474 |
+
def plot_benchmark_results(results):
|
475 |
+
"""
|
476 |
+
Plot benchmark results as bar charts.
|
477 |
+
"""
|
478 |
+
# Extract data
|
479 |
+
configs = [r["config"] for r in results]
|
480 |
+
pt_tflops = [r["pt_tflops"] for r in results]
|
481 |
+
opt_tflops = [r["opt_tflops"] for r in results]
|
482 |
+
speedups = [r["speedup"] for r in results]
|
483 |
+
|
484 |
+
# Create figure with subplots
|
485 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
486 |
+
|
487 |
+
# Plot TFLOPS comparison
|
488 |
+
x = np.arange(len(configs))
|
489 |
+
width = 0.35
|
490 |
+
ax1.bar(x - width / 2, pt_tflops, width, label="PyTorch Reference")
|
491 |
+
ax1.bar(x + width / 2, opt_tflops, width, label="Optimized Kernel")
|
492 |
+
ax1.set_xlabel("Model Configuration")
|
493 |
+
ax1.set_ylabel("TFLOPS")
|
494 |
+
ax1.set_title("Performance Comparison (Higher is Better)")
|
495 |
+
ax1.set_xticks(x)
|
496 |
+
ax1.set_xticklabels(configs)
|
497 |
+
ax1.legend()
|
498 |
+
ax1.grid(axis="y", linestyle="--", alpha=0.7)
|
499 |
+
|
500 |
+
# Plot speedup
|
501 |
+
ax2.bar(x, speedups, width=0.6, color="green")
|
502 |
+
ax2.set_xlabel("Model Configuration")
|
503 |
+
ax2.set_ylabel("Speedup (x)")
|
504 |
+
ax2.set_title("Speedup Factor (Higher is Better)")
|
505 |
+
ax2.set_xticks(x)
|
506 |
+
ax2.set_xticklabels(configs)
|
507 |
+
ax2.grid(axis="y", linestyle="--", alpha=0.7)
|
508 |
+
|
509 |
+
# Add speedup values on top of bars
|
510 |
+
for i, v in enumerate(speedups):
|
511 |
+
ax2.text(i, v + 0.1, f"{v:.2f}x", ha="center")
|
512 |
+
|
513 |
+
plt.tight_layout()
|
514 |
+
plt.savefig("mg_grouped_gemm_benchmark_results.png")
|
515 |
+
logging.info(
|
516 |
+
"Benchmark results plot saved to 'mg_grouped_gemm_benchmark_results.png'"
|
517 |
+
)
|
518 |
+
|
519 |
+
|
520 |
+
def compare_mg_implementations():
|
521 |
+
"""
|
522 |
+
Combine the M*G and N*G benchmark results for comparison.
|
523 |
+
"""
|
524 |
+
# Only run this if both NG and MG benchmarks have been run
|
525 |
+
try:
|
526 |
+
import pandas as pd
|
527 |
+
|
528 |
+
# Try to load previous benchmark results
|
529 |
+
mg_results = pd.read_csv("mg_grouped_gemm_benchmark_results.csv")
|
530 |
+
ng_results = pd.read_csv("ng_grouped_gemm_benchmark_results.csv")
|
531 |
+
|
532 |
+
# Create comparison plot
|
533 |
+
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
|
534 |
+
|
535 |
+
# Plot speedup comparison
|
536 |
+
configs = mg_results["config"].unique()
|
537 |
+
mg_speedups = mg_results.groupby("config")["speedup"].mean()
|
538 |
+
ng_speedups = ng_results.groupby("config")["speedup"].mean()
|
539 |
+
|
540 |
+
x = np.arange(len(configs))
|
541 |
+
width = 0.35
|
542 |
+
|
543 |
+
axes[0].bar(x - width / 2, mg_speedups, width, label="M*G Grouping")
|
544 |
+
axes[0].bar(x + width / 2, ng_speedups, width, label="N*G Grouping")
|
545 |
+
axes[0].set_xlabel("Model Configuration")
|
546 |
+
axes[0].set_ylabel("Speedup (x)")
|
547 |
+
axes[0].set_title("Speedup Comparison: M*G vs N*G")
|
548 |
+
axes[0].set_xticks(x)
|
549 |
+
axes[0].set_xticklabels(configs)
|
550 |
+
axes[0].legend()
|
551 |
+
axes[0].grid(axis="y", linestyle="--", alpha=0.7)
|
552 |
+
|
553 |
+
# Plot TFLOPS comparison for optimized kernels
|
554 |
+
mg_tflops = (
|
555 |
+
mg_results[mg_results["implementation"] == "optimized"]
|
556 |
+
.groupby("config")["tflops"]
|
557 |
+
.mean()
|
558 |
+
)
|
559 |
+
ng_tflops = (
|
560 |
+
ng_results[ng_results["implementation"] == "optimized"]
|
561 |
+
.groupby("config")["tflops"]
|
562 |
+
.mean()
|
563 |
+
)
|
564 |
+
|
565 |
+
axes[1].bar(x - width / 2, mg_tflops, width, label="M*G Grouping")
|
566 |
+
axes[1].bar(x + width / 2, ng_tflops, width, label="N*G Grouping")
|
567 |
+
axes[1].set_xlabel("Model Configuration")
|
568 |
+
axes[1].set_ylabel("TFLOPS")
|
569 |
+
axes[1].set_title("Performance Comparison: M*G vs N*G")
|
570 |
+
axes[1].set_xticks(x)
|
571 |
+
axes[1].set_xticklabels(configs)
|
572 |
+
axes[1].legend()
|
573 |
+
axes[1].grid(axis="y", linestyle="--", alpha=0.7)
|
574 |
+
|
575 |
+
plt.tight_layout()
|
576 |
+
plt.savefig("mg_vs_ng_comparison.png")
|
577 |
+
logging.info("Comparison plot saved to 'mg_vs_ng_comparison.png'")
|
578 |
+
|
579 |
+
except Exception as e:
|
580 |
+
logging.error(f"Could not create comparison plot: {e}")
|
581 |
+
logging.info(
|
582 |
+
"Run both M*G and N*G benchmarks first to generate comparison plots"
|
583 |
+
)
|
584 |
+
|
585 |
+
|
586 |
+
if __name__ == "__main__":
|
587 |
+
parser = argparse.ArgumentParser(
|
588 |
+
description="Benchmark M*G Grouped GEMM implementations"
|
589 |
+
)
|
590 |
+
parser.add_argument("--run-all", action="store_true", help="Run all benchmarks")
|
591 |
+
parser.add_argument(
|
592 |
+
"--triton-bench", action="store_true", help="Run Triton performance reports"
|
593 |
+
)
|
594 |
+
parser.add_argument(
|
595 |
+
"--model-configs", action="store_true", help="Benchmark model configurations"
|
596 |
+
)
|
597 |
+
parser.add_argument(
|
598 |
+
"--compare-mg-ng",
|
599 |
+
action="store_true",
|
600 |
+
help="Compare M*G and N*G implementations",
|
601 |
+
)
|
602 |
+
args = parser.parse_args()
|
603 |
+
|
604 |
+
# Check if CUDA is available
|
605 |
+
if not torch.cuda.is_available():
|
606 |
+
logging.error(
|
607 |
+
"CUDA is not available. This benchmark requires a CUDA-capable GPU."
|
608 |
+
)
|
609 |
+
exit(1)
|
610 |
+
|
611 |
+
if args.run_all or args.model_configs:
|
612 |
+
# Benchmark model configurations
|
613 |
+
logging.info("Running benchmark for model configurations...")
|
614 |
+
results = benchmark_model_configs()
|
615 |
+
plot_benchmark_results(results)
|
616 |
+
|
617 |
+
if args.run_all or args.triton_bench:
|
618 |
+
# Run Triton performance reports
|
619 |
+
logging.info("Running Triton performance reports...")
|
620 |
+
benchmark_forward.run(save_path="mg_grouped_gemm_benchmark_results")
|
621 |
+
benchmark_forward_groups.run(save_path="mg_grouped_gemm_benchmark_results")
|
622 |
+
benchmark_imbalance.run(save_path="mg_grouped_gemm_benchmark_results")
|
623 |
+
logging.info(
|
624 |
+
"Triton performance reports saved to 'mg_grouped_gemm_benchmark_results' directory"
|
625 |
+
)
|
626 |
+
|
627 |
+
if args.run_all or args.compare_mg_ng:
|
628 |
+
# Compare M*G and N*G implementations
|
629 |
+
logging.info("Comparing M*G and N*G implementations...")
|
630 |
+
compare_mg_implementations()
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .mg_grouped_gemm import grouped_gemm_forward
from .tma_autotuning import ALIGN_SIZE_M

__all__ = [
    "grouped_gemm_forward",
    "ALIGN_SIZE_M",
]
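A minimal usage sketch for the exported kernel (assuming a CUDA device and that the package is importable from this directory layout); it mirrors the call pattern used by `benchmark.py` in this change, where all groups share one weight matrix and the M dimension is partitioned by `m_sizes`:

```python
# Sketch: call the Triton M*G grouped GEMM and compare it against a plain
# per-group torch.matmul reference. Requires a CUDA GPU.
import torch

from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import (
    grouped_gemm_forward,
)

device, dtype = torch.device("cuda"), torch.float16
m_sizes = torch.tensor([1024, 2048, 1024], device=device, dtype=torch.int32)
M, K, N = int(m_sizes.sum()), 7168, 4096

x = torch.randn(M, K, dtype=dtype, device=device)  # rows grouped along M
w = torch.randn(N, K, dtype=dtype, device=device)  # shared weight for all groups

out = grouped_gemm_forward(x, w, m_sizes)          # (M, N)

# Reference: each group is an ordinary GEMM against the same weight.
ref = torch.cat([chunk @ w.T for chunk in torch.split(x, m_sizes.tolist())])
print(out.shape, (out.float() - ref.float()).abs().max())
```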
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py
ADDED
@@ -0,0 +1,299 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# pyre-unsafe
|
8 |
+
import logging
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from reference_utils import (
|
14 |
+
analyze_tensor_differences,
|
15 |
+
compute_reference_backward,
|
16 |
+
compute_reference_forward,
|
17 |
+
)
|
18 |
+
|
19 |
+
# Configure logging
|
20 |
+
logging.basicConfig(
|
21 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
22 |
+
)
|
23 |
+
|
24 |
+
# Import grouped GEMM implementations
|
25 |
+
try:
|
26 |
+
from mg_grouped_gemm import grouped_gemm_backward, grouped_gemm_forward
|
27 |
+
|
28 |
+
except ImportError:
|
29 |
+
logging.error(
|
30 |
+
"Error importing grouped GEMM modules. Make sure the implementation files are in the correct path."
|
31 |
+
)
|
32 |
+
raise
|
33 |
+
|
34 |
+
|
35 |
+
def test_forward_pass():
|
36 |
+
"""
|
37 |
+
A simple test for the M*G grouped GEMM forward pass with detailed error handling.
|
38 |
+
|
39 |
+
In M*G grouping:
|
40 |
+
- M dimension is partitioned into G groups (M_total = sum(M_sizes))
|
41 |
+
- N dimension is the same for all groups
|
42 |
+
"""
|
43 |
+
try:
|
44 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
45 |
+
|
46 |
+
# Test parameters for DeepSeek-like models
|
47 |
+
G = 1 # Number of groups
|
48 |
+
M_sizes = [
|
49 |
+
2048,
|
50 |
+
] # 2048, 2048, 2048] # Group sizes (will be adjusted)
|
51 |
+
M_total = sum(M_sizes) # Total M dimension
|
52 |
+
N = 4096 # Output dimension (same for all groups)
|
53 |
+
K = 7168 # Hidden dimension
|
54 |
+
|
55 |
+
# Create group sizes tensor
|
56 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
57 |
+
|
58 |
+
# Create input and weight tensors - using float16 for higher precision
|
59 |
+
x = torch.randn(M_total, K, dtype=torch.float16, device=device)
|
60 |
+
w = torch.randn(N, K, dtype=torch.float16, device=device)
|
61 |
+
|
62 |
+
# Log the setup
|
63 |
+
logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
|
64 |
+
logging.info(f"Group sizes: {m_sizes}")
|
65 |
+
logging.info(f"Input x shape: {x.shape}")
|
66 |
+
logging.info(f"Weight w shape: {w.shape}")
|
67 |
+
|
68 |
+
# Run forward pass
|
69 |
+
logging.info("Running forward pass with grouped GEMM")
|
70 |
+
result = grouped_gemm_forward(x, w, m_sizes)
|
71 |
+
logging.info(f"Forward result shape: {result.shape}")
|
72 |
+
|
73 |
+
# Compute reference result
|
74 |
+
logging.info("Computing reference result with PyTorch")
|
75 |
+
reference_result = compute_reference_forward(x, w, m_sizes)
|
76 |
+
|
77 |
+
# Compare results
|
78 |
+
logging.info("Comparing with PyTorch reference")
|
79 |
+
forward_close = analyze_tensor_differences(
|
80 |
+
result, reference_result, "Forward output"
|
81 |
+
)
|
82 |
+
|
83 |
+
return forward_close
|
84 |
+
|
85 |
+
except Exception as e:
|
86 |
+
logging.error(f"Test failed with error: {e}")
|
87 |
+
import traceback
|
88 |
+
|
89 |
+
logging.error(traceback.format_exc())
|
90 |
+
return False
|
91 |
+
|
92 |
+
|
93 |
+
def test_backward_pass():
|
94 |
+
"""
|
95 |
+
A simple test for the M*G grouped GEMM backward pass with detailed error handling.
|
96 |
+
|
97 |
+
In M*G grouping:
|
98 |
+
- M dimension is partitioned into G groups (M_total = sum(M_sizes))
|
99 |
+
- N dimension is the same for all groups
|
100 |
+
"""
|
101 |
+
try:
|
102 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
103 |
+
|
104 |
+
# Test parameters for DeepSeek-like models
|
105 |
+
G = 4 # Number of groups
|
106 |
+
M_sizes = [2048, 2048, 2048, 2048] # Group sizes (will be adjusted)
|
107 |
+
M_total = sum(M_sizes) # Total M dimension
|
108 |
+
N = 4096 # Output dimension (same for all groups)
|
109 |
+
K = 7168 # Hidden dimension
|
110 |
+
|
111 |
+
# Create group sizes tensor
|
112 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
113 |
+
|
114 |
+
# Create input and weight tensors - using float16 for higher precision
|
115 |
+
x = torch.randn(
|
116 |
+
M_total, K, dtype=torch.float16, device=device, requires_grad=True
|
117 |
+
)
|
118 |
+
w = torch.randn(N, K, dtype=torch.float16, device=device, requires_grad=True)
|
119 |
+
|
120 |
+
# Log the setup
|
121 |
+
logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
|
122 |
+
logging.info(f"Group sizes: {m_sizes}")
|
123 |
+
logging.info(f"Input x shape: {x.shape}")
|
124 |
+
logging.info(f"Weight w shape: {w.shape}")
|
125 |
+
|
126 |
+
# Step 1: Run forward pass
|
127 |
+
logging.info("Running forward pass")
|
128 |
+
result = grouped_gemm_forward(x, w, m_sizes)
|
129 |
+
logging.info(f"Forward result shape: {result.shape}")
|
130 |
+
|
131 |
+
# Create a gradient for backpropagation
|
132 |
+
grad_output = torch.randn_like(result)
|
133 |
+
logging.info(f"Created gradient with shape: {grad_output.shape}")
|
134 |
+
|
135 |
+
# Step 2: Run backward pass directly
|
136 |
+
logging.info("Running backward pass directly")
|
137 |
+
grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)
|
138 |
+
|
139 |
+
# Verify gradient shapes
|
140 |
+
logging.info(
|
141 |
+
f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}"
|
142 |
+
)
|
143 |
+
|
144 |
+
# Step 3: Verify gradient computation using PyTorch's autograd
|
145 |
+
logging.info("Running PyTorch reference implementation")
|
146 |
+
|
147 |
+
# Compute reference gradients
|
148 |
+
x_ref_grad, w_ref_grad = compute_reference_backward(x, w, m_sizes, grad_output)
|
149 |
+
|
150 |
+
# Compare gradients
|
151 |
+
logging.info("Comparing gradients with PyTorch reference")
|
152 |
+
grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
|
153 |
+
grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")
|
154 |
+
|
155 |
+
# Log overall result
|
156 |
+
if grad_x_close and grad_w_close:
|
157 |
+
logging.info("✓ SUCCESS: Gradients match the PyTorch reference")
|
158 |
+
else:
|
159 |
+
logging.error("✗ FAILURE: Gradient mismatch detected")
|
160 |
+
|
161 |
+
return grad_x_close and grad_w_close
|
162 |
+
|
163 |
+
except Exception as e:
|
164 |
+
logging.error(f"Test failed with error: {e}")
|
165 |
+
import traceback
|
166 |
+
|
167 |
+
logging.error(traceback.format_exc())
|
168 |
+
return False
|
169 |
+
|
170 |
+
|
171 |
+
def test_multiple_deepseek_configs():
|
172 |
+
"""
|
173 |
+
Test multiple DeepSeek model configurations with both forward and backward pass verification.
|
174 |
+
"""
|
175 |
+
# DeepSeek configurations: (G, M, K, N)
|
176 |
+
configs = [
|
177 |
+
(4, 8192, 7168, 4096), # Config 1
|
178 |
+
(4, 8192, 2048, 7168), # Config 2
|
179 |
+
(8, 4096, 7168, 4096), # Config 3
|
180 |
+
(8, 4096, 2048, 7168), # Config 4
|
181 |
+
]
|
182 |
+
|
183 |
+
results = []
|
184 |
+
|
185 |
+
for config_idx, (G, M, K, N) in enumerate(configs):
|
186 |
+
logging.info(f"\n\n===== Testing DeepSeek Config {config_idx+1} =====")
|
187 |
+
logging.info(f"G={G}, M={M}, K={K}, N={N}")
|
188 |
+
|
189 |
+
try:
|
190 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
191 |
+
|
192 |
+
# Create even group sizes
|
193 |
+
base_size = M // G
|
194 |
+
remainder = M % G
|
195 |
+
M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
|
196 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
197 |
+
|
198 |
+
# Create input and weight tensors using float16 for higher precision
|
199 |
+
x = torch.randn(
|
200 |
+
M, K, dtype=torch.float16, device=device, requires_grad=True
|
201 |
+
)
|
202 |
+
w = torch.randn(
|
203 |
+
N, K, dtype=torch.float16, device=device, requires_grad=True
|
204 |
+
)
|
205 |
+
|
206 |
+
logging.info(f"Input x shape: {x.shape}, Weight w shape: {w.shape}")
|
207 |
+
|
208 |
+
# Run forward pass
|
209 |
+
result = grouped_gemm_forward(x, w, m_sizes)
|
210 |
+
logging.info(f"Forward result shape: {result.shape}")
|
211 |
+
|
212 |
+
# ===== FORWARD PASS VERIFICATION =====
|
213 |
+
# Compute reference forward result
|
214 |
+
reference_result = compute_reference_forward(x, w, m_sizes)
|
215 |
+
|
216 |
+
# Compare forward results
|
217 |
+
forward_close = analyze_tensor_differences(
|
218 |
+
result, reference_result, "Forward output"
|
219 |
+
)
|
220 |
+
|
221 |
+
# ===== BACKWARD PASS VERIFICATION =====
|
222 |
+
# Create gradient for backpropagation
|
223 |
+
grad_output = torch.randn_like(result)
|
224 |
+
|
225 |
+
# Run backward pass
|
226 |
+
grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)
|
227 |
+
|
228 |
+
# Compute reference gradients
|
229 |
+
x_ref_grad, w_ref_grad = compute_reference_backward(
|
230 |
+
x, w, m_sizes, grad_output
|
231 |
+
)
|
232 |
+
|
233 |
+
# Compare backward results
|
234 |
+
grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
|
235 |
+
grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")
|
236 |
+
|
237 |
+
# Overall config result
|
238 |
+
backward_close = grad_x_close and grad_w_close
|
239 |
+
config_success = forward_close and backward_close
|
240 |
+
results.append(
|
241 |
+
(config_idx + 1, config_success, forward_close, backward_close)
|
242 |
+
)
|
243 |
+
|
244 |
+
# Log overall config result
|
245 |
+
if config_success:
|
246 |
+
logging.info(f"✓ SUCCESS: Config {config_idx+1} passed all tests!")
|
247 |
+
else:
|
248 |
+
logging.error(
|
249 |
+
f"✗ FAILURE: Config {config_idx+1} failed one or more tests"
|
250 |
+
)
|
251 |
+
|
252 |
+
except Exception as e:
|
253 |
+
logging.error(f"Config {config_idx+1} test failed with error: {e}")
|
254 |
+
import traceback
|
255 |
+
|
256 |
+
logging.error(traceback.format_exc())
|
257 |
+
results.append((config_idx + 1, False, False, False))
|
258 |
+
|
259 |
+
# Summary
|
260 |
+
logging.info("\n===== Test Results Summary =====")
|
261 |
+
for config_idx, overall_success, forward_success, backward_success in results:
|
262 |
+
overall_status = "✓ PASSED" if overall_success else "✗ FAILED"
|
263 |
+
forward_status = "✓ PASSED" if forward_success else "✗ FAILED"
|
264 |
+
backward_status = "✓ PASSED" if backward_success else "✗ FAILED"
|
265 |
+
|
266 |
+
logging.info(f"Config {config_idx}: {overall_status}")
|
267 |
+
logging.info(f" - Forward pass: {forward_status}")
|
268 |
+
logging.info(f" - Backward pass: {backward_status}")
|
269 |
+
|
270 |
+
return all(overall_success for _, overall_success, _, _ in results)
|
271 |
+
|
272 |
+
|
273 |
+
if __name__ == "__main__":
|
274 |
+
logging.info(
|
275 |
+
"Running verification for both forward and backward pass of M*G grouped GEMM"
|
276 |
+
)
|
277 |
+
|
278 |
+
# Run basic forward pass test
|
279 |
+
logging.info("\n===== Running basic forward pass test =====")
|
280 |
+
success_forward = test_forward_pass()
|
281 |
+
logging.info(f"Basic forward test {'succeeded' if success_forward else 'failed'}")
|
282 |
+
|
283 |
+
# Run basic backward pass test
|
284 |
+
logging.info("\n===== Running basic backward pass test =====")
|
285 |
+
success_backward = test_backward_pass()
|
286 |
+
logging.info(f"Basic backward test {'succeeded' if success_backward else 'failed'}")
|
287 |
+
|
288 |
+
# Run multiple DeepSeek configs with forward and backward verification
|
289 |
+
logging.info("\n===== Running tests for all DeepSeek configs =====")
|
290 |
+
success_configs = test_multiple_deepseek_configs()
|
291 |
+
logging.info(
|
292 |
+
f"DeepSeek configs tests {'all succeeded' if success_configs else 'had failures'}"
|
293 |
+
)
|
294 |
+
|
295 |
+
# Overall result
|
296 |
+
overall_success = success_forward and success_backward and success_configs
|
297 |
+
logging.info(
|
298 |
+
f"\nOverall test result: {'SUCCESS' if overall_success else 'FAILURE'}"
|
299 |
+
)
|
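The tests above depend on a `reference_utils` module that is not part of this diff. The sketch below is an assumption about what those helpers compute, pieced together from the docstrings above (M*G grouping: the groups partition only the M dimension and share one [N, K] weight); it is not the actual reference_utils code:

    import logging
    import torch

    def compute_reference_forward(x, w, m_sizes):
        # Per-group y_g = x_g @ w.T, concatenated back along M -> [M_total, N].
        outs, start = [], 0
        for m in m_sizes.tolist():
            outs.append(x[start : start + m] @ w.t())
            start += m
        return torch.cat(outs, dim=0)

    def compute_reference_backward(x, w, m_sizes, grad_output):
        # Re-run the reference forward under autograd and backprop grad_output.
        x_ref = x.detach().clone().requires_grad_(True)
        w_ref = w.detach().clone().requires_grad_(True)
        compute_reference_forward(x_ref, w_ref, m_sizes).backward(grad_output)
        return x_ref.grad, w_ref.grad

    def analyze_tensor_differences(actual, reference, label):
        # Loose fp16 tolerance; log the worst absolute error and return pass/fail.
        max_abs = (actual.float() - reference.float()).abs().max().item()
        ok = torch.allclose(actual.float(), reference.float(), rtol=1e-2, atol=1e-2)
        logging.info(f"{label}: max abs diff {max_abs:.4e} ({'OK' if ok else 'MISMATCH'})")
        return ok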
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py
ADDED
@@ -0,0 +1,1304 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# credit - flat index forward kernel is derived from FBGemm:
|
8 |
+
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
|
9 |
+
|
10 |
+
# pyre-unsafe
|
11 |
+
import functools
|
12 |
+
import logging
|
13 |
+
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
from typing import Any, Dict, Optional, Tuple
|
17 |
+
|
18 |
+
import torch
|
19 |
+
|
20 |
+
import triton
|
21 |
+
import triton.language as tl
|
22 |
+
from triton import Config as TConfig
|
23 |
+
|
24 |
+
from triton.runtime import driver # @manual
|
25 |
+
|
26 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
27 |
+
|
28 |
+
from tma_autotuning import (
|
29 |
+
ALIGN_SIZE_M,
|
30 |
+
_NV_CONFIGS,
|
31 |
+
CudaUtils,
|
32 |
+
early_config_prune,
|
33 |
+
TmaDescriptorHelper,
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
# Configure logging
|
38 |
+
logging.basicConfig(
|
39 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
40 |
+
)
|
41 |
+
|
42 |
+
# ============== Start Triton Kernels ===============
|
43 |
+
|
44 |
+
|
45 |
+
@triton.autotune(
|
46 |
+
configs=_NV_CONFIGS,
|
47 |
+
key=["G", "M_BUCKET", "N", "K"],
|
48 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
49 |
+
)
|
50 |
+
@triton.jit
|
51 |
+
def _kernel_mg_forward_hopper(
|
52 |
+
a_desc_ptr,
|
53 |
+
b_desc_ptr,
|
54 |
+
c_ptr,
|
55 |
+
workspace,
|
56 |
+
m_sizes,
|
57 |
+
# problem sizes
|
58 |
+
G: tl.constexpr,
|
59 |
+
M_BUCKET: tl.constexpr,
|
60 |
+
N: tl.constexpr,
|
61 |
+
K: tl.constexpr,
|
62 |
+
# config
|
63 |
+
NUM_SMS: tl.constexpr,
|
64 |
+
TMA_SIZE: tl.constexpr,
|
65 |
+
USE_EPILOGUE_SUBTILING: tl.constexpr,
|
66 |
+
# tiles
|
67 |
+
BLOCK_SIZE_M: tl.constexpr,
|
68 |
+
BLOCK_SIZE_N: tl.constexpr,
|
69 |
+
BLOCK_SIZE_K: tl.constexpr,
|
70 |
+
) -> None:
|
71 |
+
"""
|
72 |
+
Flat index style forward kernel for Hopper.
|
73 |
+
For simplicity, we always use TMA Load and TMA Store
|
74 |
+
"""
|
75 |
+
tbidx = tl.program_id(0) # thread block index
|
76 |
+
|
77 |
+
c_dtype = c_ptr.dtype.element_ty # output dtype
|
78 |
+
|
79 |
+
c_desc_ptr = workspace + (tbidx * TMA_SIZE) # for TMA Store
|
80 |
+
|
81 |
+
M_end = 0
|
82 |
+
M_start = 0
|
83 |
+
processed_tiles = 0
|
84 |
+
# Size of individual weight matrix
|
85 |
+
n_size = N // G
|
86 |
+
n_start = 0
|
87 |
+
|
88 |
+
for g in range(G):
|
89 |
+
# Move down along groups
|
90 |
+
# reset to new M offset
|
91 |
+
M_start = M_end
|
92 |
+
m_size = tl.load(m_sizes + g)
|
93 |
+
M_end = M_start + m_size
|
94 |
+
n_start = n_size * g
|
95 |
+
|
96 |
+
if m_size > 0:
|
97 |
+
# Process this group
|
98 |
+
|
99 |
+
# Acquire hold on c_desc_ptr for TMA Store
|
100 |
+
tl.extra.cuda.experimental_device_tensormap_create2d(
|
101 |
+
desc_ptr=c_desc_ptr,
|
102 |
+
global_address=c_ptr + M_start * n_size,
|
103 |
+
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
104 |
+
global_size=[m_size, n_size],
|
105 |
+
element_ty=c_dtype,
|
106 |
+
)
|
107 |
+
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
108 |
+
|
109 |
+
# tiles for this group
|
110 |
+
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
111 |
+
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
|
112 |
+
group_num_tiles = num_m_tiles * num_n_tiles
|
113 |
+
|
114 |
+
while tbidx >= processed_tiles and tbidx < (
|
115 |
+
processed_tiles + group_num_tiles
|
116 |
+
):
|
117 |
+
group_index = tbidx - processed_tiles
|
118 |
+
|
119 |
+
# columnwise
|
120 |
+
tile_m_index = group_index % num_m_tiles
|
121 |
+
tile_n_index = group_index // num_m_tiles
|
122 |
+
|
123 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
124 |
+
|
125 |
+
m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
|
126 |
+
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
127 |
+
global_n_offset = (n_start + n_offset).to(tl.int32)
|
128 |
+
|
129 |
+
for k_offset in range(0, K, BLOCK_SIZE_K):
|
130 |
+
# input block [M,K]
|
131 |
+
a = tl._experimental_descriptor_load(
|
132 |
+
a_desc_ptr,
|
133 |
+
[m_offset, k_offset],
|
134 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
135 |
+
c_dtype,
|
136 |
+
)
|
137 |
+
# weight block [N, K]
|
138 |
+
b = tl._experimental_descriptor_load(
|
139 |
+
b_desc_ptr,
|
140 |
+
[global_n_offset, k_offset],
|
141 |
+
[BLOCK_SIZE_N, BLOCK_SIZE_K],
|
142 |
+
c_dtype,
|
143 |
+
)
|
144 |
+
|
145 |
+
accumulator += tl.dot(a, b.T)
|
146 |
+
|
147 |
+
# Store using TMA
|
148 |
+
|
149 |
+
m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
150 |
+
|
151 |
+
if USE_EPILOGUE_SUBTILING:
|
152 |
+
acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
|
153 |
+
acc = tl.permute(acc, (0, 2, 1))
|
154 |
+
acc0, acc1 = tl.split(acc)
|
155 |
+
c0 = acc0.to(c_dtype)
|
156 |
+
tl._experimental_descriptor_store(
|
157 |
+
c_desc_ptr, c0, [m_offset, n_offset]
|
158 |
+
)
|
159 |
+
c1 = acc1.to(c_dtype)
|
160 |
+
tl._experimental_descriptor_store(
|
161 |
+
c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2]
|
162 |
+
)
|
163 |
+
else:
|
164 |
+
tl._experimental_descriptor_store(
|
165 |
+
c_desc_ptr,
|
166 |
+
accumulator.to(c_dtype),
|
167 |
+
[m_offset, n_offset],
|
168 |
+
)
|
169 |
+
# move to next tile in group
|
170 |
+
tbidx += NUM_SMS
|
171 |
+
# Update the total tiles count for the next group
|
172 |
+
processed_tiles += group_num_tiles
|
173 |
+
|
174 |
+
|
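# Illustrative sketch (not part of the original file): host-side model of the flat-index
# schedule used by the persistent kernels in this file. Each of the NUM_SMS programs walks
# the concatenated per-group tile list with stride NUM_SMS, which is what the
# `tbidx += NUM_SMS` / `processed_tiles` loop above implements on-device.
def _flat_schedule_for_program(pid, num_sms, m_sizes, n_size, block_m, block_n):
    flat = []
    for g, m in enumerate(m_sizes):
        num_m_tiles = -(-m // block_m)      # ceil division, mirrors tl.cdiv
        num_n_tiles = -(-n_size // block_n)
        flat.extend((g, t) for t in range(num_m_tiles * num_n_tiles))
    return flat[pid::num_sms]               # the tiles this program would process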
175 |
+
@triton.autotune(
|
176 |
+
configs=_NV_CONFIGS,
|
177 |
+
key=["G", "M_BUCKET", "N", "K"],
|
178 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
179 |
+
)
|
180 |
+
@triton.jit
|
181 |
+
def _kernel_mg_forward_tma(
|
182 |
+
a_desc_ptr,
|
183 |
+
b_desc_ptr,
|
184 |
+
c_ptr,
|
185 |
+
workspace,
|
186 |
+
m_sizes,
|
187 |
+
a_scale_ptr,
|
188 |
+
b_scale_ptr,
|
189 |
+
# problem sizes
|
190 |
+
G: tl.constexpr,
|
191 |
+
M_BUCKET: tl.constexpr,
|
192 |
+
N: tl.constexpr,
|
193 |
+
K: tl.constexpr,
|
194 |
+
# config
|
195 |
+
NUM_SMS: tl.constexpr,
|
196 |
+
USE_TMA_LOAD: tl.constexpr,
|
197 |
+
USE_TMA_STORE: tl.constexpr,
|
198 |
+
TMA_SIZE: tl.constexpr,
|
199 |
+
USE_FP8: tl.constexpr,
|
200 |
+
# tiles
|
201 |
+
BLOCK_SIZE_M: tl.constexpr,
|
202 |
+
BLOCK_SIZE_N: tl.constexpr,
|
203 |
+
BLOCK_SIZE_K: tl.constexpr,
|
204 |
+
) -> None:
|
205 |
+
"""
|
206 |
+
Flat index style forward kernel.
|
207 |
+
For simplicity, we always use TMA Load and TMA Store
|
208 |
+
"""
|
209 |
+
tbidx = tl.program_id(0) # thread block index
|
210 |
+
|
211 |
+
c_dtype = c_ptr.dtype.element_ty
|
212 |
+
|
213 |
+
c_desc_ptr = workspace + (tbidx * TMA_SIZE)
|
214 |
+
|
215 |
+
M_end = 0
|
216 |
+
processed_tiles = 0
|
217 |
+
|
218 |
+
for g in range(G):
|
219 |
+
# Move down along groups
|
220 |
+
# reset to new M offset
|
221 |
+
M_start = M_end
|
222 |
+
m_size = tl.load(m_sizes + g)
|
223 |
+
M_end = M_start + m_size
|
224 |
+
|
225 |
+
if m_size > 0:
|
226 |
+
# Process this group
|
227 |
+
n_size = N
|
228 |
+
|
229 |
+
# TMA Store prep
|
230 |
+
tl.extra.cuda.experimental_device_tensormap_create2d(
|
231 |
+
desc_ptr=c_desc_ptr,
|
232 |
+
global_address=c_ptr + M_start * N,
|
233 |
+
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
234 |
+
global_size=[m_size, n_size],
|
235 |
+
element_ty=c_dtype,
|
236 |
+
)
|
237 |
+
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
238 |
+
|
239 |
+
# tiles for this group
|
240 |
+
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
241 |
+
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
|
242 |
+
group_num_tiles = num_m_tiles * num_n_tiles
|
243 |
+
|
244 |
+
while tbidx >= processed_tiles and tbidx < (
|
245 |
+
processed_tiles + group_num_tiles
|
246 |
+
):
|
247 |
+
group_index = tbidx - processed_tiles
|
248 |
+
|
249 |
+
tile_m_index = group_index % num_m_tiles
|
250 |
+
tile_n_index = group_index // num_m_tiles
|
251 |
+
|
252 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
253 |
+
|
254 |
+
m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
|
255 |
+
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
256 |
+
|
257 |
+
for k_offset in range(0, K, BLOCK_SIZE_K):
|
258 |
+
# input block [M,K]
|
259 |
+
a = tl._experimental_descriptor_load(
|
260 |
+
a_desc_ptr,
|
261 |
+
[m_offset, k_offset],
|
262 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
263 |
+
c_dtype,
|
264 |
+
)
|
265 |
+
# weight block [N, K]
|
266 |
+
b = tl._experimental_descriptor_load(
|
267 |
+
b_desc_ptr,
|
268 |
+
[n_offset, k_offset],
|
269 |
+
[BLOCK_SIZE_N, BLOCK_SIZE_K],
|
270 |
+
c_dtype,
|
271 |
+
)
|
272 |
+
|
273 |
+
accumulator += tl.dot(a, b.T)
|
274 |
+
|
275 |
+
# Store using TMA
|
276 |
+
|
277 |
+
m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
278 |
+
# n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
279 |
+
|
280 |
+
tl._experimental_descriptor_store(
|
281 |
+
c_desc_ptr,
|
282 |
+
accumulator.to(c_dtype),
|
283 |
+
[m_offset, n_offset],
|
284 |
+
)
|
285 |
+
|
286 |
+
# Move to the next tile
|
287 |
+
tbidx += NUM_SMS
|
288 |
+
# Update the total tiles count for the next group
|
289 |
+
processed_tiles += group_num_tiles
|
290 |
+
|
291 |
+
|
292 |
+
@triton.autotune(
|
293 |
+
configs=_NV_CONFIGS,
|
294 |
+
key=["G", "M_BUCKET", "N", "K"],
|
295 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
296 |
+
)
|
297 |
+
@triton.jit
|
298 |
+
def _kernel_mg_forward_no_tma(
|
299 |
+
a_ptr,
|
300 |
+
b_ptr,
|
301 |
+
c_ptr,
|
302 |
+
workspace,
|
303 |
+
m_sizes,
|
304 |
+
# problem sizes
|
305 |
+
G: tl.constexpr,
|
306 |
+
M_BUCKET: tl.constexpr,
|
307 |
+
N: tl.constexpr,
|
308 |
+
K: tl.constexpr,
|
309 |
+
# config
|
310 |
+
NUM_SMS: tl.constexpr,
|
311 |
+
USE_TMA_LOAD: tl.constexpr,
|
312 |
+
USE_TMA_STORE: tl.constexpr,
|
313 |
+
TMA_SIZE: tl.constexpr,
|
314 |
+
# tiles
|
315 |
+
BLOCK_SIZE_M: tl.constexpr,
|
316 |
+
BLOCK_SIZE_N: tl.constexpr,
|
317 |
+
BLOCK_SIZE_K: tl.constexpr,
|
318 |
+
) -> None:
|
319 |
+
"""
|
320 |
+
Flat index style forward kernel.
|
321 |
+
For backward compatibility and Ampere, we never use TMA Load and TMA Store
|
322 |
+
"""
|
323 |
+
tbidx = tl.program_id(0) # thread block index
|
324 |
+
|
325 |
+
c_dtype = c_ptr.dtype.element_ty
|
326 |
+
c_desc_ptr = None
|
327 |
+
|
328 |
+
M_end = 0
|
329 |
+
processed_tiles = 0
|
330 |
+
|
331 |
+
for g in range(G):
|
332 |
+
# Move down along groups
|
333 |
+
# reset to new M offset
|
334 |
+
M_start = M_end
|
335 |
+
m_size = tl.load(m_sizes + g)
|
336 |
+
M_end = M_start + m_size
|
337 |
+
|
338 |
+
if m_size > 0:
|
339 |
+
# Process this group
|
340 |
+
n_size = N
|
341 |
+
|
342 |
+
# tiles for this group
|
343 |
+
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
344 |
+
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
|
345 |
+
group_num_tiles = num_m_tiles * num_n_tiles
|
346 |
+
|
347 |
+
while tbidx >= processed_tiles and tbidx < (
|
348 |
+
processed_tiles + group_num_tiles
|
349 |
+
):
|
350 |
+
group_index = tbidx - processed_tiles
|
351 |
+
|
352 |
+
tile_m_index = group_index % num_m_tiles
|
353 |
+
tile_n_index = group_index // num_m_tiles
|
354 |
+
|
355 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
356 |
+
|
357 |
+
m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
|
358 |
+
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
359 |
+
|
360 |
+
offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
361 |
+
offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
362 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
363 |
+
|
364 |
+
a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :]
|
365 |
+
b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :]
|
366 |
+
|
367 |
+
for k_offset in range(0, K, BLOCK_SIZE_K):
|
368 |
+
# Load with bounds checking
|
369 |
+
a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
|
370 |
+
b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
|
371 |
+
|
372 |
+
# Main matmul
|
373 |
+
accumulator += tl.dot(a, b.T)
|
374 |
+
|
375 |
+
# Update pointers for next block
|
376 |
+
a_ptrs += BLOCK_SIZE_K
|
377 |
+
b_ptrs += BLOCK_SIZE_K
|
378 |
+
|
379 |
+
# Store without TMA
|
380 |
+
offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
381 |
+
offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
382 |
+
|
383 |
+
c = accumulator.to(c_dtype)
|
384 |
+
|
385 |
+
tl.store(
|
386 |
+
c_ptr
|
387 |
+
+ (M_start + offs_am[:, None]) * N # Row stride is N
|
388 |
+
+ offs_bn[None, :], # Column offset
|
389 |
+
c,
|
390 |
+
mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
|
391 |
+
)
|
392 |
+
# Move to the next tile
|
393 |
+
tbidx += NUM_SMS
|
394 |
+
# Update the total tiles count for the next group
|
395 |
+
processed_tiles += group_num_tiles
|
396 |
+
|
397 |
+
|
398 |
+
"""
|
399 |
+
Backward pass for grouped GEMM with Triton, where grouping is M*G
|
400 |
+
We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
|
401 |
+
"""
|
402 |
+
|
403 |
+
|
404 |
+
# ---- dx flat linear indexed ----
|
405 |
+
@triton.autotune(
|
406 |
+
configs=_NV_CONFIGS,
|
407 |
+
key=["G", "M_BUCKET", "N", "K"],
|
408 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
409 |
+
)
|
410 |
+
@triton.jit
|
411 |
+
def _kernel_mg_dx_tma(
|
412 |
+
grad_output_desc_ptr, # [MG, N]
|
413 |
+
w_desc_ptr, # [N, K]
|
414 |
+
grad_input_ptr, # output grad_x [MG, K]
|
415 |
+
workspace, # for TMA store
|
416 |
+
m_sizes, # group sizes [G]
|
417 |
+
# problem sizes
|
418 |
+
G: tl.constexpr,
|
419 |
+
M_BUCKET: tl.constexpr,
|
420 |
+
N: tl.constexpr,
|
421 |
+
K: tl.constexpr,
|
422 |
+
# config
|
423 |
+
NUM_SMS: tl.constexpr,
|
424 |
+
USE_TMA_LOAD: tl.constexpr,
|
425 |
+
USE_TMA_STORE: tl.constexpr,
|
426 |
+
TMA_SIZE: tl.constexpr,
|
427 |
+
# tiles
|
428 |
+
BLOCK_SIZE_M: tl.constexpr,
|
429 |
+
BLOCK_SIZE_N: tl.constexpr,
|
430 |
+
BLOCK_SIZE_K: tl.constexpr,
|
431 |
+
) -> None:
|
432 |
+
"""
|
433 |
+
TMA-optimized kernel for computing gradients with respect to input (dx).
|
434 |
+
For the forward pass Y = X @ W.T, the backward for input is:
|
435 |
+
grad_X = grad_Y @ W
|
436 |
+
|
437 |
+
This maps to [MG, N] @ [N, K] -> [MG, K]
|
438 |
+
|
439 |
+
Key differences from forward:
|
440 |
+
1. W is used directly and not transposed
|
441 |
+
2. The reduction dimension is now N (not K)
|
442 |
+
3. Output is [M, K] instead of [M, N]
|
443 |
+
"""
|
444 |
+
tbidx = tl.program_id(0) # thread block index
|
445 |
+
|
446 |
+
c_dtype = grad_input_ptr.dtype.element_ty
|
447 |
+
c_desc_ptr = workspace + (tbidx * TMA_SIZE)
|
448 |
+
|
449 |
+
M_end = 0
|
450 |
+
processed_tiles = 0
|
451 |
+
|
452 |
+
for g in range(G):
|
453 |
+
# Move down along groups - same as forward
|
454 |
+
M_start = M_end
|
455 |
+
m_size = tl.load(m_sizes + g)
|
456 |
+
M_end = M_start + m_size
|
457 |
+
|
458 |
+
if m_size > 0:
|
459 |
+
# Process this group
|
460 |
+
# tiles for this group - now producing [M, K] output
|
461 |
+
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
462 |
+
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
463 |
+
group_num_tiles = num_m_tiles * num_k_tiles
|
464 |
+
|
465 |
+
# TMA Store prep for [M, K] output
|
466 |
+
tl.extra.cuda.experimental_device_tensormap_create2d(
|
467 |
+
desc_ptr=c_desc_ptr,
|
468 |
+
global_address=grad_input_ptr + M_start * K,
|
469 |
+
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
470 |
+
global_size=[m_size, K],
|
471 |
+
element_ty=c_dtype,
|
472 |
+
)
|
473 |
+
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
474 |
+
|
475 |
+
while tbidx >= processed_tiles and tbidx < (
|
476 |
+
processed_tiles + group_num_tiles
|
477 |
+
):
|
478 |
+
group_index = tbidx - processed_tiles
|
479 |
+
|
480 |
+
# Different tiling scheme for [M, K] output
|
481 |
+
tile_m_index = group_index % num_m_tiles
|
482 |
+
tile_k_index = group_index // num_m_tiles
|
483 |
+
|
484 |
+
# for grad_input block [M, K]
|
485 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
486 |
+
|
487 |
+
# Position in full matrix
|
488 |
+
m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
|
489 |
+
k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)
|
490 |
+
|
491 |
+
# reduce along N dimension (instead of K in forward)
|
492 |
+
for n_offset in range(0, N, BLOCK_SIZE_N):
|
493 |
+
# grad_output block [M, N]
|
494 |
+
grad_output = tl._experimental_descriptor_load(
|
495 |
+
grad_output_desc_ptr,
|
496 |
+
[m_offset, n_offset],
|
497 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
498 |
+
c_dtype,
|
499 |
+
)
|
500 |
+
|
501 |
+
# weight block [N, K] - no transpose needed
|
502 |
+
w = tl._experimental_descriptor_load(
|
503 |
+
w_desc_ptr,
|
504 |
+
[n_offset, k_offset],
|
505 |
+
[BLOCK_SIZE_N, BLOCK_SIZE_K],
|
506 |
+
c_dtype,
|
507 |
+
)
|
508 |
+
|
509 |
+
# grad_x = grad_output @ w
|
510 |
+
# reducing along N dimension
|
511 |
+
accumulator += tl.dot(grad_output, w)
|
512 |
+
|
513 |
+
# Store using TMA
|
514 |
+
m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
515 |
+
# k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)
|
516 |
+
|
517 |
+
tl._experimental_descriptor_store(
|
518 |
+
c_desc_ptr,
|
519 |
+
accumulator.to(c_dtype),
|
520 |
+
[m_offset, k_offset],
|
521 |
+
)
|
522 |
+
|
523 |
+
# Move to the next tile
|
524 |
+
tbidx += NUM_SMS
|
525 |
+
|
526 |
+
# Update the total tiles count for the next group
|
527 |
+
processed_tiles += group_num_tiles
|
528 |
+
|
529 |
+
|
530 |
+
# ---- dw flat linear indexed ----
|
531 |
+
|
532 |
+
|
533 |
+
@triton.autotune(
|
534 |
+
configs=_NV_CONFIGS,
|
535 |
+
key=["G", "M_BUCKET", "N", "K"],
|
536 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
537 |
+
)
|
538 |
+
@triton.jit
|
539 |
+
def _kernel_mg_dw_tma(
|
540 |
+
x_desc_ptr, # input descriptor [M_total, K]
|
541 |
+
grad_output_desc_ptr, # grad_output descriptor [M_total, N]
|
542 |
+
grad_weight_ptr, # output grad_w [N, K]
|
543 |
+
workspace, # workspace for TMA store
|
544 |
+
m_sizes, # group sizes [G]
|
545 |
+
# problem sizes
|
546 |
+
G: tl.constexpr,
|
547 |
+
M_BUCKET: tl.constexpr,
|
548 |
+
N: tl.constexpr,
|
549 |
+
K: tl.constexpr,
|
550 |
+
# config
|
551 |
+
NUM_SMS: tl.constexpr,
|
552 |
+
USE_TMA_LOAD: tl.constexpr,
|
553 |
+
USE_TMA_STORE: tl.constexpr,
|
554 |
+
TMA_SIZE: tl.constexpr,
|
555 |
+
# tiles
|
556 |
+
BLOCK_SIZE_N: tl.constexpr,
|
557 |
+
BLOCK_SIZE_K: tl.constexpr,
|
558 |
+
BLOCK_SIZE_M: tl.constexpr, # block size for reduction dimension
|
559 |
+
) -> None:
|
560 |
+
"""
|
561 |
+
Improved TMA-optimized kernel for computing gradients with respect to weights (dw).
|
562 |
+
Uses flat index structure similar to forward.
|
563 |
+
|
564 |
+
For the forward pass Y = X @ W.T,
|
565 |
+
the backward for weights is:
|
566 |
+
grad_W = grad_Y.T @ X
|
567 |
+
|
568 |
+
Where:
|
569 |
+
- grad_Y is [MG, N]
|
570 |
+
- X is [MG, K]
|
571 |
+
- grad_W is [N, K]
|
572 |
+
- we return [N,K]
|
573 |
+
"""
|
574 |
+
# Get thread block index
|
575 |
+
tbidx = tl.program_id(0)
|
576 |
+
|
577 |
+
# Get output data type
|
578 |
+
c_dtype = grad_weight_ptr.dtype.element_ty
|
579 |
+
|
580 |
+
# Calculate number of output tiles
|
581 |
+
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
582 |
+
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
583 |
+
total_output_tiles = num_n_tiles * num_k_tiles
|
584 |
+
|
585 |
+
# Process tiles in strided manner across SMs
|
586 |
+
for tile_idx in range(tbidx, total_output_tiles, NUM_SMS):
|
587 |
+
# Calculate tile indices
|
588 |
+
tile_n_idx = tile_idx % num_n_tiles
|
589 |
+
tile_k_idx = tile_idx // num_n_tiles
|
590 |
+
|
591 |
+
# Calculate global offsets
|
592 |
+
n_offset = tile_n_idx * BLOCK_SIZE_N
|
593 |
+
k_offset = tile_k_idx * BLOCK_SIZE_K
|
594 |
+
|
595 |
+
# Initialize accumulator for this output tile [N, K]
|
596 |
+
accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)
|
597 |
+
|
598 |
+
# Process each group
|
599 |
+
M_end = 0
|
600 |
+
for g in range(G):
|
601 |
+
# Get group boundaries
|
602 |
+
M_start = M_end
|
603 |
+
m_size = tl.load(m_sizes + g)
|
604 |
+
M_end = M_start + m_size
|
605 |
+
|
606 |
+
# Only process if group is non-empty
|
607 |
+
if m_size > 0:
|
608 |
+
# Process this group in chunks along the M dimension
|
609 |
+
for m_offset in range(0, m_size, BLOCK_SIZE_M):
|
610 |
+
# Calculate actual block size (handling boundary)
|
611 |
+
m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset)
|
612 |
+
|
613 |
+
# Only process if we have actual work to do
|
614 |
+
if m_block_size > 0:
|
615 |
+
# Global offset for this chunk
|
616 |
+
m_global_offset = M_start + m_offset
|
617 |
+
|
618 |
+
if USE_TMA_LOAD:
|
619 |
+
# Load input chunk [M_chunk, K] using TMA
|
620 |
+
x_block = tl._experimental_descriptor_load(
|
621 |
+
x_desc_ptr,
|
622 |
+
[m_global_offset, k_offset],
|
623 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
624 |
+
c_dtype,
|
625 |
+
)
|
626 |
+
|
627 |
+
# Load grad_output chunk [M_chunk, N] using TMA
|
628 |
+
grad_output_block = tl._experimental_descriptor_load(
|
629 |
+
grad_output_desc_ptr,
|
630 |
+
[m_global_offset, n_offset],
|
631 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
632 |
+
c_dtype,
|
633 |
+
)
|
634 |
+
|
635 |
+
# Apply masks for valid regions
|
636 |
+
offs_m = tl.arange(0, BLOCK_SIZE_M)
|
637 |
+
m_mask = offs_m < m_block_size
|
638 |
+
|
639 |
+
# Zero out invalid elements
|
640 |
+
x_block = tl.where(m_mask[:, None], x_block, 0.0)
|
641 |
+
grad_output_block = tl.where(
|
642 |
+
m_mask[:, None], grad_output_block, 0.0
|
643 |
+
)
|
644 |
+
else:
|
645 |
+
# Manual load with bounds checking
|
646 |
+
offs_m = tl.arange(0, BLOCK_SIZE_M)
|
647 |
+
offs_n = tl.arange(0, BLOCK_SIZE_N)
|
648 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
649 |
+
|
650 |
+
# Create masks
|
651 |
+
m_mask = offs_m < m_block_size
|
652 |
+
n_mask = offs_n < N - n_offset
|
653 |
+
k_mask = offs_k < K - k_offset
|
654 |
+
|
655 |
+
# Combined masks
|
656 |
+
mk_mask = m_mask[:, None] & k_mask[None, :]
|
657 |
+
mn_mask = m_mask[:, None] & n_mask[None, :]
|
658 |
+
|
659 |
+
# Global offsets for loading
|
660 |
+
m_global_offs = m_global_offset + offs_m
|
661 |
+
|
662 |
+
# Load x block [M_chunk, K]
|
663 |
+
x_block = tl.load(
|
664 |
+
x_desc_ptr
|
665 |
+
+ m_global_offs[:, None] * K
|
666 |
+
+ (k_offset + offs_k)[None, :],
|
667 |
+
mask=mk_mask,
|
668 |
+
other=0.0,
|
669 |
+
)
|
670 |
+
|
671 |
+
# Load grad_output block [M_chunk, N]
|
672 |
+
grad_output_block = tl.load(
|
673 |
+
grad_output_desc_ptr
|
674 |
+
+ m_global_offs[:, None] * N
|
675 |
+
+ (n_offset + offs_n)[None, :],
|
676 |
+
mask=mn_mask,
|
677 |
+
other=0.0,
|
678 |
+
)
|
679 |
+
|
680 |
+
# Compute partial contribution: grad_W += grad_Y.T @ X
|
681 |
+
# transpose grad_output for the matmul
|
682 |
+
contribution = tl.dot(
|
683 |
+
grad_output_block.to(tl.float32).T, # [N, M_chunk]
|
684 |
+
x_block.to(tl.float32), # [M_chunk, K]
|
685 |
+
)
|
686 |
+
|
687 |
+
# Accumulate
|
688 |
+
accumulator += contribution
|
689 |
+
|
690 |
+
# Store the result
|
691 |
+
if USE_TMA_STORE:
|
692 |
+
# Store using TMA
|
693 |
+
tl._experimental_descriptor_store(
|
694 |
+
workspace, # TMA store descriptor
|
695 |
+
accumulator.to(c_dtype),
|
696 |
+
[n_offset, k_offset],
|
697 |
+
)
|
698 |
+
else:
|
699 |
+
# Manual store with bounds checking
|
700 |
+
offs_n = tl.arange(0, BLOCK_SIZE_N)
|
701 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
702 |
+
|
703 |
+
# Create masks for bounds checking
|
704 |
+
n_mask = offs_n < N - n_offset
|
705 |
+
k_mask = offs_k < K - k_offset
|
706 |
+
output_mask = n_mask[:, None] & k_mask[None, :]
|
707 |
+
|
708 |
+
# Store the result
|
709 |
+
tl.store(
|
710 |
+
grad_weight_ptr
|
711 |
+
+ (n_offset + offs_n)[:, None] * K
|
712 |
+
+ (k_offset + offs_k)[None, :],
|
713 |
+
accumulator.to(c_dtype),
|
714 |
+
mask=output_mask,
|
715 |
+
)
|
716 |
+
|
717 |
+
|
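# Illustrative sketch (not part of the original file): the dw kernel above accumulates
# grad_W = grad_Y.T @ X over all groups and M-chunks. A one-line PyTorch reference:
def _dw_reference(x, grad_output):
    return grad_output.t() @ x  # [N, M_total] @ [M_total, K] -> [N, K]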
718 |
+
# ======== End Triton kernels ========
|
719 |
+
|
720 |
+
# ======== Triton wrapper functions ========
|
721 |
+
|
722 |
+
# ----- main forward pass wrapper -----
|
723 |
+
|
724 |
+
|
725 |
+
def grouped_gemm_forward(
|
726 |
+
x: torch.Tensor,
|
727 |
+
w: torch.Tensor,
|
728 |
+
m_sizes: torch.Tensor,
|
729 |
+
tma_size: int = 128,
|
730 |
+
) -> torch.Tensor:
|
731 |
+
"""
|
732 |
+
M*G style grouped GEMM with TMA and Float8 support.
|
733 |
+
# Removed for now - FP8 support is triggered by passing x_scale and w_scale tensors.
|
734 |
+
|
735 |
+
"""
|
736 |
+
if not CudaUtils.verify_tma():
|
737 |
+
raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
|
738 |
+
|
739 |
+
G = m_sizes.shape[0]
|
740 |
+
|
741 |
+
assert x.is_contiguous()
|
742 |
+
assert w.is_contiguous()
|
743 |
+
assert m_sizes.is_contiguous()
|
744 |
+
|
745 |
+
# Total input size is now [M_total, K] where M_total is the sum of all group sizes
|
746 |
+
M_total, K = x.shape
|
747 |
+
N = w.shape[0] # N is now the same for all groups
|
748 |
+
|
749 |
+
assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"
|
750 |
+
|
751 |
+
# Verify that all group sizes are multiples of ALIGN_SIZE_M
|
752 |
+
# This check is commented out because it will involve a GPU-CPU sync
|
753 |
+
# assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0, "Group sizes must be a multiple of ALIGN_SIZE_M"
|
754 |
+
|
755 |
+
# Create output tensor with shape [M_total, N // G] (the Hopper kernel gives each group its own N // G slice of w)
|
756 |
+
y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)
|
757 |
+
|
758 |
+
if M_total == 0:
|
759 |
+
return y
|
760 |
+
|
761 |
+
NUM_SMS = CudaUtils.get_num_sms()
|
762 |
+
USE_TMA_LOAD = True
|
763 |
+
USE_TMA_STORE = True
|
764 |
+
USE_EPILOGUE_SUBTILING = False
|
765 |
+
|
766 |
+
# TMA descriptor helper
|
767 |
+
desc_helper = None
|
768 |
+
desc_x = x
|
769 |
+
desc_w = w
|
770 |
+
workspace = None
|
771 |
+
|
772 |
+
if USE_TMA_LOAD:
|
773 |
+
desc_helper = TmaDescriptorHelper(tma_size=tma_size)
|
774 |
+
desc_helper.init_tma_descriptor("x")
|
775 |
+
desc_helper.init_tma_descriptor("w")
|
776 |
+
desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
|
777 |
+
desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
|
778 |
+
|
779 |
+
if USE_TMA_STORE:
|
780 |
+
workspace = torch.empty(
|
781 |
+
NUM_SMS * desc_helper.tma_size,
|
782 |
+
device=x.device,
|
783 |
+
dtype=torch.uint8,
|
784 |
+
)
|
785 |
+
|
786 |
+
def grid(META):
|
787 |
+
if USE_TMA_LOAD:
|
788 |
+
nonlocal desc_helper
|
789 |
+
desc_helper.fill_2d_tma_descriptor(
|
790 |
+
"x",
|
791 |
+
x.data_ptr(),
|
792 |
+
M_total,
|
793 |
+
K,
|
794 |
+
META["BLOCK_SIZE_M"],
|
795 |
+
META["BLOCK_SIZE_K"],
|
796 |
+
x.element_size(),
|
797 |
+
)
|
798 |
+
|
799 |
+
desc_helper.fill_2d_tma_descriptor(
|
800 |
+
"w",
|
801 |
+
w.data_ptr(),
|
802 |
+
N,
|
803 |
+
K,
|
804 |
+
META["BLOCK_SIZE_N"],
|
805 |
+
META["BLOCK_SIZE_K"],
|
806 |
+
w.element_size(),
|
807 |
+
)
|
808 |
+
return (NUM_SMS,)
|
809 |
+
|
810 |
+
M_BUCKET = triton.next_power_of_2(M_total)
|
811 |
+
|
812 |
+
_kernel_mg_forward_hopper[grid](
|
813 |
+
desc_x,
|
814 |
+
desc_w,
|
815 |
+
y,
|
816 |
+
workspace,
|
817 |
+
m_sizes,
|
818 |
+
G,
|
819 |
+
M_BUCKET,
|
820 |
+
N,
|
821 |
+
K,
|
822 |
+
NUM_SMS,
|
823 |
+
TMA_SIZE=tma_size,
|
824 |
+
USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
|
825 |
+
)
|
826 |
+
|
827 |
+
return y
|
828 |
+
|
829 |
+
|
830 |
+
# ======== Improved Backward =============
|
831 |
+
def grouped_gemm_backward(
|
832 |
+
grad_output: torch.Tensor,
|
833 |
+
x: torch.Tensor,
|
834 |
+
w: torch.Tensor,
|
835 |
+
m_sizes: torch.Tensor,
|
836 |
+
use_tma: bool = True,
|
837 |
+
tma_size: int = 128,
|
838 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
839 |
+
"""
|
840 |
+
Unified backward pass for grouped GeMM with M*G grouping.
|
841 |
+
Uses optimized TMA-based implementations for both dx and dw when available.
|
842 |
+
|
843 |
+
Args:
|
844 |
+
grad_output: Gradient of output, shape [M_total, N]
|
845 |
+
x: Input tensor from forward pass, shape [M_total, K]
|
846 |
+
w: Weight tensor from forward pass, shape [N, K]
|
847 |
+
m_sizes: Group sizes tensor, shape [G]
|
848 |
+
use_tma: Whether to try using TMA acceleration (if available)
|
849 |
+
tma_size: Size of TMA descriptor in bytes
|
850 |
+
|
851 |
+
|
852 |
+
Returns:
|
853 |
+
Tuple of gradients with respect to x and w: (grad_x, grad_w)
|
854 |
+
"""
|
855 |
+
logging.info("Starting unified grouped_gemm_backward")
|
856 |
+
|
857 |
+
# do this once, seems expensive
|
858 |
+
NUM_SMS = CudaUtils.get_num_sms()
|
859 |
+
|
860 |
+
# Basic validation
|
861 |
+
G = m_sizes.shape[0]
|
862 |
+
M_total, K_x = x.shape
|
863 |
+
M_grad, N = grad_output.shape
|
864 |
+
N_w, K_w = w.shape
|
865 |
+
|
866 |
+
# Check dimensions
|
867 |
+
if K_x != K_w:
|
868 |
+
raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
|
869 |
+
if M_total != M_grad:
|
870 |
+
raise ValueError(
|
871 |
+
f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
|
872 |
+
)
|
873 |
+
|
874 |
+
# Check total M matches sum of group sizes
|
875 |
+
sum_m_sizes = m_sizes.sum().item()
|
876 |
+
if M_total != sum_m_sizes:
|
877 |
+
raise ValueError(
|
878 |
+
f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
|
879 |
+
)
|
880 |
+
|
881 |
+
# Make sure inputs are contiguous
|
882 |
+
grad_output = grad_output.contiguous()
|
883 |
+
x = x.contiguous()
|
884 |
+
w = w.contiguous()
|
885 |
+
m_sizes = m_sizes.contiguous()
|
886 |
+
|
887 |
+
# Check TMA support
|
888 |
+
can_use_tma = use_tma and CudaUtils.verify_tma()
|
889 |
+
if use_tma and not can_use_tma:
|
890 |
+
logging.info("TMA requested but not supported on this device")
|
891 |
+
use_tma = False
|
892 |
+
|
893 |
+
# Compute grad_x using flat linear implementation
|
894 |
+
try:
|
895 |
+
logging.info(f"Computing grad_x with flat linear kernel")
|
896 |
+
|
897 |
+
# Use TMA-optimized implementation
|
898 |
+
grad_x = grouped_gemm_dx_tma(
|
899 |
+
grad_output=grad_output,
|
900 |
+
w=w,
|
901 |
+
m_sizes=m_sizes,
|
902 |
+
num_sms=NUM_SMS,
|
903 |
+
tma_size=tma_size,
|
904 |
+
)
|
905 |
+
|
906 |
+
except Exception as e:
|
907 |
+
logging.error(f"Error in grad_x computation: {e}")
|
908 |
+
raise
|
909 |
+
|
910 |
+
# Compute grad_w using flat linear style implementation
|
911 |
+
try:
|
912 |
+
logging.info(f"Computing grad_w with flat linear kernel")
|
913 |
+
|
914 |
+
grad_w = grouped_gemm_dw_tma(
|
915 |
+
x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
|
916 |
+
)
|
917 |
+
except Exception as e:
|
918 |
+
logging.error(f"Error in grad_w computation: {e}")
|
919 |
+
raise
|
920 |
+
|
921 |
+
return grad_x, grad_w
|
922 |
+
|
923 |
+
|
924 |
+
# ----- dx backward pass wrapper -----
|
925 |
+
|
926 |
+
|
927 |
+
def grouped_gemm_dx_tma(
|
928 |
+
grad_output: torch.Tensor,
|
929 |
+
w: torch.Tensor,
|
930 |
+
m_sizes: torch.Tensor,
|
931 |
+
num_sms: int = 132,
|
932 |
+
tma_size: int = 128,
|
933 |
+
) -> torch.Tensor:
|
934 |
+
"""
|
935 |
+
Optimized backward pass wrapper for computing gradient with respect to input (dx)
|
936 |
+
using TMA patterns similar to the forward pass.
|
937 |
+
|
938 |
+
Args:
|
939 |
+
grad_output: Gradient of output, shape [M_total, N]
|
940 |
+
w: Weight tensor, shape [N, K]
|
941 |
+
m_sizes: Group sizes tensor, shape [G]
|
942 |
+
tma_size: Size of TMA descriptor
|
943 |
+
# using_fp8: Whether to use FP8 quantization
|
944 |
+
# grad_output_scale: Scale for grad_output in FP8 mode
|
945 |
+
# w_scale: Scale for w in FP8 mode
|
946 |
+
|
947 |
+
Returns:
|
948 |
+
grad_x: Gradient with respect to x, shape [M_total, K]
|
949 |
+
"""
|
950 |
+
"""
|
951 |
+
Optimized backward pass for computing gradient with respect to input (dx)
|
952 |
+
using TMA patterns similar to the forward pass.
|
953 |
+
|
954 |
+
Args:
|
955 |
+
grad_output: Gradient of output, shape [M_total, N]
|
956 |
+
w: Weight tensor, shape [N, K]
|
957 |
+
m_sizes: Group sizes tensor, shape [G]
|
958 |
+
tma_size: Size of TMA descriptor
|
959 |
+
using_fp8: Whether to use FP8 quantization
|
960 |
+
# grad_output_scale: Scale for grad_output in FP8 mode
|
961 |
+
# w_scale: Scale for w in FP8 mode
|
962 |
+
|
963 |
+
Returns:
|
964 |
+
grad_x: Gradient with respect to x, shape [M_total, K]
|
965 |
+
"""
|
966 |
+
if not CudaUtils.verify_tma():
|
967 |
+
raise NotImplementedError("Optimized dx computation requires TMA support")
|
968 |
+
|
969 |
+
G = m_sizes.shape[0]
|
970 |
+
|
971 |
+
assert grad_output.is_contiguous()
|
972 |
+
assert w.is_contiguous()
|
973 |
+
assert m_sizes.is_contiguous()
|
974 |
+
|
975 |
+
M_total, N_grad = grad_output.shape
|
976 |
+
N_w, K = w.shape
|
977 |
+
|
978 |
+
# Check dimensions
|
979 |
+
assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})"
|
980 |
+
|
981 |
+
# Verify that the sum of m_sizes matches M_total
|
982 |
+
sum_m_sizes = m_sizes.sum().item()
|
983 |
+
assert (
|
984 |
+
M_total == sum_m_sizes
|
985 |
+
), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
|
986 |
+
|
987 |
+
# Create output tensor (grad_x) with shape [M_total, K]
|
988 |
+
grad_x = torch.empty(
|
989 |
+
(M_total, K), device=grad_output.device, dtype=grad_output.dtype
|
990 |
+
)
|
991 |
+
|
992 |
+
NUM_SMS = num_sms # CudaUtils.get_num_sms()
|
993 |
+
USE_TMA_LOAD = True
|
994 |
+
USE_TMA_STORE = True
|
995 |
+
|
996 |
+
# Set up TMA descriptors
|
997 |
+
desc_helper = TmaDescriptorHelper(tma_size=tma_size)
|
998 |
+
desc_helper.init_tma_descriptor("grad_output")
|
999 |
+
desc_helper.init_tma_descriptor("w")
|
1000 |
+
desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output")
|
1001 |
+
desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
|
1002 |
+
|
1003 |
+
# Allocate workspace for TMA store
|
1004 |
+
workspace = torch.empty(
|
1005 |
+
NUM_SMS * desc_helper.tma_size,
|
1006 |
+
device=grad_output.device,
|
1007 |
+
dtype=torch.uint8,
|
1008 |
+
)
|
1009 |
+
|
1010 |
+
def grid(META):
|
1011 |
+
# Fill TMA descriptors with appropriate dimensions
|
1012 |
+
desc_helper.fill_2d_tma_descriptor(
|
1013 |
+
"grad_output",
|
1014 |
+
grad_output.data_ptr(),
|
1015 |
+
M_total,
|
1016 |
+
N_grad,
|
1017 |
+
META["BLOCK_SIZE_M"],
|
1018 |
+
META["BLOCK_SIZE_N"],
|
1019 |
+
grad_output.element_size(),
|
1020 |
+
)
|
1021 |
+
|
1022 |
+
desc_helper.fill_2d_tma_descriptor(
|
1023 |
+
"w",
|
1024 |
+
w.data_ptr(),
|
1025 |
+
N_w,
|
1026 |
+
K,
|
1027 |
+
META["BLOCK_SIZE_N"],
|
1028 |
+
META["BLOCK_SIZE_K"],
|
1029 |
+
w.element_size(),
|
1030 |
+
)
|
1031 |
+
return (NUM_SMS,)
|
1032 |
+
|
1033 |
+
M_BUCKET = triton.next_power_of_2(M_total)
|
1034 |
+
|
1035 |
+
# Launch the flat linear kernel for computing grad_x
|
1036 |
+
_kernel_mg_dx_tma[grid](
|
1037 |
+
desc_grad_output,
|
1038 |
+
desc_w,
|
1039 |
+
grad_x,
|
1040 |
+
workspace,
|
1041 |
+
m_sizes,
|
1042 |
+
G,
|
1043 |
+
M_BUCKET,
|
1044 |
+
N_grad, # N dimension is now the reduction dimension
|
1045 |
+
K,
|
1046 |
+
NUM_SMS,
|
1047 |
+
USE_TMA_LOAD,
|
1048 |
+
USE_TMA_STORE,
|
1049 |
+
TMA_SIZE=tma_size,
|
1050 |
+
)
|
1051 |
+
|
1052 |
+
return grad_x
|
1053 |
+
|
1054 |
+
|
1055 |
+
# ======== dw wrapper function ==========
|
1056 |
+
|
1057 |
+
|
1058 |
+
def grouped_gemm_dw_tma(
|
1059 |
+
x: torch.Tensor,
|
1060 |
+
grad_output: torch.Tensor,
|
1061 |
+
m_sizes: torch.Tensor,
|
1062 |
+
num_sms: int = 132,
|
1063 |
+
tma_size: int = 128,
|
1064 |
+
) -> torch.Tensor:
|
1065 |
+
"""
|
1066 |
+
Optimized flat linear kernel computation of gradients with respect to weights (dw) using TMA.
|
1067 |
+
For the forward pass Y = X @ W.T, the backward for weights is:
|
1068 |
+
grad_W = grad_Y.T @ X
|
1069 |
+
|
1070 |
+
Args:
|
1071 |
+
x: Input tensor, shape [M_total, K]
|
1072 |
+
grad_output: Gradient of output, shape [M_total, N]
|
1073 |
+
m_sizes: Group sizes tensor, shape [G]
|
1074 |
+
tma_size: Size of TMA descriptor in bytes
|
1075 |
+
|
1076 |
+
|
1077 |
+
Returns:
|
1078 |
+
grad_w: Gradient with respect to weights, shape [N, K]
|
1079 |
+
"""
|
1080 |
+
# Check TMA support
|
1081 |
+
has_tma_support = CudaUtils.verify_tma()
|
1082 |
+
|
1083 |
+
# Get group count
|
1084 |
+
G = m_sizes.shape[0]
|
1085 |
+
|
1086 |
+
# Ensure contiguous tensors
|
1087 |
+
x = x.contiguous()
|
1088 |
+
grad_output = grad_output.contiguous()
|
1089 |
+
m_sizes = m_sizes.contiguous()
|
1090 |
+
|
1091 |
+
# Get dimensions
|
1092 |
+
M_total, K_x = x.shape
|
1093 |
+
M_grad, N = grad_output.shape
|
1094 |
+
|
1095 |
+
# Check dimensions
|
1096 |
+
assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})"
|
1097 |
+
|
1098 |
+
# Verify that the sum of m_sizes matches M_total
|
1099 |
+
sum_m_sizes = m_sizes.sum().item()
|
1100 |
+
assert (
|
1101 |
+
sum_m_sizes == M_total
|
1102 |
+
), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
|
1103 |
+
|
1104 |
+
# Create output tensor (grad_w) with shape [N, K]
|
1105 |
+
grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype)
|
1106 |
+
|
1107 |
+
NUM_SMS = num_sms
|
1108 |
+
|
1109 |
+
# TODO - hardcoded for now...but should set TMA flags based on hardware support
|
1110 |
+
USE_TMA_LOAD = True # has_tma_support
|
1111 |
+
USE_TMA_STORE = True # has_tma_support
|
1112 |
+
|
1113 |
+
# Set up TMA descriptors or direct pointers
|
1114 |
+
if USE_TMA_LOAD or USE_TMA_STORE:
|
1115 |
+
desc_helper = TmaDescriptorHelper(tma_size=tma_size)
|
1116 |
+
|
1117 |
+
if USE_TMA_LOAD:
|
1118 |
+
desc_helper.init_tma_descriptor("x")
|
1119 |
+
desc_helper.init_tma_descriptor("grad_output")
|
1120 |
+
x_desc = desc_helper.get_tma_descriptor_kernel_param("x")
|
1121 |
+
grad_output_desc = desc_helper.get_tma_descriptor_kernel_param(
|
1122 |
+
"grad_output"
|
1123 |
+
)
|
1124 |
+
else:
|
1125 |
+
x_desc = x
|
1126 |
+
grad_output_desc = grad_output
|
1127 |
+
|
1128 |
+
if USE_TMA_STORE:
|
1129 |
+
desc_helper.init_tma_descriptor("grad_w")
|
1130 |
+
workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w")
|
1131 |
+
else:
|
1132 |
+
workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
|
1133 |
+
else:
|
1134 |
+
# If not using TMA, just use the tensors directly
|
1135 |
+
x_desc = x
|
1136 |
+
grad_output_desc = grad_output
|
1137 |
+
workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
|
1138 |
+
|
1139 |
+
# M_BUCKET for grid size
|
1140 |
+
M_BUCKET = triton.next_power_of_2(M_total)
|
1141 |
+
|
1142 |
+
# Define grid for kernel launch
|
1143 |
+
def grid(META):
|
1144 |
+
if USE_TMA_LOAD or USE_TMA_STORE:
|
1145 |
+
|
1146 |
+
if USE_TMA_LOAD:
|
1147 |
+
desc_helper.fill_2d_tma_descriptor(
|
1148 |
+
"x",
|
1149 |
+
x.data_ptr(),
|
1150 |
+
M_total,
|
1151 |
+
K_x,
|
1152 |
+
META["BLOCK_SIZE_M"],
|
1153 |
+
META["BLOCK_SIZE_K"],
|
1154 |
+
x.element_size(),
|
1155 |
+
)
|
1156 |
+
|
1157 |
+
desc_helper.fill_2d_tma_descriptor(
|
1158 |
+
"grad_output",
|
1159 |
+
grad_output.data_ptr(),
|
1160 |
+
M_total,
|
1161 |
+
N,
|
1162 |
+
META["BLOCK_SIZE_M"],
|
1163 |
+
META["BLOCK_SIZE_N"],
|
1164 |
+
grad_output.element_size(),
|
1165 |
+
)
|
1166 |
+
|
1167 |
+
if USE_TMA_STORE:
|
1168 |
+
desc_helper.fill_2d_tma_descriptor(
|
1169 |
+
"grad_w",
|
1170 |
+
grad_w.data_ptr(),
|
1171 |
+
N,
|
1172 |
+
K_x,
|
1173 |
+
META["BLOCK_SIZE_N"],
|
1174 |
+
META["BLOCK_SIZE_K"],
|
1175 |
+
grad_w.element_size(),
|
1176 |
+
)
|
1177 |
+
|
1178 |
+
# Return grid size - one block per SM for balanced work distribution
|
1179 |
+
return (NUM_SMS,)
|
1180 |
+
|
1181 |
+
# Launch the optimized kernel
|
1182 |
+
_kernel_mg_dw_tma[grid](
|
1183 |
+
x_desc,
|
1184 |
+
grad_output_desc,
|
1185 |
+
grad_w,
|
1186 |
+
workspace,
|
1187 |
+
m_sizes,
|
1188 |
+
G,
|
1189 |
+
M_BUCKET,
|
1190 |
+
N,
|
1191 |
+
K_x,
|
1192 |
+
NUM_SMS,
|
1193 |
+
USE_TMA_LOAD,
|
1194 |
+
USE_TMA_STORE,
|
1195 |
+
TMA_SIZE=tma_size,
|
1196 |
+
)
|
1197 |
+
|
1198 |
+
return grad_w
|
1199 |
+
|
1200 |
+
|
1201 |
+
# ======== End Backwards Wrapper Functions =============
|
1202 |
+
|
1203 |
+
# ======== PyTorch wrapper functions ========
|
1204 |
+
|
1205 |
+
|
1206 |
+
class GroupedGEMM_mg(torch.autograd.Function):
|
1207 |
+
"""
|
1208 |
+
Autograd function for GroupedGEMM with M*G grouping.
|
1209 |
+
Supports both standard and FP8 quantized operations.
|
1210 |
+
"""
|
1211 |
+
|
1212 |
+
@staticmethod
|
1213 |
+
def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128):
|
1214 |
+
"""
|
1215 |
+
Forward pass of GroupedGEMM.
|
1216 |
+
|
1217 |
+
Args:
|
1218 |
+
x: Input tensor, shape [M_total, K]
|
1219 |
+
w: Weight tensor, shape [N, K]
|
1220 |
+
m_sizes: Tensor of shape [G] containing the size of each group
|
1221 |
+
use_tma: Whether to try using TMA acceleration (if available)
|
1222 |
+
tma_size: Size of TMA descriptor in bytes
|
1223 |
+
|
1224 |
+
|
1225 |
+
Returns:
|
1226 |
+
Output tensor, shape [M_total, N]
|
1227 |
+
"""
|
1228 |
+
|
1229 |
+
# Use regular forward without quantization
|
1230 |
+
output = grouped_gemm_forward(
|
1231 |
+
x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False
|
1232 |
+
)
|
1233 |
+
|
1234 |
+
# Save inputs and parameters for backward pass
|
1235 |
+
ctx.save_for_backward(x, w, m_sizes)
|
1236 |
+
ctx.use_tma = use_tma
|
1237 |
+
ctx.tma_size = tma_size
|
1238 |
+
|
1239 |
+
ctx.save_for_backward(x, w, m_sizes)
|
1240 |
+
|
1241 |
+
return output
|
1242 |
+
|
1243 |
+
@staticmethod
|
1244 |
+
def backward(ctx, grad_output):
|
1245 |
+
"""
|
1246 |
+
Backward pass of M*G GroupedGEMM.
|
1247 |
+
|
1248 |
+
Args:
|
1249 |
+
grad_output: Gradient of output, shape [M_total, N]
|
1250 |
+
|
1251 |
+
Returns:
|
1252 |
+
Tuple of gradients:
|
1253 |
+
- grad_x: Gradient with respect to x, shape [M_total, K]
|
1254 |
+
- grad_w: Gradient with respect to w, shape [N, K]
|
1255 |
+
- None: Gradient with respect to m_sizes (not differentiable)
|
1256 |
+
- None: Gradient with respect to use_tma (not differentiable)
|
1257 |
+
- None: Gradient with respect to tma_size (not differentiable)
|
1258 |
+
|
1259 |
+
"""
|
1260 |
+
# Retrieve saved tensors and parameters
|
1261 |
+
|
1262 |
+
x, w, m_sizes = ctx.saved_tensors
|
1263 |
+
|
1264 |
+
use_tma = ctx.use_tma
|
1265 |
+
tma_size = ctx.tma_size
|
1266 |
+
|
1267 |
+
# Compute gradients using the unified implementation
|
1268 |
+
grad_x, grad_w = grouped_gemm_backward(
|
1269 |
+
grad_output=grad_output,
|
1270 |
+
x=x,
|
1271 |
+
w=w,
|
1272 |
+
m_sizes=m_sizes,
|
1273 |
+
use_tma=use_tma,
|
1274 |
+
tma_size=tma_size,
|
1275 |
+
)
|
1276 |
+
|
1277 |
+
# Return gradients for all inputs (None for non-differentiable parameters)
|
1278 |
+
return grad_x, grad_w, None, None, None
|
1279 |
+
|
1280 |
+
|
1281 |
+
def mg_grouped_gemm(
|
1282 |
+
x: torch.Tensor,
|
1283 |
+
w: torch.Tensor,
|
1284 |
+
m_sizes: torch.Tensor,
|
1285 |
+
use_tma: bool = True,
|
1286 |
+
tma_size: int = 128,
|
1287 |
+
using_fp8: bool = False,
|
1288 |
+
) -> torch.Tensor:
|
1289 |
+
"""
|
1290 |
+
Unified differentiable grouped GEMM operation for M*G grouped GEMM.
|
1291 |
+
Supports both standard precision and FP8 quantized operations.
|
1292 |
+
|
1293 |
+
Args:
|
1294 |
+
x: Input tensor, shape [M_total, K]
|
1295 |
+
w: Weight tensor, shape [N, K]
|
1296 |
+
m_sizes: Tensor of shape [G] containing the size of each group
|
1297 |
+
use_tma: Whether to try using TMA acceleration (if available)
|
1298 |
+
tma_size: Size of TMA descriptor in bytes
|
1299 |
+
using_fp8: Whether to use FP8 quantization
|
1300 |
+
|
1301 |
+
Returns:
|
1302 |
+
Output tensor, shape [M_total, N]
|
1303 |
+
"""
|
1304 |
+
return GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size)  # using_fp8 is not forwarded; forward() hardcodes using_fp8=False
|
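Below is a minimal usage sketch for the `mg_grouped_gemm` wrapper added above. It is not part of the diff: the import path simply mirrors where this file lands in the repo, the dtype/device choices and group sizes are assumptions (group sizes are kept as multiples of the 128-row alignment used by the autotune configs), and a TMA-capable (Hopper or newer) GPU is assumed.

# Usage sketch (not part of this PR): exercises forward and backward of the
# differentiable grouped GEMM. Shapes follow the docstrings above.
import torch

from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr.mg_grouped_gemm import (
    mg_grouped_gemm,
)

G, K, N = 4, 256, 512
# Per-group row counts; their sum is M_total, matching the assertions in the wrappers.
# Multiples of 128 are chosen to line up with BLOCK_SIZE_M in the autotune configs.
m_sizes = torch.tensor([128, 256, 128, 512], dtype=torch.int32, device="cuda")
M_total = int(m_sizes.sum())

x = torch.randn(M_total, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)

y = mg_grouped_gemm(x, w, m_sizes)   # [M_total, N]
y.sum().backward()                   # runs the dx / dw backward wrappers above
print(y.shape, x.grad.shape, w.grad.shape)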
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py
ADDED
@@ -0,0 +1,240 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# credit - TMAHelper class, AutoTuning are derived from FBGemm:
|
8 |
+
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
|
9 |
+
|
10 |
+
# pyre-unsafe
|
11 |
+
import functools
|
12 |
+
|
13 |
+
import os
|
14 |
+
import sys
|
15 |
+
from typing import Any, Dict, Optional, Tuple
|
16 |
+
|
17 |
+
import torch
|
18 |
+
|
19 |
+
import triton
|
20 |
+
import triton.language as tl
|
21 |
+
from triton import Config as TConfig
|
22 |
+
|
23 |
+
from triton.runtime import driver # @manual
|
24 |
+
|
25 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
26 |
+
|
27 |
+
|
28 |
+
# ===== Supporting utils, CUDA and TMA =====
|
29 |
+
|
30 |
+
|
31 |
+
class CudaUtils:
|
32 |
+
@staticmethod
|
33 |
+
def is_cuda() -> bool:
|
34 |
+
"""Check if Triton is running on CUDA backend."""
|
35 |
+
return driver.active.get_current_target().backend == "cuda"
|
36 |
+
|
37 |
+
@staticmethod
|
38 |
+
def verify_tma() -> bool:
|
39 |
+
"""Check if TMA is supported on the current device."""
|
40 |
+
return (
|
41 |
+
CudaUtils.is_cuda()
|
42 |
+
and torch.cuda.is_available()
|
43 |
+
and torch.cuda.get_device_capability()[0] >= 9
|
44 |
+
)
|
45 |
+
|
46 |
+
@staticmethod
|
47 |
+
def get_num_sms() -> int:
|
48 |
+
"""Get the number of streaming multiprocessors on the current device."""
|
49 |
+
if not CudaUtils.is_cuda():
|
50 |
+
raise RuntimeError("Triton is not running on CUDA backend")
|
51 |
+
if not torch.cuda.is_available():
|
52 |
+
raise RuntimeError("CUDA is not available")
|
53 |
+
return torch.cuda.get_device_properties("cuda").multi_processor_count
|
54 |
+
|
55 |
+
|
56 |
+
class TmaDescriptorHelper:
|
57 |
+
"""Helper class for managing TMA descriptors in Triton kernels."""
|
58 |
+
|
59 |
+
class KernelParamWrapper:
|
60 |
+
"""Wrapper to implement the TmaDescKernelParam interface."""
|
61 |
+
|
62 |
+
def __init__(self, desc: torch.Tensor):
|
63 |
+
self.desc = desc
|
64 |
+
|
65 |
+
def tma_desc_cpu_ptr(self) -> int:
|
66 |
+
"""Return the CPU pointer to the TMA descriptor."""
|
67 |
+
return self.desc.data_ptr()
|
68 |
+
|
69 |
+
def __init__(self, tma_size: int = 128):
|
70 |
+
"""Initialize the TMA descriptor helper.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
tma_size: Size of the TMA descriptor in bytes
|
74 |
+
"""
|
75 |
+
if not CudaUtils.verify_tma():
|
76 |
+
raise RuntimeError(
|
77 |
+
"TMA not supported on this device (requires Hopper or newer)"
|
78 |
+
)
|
79 |
+
if "nv_tma_desc_type" not in dir(tl):
|
80 |
+
raise RuntimeError(
|
81 |
+
"TMA grid constant descriptors not supported in your Triton version"
|
82 |
+
)
|
83 |
+
|
84 |
+
self.tma_size = tma_size
|
85 |
+
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
|
86 |
+
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
|
87 |
+
self.descriptors: Dict[str, torch.Tensor] = {}
|
88 |
+
|
89 |
+
def init_tma_descriptor(self, name: str) -> None:
|
90 |
+
"""Initialize a TMA descriptor with the given name.
|
91 |
+
|
92 |
+
Call this method outside of the lambda function for grid size.
|
93 |
+
"""
|
94 |
+
self.descriptors[name] = torch.empty(
|
95 |
+
self.tma_size, device="cpu", dtype=torch.int8
|
96 |
+
)
|
97 |
+
|
98 |
+
def fill_1d_tma_descriptor(
|
99 |
+
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
|
100 |
+
) -> None:
|
101 |
+
"""Fill a 1D TMA descriptor.
|
102 |
+
|
103 |
+
Call this method inside the lambda function for grid size.
|
104 |
+
"""
|
105 |
+
if name not in self.descriptors:
|
106 |
+
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
107 |
+
|
108 |
+
desc_x = self.descriptors[name]
|
109 |
+
if desc_x.data_ptr() % 64 != 0:
|
110 |
+
raise ValueError("TMA descriptor must be 64-byte aligned")
|
111 |
+
self.fill_1d_tma_descriptor_inner(
|
112 |
+
ptr, dim, block_dim, element_size, desc_x.data_ptr()
|
113 |
+
)
|
114 |
+
|
115 |
+
def fill_2d_tma_descriptor(
|
116 |
+
self,
|
117 |
+
name: str,
|
118 |
+
ptr: int,
|
119 |
+
dim1: int,
|
120 |
+
dim0: int,
|
121 |
+
block_dim1: int,
|
122 |
+
block_dim0: int,
|
123 |
+
element_size: int,
|
124 |
+
) -> None:
|
125 |
+
"""Fill a 2D TMA descriptor.
|
126 |
+
|
127 |
+
Call this method inside the lambda function for grid size.
|
128 |
+
"""
|
129 |
+
if name not in self.descriptors:
|
130 |
+
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
131 |
+
|
132 |
+
desc_x = self.descriptors[name]
|
133 |
+
if desc_x.data_ptr() % 64 != 0:
|
134 |
+
raise ValueError("TMA descriptor must be 64-byte aligned")
|
135 |
+
self.fill_2d_tma_descriptor_inner(
|
136 |
+
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
|
137 |
+
)
|
138 |
+
|
139 |
+
def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
|
140 |
+
"""Get the TMA descriptor kernel parameter for the given name."""
|
141 |
+
if name not in self.descriptors or self.descriptors[name] is None:
|
142 |
+
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
143 |
+
return self.KernelParamWrapper(self.descriptors[name])
|
144 |
+
|
145 |
+
|
146 |
+
# ====== Autotuning utilities ======
|
147 |
+
ALIGN_SIZE_M = 128
|
148 |
+
|
149 |
+
_NV_CONFIGS = [
|
150 |
+
triton.Config(
|
151 |
+
{
|
152 |
+
"BLOCK_SIZE_M": block_size_m,
|
153 |
+
"BLOCK_SIZE_N": block_size_n,
|
154 |
+
"BLOCK_SIZE_K": block_size_k,
|
155 |
+
},
|
156 |
+
num_stages=num_stages,
|
157 |
+
num_warps=num_warps,
|
158 |
+
num_ctas=num_ctas,
|
159 |
+
)
|
160 |
+
for block_size_m in [ALIGN_SIZE_M]
|
161 |
+
for block_size_n in [64, 128, 256]
|
162 |
+
for block_size_k in [64, 128, 256]
|
163 |
+
for num_stages in [3, 4]
|
164 |
+
for num_warps in [4, 8]
|
165 |
+
for num_ctas in [1]
|
166 |
+
]
|
167 |
+
|
168 |
+
|
169 |
+
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
|
170 |
+
device = torch.cuda.current_device()
|
171 |
+
# Check for all possible pointer parameter names
|
172 |
+
if "grad_input_ptr" in named_args:
|
173 |
+
ptr_name = "grad_input_ptr"
|
174 |
+
elif "c_ptr" in named_args:
|
175 |
+
ptr_name = "c_ptr"
|
176 |
+
elif "grad_weight_ptr" in named_args:
|
177 |
+
ptr_name = "grad_weight_ptr"
|
178 |
+
else:
|
179 |
+
raise KeyError("No recognized pointer parameter found in kernel arguments")
|
180 |
+
|
181 |
+
if dtsize is None:
|
182 |
+
dtsize = named_args[ptr_name].element_size()
|
183 |
+
if dtype is None:
|
184 |
+
dtype = named_args[ptr_name].dtype
|
185 |
+
|
186 |
+
pruned_configs = []
|
187 |
+
for config in configs:
|
188 |
+
kw = config.kwargs
|
189 |
+
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
|
190 |
+
kw["BLOCK_SIZE_M"],
|
191 |
+
kw["BLOCK_SIZE_N"],
|
192 |
+
kw["BLOCK_SIZE_K"],
|
193 |
+
config.num_stages,
|
194 |
+
)
|
195 |
+
G, M, N, K = (
|
196 |
+
named_args["G"],
|
197 |
+
named_args["M_BUCKET"],
|
198 |
+
named_args["N"],
|
199 |
+
named_args["K"],
|
200 |
+
)
|
201 |
+
|
202 |
+
# 1. make sure we have enough smem
|
203 |
+
max_shared_memory = driver.active.utils.get_device_properties(device)[
|
204 |
+
"max_shared_mem"
|
205 |
+
]
|
206 |
+
|
207 |
+
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
|
208 |
+
if required_shared_memory > max_shared_memory:
|
209 |
+
continue
|
210 |
+
|
211 |
+
M_PER_GROUP = M // G
|
212 |
+
MIN_M_TILES = 64
|
213 |
+
# 2. make sure we don't load M tiles that are too big
|
214 |
+
if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
|
215 |
+
continue
|
216 |
+
# 3. make sure we don't load N tiles that are too small
|
217 |
+
if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
|
218 |
+
continue
|
219 |
+
|
220 |
+
num_sm = driver.active.utils.get_device_properties(device)[
|
221 |
+
"multiprocessor_count"
|
222 |
+
]
|
223 |
+
N_TILES = N // BLOCK_N
|
224 |
+
MIN_N_TILES = 64
|
225 |
+
# 4. make sure we don't load N tiles that are too big
|
226 |
+
if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
|
227 |
+
continue
|
228 |
+
# 5. make sure we don't load N tiles that are too small
|
229 |
+
if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
|
230 |
+
continue
|
231 |
+
# 6. make sure K can be evenly divided
|
232 |
+
if K % BLOCK_K != 0:
|
233 |
+
continue
|
234 |
+
|
235 |
+
pruned_configs.append(config)
|
236 |
+
|
237 |
+
return pruned_configs
|
238 |
+
|
239 |
+
|
240 |
+
# ======== End Autotuning utilities ========
|
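For context, the config list and pruning hook above are meant to be attached to the Triton kernels in `mg_grouped_gemm.py` via `triton.autotune`. The sketch below shows only that wiring; `_toy_kernel` and its arguments are placeholders (not part of this PR), and `_NV_CONFIGS` / `early_config_prune` are assumed to be importable from this module.

# Wiring sketch (not part of this PR): how the autotune pieces above are consumed.
import triton
import triton.language as tl

from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr.tma_autotuning import (
    _NV_CONFIGS,
    early_config_prune,
)


@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _toy_kernel(
    c_ptr,  # early_config_prune looks for c_ptr / grad_input_ptr / grad_weight_ptr
    G,
    M_BUCKET,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    pass  # placeholder body; the real kernels live in mg_grouped_gemm.py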
torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc
ADDED
Binary file (10.5 kB). View file
|
|
torchtitan/experiments/llama4/model/args.py
ADDED
@@ -0,0 +1,109 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from typing import Optional
|
10 |
+
|
11 |
+
from torch import nn
|
12 |
+
from torchtitan.components.tokenizer import Tokenizer
|
13 |
+
from torchtitan.config_manager import JobConfig
|
14 |
+
|
15 |
+
from torchtitan.protocols.train_spec import BaseModelArgs
|
16 |
+
from torchtitan.tools.logging import logger
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class TransformerModelArgs(BaseModelArgs):
|
21 |
+
dim: int = 4096
|
22 |
+
n_layers: int = 32
|
23 |
+
n_heads: int = 32
|
24 |
+
n_kv_heads: Optional[int] = None
|
25 |
+
vocab_size: int = -1 # defined later by tokenizer
|
26 |
+
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
27 |
+
ffn_dim_multiplier: Optional[float] = None
|
28 |
+
norm_eps: float = 1e-5
|
29 |
+
rope_theta: float = 10000
|
30 |
+
|
31 |
+
max_seq_len: int = 2048
|
32 |
+
# If `True`, then each transformer block init uses its layer ID, and if
|
33 |
+
# `False`, each uses the total number of transformer blocks
|
34 |
+
depth_init: bool = True
|
35 |
+
norm_type: str = "rmsnorm"
|
36 |
+
|
37 |
+
use_flex_attn: bool = False
|
38 |
+
attn_mask_type: str = "causal"
|
39 |
+
eos_id: int = 0
|
40 |
+
|
41 |
+
# MoE args
|
42 |
+
moe_enabled: bool = True
|
43 |
+
num_experts: int = 8
|
44 |
+
use_shared_expert: bool = True
|
45 |
+
auto_scale_hidden_dim: bool = True
|
46 |
+
# frequency of using MoE layer instead of feedforward layer in a transformer block
|
47 |
+
interleave_moe_layer_step: int = 2
|
48 |
+
# token-choice
|
49 |
+
top_k: int = 1
|
50 |
+
|
51 |
+
def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
|
52 |
+
self.norm_type = job_config.model.norm_type
|
53 |
+
self.vocab_size = tokenizer.n_words
|
54 |
+
self.max_seq_len = job_config.training.seq_len
|
55 |
+
self.use_flex_attn = job_config.model.use_flex_attn
|
56 |
+
|
57 |
+
def get_nparams_and_flops(
|
58 |
+
self, model: nn.Module, seq_len: int
|
59 |
+
) -> tuple[int, float]:
|
60 |
+
nparams_embedding = 0
|
61 |
+
nparams_moe_router = 0
|
62 |
+
nparams_shared_expert = 0
|
63 |
+
nparams_experts = 0
|
64 |
+
nparams_dense = 0
|
65 |
+
|
66 |
+
for name, p in model.named_parameters():
|
67 |
+
if "embedding" in name:
|
68 |
+
nparams_embedding += p.numel()
|
69 |
+
nparams_dense += p.numel()
|
70 |
+
elif "moe.shared_expert" in name:
|
71 |
+
nparams_shared_expert += p.numel()
|
72 |
+
elif "moe.router" in name:
|
73 |
+
nparams_moe_router += p.numel()
|
74 |
+
elif "moe.experts" in name:
|
75 |
+
nparams_experts += p.numel()
|
76 |
+
else:
|
77 |
+
nparams_dense += p.numel()
|
78 |
+
|
79 |
+
nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
|
80 |
+
nparams = nparams_dense + nparams_sparse
|
81 |
+
nparams_sparse_active = (
|
82 |
+
nparams_moe_router
|
83 |
+
+ nparams_shared_expert
|
84 |
+
+ nparams_experts * self.top_k // self.num_experts
|
85 |
+
)
|
86 |
+
|
87 |
+
logger.info(
|
88 |
+
f"Total parameter count: dense {nparams_dense:,}, "
|
89 |
+
f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
|
90 |
+
)
|
91 |
+
|
92 |
+
l, h, q, t = (
|
93 |
+
self.n_layers,
|
94 |
+
self.n_heads,
|
95 |
+
self.dim // self.n_heads,
|
96 |
+
seq_len,
|
97 |
+
)
|
98 |
+
# Reasoning behind the factor of 12 for the self-attention part of the formula:
|
99 |
+
# 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
|
100 |
+
# 2. the flash attention does 1 more matmul recomputation in the backward
|
101 |
+
# but recomputation should not be counted in calculating MFU (+0)
|
102 |
+
# 3. each matmul performs 1 multiplication and 1 addition (*2)
|
103 |
+
# 4. we follow the convention and do not account for sparsity in causal attention
|
104 |
+
num_flops_per_token = (
|
105 |
+
6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
|
106 |
+
+ 12 * l * h * q * t
|
107 |
+
)
|
108 |
+
|
109 |
+
return nparams, num_flops_per_token
|
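As a quick sanity check of the FLOPs accounting in `get_nparams_and_flops`, the snippet below reproduces the formula with toy numbers (these dimensions are illustrative only, not a real Llama 4 configuration).

# Toy walk-through of the FLOPs-per-token formula above (illustrative numbers only).
l, h, q, t = 16, 16, 64, 2048          # layers, heads, head dim, seq_len
nparams_dense = 500_000_000            # dense params, including embeddings
nparams_embedding = 100_000_000
nparams_sparse_active = 50_000_000     # router + shared expert + active experts

num_flops_per_token = (
    6 * (nparams_dense - nparams_embedding + nparams_sparse_active)  # matmul params
    + 12 * l * h * q * t                                             # self-attention
)
print(f"{num_flops_per_token:,}")      # ~3.1e9 FLOPs/token with these toy values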
torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh
ADDED
@@ -0,0 +1,25 @@
1 |
+
#!/usr/bin/bash
|
2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
# All rights reserved.
|
4 |
+
|
5 |
+
# This source code is licensed under the BSD-style license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
set -ex
|
9 |
+
|
10 |
+
# use envs as local overrides for convenience
|
11 |
+
# e.g.
|
12 |
+
# LOG_RANK=0,1 NGPU=4 ./convert_meta_to_dcp_with_gpus.sh
|
13 |
+
NGPU=${NGPU:-"8"}
|
14 |
+
LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7}
|
15 |
+
CONFIG_FILE=${CONFIG_FILE:-"../train_configs/llama4_17bx16e.toml"}
|
16 |
+
|
17 |
+
overrides=""
|
18 |
+
if [ $# -ne 0 ]; then
|
19 |
+
overrides="$*"
|
20 |
+
fi
|
21 |
+
|
22 |
+
PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
|
23 |
+
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
|
24 |
+
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
|
25 |
+
convert_meta_to_dcp_with_gpus_meta.py --job.config_file ${CONFIG_FILE} $overrides
|
torchtitan/experiments/multimodal/tests/__init__.py
ADDED
@@ -0,0 +1,5 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
torchtitan/experiments/multimodal/tests/test_utils.py
ADDED
@@ -0,0 +1,58 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
from typing import Optional, Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
|
15 |
+
def fixed_init_tensor(
|
16 |
+
shape: torch.Size,
|
17 |
+
min_val: Union[float, int] = 0.0,
|
18 |
+
max_val: Union[float, int] = 1.0,
|
19 |
+
nonlinear: bool = False,
|
20 |
+
dtype: torch.dtype = torch.float,
|
21 |
+
):
|
22 |
+
"""
|
23 |
+
Utility for generating deterministic tensors of a given shape. In general stuff
|
24 |
+
like torch.ones, torch.eye, etc can result in trivial outputs. This utility
|
25 |
+
generates a range tensor [min_val, max_val) of a specified dtype, applies
|
26 |
+
a sine function if nonlinear=True, then reshapes to the appropriate shape.
|
27 |
+
"""
|
28 |
+
n_elements = math.prod(shape)
|
29 |
+
step_size = (max_val - min_val) / n_elements
|
30 |
+
x = torch.arange(min_val, max_val, step_size, dtype=dtype)
|
31 |
+
x = x.reshape(shape)
|
32 |
+
if nonlinear:
|
33 |
+
return torch.sin(x)
|
34 |
+
return x
|
35 |
+
|
36 |
+
|
37 |
+
@torch.no_grad
|
38 |
+
def fixed_init_model(
|
39 |
+
model: nn.Module,
|
40 |
+
min_val: Union[float, int] = 0.0,
|
41 |
+
max_val: Union[float, int] = 1.0,
|
42 |
+
nonlinear: bool = False,
|
43 |
+
dtype: Optional[torch.dtype] = None,
|
44 |
+
):
|
45 |
+
"""
|
46 |
+
This utility initializes all parameters of a model deterministically using the
|
47 |
+
function fixed_init_tensor above. See that docstring for details of each parameter.
|
48 |
+
"""
|
49 |
+
for _, param in model.named_parameters():
|
50 |
+
param.copy_(
|
51 |
+
fixed_init_tensor(
|
52 |
+
param.shape,
|
53 |
+
min_val=min_val,
|
54 |
+
max_val=max_val,
|
55 |
+
nonlinear=nonlinear,
|
56 |
+
dtype=param.dtype if dtype is None else dtype,
|
57 |
+
)
|
58 |
+
)
|
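A small sketch of how these helpers would be used (assuming the module is importable from its location in this diff): `fixed_init_model` fills every parameter with the same deterministic pattern on every run, which is what makes the multimodal tests reproducible.

# Sketch (not part of this PR): deterministic init of a toy module.
import torch
from torch import nn

from torchtitan.experiments.multimodal.tests.test_utils import fixed_init_model

layer = nn.Linear(4, 4, bias=False)
fixed_init_model(layer, min_val=-1.0, max_val=1.0, nonlinear=True)

x = torch.ones(1, 4)
print(layer(x))  # identical on every run: weights are a fixed sin(arange) pattern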
torchtitan/experiments/multimodal/tokenizer/tiktoken.py
ADDED
@@ -0,0 +1,232 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
8 |
+
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
9 |
+
|
10 |
+
import os
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import (
|
13 |
+
AbstractSet,
|
14 |
+
Any,
|
15 |
+
cast,
|
16 |
+
Collection,
|
17 |
+
Dict,
|
18 |
+
Iterator,
|
19 |
+
List,
|
20 |
+
Literal,
|
21 |
+
Mapping,
|
22 |
+
Optional,
|
23 |
+
Sequence,
|
24 |
+
Union,
|
25 |
+
)
|
26 |
+
|
27 |
+
import tiktoken
|
28 |
+
import torch
|
29 |
+
from tiktoken.load import load_tiktoken_bpe
|
30 |
+
|
31 |
+
from torchtitan.components.tokenizer import Tokenizer
|
32 |
+
from torchtitan.config_manager import JobConfig
|
33 |
+
from torchtitan.tools.logging import logger
|
34 |
+
|
35 |
+
IMAGE_TOKEN_ID = 128256
|
36 |
+
IGNORE_INDEX = -100
|
37 |
+
|
38 |
+
|
39 |
+
class TikTokenizer(Tokenizer):
|
40 |
+
"""
|
41 |
+
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
model_path (str): The path to the Tiktoken model file.
|
45 |
+
"""
|
46 |
+
|
47 |
+
special_tokens: Dict[str, int]
|
48 |
+
|
49 |
+
num_reserved_special_tokens = 256
|
50 |
+
|
51 |
+
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501, B950
|
52 |
+
|
53 |
+
def __init__(self, model_path: str):
|
54 |
+
super().__init__(model_path)
|
55 |
+
assert os.path.isfile(model_path), model_path
|
56 |
+
|
57 |
+
mergeable_ranks = load_tiktoken_bpe(model_path)
|
58 |
+
num_base_tokens = len(mergeable_ranks)
|
59 |
+
special_tokens = [
|
60 |
+
"<|begin_of_text|>",
|
61 |
+
"<|end_of_text|>",
|
62 |
+
"<|reserved_special_token_0|>",
|
63 |
+
"<|reserved_special_token_1|>",
|
64 |
+
"<|reserved_special_token_2|>",
|
65 |
+
"<|reserved_special_token_3|>",
|
66 |
+
"<|start_header_id|>",
|
67 |
+
"<|end_header_id|>",
|
68 |
+
"<|reserved_special_token_4|>",
|
69 |
+
"<|eot_id|>", # end of turn
|
70 |
+
] + [
|
71 |
+
f"<|reserved_special_token_{i}|>"
|
72 |
+
for i in range(5, self.num_reserved_special_tokens - 5)
|
73 |
+
]
|
74 |
+
self.special_tokens = {
|
75 |
+
token: num_base_tokens + i for i, token in enumerate(special_tokens)
|
76 |
+
}
|
77 |
+
self.special_tokens["<|image|>"] = IMAGE_TOKEN_ID
|
78 |
+
self.model = tiktoken.Encoding(
|
79 |
+
name=Path(model_path).name,
|
80 |
+
pat_str=self.pat_str,
|
81 |
+
mergeable_ranks=mergeable_ranks,
|
82 |
+
special_tokens=self.special_tokens,
|
83 |
+
)
|
84 |
+
|
85 |
+
self._n_words: int = self.model.n_vocab
|
86 |
+
# BOS / EOS token IDs
|
87 |
+
self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
|
88 |
+
self.eos_id: int = self.special_tokens["<|end_of_text|>"]
|
89 |
+
self.pad_id: int = -1
|
90 |
+
self.image_id = IMAGE_TOKEN_ID
|
91 |
+
self.stop_tokens = {
|
92 |
+
self.special_tokens["<|end_of_text|>"],
|
93 |
+
self.special_tokens["<|eot_id|>"],
|
94 |
+
}
|
95 |
+
logger.info(
|
96 |
+
f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}, IMAGE ID {self.image_id}"
|
97 |
+
)
|
98 |
+
|
99 |
+
def encode(
|
100 |
+
self,
|
101 |
+
s: str,
|
102 |
+
*,
|
103 |
+
bos: bool,
|
104 |
+
eos: bool,
|
105 |
+
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
|
106 |
+
disallowed_special: Optional[Union[Literal["all"], Collection[str]]] = None,
|
107 |
+
) -> List[int]:
|
108 |
+
"""
|
109 |
+
Encodes a string into a list of token IDs.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
s (str): The input string to be encoded.
|
113 |
+
bos (bool): Whether to prepend the beginning-of-sequence token.
|
114 |
+
eos (bool): Whether to append the end-of-sequence token.
|
115 |
+
allowed_special ("all"|set[str]): allowed special tokens in string
|
116 |
+
disallowed_special ("all"|set[str]): special tokens that raise an error when in string
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
list[int]: A list of token IDs.
|
120 |
+
|
121 |
+
By default, setting disallowed_special=() encodes a string by ignoring
|
122 |
+
special tokens. Specifically:
|
123 |
+
- Setting `disallowed_special` to () will cause all text corresponding
|
124 |
+
to special tokens to be encoded as natural text (instead of raising
|
125 |
+
an error).
|
126 |
+
- Setting `allowed_special` to "all" will treat all text corresponding
|
127 |
+
to special tokens to be encoded as special tokens.
|
128 |
+
"""
|
129 |
+
assert type(s) is str
|
130 |
+
allowed_special = allowed_special or set()
|
131 |
+
disallowed_special = disallowed_special or ()
|
132 |
+
|
133 |
+
# The tiktoken tokenizer can handle <=400k chars without
|
134 |
+
# pyo3_runtime.PanicException.
|
135 |
+
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
136 |
+
|
137 |
+
# https://github.com/openai/tiktoken/issues/195
|
138 |
+
# Here we iterate over subsequences and split if we exceed the limit
|
139 |
+
# of max consecutive non-whitespace or whitespace characters.
|
140 |
+
MAX_NO_WHITESPACES_CHARS = 25_000
|
141 |
+
|
142 |
+
substrs = (
|
143 |
+
substr
|
144 |
+
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
|
145 |
+
for substr in self._split_whitespaces_or_nonwhitespaces(
|
146 |
+
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
|
147 |
+
)
|
148 |
+
)
|
149 |
+
t: List[int] = []
|
150 |
+
for substr in substrs:
|
151 |
+
t.extend(
|
152 |
+
self.model.encode(
|
153 |
+
substr,
|
154 |
+
allowed_special=allowed_special,
|
155 |
+
disallowed_special=disallowed_special,
|
156 |
+
)
|
157 |
+
)
|
158 |
+
if bos:
|
159 |
+
t.insert(0, self.bos_id)
|
160 |
+
if eos:
|
161 |
+
t.append(self.eos_id)
|
162 |
+
return t
|
163 |
+
|
164 |
+
def decode(self, t: Sequence[int]) -> str:
|
165 |
+
"""
|
166 |
+
Decodes a list of token IDs into a string.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
t (List[int]): The list of token IDs to be decoded.
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
str: The decoded string.
|
173 |
+
"""
|
174 |
+
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
|
175 |
+
return self.model.decode(cast(List[int], t))
|
176 |
+
|
177 |
+
@staticmethod
|
178 |
+
def _split_whitespaces_or_nonwhitespaces(
|
179 |
+
s: str, max_consecutive_slice_len: int
|
180 |
+
) -> Iterator[str]:
|
181 |
+
"""
|
182 |
+
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
|
183 |
+
consecutive whitespaces or consecutive non-whitespaces.
|
184 |
+
"""
|
185 |
+
current_slice_len = 0
|
186 |
+
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
|
187 |
+
slice_start = 0
|
188 |
+
|
189 |
+
for i in range(len(s)):
|
190 |
+
is_now_space = s[i].isspace()
|
191 |
+
|
192 |
+
if current_slice_is_space ^ is_now_space:
|
193 |
+
current_slice_len = 1
|
194 |
+
current_slice_is_space = is_now_space
|
195 |
+
else:
|
196 |
+
current_slice_len += 1
|
197 |
+
if current_slice_len > max_consecutive_slice_len:
|
198 |
+
yield s[slice_start:i]
|
199 |
+
slice_start = i
|
200 |
+
current_slice_len = 1
|
201 |
+
yield s[slice_start:]
|
202 |
+
|
203 |
+
def encode_multimodal(self, sample: Mapping[str, Any]) -> List[int]:
|
204 |
+
"""
|
205 |
+
Tokenizes a `str` of text and creates `labels` masking BOS, EOS and `image_id` tokens.
|
206 |
+
"""
|
207 |
+
# TODO(tj.solergibert) Should we keep `input_ids` OR `tokens` across this class, VisionCrossAttentionMask & the collator?
|
208 |
+
# For me it makes more sense to split `tokens` between `input_ids` & `labels` as in train.py BUT the `MultimodalDecoder`
|
209 |
+
# & everything else expects `tokens`
|
210 |
+
text = sample["text"]
|
211 |
+
tokens = self.encode(
|
212 |
+
text, bos=True, eos=True, allowed_special=set(["<|image|>"])
|
213 |
+
)
|
214 |
+
input_ids = torch.LongTensor(tokens[:-1])
|
215 |
+
labels = torch.LongTensor(tokens[1:])
|
216 |
+
labels = torch.where(
|
217 |
+
torch.isin(
|
218 |
+
labels, torch.LongTensor([self.bos_id, self.eos_id, self.image_id])
|
219 |
+
),
|
220 |
+
IGNORE_INDEX,
|
221 |
+
labels,
|
222 |
+
)
|
223 |
+
|
224 |
+
assert len(input_ids) == len(labels) # TODO(tj.solergibert) Delete
|
225 |
+
|
226 |
+
sample.update({"tokens": input_ids, "labels": labels})
|
227 |
+
|
228 |
+
return sample
|
229 |
+
|
230 |
+
|
231 |
+
def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer:
|
232 |
+
return TikTokenizer(job_config.model.tokenizer_path)
|
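A hedged usage sketch for the multimodal tokenizer above: the `.model` path is a placeholder for any Llama-3-style tiktoken BPE file, and the import path simply mirrors where this file lands in the repo.

# Usage sketch (not part of this PR): tokenize a multimodal sample.
from torchtitan.experiments.multimodal.tokenizer.tiktoken import (
    IGNORE_INDEX,
    TikTokenizer,
)

tokenizer = TikTokenizer("assets/tokenizer/original/tokenizer.model")  # placeholder path

sample = {"text": "<|image|> A photo of two cats sleeping on a couch."}
out = tokenizer.encode_multimodal(sample)

# `tokens` and `labels` are shifted by one position; BOS, EOS and <|image|> targets
# are masked to IGNORE_INDEX (-100) so they do not contribute to the loss.
print(out["tokens"].shape, out["labels"].shape)
print((out["labels"] == IGNORE_INDEX).sum().item(), "masked label positions")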
torchtitan/experiments/multimodal/utils.py
ADDED
@@ -0,0 +1,437 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
from collections import defaultdict
|
10 |
+
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import List, Optional, Set, Tuple, Union
|
13 |
+
from urllib import request
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torchvision
|
17 |
+
from torchvision.transforms.v2 import functional as F
|
18 |
+
|
19 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.tile_crop.py
|
20 |
+
def tile_crop(image: torch.Tensor, tile_size: int) -> torch.Tensor:
|
21 |
+
"""
|
22 |
+
Divides a tensor into equally sized tiles. The tensor should be divisible by tile_size.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
image (torch.Tensor): Input image to crop into tiles.
|
26 |
+
tile_size (int): Size of each tile.
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
torch.Tensor: torch.Tensor of shape [num_tiles, channel_size, tile_size, tile_size]
|
30 |
+
|
31 |
+
Examples:
|
32 |
+
>>> image = torch.rand(3, 200, 300)
|
33 |
+
>>> tiles = tile_crop(image, tile_size=50)
|
34 |
+
>>> tiles.shape # 4x6 = 24 tiles
|
35 |
+
torch.Size([24, 3, 50, 50])
|
36 |
+
|
37 |
+
>>> image = torch.rand(3, 400, 600)
|
38 |
+
>>> tiles = tile_crop(image, tile_size=200)
|
39 |
+
>>> tiles.shape # 2x3 = 6 tiles
|
40 |
+
torch.Size([6, 3, 200, 200])
|
41 |
+
"""
|
42 |
+
|
43 |
+
channel_size, height, width = image.shape
|
44 |
+
|
45 |
+
# assert sizes are divisible
|
46 |
+
assert (
|
47 |
+
height % tile_size == 0 and width % tile_size == 0
|
48 |
+
), f"Image size {height}x{width} is not divisible by tile size {tile_size}"
|
49 |
+
|
50 |
+
# Reshape to split height and width into tile_size blocks
|
51 |
+
tiles_height = height // tile_size
|
52 |
+
tiles_width = width // tile_size
|
53 |
+
|
54 |
+
reshaped = image.view(channel_size, tiles_height, tile_size, tiles_width, tile_size)
|
55 |
+
|
56 |
+
# Transpose to bring tiles together
|
57 |
+
# We want [tiles_height, tiles_width, channel_size, tile_size, tile_size]
|
58 |
+
transposed = reshaped.permute(1, 3, 0, 2, 4)
|
59 |
+
|
60 |
+
# Flatten the tiles
|
61 |
+
tiles = transposed.contiguous().view(
|
62 |
+
tiles_height * tiles_width, channel_size, tile_size, tile_size
|
63 |
+
)
|
64 |
+
|
65 |
+
return tiles
|
66 |
+
|
67 |
+
|
68 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
|
69 |
+
def resize_with_pad(
|
70 |
+
image: torch.Tensor,
|
71 |
+
target_size: Tuple[int, int],
|
72 |
+
resample: torchvision.transforms.InterpolationMode,
|
73 |
+
max_size: Optional[int] = None,
|
74 |
+
) -> torch.Tensor:
|
75 |
+
"""
|
76 |
+
Resizes and pads an image to target_size without causing distortion.
|
77 |
+
The user can set max_size to limit upscaling when target_size exceeds image_size.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
image (torch.Tensor): The input image tensor in the format [..., H, W].
|
81 |
+
target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width].
|
82 |
+
resample (torchvision.transforms.InterpolationMode): Resampling method used when resizing images.
|
83 |
+
Supports torchvision.transforms.InterpolationMode.NEAREST, InterpolationMode.NEAREST_EXACT,
|
84 |
+
InterpolationMode.BILINEAR and InterpolationMode.BICUBIC.
|
85 |
+
max_size (Optional[int]): The maximum size to upscale the image to.
|
86 |
+
If None, will upscale up to target_size.
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
torch.Tensor: The resized and padded image tensor in the format [..., H, W].
|
90 |
+
|
91 |
+
Examples:
|
92 |
+
|
93 |
+
Example 1: The image will be upscaled from (300, 800) to (448, 1194), since 448 is the limiting side,
|
94 |
+
and then padded from (448, 1194) to (448, 1344).
|
95 |
+
|
96 |
+
>>> max_size = None
|
97 |
+
>>> image = torch.rand([3, 300, 800])
|
98 |
+
>>> target_size = (448, 1344)
|
99 |
+
>>> resample = torchvision.transforms.InterpolationMode.BILINEAR
|
100 |
+
>>> output = resize_with_pad(image, target_size, resample, max_size)
|
101 |
+
|
102 |
+
Example 2: The image will stay as is, since 800 > 600, and then padded from (300, 800) to (448, 1344).
|
103 |
+
|
104 |
+
>>> max_size = 600
|
105 |
+
>>> image = torch.rand([3, 300, 800])
|
106 |
+
>>> target_size = (448, 1344)
|
107 |
+
>>> resample = torchvision.transforms.InterpolationMode.BILINEAR
|
108 |
+
>>> output = resize_with_pad(image, target_size, resample, max_size)
|
109 |
+
|
110 |
+
Example 3: The image will be downscaled from (500, 1000) to (224, 448),
|
111 |
+
and padded from (224, 448) to (448, 448).
|
112 |
+
|
113 |
+
>>> max_size = 600
|
114 |
+
>>> image = torch.rand([3, 500, 1000])
|
115 |
+
>>> target_size = (448, 488)
|
116 |
+
>>> resample = torchvision.transforms.InterpolationMode.BILINEAR
|
117 |
+
>>> output = resize_with_pad(image, target_size, resample, max_size)
|
118 |
+
|
119 |
+
"""
|
120 |
+
|
121 |
+
image_height, image_width = image.shape[-2:]
|
122 |
+
image_size = (image_height, image_width)
|
123 |
+
|
124 |
+
# If target_size requires upscaling, we might want to limit the upscaling to max_size
|
125 |
+
if max_size is not None:
|
126 |
+
new_target_height = min(max(image_height, max_size), target_size[0])
|
127 |
+
new_target_width = min(max(image_width, max_size), target_size[1])
|
128 |
+
target_size_resize = (new_target_height, new_target_width)
|
129 |
+
else:
|
130 |
+
target_size_resize = target_size
|
131 |
+
|
132 |
+
# resize to target_size while preserving aspect ratio
|
133 |
+
new_size_preserving_aspect_ratio = _get_max_res_without_distortion(
|
134 |
+
image_size=image_size,
|
135 |
+
target_size=target_size_resize,
|
136 |
+
)
|
137 |
+
|
138 |
+
image = F.resize(
|
139 |
+
inpt=image,
|
140 |
+
size=list(new_size_preserving_aspect_ratio),
|
141 |
+
interpolation=resample,
|
142 |
+
antialias=True,
|
143 |
+
)
|
144 |
+
|
145 |
+
image = _pad_image_top_left(image=image, target_size=target_size)
|
146 |
+
|
147 |
+
return image
|
148 |
+
|
149 |
+
|
150 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
|
151 |
+
def _pad_image_top_left(
|
152 |
+
image: torch.Tensor,
|
153 |
+
target_size: Tuple[int, int],
|
154 |
+
) -> torch.Tensor:
|
155 |
+
"""
|
156 |
+
Places the image at the top left of the canvas and pads with 0 the right and bottom
|
157 |
+
to fit to the target resolution. If target_size < image_size, it will crop the image.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
image (torch.Tensor): The input image tensor in the format [..., H, W].
|
161 |
+
target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width].
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
torch.Tensor: The padded image tensor in the format [..., H, W].
|
165 |
+
"""
|
166 |
+
|
167 |
+
image_size = image.shape[-2:]
|
168 |
+
|
169 |
+
height, width = image_size
|
170 |
+
target_height, target_width = target_size
|
171 |
+
|
172 |
+
pad_x = target_width - width
|
173 |
+
pad_y = target_height - height
|
174 |
+
|
175 |
+
padding = [0, 0, pad_x, pad_y]
|
176 |
+
return F.pad(inpt=image, padding=padding)
|
177 |
+
|
178 |
+
|
179 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
|
180 |
+
def _get_max_res_without_distortion(
|
181 |
+
image_size: Tuple[int, int],
|
182 |
+
target_size: Tuple[int, int],
|
183 |
+
) -> Tuple[int, int]:
|
184 |
+
"""
|
185 |
+
Determines the maximum resolution to which an image can be resized to without distorting its
|
186 |
+
aspect ratio, based on the target resolution.
|
187 |
+
|
188 |
+
For example, if image_size = (200,400) and target_size = (600,800),
|
189 |
+
scale_h = 600/200 = 3
|
190 |
+
scale_w = 800/400 = 2
|
191 |
+
So the maximum that we can upscale without distortion is min(scale_h, scale_w) = 2
|
192 |
+
|
193 |
+
Since scale_w is the limiting side, then new_w = target_w, and new_h = old_h*scale_w
|
194 |
+
|
195 |
+
Args:
|
196 |
+
image_size (Tuple[int, int]): The original resolution of the image.
|
197 |
+
target_size (Tuple[int, int]): The desired resolution to fit the image into.
|
198 |
+
Returns:
|
199 |
+
Tuple[int, int]: The optimal dimensions to which the image should be resized.
|
200 |
+
Examples:
|
201 |
+
>>> _get_max_res_without_distortion([200, 300], target_size = (450, 200))
|
202 |
+
(133, 200)
|
203 |
+
>>> _get_max_res_without_distortion([800, 600], target_size = (450, 1300))
|
204 |
+
(450, 337)
|
205 |
+
"""
|
206 |
+
|
207 |
+
original_height, original_width = image_size
|
208 |
+
target_height, target_width = target_size
|
209 |
+
|
210 |
+
scale_w = target_width / original_width
|
211 |
+
scale_h = target_height / original_height
|
212 |
+
|
213 |
+
if scale_w < scale_h:
|
214 |
+
new_width = target_width
|
215 |
+
new_height = min(math.floor(original_height * scale_w), target_height)
|
216 |
+
else:
|
217 |
+
new_height = target_height
|
218 |
+
new_width = min(math.floor(original_width * scale_h), target_width)
|
219 |
+
|
220 |
+
return new_height, new_width
|
221 |
+
|
222 |
+
|
223 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
|
224 |
+
def _get_factors(n: int) -> Set[int]:
|
225 |
+
"""
|
226 |
+
Calculate all factors of a given number, i.e. a divisor that leaves no remainder.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
n (int): The number to find factors for.
|
230 |
+
|
231 |
+
Returns:
|
232 |
+
set: A set containing all factors of the number.
|
233 |
+
|
234 |
+
Examples:
|
235 |
+
>>> _get_factors(n=12)
|
236 |
+
{1, 2, 3, 4, 6, 12}
|
237 |
+
"""
|
238 |
+
factors_set = set()
|
239 |
+
|
240 |
+
for i in range(1, int(n**0.5) + 1):
|
241 |
+
if n % i == 0:
|
242 |
+
factors_set.add(i)
|
243 |
+
factors_set.add(n // i)
|
244 |
+
return factors_set
|
245 |
+
|
246 |
+
|
247 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
|
248 |
+
def get_canvas_best_fit(
|
249 |
+
image: torch.Tensor, possible_resolutions: torch.Tensor, resize_to_max_canvas: bool
|
250 |
+
) -> Tuple[int, int]:
|
251 |
+
"""
|
252 |
+
Determines the best canvas possible from a list of possible resolutions to
|
253 |
+
resize an image to, without distortion.
|
254 |
+
|
255 |
+
For each possible resolution, calculates the scaling factors for
|
256 |
+
width and height, and selects the smallest one, which is the limiting side.
|
257 |
+
E.g. if to match a canvas shape you have to upscale an image's height by 2x, and width by 1.5x,
|
258 |
+
then the maximum upscaling without distortion is min(2, 1.5) = 1.5.
|
259 |
+
|
260 |
+
If there are multiple canvases that satisfy the conditions,
|
261 |
+
we pick the one with the lowest area to minimize padding.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
image (torch.Tensor): The image we want to fit into a canvas.
|
265 |
+
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
|
266 |
+
row represents a possible canvas.
|
267 |
+
resize_to_max_canvas (bool): If True, pick the canvas that allows maximum scaling.
|
268 |
+
If False, pick the canvas that minimizes downscaling, including no downscaling at all.
|
269 |
+
|
270 |
+
Returns:
|
271 |
+
Tuple[int, int]: The best resolution to fit the image into.
|
272 |
+
|
273 |
+
Examples:
|
274 |
+
>>> image = torch.rand(3, 200, 300)
|
275 |
+
>>> possible_resolutions = torch.tensor([
|
276 |
+
... [224, 672],
|
277 |
+
... [672, 224],
|
278 |
+
... [224, 448],
|
279 |
+
... [448, 224],
|
280 |
+
... [224, 224]
|
281 |
+
... ])
|
282 |
+
>>> get_canvas_best_fit(image, possible_resolutions, resize_to_max_canvas=False)
|
283 |
+
(224, 448)
|
284 |
+
|
285 |
+
In the example above, we calculate the scaling factors for each possible resolution
|
286 |
+
|
287 |
+
>>> scale_height = torch.tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
|
288 |
+
>>> scale_width = torch.tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
|
289 |
+
>>> scales = torch.tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
|
290 |
+
|
291 |
+
Two options have scaling_factor > 1, since resize_to_max_canvas is False, we pick the smallest
|
292 |
+
|
293 |
+
>>> upscaling_options = torch.tensor([1.1200, 1.1200])
|
294 |
+
>>> selected_scale = torch.tensor(1.1200)
|
295 |
+
|
296 |
+
There are two possible options, so we pick the one with the smallest area
|
297 |
+
|
298 |
+
>>> areas = torch.tensor([150528, 100352]) # for resolutions [672, 224] and [224, 448], respectively
|
299 |
+
>>> optimal_canvas = torch.tensor([224, 448]) # resolution with the smallest area
|
300 |
+
"""
|
301 |
+
|
302 |
+
original_height, original_width = image.shape[-2:]
|
303 |
+
|
304 |
+
# possible resolutions heights/widths
|
305 |
+
target_heights, target_widths = (
|
306 |
+
possible_resolutions[:, 0],
|
307 |
+
possible_resolutions[:, 1],
|
308 |
+
)
|
309 |
+
|
310 |
+
# scaling factors to resize the image without distortion
|
311 |
+
scale_w = target_widths / original_width
|
312 |
+
scale_h = target_heights / original_height
|
313 |
+
|
314 |
+
# get limiting side scaling -> no distortion
|
315 |
+
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
|
316 |
+
|
317 |
+
# filter only scales that allow upscaling
|
318 |
+
upscaling_options = scales[scales >= 1]
|
319 |
+
if len(upscaling_options) > 0:
|
320 |
+
if resize_to_max_canvas:
|
321 |
+
selected_scale = torch.max(upscaling_options)
|
322 |
+
else:
|
323 |
+
selected_scale = torch.min(upscaling_options)
|
324 |
+
else:
|
325 |
+
# no upscaling possible,
|
326 |
+
# get the minimum downscaling (max scale for scales<1)
|
327 |
+
downscaling_options = scales[scales < 1]
|
328 |
+
selected_scale = torch.max(downscaling_options)
|
329 |
+
|
330 |
+
# get all resolutions that support this scaling factor,
|
331 |
+
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
|
332 |
+
chosen_canvas = possible_resolutions[scales == selected_scale]
|
333 |
+
|
334 |
+
# if there are multiple resolutions,
|
335 |
+
# get the one with minimum area to reduce padding
|
336 |
+
if len(chosen_canvas) > 1:
|
337 |
+
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
|
338 |
+
optimal_idx = torch.argmin(areas)
|
339 |
+
optimal_canvas = chosen_canvas[optimal_idx]
|
340 |
+
else:
|
341 |
+
optimal_canvas = chosen_canvas[0]
|
342 |
+
|
343 |
+
return tuple(optimal_canvas.tolist())
|
344 |
+
|
345 |
+
|
346 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
|
347 |
+
def find_supported_resolutions(
|
348 |
+
max_num_tiles: int, tile_size: int
|
349 |
+
) -> List[Tuple[int, int]]:
|
350 |
+
"""
|
351 |
+
Computes all combinations of resolutions, multiple of tile_size,
|
352 |
+
that contain up to max_num_tiles. Useful for when dividing an image into tiles.
|
353 |
+
|
354 |
+
For example, if we want at most 2 tiles per image, then we can support the
|
355 |
+
following resolutions: (1x1, 1x2, 2x1) * tile_size
|
356 |
+
|
357 |
+
Args:
|
358 |
+
max_num_tiles (int): Maximum number of tiles.
|
359 |
+
tile_size (int): Size of the side of the tile.
|
360 |
+
|
361 |
+
Returns:
|
362 |
+
List[Tuple[int, int]]: List of possible resolutions as tuples (height, width).
|
363 |
+
|
364 |
+
Examples:
|
365 |
+
|
366 |
+
>>> max_num_tiles = 4
|
367 |
+
>>> tile_size = 224
|
368 |
+
>>> find_supported_resolutions(max_num_tiles, tile_size)
|
369 |
+
[(224, 896), (448, 448), (224, 224), (896, 224), (224, 672), (672, 224), (224, 448), (448, 224)]
|
370 |
+
"""
|
371 |
+
|
372 |
+
# create dictionary {aspect_ratio: [resolution1, ..., resolution n]}
|
373 |
+
# example {0.25: [(1,4)], 1.0: [(2,2), (1,1)], 4.0: [(4,1)]}
|
374 |
+
asp_dict = defaultdict(list)
|
375 |
+
for _tile_size in range(max_num_tiles, 0, -1):
|
376 |
+
factors = sorted(_get_factors(_tile_size))
|
377 |
+
asp_ratios = [(factor, _tile_size // factor) for factor in factors]
|
378 |
+
for height, width in asp_ratios:
|
379 |
+
ratio_float = height / width
|
380 |
+
asp_dict[ratio_float].append((height, width))
|
381 |
+
|
382 |
+
# get the resolutions multiplied by the tile_size
|
383 |
+
possible_resolutions = []
|
384 |
+
for ar, resolution in asp_dict.items():
|
385 |
+
for height, width in resolution:
|
386 |
+
possible_resolutions.append((height * tile_size, width * tile_size))
|
387 |
+
|
388 |
+
return possible_resolutions
|
389 |
+
|
390 |
+
|
391 |
+
# NOTE Copied from torchtune.data._utils.py
|
392 |
+
def load_image(image_loc: Union[Path, str]) -> torch.Tensor:
|
393 |
+
"""
|
394 |
+
Convenience method to load an image in torch.Tensor format from a local file path or remote source.
|
395 |
+
|
396 |
+
Args:
|
397 |
+
image_loc (Union[Path, str]): Local file path or remote source pointing to the image
|
398 |
+
which will be loaded in PIL format.
|
399 |
+
|
400 |
+
Note:
|
401 |
+
If loading an image from a remote source, the function expects the URL provided in ``image_loc``
|
402 |
+
to start with "http" or "https" e.g. "https://www.wikipedia.org/en/bird.jpg".
|
403 |
+
|
404 |
+
Raises:
|
405 |
+
ValueError: If the image cannot be loaded from remote source, **or**
|
406 |
+
if the image cannot be opened as a :class:`~torch.Tensor`.
|
407 |
+
|
408 |
+
Examples:
|
409 |
+
>>> # Load from remote source
|
410 |
+
>>> image = load_image("https://www.wikipedia.org/en/bird.jpg")
|
411 |
+
|
412 |
+
>>> # Load from local file path
|
413 |
+
>>> image = load_image(Path("/home/user/bird.jpg"))
|
414 |
+
|
415 |
+
Returns:
|
416 |
+
torch.Tensor: The loaded image.
|
417 |
+
"""
|
418 |
+
|
419 |
+
# If pointing to remote source, try to load to local
|
420 |
+
if isinstance(image_loc, str) and image_loc.startswith("http"):
|
421 |
+
try:
|
422 |
+
image_loc = request.urlopen(image_loc).read()
|
423 |
+
image = torchvision.io.decode_image(
|
424 |
+
torch.frombuffer(image_loc, dtype=torch.uint8),
|
425 |
+
mode="RGB",
|
426 |
+
)
|
427 |
+
except Exception as e:
|
428 |
+
raise ValueError("Failed to load remote image as torch.Tensor") from e
|
429 |
+
|
430 |
+
# Open the local image as a Tensor image
|
431 |
+
else:
|
432 |
+
try:
|
433 |
+
image = torchvision.io.decode_image(image_loc, mode="RGB")
|
434 |
+
except Exception as e:
|
435 |
+
raise ValueError("Failed to load local image as torch.Tensor") from e
|
436 |
+
|
437 |
+
return image
|
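The helpers above compose into the usual tile-based preprocessing flow: pick the best-fit canvas, resize with padding, then crop into fixed-size tiles. A hedged sketch follows (illustrative sizes; the import path mirrors this diff).

# Preprocessing sketch (not part of this PR), chaining the utilities above.
import torch
import torchvision

from torchtitan.experiments.multimodal.utils import (
    find_supported_resolutions,
    get_canvas_best_fit,
    resize_with_pad,
    tile_crop,
)

tile_size, max_num_tiles = 224, 4
image = torch.rand(3, 300, 800)  # stand-in for a decoded RGB image

resolutions = torch.tensor(find_supported_resolutions(max_num_tiles, tile_size))
canvas = get_canvas_best_fit(image, resolutions, resize_to_max_canvas=False)

resized = resize_with_pad(
    image,
    target_size=canvas,
    resample=torchvision.transforms.InterpolationMode.BILINEAR,
)
tiles = tile_crop(resized, tile_size)
print(canvas, tiles.shape)  # (224, 672) -> tiles of shape [3, 3, 224, 224]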
torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.11 kB). View file
|
|
torchtitan/experiments/simple_fsdp/__pycache__/model.cpython-312.pyc
ADDED
Binary file (1.14 kB). View file
|
|
torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc
ADDED
Binary file (2.61 kB). View file
|
|
torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc
ADDED
Binary file (6.83 kB). View file
|
|
torchtitan/experiments/simple_fsdp/tests/__init__.py
ADDED
@@ -0,0 +1,5 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
torchtitan/experiments/simple_fsdp/tests/test_numerics.py
ADDED
@@ -0,0 +1,128 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
import copy
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch.distributed._composable.fsdp import fully_shard
|
10 |
+
|
11 |
+
from torch.testing._internal.common_fsdp import FSDPTest
|
12 |
+
|
13 |
+
from torchtitan.components.loss import cross_entropy_loss
|
14 |
+
from torchtitan.distributed import ParallelDims
|
15 |
+
from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel
|
16 |
+
|
17 |
+
|
18 |
+
class TestSimpleFSDP(FSDPTest):
|
19 |
+
def init_test(self):
|
20 |
+
self.optimizer = torch.optim.Adam
|
21 |
+
self.loss_fn = cross_entropy_loss
|
22 |
+
data_parallel_shard_degree = -1
|
23 |
+
if self.mode == "replicate":
|
24 |
+
self.dp_mesh_dim_names = ("dp_replicate",)
|
25 |
+
data_parallel_replicate_degree = self.world_size
|
26 |
+
elif self.mode == "fully_shard":
|
27 |
+
self.dp_mesh_dim_names = ("dp_shard_cp",)
|
28 |
+
data_parallel_replicate_degree = 1
|
29 |
+
elif self.mode == "hybrid_shard":
|
30 |
+
self.dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
|
31 |
+
data_parallel_replicate_degree = self.world_size // 2
|
32 |
+
else:
|
33 |
+
raise ValueError(f"Unsupported mode {mode}")
|
34 |
+
|
35 |
+
self.parallel_dims = ParallelDims(
|
36 |
+
dp_shard=data_parallel_shard_degree,
|
37 |
+
dp_replicate=data_parallel_replicate_degree,
|
38 |
+
cp=1,
|
39 |
+
tp=1,
|
40 |
+
pp=1,
|
41 |
+
world_size=self.world_size,
|
42 |
+
enable_loss_parallel=True,
|
43 |
+
)
|
44 |
+
self.device_mesh = self.parallel_dims.build_mesh(device_type="cuda")
|
45 |
+
|
46 |
+
def get_input(self):
|
47 |
+
inputs = torch.randn(8, 8).cuda()
|
48 |
+
labels = torch.randn(8, 8).cuda()
|
49 |
+
model = torch.nn.Linear(8, 8)
|
50 |
+
return model, inputs, labels
|
51 |
+
|
52 |
+
def run_fsdp2(self, model, inputs, labels, epoch=20):
|
53 |
+
fully_shard(model, mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)])
|
54 |
+
optim = self.optimizer(model.parameters(), lr=1e-4)
|
55 |
+
losses = []
|
56 |
+
for _ in range(epoch):
|
57 |
+
optim.zero_grad()
|
58 |
+
out = model(inputs)
|
59 |
+
loss = self.loss_fn(out, labels)
|
60 |
+
loss.backward()
|
61 |
+
optim.step()
|
62 |
+
losses.append(loss)
|
63 |
+
return losses
|
64 |
+
|
65 |
+
def run_simple_fsdp(self, model, inputs, labels, epoch=20):
|
66 |
+
model = data_parallel(
|
67 |
+
model,
|
68 |
+
device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)],
|
69 |
+
mode=self.mode,
|
70 |
+
)
|
71 |
+
optim = self.optimizer(model.parameters(), lr=1e-4)
|
72 |
+
losses = []
|
73 |
+
for _ in range(epoch):
|
74 |
+
optim.zero_grad()
|
75 |
+
out = model(inputs)
|
76 |
+
loss = self.loss_fn(out, labels)
|
77 |
+
loss.backward()
|
78 |
+
optim.step()
|
79 |
+
losses.append(loss)
|
80 |
+
return losses
|
81 |
+
|
82 |
+
def test_replicate_convergence(self):
|
83 |
+
# unit test for replicate mode
|
84 |
+
self.mode = "replicate"
|
85 |
+
self.init_test()
|
86 |
+
model, inputs, labels = self.get_input()
|
87 |
+
|
88 |
+
fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
|
89 |
+
simple_fsdp_replicate_losses = self.run_simple_fsdp(
|
90 |
+
copy.deepcopy(model), inputs, labels
|
91 |
+
)
|
92 |
+
|
93 |
+
for fsdp2_loss, simple_fsdp_replicate_loss in zip(
|
94 |
+
fsdp2_losses, simple_fsdp_replicate_losses
|
95 |
+
):
|
96 |
+
assert torch.allclose(fsdp2_loss, simple_fsdp_replicate_loss)
|
97 |
+
|
98 |
+
def test_fullyshard_convergence(self):
|
99 |
+
# unit test for fully_shard mode
|
100 |
+
self.mode = "fully_shard"
|
101 |
+
self.init_test()
|
102 |
+
model, inputs, labels = self.get_input()
|
103 |
+
|
104 |
+
fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
|
105 |
+
simple_fsdp_fullyshard_losses = self.run_simple_fsdp(
|
106 |
+
copy.deepcopy(model), inputs, labels
|
107 |
+
)
|
108 |
+
|
109 |
+
for fsdp2_loss, simple_fsdp_fullyshard_loss in zip(
|
110 |
+
fsdp2_losses, simple_fsdp_fullyshard_losses
|
111 |
+
):
|
112 |
+
assert torch.allclose(fsdp2_loss, simple_fsdp_fullyshard_loss)
|
113 |
+
|
114 |
+
def test_hybridshard_convergence(self):
|
115 |
+
# unit test for hybrid_shard mode
|
116 |
+
self.mode = "hybrid_shard"
|
117 |
+
self.init_test()
|
118 |
+
model, inputs, labels = self.get_input()
|
119 |
+
|
120 |
+
fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
|
121 |
+
simple_fsdp_hybridshard_losses = self.run_simple_fsdp(
|
122 |
+
copy.deepcopy(model), inputs, labels
|
123 |
+
)
|
124 |
+
|
125 |
+
for fsdp2_loss, simple_fsdp_hybridshard_loss in zip(
|
126 |
+
fsdp2_losses, simple_fsdp_hybridshard_losses
|
127 |
+
):
|
128 |
+
assert torch.allclose(fsdp2_loss, simple_fsdp_hybridshard_loss)
|
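
For intuition, the check above boils down to: two identically initialized copies of a model, fed the same data, should produce step-for-step matching losses whether wrapped by FSDP2's fully_shard or by SimpleFSDP's data_parallel. A single-process sketch of that parity pattern (plain nn.Linear and MSE loss as stand-ins; the distributed wrappers and cross_entropy_loss require a multi-GPU run, so nothing here is part of the repository):

# Single-process analogue of the parity check in test_numerics.py.
import copy

import torch


def train_losses(model, inputs, labels, steps=20):
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    losses = []
    for _ in range(steps):
        optim.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), labels)
        loss.backward()
        optim.step()
        losses.append(loss.detach())
    return losses


torch.manual_seed(0)
model = torch.nn.Linear(8, 8)
inputs, labels = torch.randn(8, 8), torch.randn(8, 8)

# Two independent runs from identical copies of the same initialization.
a = train_losses(copy.deepcopy(model), inputs, labels)
b = train_losses(copy.deepcopy(model), inputs, labels)
assert all(torch.allclose(x, y) for x, y in zip(a, b))
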
torchtitan/models/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (195 Bytes)

torchtitan/models/__pycache__/norms.cpython-312.pyc
ADDED
Binary file (1.39 kB)

torchtitan/models/llama3/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.57 kB)

torchtitan/models/llama3/__pycache__/parallelize_llama.cpython-312.pyc
ADDED
Binary file (15.1 kB)

torchtitan/models/llama3/parallelize_llama.py
ADDED
@@ -0,0 +1,398 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

from collections import defaultdict

import torch
import torch.nn as nn
from torch.distributed._composable.replicate import replicate
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.tools.logging import logger


def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        # For now, float8 all-gather with TP is only supported for tensorwise
        # float8 scaling recipes. For rowwise recipes, we use regular TP and
        # all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

    if job_config.model.use_flex_attn:
        if job_config.activation_checkpoint.mode == "selective":
            raise ValueError(
                "FlexAttention is not compatible with selective AC yet. "
                "See https://github.com/pytorch/pytorch/issues/147879"
            )

        if parallel_dims.cp_enabled:
            raise ValueError(
                "FlexAttention is not compatible with CP yet. "
                "We are still working on this."
            )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP does not support > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model


def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8_tensorwise_tp: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism."""
    # 1. Parallelize the embedding and shard its outputs (which are the first
    # transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )

    # Parallel styles used for transformer block linear weights and their
    # inputs may be different for float8 linears with tensorwise scaling.
    if enable_float8_tensorwise_tp:
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        rowwise_parallel, colwise_parallel, prepare_module_input = (
            Float8RowwiseParallel,
            Float8ColwiseParallel,
            PrepareFloat8ModuleInput,
        )
    else:
        rowwise_parallel, colwise_parallel, prepare_module_input = (
            RowwiseParallel,
            ColwiseParallel,
            PrepareModuleInput,
        )

    # Apply tensor + sequence parallelism to every transformer block
    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
    # by folding (and unfolding) the batch dimension and the sequence dimension.
    # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
    for layer_id, transformer_block in model.layers.items():
        layer_plan = {
            "attention_norm": SequenceParallel(),
            "attention": prepare_module_input(
                input_layouts=(Shard(1), None),
                desired_input_layouts=(Replicate(), None),
            ),
            "attention.wq": colwise_parallel(),
            "attention.wk": colwise_parallel(),
            "attention.wv": colwise_parallel(),
            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
            "ffn_norm": SequenceParallel(),
            "feed_forward": prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "feed_forward.w1": colwise_parallel(),
            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
            "feed_forward.w3": colwise_parallel(),
        }

        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )


# for selective op activation checkpointing
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
    torch.ops.aten.max.default,
}


def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            f"Valid options: 'op' or a positive int representing layer frequency"
        )
    if use_op_sac:
        from torch.utils.checkpoint import (
            CheckpointPolicy,
            create_selective_checkpoint_contexts,
        )

        def _get_custom_policy(meta):
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1
                # Saves output of all compute ops, except every second mm
                to_save = func in _save_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                )
                return (
                    CheckpointPolicy.MUST_SAVE
                    if to_save
                    else CheckpointPolicy.PREFER_RECOMPUTE
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            meta = defaultdict(int)
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        return ptd_checkpoint_wrapper(
            module,
            context_fn=selective_checkpointing_context_fn,
            preserve_rng_state=False,
        )
    elif use_layer_sac:
        # Checkpoint every `ac_freq` of the modules passed to this function
        ac_freq = int(ac_config.selective_ac_option)
        ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
        ptd_checkpoint_wrapper._count += 1
        if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
        else:
            return module


def apply_ac(model: nn.Module, ac_config):
    """Apply activation checkpointing to the model."""
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
        model.layers.register_module(layer_id, transformer_block)

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")


def apply_compile(model: nn.Module):
    """
    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = torch.compile(transformer_block, fullgraph=True)
        model.layers.register_module(layer_id, transformer_block)

    logger.info("Compiling each TransformerBlock with torch.compile")


def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    pp_enabled: bool,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
):
    """
    Apply data parallelism (via FSDP2) to the model.

    Args:
        model (nn.Module): The model to apply data parallelism to.
        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
        param_dtype (torch.dtype): The data type to use for model parameters.
        reduce_dtype (torch.dtype): The data type to use for reduction operations.
        pp_enabled (bool): Whether pipeline parallelism is enabled.
        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
        reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default".
            Other options: "never", "always".
            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
            - "always" will enable `reshard_after_forward` for all forward passes.
            - "never" will disable `reshard_after_forward` for all forward passes.

    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

    for layer_id, transformer_block in model.layers.items():
        if reshard_after_forward_policy == "always":
            reshard_after_forward = True
        elif reshard_after_forward_policy == "never":
            reshard_after_forward = False
        elif reshard_after_forward_policy == "default":
            if pp_enabled:
                # For PP, do not reshard after forward to avoid per-microbatch
                # all-gathers, which can be expensive and non-overlapped
                reshard_after_forward = False
            else:
                # As an optimization, do not reshard after forward for the last
                # transformer block since FSDP would prefetch it immediately
                reshard_after_forward = int(layer_id) < len(model.layers) - 1
        else:
            raise ValueError(
                f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
            )
        fully_shard(
            transformer_block,
            **fsdp_config,
            reshard_after_forward=reshard_after_forward,
        )
    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)


def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    if enable_compile:
        if enable_compiled_autograd:
            torch._dynamo.config.optimize_ddp = (
                "python_reducer_without_compiled_forward"
            )
        else:
            torch._dynamo.config.optimize_ddp = "ddp_optimizer"

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")
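
The op-level selective AC policy above is self-contained enough to try outside torchtitan. A minimal sketch, assuming a recent PyTorch where CheckpointPolicy and create_selective_checkpoint_contexts ship in torch.utils.checkpoint (the same imports used above), with a toy bias-free MLP standing in for a TransformerBlock:

# Standalone sketch of "save matmuls, but recompute every second mm".
from collections import defaultdict

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


def selective_context_fn():
    # Fresh counters per forward pass, mirroring selective_checkpointing_context_fn.
    counts = defaultdict(int)

    def policy(ctx, func, *args, **kwargs):
        key = "recompute" if ctx.is_recompute else "forward"
        if func == torch.ops.aten.mm.default:
            counts[key] += 1
            # Keep every other matmul output; recompute the rest in backward.
            if counts[key] % 2 == 1:
                return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return create_selective_checkpoint_contexts(policy)


mlp = torch.nn.Sequential(
    torch.nn.Linear(16, 16, bias=False),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 16, bias=False),
)
x = torch.randn(4, 16, requires_grad=True)
out = checkpoint(mlp, x, use_reentrant=False, context_fn=selective_context_fn)
out.sum().backward()

The separate forward/recompute counters keep the decision for the k-th matmul consistent between the original forward and the recomputation, which is what makes the "every second mm" rule safe.
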
torchtitan/models/llama3/train_configs/llama3_70b.toml
ADDED
@@ -0,0 +1,62 @@
# torchtitan Config.toml
# NOTE: this toml config is a preset for 64 A100 GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 70B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "70B"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
tokenizer_path = "./assets/tokenizer/original/tokenizer.model"
# converters = "float8"

[optimizer]
name = "AdamW"
lr = 1.5e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 200  # lr scheduler warm up, normally 20% of the train steps

[training]
batch_size = 8
seq_len = 8192
max_norm = 1.0  # grad norm clipping
steps = 1000
compile = false
dataset = "c4"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 8  # 8-way TP
pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'full'

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = "output"
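
The preset is plain TOML, so it can be inspected directly when tuning a run. torchtitan itself reads it through its JobConfig/config_manager; the snippet below is only a standalone illustration and assumes Python 3.11+ for tomllib:

# Inspect a few knobs from the 70B preset above.
import tomllib

with open("torchtitan/models/llama3/train_configs/llama3_70b.toml", "rb") as f:
    cfg = tomllib.load(f)

print(cfg["model"]["flavor"])                        # "70B"
print(cfg["parallelism"]["tensor_parallel_degree"])  # 8
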
torchtitan/protocols/train_spec.py
ADDED
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from abc import abstractmethod
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from typing import Protocol, TypeAlias

import torch
import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule

from torchtitan.components.dataloader import BaseDataLoader
from torchtitan.components.ft import FTManager
from torchtitan.components.loss import LossFunction
from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.metrics import MetricsProcessor
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig

DeviceType = int | str | torch.device


@dataclass
class BaseModelArgs:
    """All ModelArgs should inherit from this class.

    The only usage of this class is type checking but allows us to extend common
    arguments to all models in the future.
    """

    _enforced: str = "This field is used to enforce all fields have defaults."

    @abstractmethod
    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        pass

    @abstractmethod
    def get_nparams_and_flops(
        self, model: nn.Module, seq_len: int
    ) -> tuple[int, float]:
        pass


class ModelProtocol(Protocol):
    """Defines the interface for a model class.

    This is used to enforce that all model classes have some methods that are
    required by the TorchTitan trainer.
    """

    @classmethod
    def from_model_args(cls, args: BaseModelArgs) -> nn.Module:
        ...


ParallelizeFunction: TypeAlias = Callable[..., nn.Module]
PipeliningFunction: TypeAlias = Callable[
    ..., tuple[_PipelineSchedule, list[nn.Module], bool, bool]
]
DataLoaderBuilder: TypeAlias = Callable[..., BaseDataLoader]
TokenizerBuilder: TypeAlias = Callable[..., Tokenizer]
MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor]
OptimizersBuilder: TypeAlias = Callable[
    [list[nn.Module], JobConfig, FTManager], OptimizersContainer
]
LRSchedulersBuilder: TypeAlias = Callable[
    [OptimizersContainer, JobConfig], LRSchedulersContainer
]
LossFunctionBuilder: TypeAlias = Callable[..., LossFunction]


@dataclass
class TrainSpec:
    name: str
    cls: type[nn.Module]
    config: Mapping[str, BaseModelArgs]
    parallelize_fn: ParallelizeFunction
    pipelining_fn: PipeliningFunction | None
    build_optimizers_fn: OptimizersBuilder
    build_lr_schedulers_fn: LRSchedulersBuilder
    build_dataloader_fn: DataLoaderBuilder
    build_tokenizer_fn: TokenizerBuilder | None
    build_loss_fn: LossFunctionBuilder
    build_metrics_processor_fn: MetricsProcessorBuilder | None = None


_train_specs = {}


def register_train_spec(train_spec: TrainSpec) -> None:
    global _train_specs
    if train_spec.name in _train_specs:
        raise ValueError(f"Model {train_spec.name} is already registered.")

    _train_specs[train_spec.name] = train_spec


def get_train_spec(name: str) -> TrainSpec:
    global _train_specs
    if name not in _train_specs:
        raise ValueError(f"Model {name} is not registered.")
    return _train_specs[name]


def apply_to_train_specs(func: Callable[[TrainSpec], TrainSpec]) -> None:
    global _train_specs
    for name, train_spec in _train_specs.items():
        _train_specs[name] = func(train_spec)
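
A hypothetical sketch of how a model family would plug into this registry. Every name defined in the sketch (ToyModel, ToyModelArgs, the lambda builders) is a placeholder; a real spec, such as the llama3 one, passes the concrete builders from torchtitan.components plus the model's parallelize/pipelining functions:

# Placeholder registration example; assumes torchtitan is importable.
import torch.nn as nn

from torchtitan.protocols.train_spec import (
    BaseModelArgs,
    TrainSpec,
    get_train_spec,
    register_train_spec,
)


class ToyModelArgs(BaseModelArgs):
    pass


class ToyModel(nn.Linear):
    pass


register_train_spec(
    TrainSpec(
        name="toy",                                    # must be unique, else ValueError
        cls=ToyModel,                                  # nn.Module subclass (placeholder)
        config={"debug": ToyModelArgs()},              # flavor name -> model args
        parallelize_fn=lambda model, *a, **kw: model,  # no-op placeholder
        pipelining_fn=None,                            # no pipeline parallelism
        build_optimizers_fn=lambda *a, **kw: None,     # placeholder builders
        build_lr_schedulers_fn=lambda *a, **kw: None,
        build_dataloader_fn=lambda *a, **kw: None,
        build_tokenizer_fn=None,
        build_loss_fn=lambda *a, **kw: None,
    )
)

spec = get_train_spec("toy")  # the trainer later resolves specs by name
assert spec.cls is ToyModel
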
train.sh
ADDED
@@ -0,0 +1,121 @@
#!/usr/bin/bash

params=""
if [ $# -ne 0 ]; then
    params="$*"
fi

# use envs as local params for convenience
# e.g.
# NNODE=1 NGPU=8 LOG_RANK=0 ./train.sh
NNODE=${NNODE:-"1"}
NGPU=${NGPU:-"8"}
LOG_RANK=${LOG_RANK:-0}

if [[ -z "${MASTER_ADDR}" ]]; then
    export MASTER_ADDR="localhost"
fi
if [[ -z "${MASTER_PORT}" ]]; then
    export MASTER_PORT="0"
fi

: '
Usage:

bash train.sh -h

Training a 340M model:

NNODE=1 NGPU=8 LOG_RANK=0 bash train.sh \
  --job.config_file flame/models/fla.toml \
  --job.dump_folder exp/transformer-340M-10B/batch32.seqlen2048.warmup1024.update1.steps20480.lr3e-4 \
  --model.config configs/transformer_340M.json \
  --model.tokenizer_path fla-hub/transformer-1.3B-100B \
  --optimizer.name AdamW \
  --optimizer.eps 1e-15 \
  --optimizer.lr 3e-4 \
  --lr_scheduler.warmup_steps 1024 \
  --lr_scheduler.lr_min 0.1 \
  --lr_scheduler.decay_type cosine \
  --training.batch_size 32 \
  --training.seq_len 2048 \
  --training.gradient_accumulation_steps 1 \
  --training.steps 20480 \
  --training.max_norm 1.0 \
  --training.skip_nan_inf \
  --training.dataset HuggingFaceFW/fineweb-edu \
  --training.dataset_name default \
  --training.dataset_split train \
  --training.streaming \
  --training.num_workers 32 \
  --training.prefetch_factor 2 \
  --training.seed 42 \
  --training.compile \
  --training.tensor_parallel_degree 1 \
  --training.disable_loss_parallel \
  --checkpoint.interval 2048 \
  --checkpoint.load_step -1 \
  --metrics.log_freq 1
'

echo "Launching training..."

set -x
path=$(grep -oP '(?<=--job.dump_folder )[^ ]+' <<< "$params")
steps=$(grep -oP '(?<=--training.steps )[^ ]+' <<< "$params")
config=$(grep -oP '(?<=--model.config )[^ ]+' <<< "$params")
tokenizer=$(grep -oP '(?<=--model.tokenizer_path )[^ ]+' <<< "$params")
model=$(
  python -c "import fla, sys; from transformers import AutoConfig; print(AutoConfig.from_pretrained(sys.argv[1]).to_json_string())" "$config" | jq -r '.model_type'
)

mkdir -p $path
cp * $path
cp -r configs $path
cp -r flame $path
cp -r 3rdparty/flash-linear-attention/fla $path
cp -r 3rdparty/torchtitan/torchtitan $path

# for offline systems
# export TRANSFORMERS_OFFLINE=1
# export HF_DATASETS_OFFLINE=1
# export HF_HUB_OFFLINE=1
if [ "$date" == "" ]; then
    date=$(date +%Y%m%d%H%M)
fi
RUN_NAME="$model-$(basename $path)"
RUN_ID="$RUN_NAME-$date"

export WANDB_RESUME=allow
if [[ -z "${WANDB_PROJECT}" ]]; then
    export WANDB_PROJECT="fla"
fi
if [[ -z "${WANDB_NAME}" ]]; then
    export WANDB_NAME="$RUN_NAME"
fi
if [[ -z "${WANDB_RUN_ID}" ]]; then
    export WANDB_RUN_ID="$RUN_ID"
fi

PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
torchrun --nnodes=${NNODE} \
    --nproc_per_node=${NGPU} \
    --rdzv_backend c10d \
    --rdzv_endpoint "${MASTER_ADDR}:${MASTER_PORT}" \
    --local-ranks-filter ${LOG_RANK} \
    --role rank \
    --tee 3 \
    --log-dir $path/logs \
    -m flame.train \
    $params

echo "TRAINING DONE!"
echo "Converting the DCP checkpoints to HF format..."

python -m flame.utils.convert_dcp_to_hf \
  --path $path \
  --step $steps \
  --config $config \
  --tokenizer $tokenizer

echo "RUNNING DONE!"
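
The model=$(python -c ... | jq ...) line above round-trips the model config through JSON only to read one field for the run name. An equivalent plain-Python lookup, with the config path taken from the usage example above and purely illustrative:

# Read model_type directly; importing fla first registers its custom configs
# with transformers, just as the one-liner in train.sh does.
import fla  # noqa: F401
from transformers import AutoConfig

model_type = AutoConfig.from_pretrained("configs/transformer_340M.json").model_type
print(model_type)
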