Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- flame/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/__pycache__/config_manager.cpython-312.pyc +0 -0
- flame/__pycache__/data.cpython-312.pyc +0 -0
- flame/components/__init__.py +0 -0
- flame/components/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
- flame/components/checkpoint.py +59 -0
- flame/models/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/models/__pycache__/parallelize_fla.cpython-312.pyc +0 -0
- flame/models/__pycache__/pipeline_fla.cpython-312.pyc +0 -0
- flame/models/activation_offloading.py +447 -0
- flame/models/fla.toml +67 -0
- flame/models/parallelize_fla.py +550 -0
- flame/models/pipeline_fla.py +162 -0
- flame/tools/__init__.py +0 -0
- flame/tools/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/tools/__pycache__/utils.cpython-312.pyc +0 -0
- flame/utils/__init__.py +0 -0
- flame/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/utils/__pycache__/checkpoint.cpython-312.pyc +0 -0
- flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc +0 -0
- flame/utils/__pycache__/hf_utils.cpython-312.pyc +0 -0
- flame/utils/checkpoint.py +50 -0
- flame/utils/convert_dcp_to_hf.py +66 -0
- flame/utils/convert_hf_to_dcp.py +34 -0
- flame/utils/hf_utils.py +77 -0
- tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/logs/debug-internal.log +90 -0
- torchtitan/components/__pycache__/float8.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/ft.cpython-312.pyc +0 -0
- torchtitan/components/float8.py +150 -0
- torchtitan/components/ft.py +143 -0
- torchtitan/experiments/deepseek_v3/model_config.py +204 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py +11 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py +159 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py +260 -0
- torchtitan/experiments/deepseek_v3/train.py +142 -0
- torchtitan/experiments/flux/dataset/tokenizer.py +64 -0
- torchtitan/experiments/flux/flux_argparser.py +42 -0
- torchtitan/experiments/flux/model/autoencoder.py +388 -0
- torchtitan/experiments/flux/model/hf_embedder.py +40 -0
- torchtitan/experiments/flux/model/layers.py +286 -0
- torchtitan/experiments/flux/model/model.py +177 -0
- torchtitan/experiments/flux/scripts/download_autoencoder.py +61 -0
- torchtitan/experiments/flux/tests/test_flux_dataloader.py +103 -0
- torchtitan/experiments/flux/tests/test_generate_image.py +252 -0
- torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py +885 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py +299 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py +1304 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py +126 -0
flame/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (156 Bytes).
flame/__pycache__/config_manager.cpython-312.pyc
ADDED
Binary file (36.9 kB).
flame/__pycache__/data.cpython-312.pyc
ADDED
Binary file (31.3 kB).
flame/components/__init__.py
ADDED
File without changes
flame/components/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (141 Bytes).
flame/components/__pycache__/checkpoint.cpython-312.pyc
ADDED
Binary file (3.21 kB).
flame/components/checkpoint.py
ADDED
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from datetime import timedelta
from io import BytesIO
from typing import Any, Dict, List

import torch
from torch.distributed.checkpoint.stateful import Stateful


@dataclass
class TrainState(Stateful):
    step: int = 0
    skipped_step: int = 0
    token: int = 0
    elapsed: timedelta = timedelta(0)
    global_avg_losses: List[float] = field(default_factory=list)
    global_max_losses: List[float] = field(default_factory=list)
    log_steps: List[int] = field(default_factory=list)

    def state_dict(self) -> Dict[str, Any]:
        # Only checkpoint global_avg_losses and global_max_losses per log frequency
        # to avoid sync overhead in every iteration.
        global_avg_losses_bytes = BytesIO()
        torch.save(self.global_avg_losses, global_avg_losses_bytes)
        global_max_losses_bytes = BytesIO()
        torch.save(self.global_max_losses, global_max_losses_bytes)
        log_steps_bytes = BytesIO()
        torch.save(self.log_steps, log_steps_bytes)
        return {
            "step": torch.tensor(self.step, dtype=torch.int32),
            "skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
            "token": torch.tensor(self.token, dtype=torch.int64),
            "elapsed": self.elapsed,
            "global_avg_losses": global_avg_losses_bytes,
            "global_max_losses": global_max_losses_bytes,
            "log_steps": log_steps_bytes,
        }

    def load_state_dict(self, state_dict) -> None:
        self.step = state_dict["step"].item()
        self.skipped_step = state_dict.get("skipped_step", 0).item()
        self.token = state_dict["token"].item()
        self.elapsed = state_dict["elapsed"]
        state_dict["global_avg_losses"].seek(0)
        self.global_avg_losses = torch.load(
            state_dict["global_avg_losses"], weights_only=False
        )
        state_dict["global_max_losses"].seek(0)
        self.global_max_losses = torch.load(
            state_dict["global_max_losses"], weights_only=False
        )
        state_dict["log_steps"].seek(0)
        self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
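For orientation (not part of the diff), here is a minimal sketch of how this TrainState could round-trip through its Stateful interface, e.g. before being handed to torch.distributed.checkpoint. It assumes the flame package above is importable; the field values are made up for illustration.

from datetime import timedelta

from flame.components.checkpoint import TrainState

# Populate a state as a training loop might.
state = TrainState(step=100, token=409600, elapsed=timedelta(minutes=5))
state.log_steps.append(100)
state.global_avg_losses.append(2.31)

# state_dict() packs scalars as tensors and the loss histories as BytesIO blobs.
sd = state.state_dict()

# A fresh TrainState restores everything from that mapping.
restored = TrainState()
restored.load_state_dict(sd)
assert restored.step == 100 and restored.log_steps == [100]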
flame/models/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (137 Bytes).
flame/models/__pycache__/parallelize_fla.cpython-312.pyc
ADDED
Binary file (22.1 kB).
flame/models/__pycache__/pipeline_fla.cpython-312.pyc
ADDED
Binary file (5.75 kB).
flame/models/activation_offloading.py
ADDED
@@ -0,0 +1,447 @@
# Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/training/_activation_offloading.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
from typing import Union
from warnings import warn

import psutil
import torch
from torch import nn
from torch.autograd.graph import saved_tensors_hooks

from torchtitan.tools.logging import logger

try:
    import torchao
    from torchao.dtypes.nf4tensor import NF4Tensor
except ImportError:
    torchao = None
    NF4Tensor = None
    logger.warning("torchao not found. ")

# from torchtune.modules import TiedLinear


class OffloadActivations(saved_tensors_hooks):
    """Context manager under which activation tensors created in the forward pass will be offloaded.

    Enable the memory efficiency technique of activation offloading, where activations bigger than
    min_offload_size bytes will be offloaded to CPU in the forward and brought back in the backward.
    This is in contrast to maintaining the activation on GPU VRAM throughout the program.

    This manager contains the option of using one additional CUDA stream to handle the communication
    between CUDA and CPU, which is intended to overlap with the default computation stream to improve
    runtime. We designed synchronization with a few heuristics for optimizing the tradeoff between
    runtime vs memory usage.

    Args:
        use_pin_memory (bool): Whether or not the offloaded Tensor will be placed in pinned
            memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly
            but is a limited resource. Default: True.

        use_streams (bool): Whether or not to use streams for performance optimization where
            the communications get overlapped with the computation. Requires a torch build
            after torch-2.5.0.]. Default: True.

        max_fwd_stash_size (int): The maximum size of the forward stash, or the maximum number of
            consecutive activations to keep alive during the forward pass. This number must be at
            least 1. Keeping alive more activations will potentially allow more overlap between the
            communication and compute streams at the cost of increasing memory usage. Keeping alive
            fewer activations will conserve memory, but may cause poor overlap between the streams,
            increasing runtime. Default: 5.

        min_offload_size (int): The minimum number of bytes a Tensor must be in order to qualify
            for offloading. If the tensor is too small, we do not want to waste bandwidth and resources
            moving it to CPU and back. Default: 1024 bytes.

    Raises:
        ValueError: if max_fwd_stash_size is not at least 1.

    Example:
        >>> with OffloadActivations():
        >>>     logits = model(inputs)
        >>>     loss = ...
        >>>     loss.backward()
    """

    def __init__(
        self,
        use_pin_memory: bool = True,
        use_streams: bool = True,
        max_fwd_stash_size: int = 5,
        min_offload_size: int = 1024,
    ) -> None:

        self.use_streams: bool = use_streams

        self.min_tensor_size_bytes = (
            min_offload_size  # we don't want to bother with small tensors
        )
        self.tracker = (
            {}
        )  # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where
        self.tensor_id: int = 0
        self.is_first_forward_call = True
        self.is_first_backward_call = True
        self.is_first_forward_pass = True

        # managing cpu memory
        self.use_pin_memory: bool = use_pin_memory
        self.virtual_memory_safe_pct = (
            60  # we should not exceed this percentage of memory
        )

        self.s0 = torch.cuda.default_stream()  # comp stream

        # for streaming
        if self.use_streams:
            self.s1 = torch.cuda.Stream()  # comms stream
            self.fwd_stash = {}  # tensor_id => (activation, ev1)
            if max_fwd_stash_size < 1:
                raise ValueError(
                    f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}"
                )
            self.max_fwd_stash_size = max_fwd_stash_size
            self.bwd_tensor_stash = {}  # tensor_id => activation
            self.bwd_ev_stash = {}  # tensor_id => ev0
            self.curr_graph_id = None
            self.curr_autograd_node = None

        # -------- platform util functions -------- #
        def verify_sufficient_virtual_memory():
            curr_pct = get_cpu_ram_pct()
            if curr_pct > self.virtual_memory_safe_pct:
                warn(
                    f"***** WARNING: {curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used"
                )

        def get_cpu_ram_pct() -> float:
            # get the percentage of memory used by the system
            return psutil.virtual_memory().percent

        def get_tensor_id() -> int:
            # create a unique id for each tensor we are managing
            self.tensor_id += 1
            return self.tensor_id

        def get_num_bytes_tensor(x: torch.Tensor) -> int:
            # get the number of bytes in a tensor, for memory management purposes
            return (
                x.element_size() * x.nelement()
            )  # x.element_size() * x._base_storage().nbytes()

        # -------- core pack / unpack work -------- #
        def pack_tensor(activation: torch.Tensor) -> int:
            # activations are passed in during forward pass - from here we take over and return a unique id
            if self.is_first_forward_call:
                assert (
                    len(self.tracker) == 0
                ), "backward pass should have cleared tracker of all tensors"

                # set training phase trackers
                self.is_first_forward_call = False
                self.is_first_backward_call = True

            # query for basic tensor info
            num_bytes = get_num_bytes_tensor(activation)
            tensor_id = get_tensor_id()

            # only offload hefty bois if they're activations on CUDA (our heuristic
            # for that is to check if they're not params or buffers)!
            if (
                activation.is_cuda
                and num_bytes >= self.min_tensor_size_bytes
                and (
                    not isinstance(activation, torch.nn.Parameter)
                    and not isinstance(activation, torch.nn.Buffer)
                )
            ):
                if self.use_streams:
                    # First, sync back and dereference previously offloaded tensors
                    # as the offloading should be done sufficiently long ago.
                    for id in [k for k in self.fwd_stash.keys()]:
                        if id <= tensor_id - self.max_fwd_stash_size:
                            _, ev = self.fwd_stash[id]
                            self.s0.wait_event(ev)
                            del self.fwd_stash[id]
                        else:
                            break

                    # Sync in, offload, and add an event to sync back later
                    self.s1.wait_stream(self.s0)

                stream = self.s1 if self.use_streams else self.s0
                with torch.cuda.stream(stream):
                    try:
                        cpu_tensor = torch.empty_like(
                            activation, pin_memory=self.use_pin_memory, device="cpu"
                        )
                    except NotImplementedError as e:
                        if (
                            isinstance(activation, NF4Tensor)
                            and torchao.__version__ < "0.6.0.dev20240917"
                        ):
                            raise RuntimeError(
                                "Offloading NF4Tensors requires torchao-0.6.0.dev20240917 or later"
                            ) from e
                        raise e
                    cpu_tensor.copy_(activation, non_blocking=True)
                    self.tracker[tensor_id] = (
                        cpu_tensor,
                        True,
                    )  # True = (in future) modified

                if self.use_streams:
                    event = self.s1.record_event()

                    # Stash to keep activation alive til s1 is done
                    self.fwd_stash[tensor_id] = (activation, event)
            else:
                self.tracker[tensor_id] = (
                    activation,
                    False,
                )  # False = not modified, tensor is as is

            return tensor_id

        def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
            # backward pass - we are called with the tensor_id, which
            # we will use to retrieve the saved/offloaded tensor
            if self.is_first_backward_call:
                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                self.is_first_backward_call = False
                self.is_first_forward_call = True

            assert (
                unpack_tensor_id in self.tracker
            ), f"untracked tensor with id {unpack_tensor_id}"

            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
                maybe_gpu_tensor = gpu_tensor

            # clear tensor from tracking
            del self.tracker[unpack_tensor_id]
            return maybe_gpu_tensor

        def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
            # backward pass - we are called with the tensor_id, which
            # we will use to retrieve the saved/offloaded tensor
            if self.is_first_backward_call:
                self.curr_graph_id = torch._C._current_graph_task_id()

                def wait_and_del_remaining_references() -> None:
                    for id in [k for k in self.bwd_tensor_stash.keys()]:
                        event = self.bwd_ev_stash[id]
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[id]

                # Register a callback to the end of autograd to clean everything up
                torch.autograd.variable.Variable._execution_engine.queue_callback(
                    wait_and_del_remaining_references
                )

                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                self.is_first_backward_call = False
                self.is_first_forward_call = True

            assert (
                unpack_tensor_id in self.tracker
            ), f"untracked tensor with id {unpack_tensor_id}"

            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                # Get data on the current autograd node
                graph_id = torch._C._current_graph_task_id()
                node = torch._C._current_autograd_node()
                prev_node_ids = []

                # If we're on a new node, mark prev node's tensors to be freed later
                if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
                    self.curr_autograd_node = node
                    prev_node_ids = [id for id in self.bwd_tensor_stash.keys()]

                brought_back_from_cpu = True
                if unpack_tensor_id in self.fwd_stash:
                    maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0]
                    brought_back_from_cpu = False
                else:
                    # Kick off the process to bring tensors back
                    with torch.cuda.stream(self.s1):
                        gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
                        maybe_gpu_tensor = gpu_tensor

                    # Tell comp stream to wait for the info to be loaded before executing
                    self.s0.wait_stream(self.s1)

                # Stash the tensor to keep memory alive until compute stream is complete
                self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor

                # Note: [Track views of the unpacked]
                # Why do we get the use count of the unpacked tensor here? We want an
                # initial count to compare to later, during the post-hook of the
                # backward node, when we need to decide whether we're allowed to free
                # the tensor yet. In what obscure cases must we delay freeing the
                # tensor (and thus call record_stream)?
                # 1. Any of the outputs of the backward node is a view of the unpacked
                #    tensor.
                # 2. In the case that this unpacked tensor will be used in a
                #    checkpointed region, if one of the recomputed saved tensors ends
                #    up as a view of the unpacked tensor.
                # 3. The user abuses the system somehow and manually relies on the
                #    unpacked tensor to exist after the backward node has executed.
                storage_refcount = torch._C._storage_Use_Count(
                    maybe_gpu_tensor.untyped_storage()._cdata
                )

                def hook(outputs, inputs):
                    # create events for the current node inputs/outputs if they were streamed in
                    if brought_back_from_cpu:
                        # See Note: [Track views of the unpacked]
                        # IF any of the outputs is a view of the tensor, OR if a view of
                        # the tensor has been saved as a part of checkpoint's recompute
                        # process, OR the user has abusedly incurred a reference on the
                        # unpacked tensor, THEN the tensor might be used later and we
                        # cannot presume to delete it after only the current node is
                        # done! So we use our frenemy, record_stream, to ensure the
                        # Tensor stays unmessed with until it's done getting used in the
                        # compute stream (s0 here). Note that the con here is we introduce
                        # non-deterministic (thus higher) memory usage, but this case
                        # should not happen often.
                        unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
                        if (
                            torch._C._storage_Use_Count(
                                unpacked_tensor.untyped_storage()._cdata
                            )
                            > storage_refcount
                        ):
                            unpacked_tensor.record_stream(self.s0)
                            del self.bwd_tensor_stash[unpack_tensor_id]
                        else:
                            event = self.s0.record_event()
                            self.bwd_ev_stash[unpack_tensor_id] = event

                    # if there are still things in the fwd_stash, get rid of them as we're in bwd now
                    for id in [k for k in self.fwd_stash.keys()]:
                        _, ev = self.fwd_stash[id]
                        self.s0.wait_event(ev)
                        del self.fwd_stash[id]

                    # wait on prev node's events and del those
                    for id in prev_node_ids:
                        event = self.bwd_ev_stash[id]
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[id]

                    return outputs

                node.register_hook(hook)

            # clear tensor from tracking
            del self.tracker[unpack_tensor_id]
            return maybe_gpu_tensor

        unpack_tensor = (
            unpack_tensor_with_streams
            if self.use_streams
            else unpack_tensor_single_stream
        )
        super().__init__(pack_tensor, unpack_tensor)


class NoOpManager(saved_tensors_hooks):
    """
    A saved_tensors_hook manager used to disable any other saved_tensors_hook manager
    applied before. This relies on the behavior that only the most recently registered
    saved_tensors_hook will run.

    One example usage is to opt a local region of code out of activations offloading,
    which is usually applied globally to best track state.
    """

    def __init__(self) -> None:
        def noop(tensor):
            return tensor

        super().__init__(noop, noop)


def get_act_offloading_ctx_manager(
    model: nn.Module, enable_activation_offloading: bool
) -> Union[OffloadActivations, contextlib.nullcontext]:
    """Returns the activation offloading context manager for the model, which will be
    a null context if enable_activation_offloading is False.

    If activation offloading is enabled, we return the OffloadActivations context manager.
    If activation offloading is disabled, we return a NoOpManager context manager.

    Args:
        model (nn.Module): the model to wrap with the activation offloading context manager.
        enable_activation_offloading (bool): whether or not to enable activation offloading
            for the model.

    Returns:
        contextlib.ContextDecorator: the activation offloading context manager for the model.

    Raises:
        NotImplementedError: If the model is a multimodal model and activation offloading is enabled.
    """
    if enable_activation_offloading:
        activations_handling_ctx = OffloadActivations()

        # Below is our hack to disable offloading the last output Linear in every
        # step, as the cost for offloading the activation and then soon after bringing
        # it back is expensive. Moreover, due to heuristics in our streaming API,
        # we actually use more memory if we offload it as it interferes with chunkedCE.
        output_head_detected = False
        noop_ctx = NoOpManager()

        if hasattr(model, "output"):
            if isinstance(model.output, nn.Module):
                model.output.register_forward_pre_hook(
                    lambda *args: noop_ctx.__enter__()
                )
                model.output.register_forward_hook(
                    lambda *args: noop_ctx.__exit__(), always_call=True
                )
                print("registering hooks for model.output ============ ")
                output_head_detected = True
            # ================================
            # ! TODO[flame] check if we need to detal with TiedLinear
            # The following code appears in `torchtune`
            # elif isinstance(model.output, TiedLinear):
            #     model.output.linear.register_forward_pre_hook(
            #         lambda *args: noop_ctx.__enter__()
            #     )
            #     model.output.linear.register_forward_hook(
            #         lambda *args: noop_ctx.__exit__(), always_call=True
            #     )
            #     output_head_detected = True

        if not output_head_detected:
            logger.warning(
                "During activation offloading, no output head was detected. "
                "If your model has an output head, it will be offloaded. "
                "This usually greatly slows training, given the large vocabulary size. "
                "To change this behavior, set your output head as model.output and make it "
                "an nn.Module."
            )

    else:
        activations_handling_ctx = contextlib.nullcontext()

    return activations_handling_ctx
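As a usage illustration (not part of the diff), the context manager above wraps an ordinary forward/backward pass, mirroring its docstring example. The sketch below assumes a CUDA device and that the flame package is on the path; the layer sizes are arbitrary.

import torch
import torch.nn as nn

from flame.models.activation_offloading import OffloadActivations

# A toy model; any activation larger than min_offload_size bytes qualifies for offloading.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
x = torch.randn(8, 2048, 1024, device="cuda")

# Activations packed during forward are copied to (pinned) CPU memory on a side
# stream and streamed back on demand during backward.
with OffloadActivations(use_pin_memory=True, use_streams=True, min_offload_size=1024):
    loss = model(x).float().pow(2).mean()
    loss.backward()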
flame/models/fla.toml
ADDED
@@ -0,0 +1,67 @@
[model]
config = "fla-hub/transformer-1.3B-100B"
tokenizer_path = "fla-hub/transformer-1.3B-100B"

[job]
dump_folder = "exp"
print_args = true

[training]
batch_size = 32
seq_len = 2048
context_len = 2048
gradient_accumulation_steps = 1
steps = 20480
max_norm = 1.0
skip_nan_inf = true
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
compile = false
dataset = "HuggingFaceFW/fineweb-edu"
dataset_name = "default"
num_workers = 32
pin_memory = false
persistent_workers = false
prefetch_factor = 2
seed = 42
varlen = false

[optimizer]
name = "AdamW"
eps = 1e-15
lr = 3e-4

[lr_scheduler]
warmup_steps = 1024
decay_type = "cosine"
lr_min = 0.1

[checkpoint]
enable_checkpoint = true
folder = "checkpoint"
interval_type = "steps"
interval = 2048
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 512

[metrics]
log_freq = 32
enable_wandb = true

[experimental]
context_parallel_degree = 1
pipeline_parallel_degree = 1

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false

[activation_checkpoint]
mode = "none"
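To sanity-check such a config outside the trainer, the TOML can be read with the standard library. A small sketch, assuming Python 3.11+ (for tomllib) and that the file sits at flame/models/fla.toml:

import tomllib

with open("flame/models/fla.toml", "rb") as f:
    cfg = tomllib.load(f)

# Tokens consumed per optimizer step at these settings:
# batch_size * seq_len * gradient_accumulation_steps = 32 * 2048 * 1 = 65536.
tokens_per_step = (
    cfg["training"]["batch_size"]
    * cfg["training"]["seq_len"]
    * cfg["training"]["gradient_accumulation_steps"]
)
print(cfg["model"]["config"], tokens_per_step)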
flame/models/parallelize_fla.py
ADDED
@@ -0,0 +1,550 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

from collections import defaultdict

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    PrepareModuleOutput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module
)

from fla.modules.fused_linear_cross_entropy import LinearLossParallel
from fla.modules.mlp import SwiGLULinearParallel
from fla.modules.parallel import PrepareModuleWeight
from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.tools.logging import logger


def parallelize_fla(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.experimental.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")
        enable_float8_linear = "float8" in job_config.model.converters
        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8=enable_float8_linear,
            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
        )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-block compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
        )


class TPPlan:
    def __init__(
        self,
        model=None,
        loss_parallel=False,
        enable_float8=False,
    ):
        self.model = model
        self.loss_parallel = loss_parallel
        self.enable_float8 = enable_float8
        self.base_model_prefix = getattr(model, "base_model_prefix", "model")

        # TODO(vkuzo): once float8 configuration supports delayed scaling,
        # add a check here to enforce supported float8 all-gather configurations
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        try:
            from torchao.float8.float8_tensor_parallel import (
                Float8ColwiseParallel,
                Float8RowwiseParallel,
                PrepareFloat8ModuleInput
            )
        except ImportError:
            Float8ColwiseParallel = None
            Float8RowwiseParallel = None
            PrepareFloat8ModuleInput = None
        if self.enable_float8 and Float8ColwiseParallel is not None:
            self.rowwise_parallel = Float8RowwiseParallel
            self.colwise_parallel = Float8ColwiseParallel
            self.prepare_module_input = PrepareFloat8ModuleInput
            self.prepare_module_output = PrepareModuleOutput
        else:
            self.rowwise_parallel = RowwiseParallel
            self.colwise_parallel = ColwiseParallel
            self.prepare_module_input = PrepareModuleInput
            self.prepare_module_output = PrepareModuleOutput

    @property
    def model_plan(self):
        plans = {
            f"{self.base_model_prefix}.embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            f"{self.base_model_prefix}.norm": SequenceParallel(),
        }
        if self.loss_parallel:
            plans.update(
                {
                    "lm_head": ColwiseParallel(
                        input_layouts=Shard(1),
                        output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
                        use_local_output=not self.loss_parallel,
                    ),
                }
            )
        else:
            plans.update(
                {
                    "lm_head": PrepareModuleWeight(layouts=Replicate()),
                    "criterion": LinearLossParallel(),
                }
            )
        return plans

    @property
    def layer_plan(self):
        return {
            "attn_norm": SequenceParallel(),
            **self.attn_plan,
            "mlp_norm": SequenceParallel(),
            **self.mlp_plan,
        }

    @property
    def attn_plan(self):
        raise NotImplementedError(
            f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
        )

    @property
    def mlp_plan(self):
        return {
            "mlp": self.prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "mlp.gate_proj": self.colwise_parallel(),
            "mlp.up_proj": self.colwise_parallel(),
            "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
            "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
        }


class TransformerTPPlan(TPPlan):

    @property
    def attn_plan(self):
        return {
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
            "attn.q_proj": self.colwise_parallel(),
            "attn.k_proj": self.colwise_parallel(),
            "attn.v_proj": self.colwise_parallel(),
            "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
        }


class GLATPPlan(TPPlan):

    @property
    def attn_plan(self):
        return {
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
            "attn.q_proj": self.colwise_parallel(),
            "attn.k_proj": self.colwise_parallel(),
            "attn.v_proj": self.colwise_parallel(),
            "attn.g_proj": self.colwise_parallel(),
            "attn.gk_proj.0": PrepareModuleWeight(layouts=Replicate()),
            "attn.gk_proj.1": self.colwise_parallel(),
            "attn.g_norm": SequenceParallel(sequence_dim=-1),
            "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
        }


TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}


def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism."""
    # 1. Parallelize the embedding and shard its outputs (which are the first
    # transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    tp_plan = TP_PLAN_MAP[model.config.model_type](
        model, loss_parallel=loss_parallel, enable_float8=enable_float8
    )
    parallelize_module(model, tp_mesh, tp_plan.model_plan)

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for tensor parallelism")
    else:
        for _, block in enumerate(blocks):
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=tp_plan.layer_plan,
            )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )


# for selective op activation checkpointing
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
    torch.ops.aten.max.default,
}


def _apply_ac_to_block(module: nn.Module, ac_config):
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            f"Valid options: 'op' or a positive int representing layer frequency"
        )
    if use_op_sac:
        from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

        def _get_custom_policy(meta):
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1
                # Saves output of all compute ops, except every second mm
                to_save = func in _save_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                )
                return (
                    CheckpointPolicy.MUST_SAVE
                    if to_save
                    else CheckpointPolicy.PREFER_RECOMPUTE
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            meta = defaultdict(int)
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        return ptd_checkpoint_wrapper(
            module,
            context_fn=selective_checkpointing_context_fn,
            preserve_rng_state=False,
        )
    elif use_layer_sac:
        # Checkpoint every `ac_freq` of the modules passed to this function
        ac_freq = int(ac_config.selective_ac_option)
        ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
        ptd_checkpoint_wrapper._count += 1
        if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
        else:
            return module


def apply_ac(model: nn.Module, ac_config):
    """Apply activation checkpointing to the model."""
    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for activation checkpointing")
        return

    for layer_id, block in blocks.named_children():
        block = _apply_ac_to_block(block, ac_config)
        blocks.register_module(layer_id, block)

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")


def apply_compile(model: nn.Module):
    """
    Apply torch.compile to each block, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for torch.compile")
    else:
        for layer_id, block in blocks.named_children():
            block = torch.compile(block)
            blocks.register_module(layer_id, block)
        logger.info("Compiling each block with torch.compile")

    real_model = get_model(model)

    logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
    embeddings_key = get_components_name(real_model, "tok_embeddings")
    if embeddings_key is not None:
        embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
        real_model.register_module(embeddings_key, embeddings)

    norm_key = get_components_name(real_model, "norm")
    if norm_key is not None:
        norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
        real_model.register_module(norm_key, norm)

    lm_head_key = get_components_name(model, "lm_head")
    if lm_head_key is not None:
        lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
        model.register_module(lm_head_key, lm_head)

    logger.info("Compiling the entire model with torch.compile")
    model = torch.compile(model)


def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    pp_enabled: bool,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
):
    """
    Apply data parallelism (via FSDP2) to the model.

    Args:
        model (nn.Module): The model to apply data parallelism to.
        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
        param_dtype (torch.dtype): The data type to use for model parameters.
        reduce_dtype (torch.dtype): The data type to use for reduction operations.
        pp_enabled (bool): Whether pipeline parallelism is enabled.
        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
        reshard_after_forward_policy (str, optional):
            The policy to use for resharding after forward pass. Defaults to "default".
            Other options: "never", "always".
            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
            - "always" will enable `reshard_after_forward` for all forward passes.
            - "never" will disable `reshard_after_forward` for all forward passes.

    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for FSDP")
    else:
        total_blocks = len(blocks)
        for layer_id, block in enumerate(blocks):
            if reshard_after_forward_policy == "always":
                reshard_after_forward = True
            elif reshard_after_forward_policy == "never":
                reshard_after_forward = False
            elif reshard_after_forward_policy == "default":
                if pp_enabled:
                    # For PP, do not reshard after forward to avoid per-microbatch
                    # all-gathers, which can be expensive and non-overlapped
                    reshard_after_forward = False
                else:
                    # As an optimization, do not reshard after forward for the last
                    # transformer block since FSDP would prefetch it immediately
                    reshard_after_forward = int(layer_id) < total_blocks - 1
            else:
                raise ValueError(
                    f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
                )
            fully_shard(
                block,
                **fsdp_config,
                reshard_after_forward=reshard_after_forward,
            )

    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)


def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    if enable_compile:
        if enable_compiled_autograd:
            torch._dynamo.config.optimize_ddp = (
                "python_reducer_without_compiled_forward"
            )
        else:
            torch._dynamo.config.optimize_ddp = "ddp_optimizer"

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")


def get_model(model):
    base_model_prefix = getattr(model, "base_model_prefix", "model")
    if not hasattr(model, base_model_prefix):
        return None
    model = getattr(model, base_model_prefix)
    return model


def get_blocks(model):
    # TODO[flame]: adapt for network not using 'layers' attribute
    model = get_model(model)
    if not hasattr(model, "layers"):
        logger.warning('no "layers" in model can be found')
        return None
    return model.layers


def get_components_name(model, component_name):
    """
    We try to catch tok_embeddings, norm layers and lm_head layers
    We do not catch the layer names in the blocks, for blocks see `get_blocks`
    We assume the model has the following structure:
        LlamaForCausalLM:
            Model:
                embed_tokens,
                layers,
                norm,
            lm_head
    ***
    so, to search 'tok_embeddings' and 'norm' we need to pass `get_model(model)`
    and for 'lm_head' we need to pass `model`
    ***
    """

    if component_name == "tok_embeddings":
        if hasattr(model, "tok_embeddings"):
            return "tok_embeddings"
        elif hasattr(model, "embed_tokens"):
            return "embed_tokens"
        elif hasattr(model, "embeddings"):
            return "embeddings"
        else:
            logger.warning("No tok_embeddings found in model")
            return None

    elif component_name == "norm":
        if hasattr(model, "norm"):
            return "norm"
        elif hasattr(model, "norms"):
            return "norms"
        elif hasattr(model, "layernorm"):
            return "layernorm"
        else:
            logger.warning("No norm found in model")
            return None

    elif component_name == "lm_head":
        if hasattr(model, "lm_head"):
            return "lm_head"
        else:
            logger.warning("No lm_head found in model")
            return None
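The helper trio at the bottom of this file (get_model, get_blocks, get_components_name) drives every wrapper above. As a quick illustration (not part of the diff) of how they resolve names on a HuggingFace-style causal LM, the sketch below uses a made-up DummyForCausalLM and assumes flame and torchtitan are installed:

import torch.nn as nn

from flame.models.parallelize_fla import get_blocks, get_components_name, get_model

class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(128, 16)
        self.layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(2))
        self.norm = nn.LayerNorm(16)

class DummyForCausalLM(nn.Module):
    base_model_prefix = "model"  # get_model() follows this attribute

    def __init__(self):
        super().__init__()
        self.model = Backbone()
        self.lm_head = nn.Linear(16, 128)

m = DummyForCausalLM()
backbone = get_model(m)                                 # resolves to m.model
print(len(get_blocks(m)))                               # 2 blocks to wrap with AC/compile/FSDP
print(get_components_name(backbone, "tok_embeddings"))  # "embed_tokens"
print(get_components_name(backbone, "norm"))            # "norm"
print(get_components_name(m, "lm_head"))                # "lm_head"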
flame/models/pipeline_fla.py
ADDED
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D pipeline parallelism to the Llama model.

import copy
from typing import Callable, Optional, Union

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble, _PipelineSchedule, get_schedule_class
from transformers import PretrainedConfig

from flame.models.parallelize_fla import get_blocks, get_components_name, get_model
from torchtitan.config_manager import JobConfig
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.distributed.pipeline import build_pipeline_schedule, generate_split_points, stage_ids_this_rank
from torchtitan.tools.logging import logger

DeviceType = Union[int, str, torch.device]


def pipeline_fla(
    model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
    loss_fn: Callable[..., torch.Tensor],
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    stages, models = pipeline_fla_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the input_ids and labels
    has_first_stage = False
    has_last_stage = False
    for stage in stages:
        if stage.is_first:
            has_first_stage = True
        if stage.is_last:
            has_last_stage = True

    return pp_schedule, models, has_first_stage, has_last_stage


def pipeline_fla_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.

    It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.

    The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
    parallelism.
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()

    splits = (
        job_config.experimental.pipeline_parallel_split_points
        or generate_split_points(
            job_config, parallel_dims.pp, model_config.num_hidden_layers
        )
    )

    def _build_stage(
        stage_idx: int,
        start_layer: Optional[str],
        stop_layer: Optional[str],
        is_first: bool = False,
        is_last: bool = False,
    ) -> tuple[PipelineStage, nn.Module]:
        model = copy.deepcopy(whole_model)
        if not is_first:
            # we do `model.tok_embeddings = None` here
            real_model = get_model(model)
            tok_embeddings_name = get_components_name(real_model, "tok_embeddings")
            setattr(real_model, tok_embeddings_name, None)

        drop_layers = start_layer is not None
        # Get module dictionary from get_blocks(model)
        # and Create a list of keys before modifying dictionary
        module_dict = get_blocks(model)._modules  # Store reference
        layer_names = list(module_dict.keys())

        # Iterate over the list of keys instead of `_modules.items()`
        for name in layer_names:
            # Dynamically determine prefix (blocks.* or layers.*)
            prefix = start_layer.split(".")[0] if start_layer else "layers"
            layer_name = f"{prefix}.{name}"  # Construct the correct name format

            # Ensure `drop_layers` activation is based on actual naming
            if layer_name == start_layer:
                drop_layers = False
            if layer_name == stop_layer:
                drop_layers = True

            # Delete layer if drop_layers is active
            if drop_layers:
                del module_dict[name]  # Safe deletion from stored dictionary

        if not is_last:
            # we do `model.norm = None` and `model.output = None`
            real_model = get_model(model)
            norm_name = get_components_name(real_model, "norm")
            setattr(real_model, norm_name, None)

            head_name = get_components_name(model, "lm_head")
            setattr(model, head_name, None)

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model

    num_stages = len(splits) + 1
    stage_idx = pp_rank

    stages = []
    models = []

    schedule_class = get_schedule_class(
        job_config.experimental.pipeline_parallel_schedule
    )
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
|
148 |
+
stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
|
149 |
+
stage, model_chunk = _build_stage(
|
150 |
+
stage_idx,
|
151 |
+
start_layer,
|
152 |
+
stop_layer,
|
153 |
+
is_first=stage_idx == 0,
|
154 |
+
is_last=stage_idx == num_stages - 1,
|
155 |
+
)
|
156 |
+
logger.info(
|
157 |
+
f"PP rank {pp_rank} is building stage_idx {stage_idx}"
|
158 |
+
f" with start_layer {start_layer}, stop_layer {stop_layer}"
|
159 |
+
)
|
160 |
+
stages.append(stage)
|
161 |
+
models.append(model_chunk)
|
162 |
+
return stages, models
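
Note: to make the split-point handling in `_build_stage` above easier to follow, here is a minimal, dependency-free sketch of how a stage decides which transformer blocks to keep. The layer names and the split point used below are hypothetical illustrations, not values taken from this diff.

def kept_blocks(layer_names, start_layer, stop_layer, prefix="layers"):
    # Mirrors the drop_layers toggle: drop everything before start_layer
    # and everything from stop_layer onward.
    drop = start_layer is not None
    kept = []
    for name in layer_names:
        full = f"{prefix}.{name}"
        if full == start_layer:
            drop = False
        if full == stop_layer:
            drop = True
        if not drop:
            kept.append(name)
    return kept

# With a single split at "layers.2", stage 0 keeps blocks 0-1 and stage 1 keeps blocks 2-3:
print(kept_blocks(["0", "1", "2", "3"], None, "layers.2"))  # ['0', '1']
print(kept_blocks(["0", "1", "2", "3"], "layers.2", None))  # ['2', '3']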
|
flame/tools/__init__.py
ADDED
File without changes
|
flame/tools/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (136 Bytes). View file
|
|
flame/tools/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (2.14 kB). View file
|
|
flame/utils/__init__.py
ADDED
File without changes
|
flame/utils/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (136 Bytes). View file
|
|
flame/utils/__pycache__/checkpoint.cpython-312.pyc
ADDED
Binary file (4.07 kB). View file
|
|
flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc
ADDED
Binary file (3.73 kB). View file
|
|
flame/utils/__pycache__/hf_utils.cpython-312.pyc
ADDED
Binary file (4.46 kB). View file
|
|
flame/utils/checkpoint.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import re
|
4 |
+
import shutil
|
5 |
+
from torchtitan.tools.logging import logger
|
6 |
+
|
7 |
+
|
8 |
+
def cleanup_local_checkpoints(checkpoint_dir: str, keep_latest_k: int):
|
9 |
+
"""Removes older checkpoint directories locally, keeping only the latest k for both DCP and HF formats."""
|
10 |
+
if keep_latest_k <= 0:
|
11 |
+
return # Keep all checkpoints
|
12 |
+
|
13 |
+
logger.info(f"Cleaning up local checkpoints in {checkpoint_dir}, keeping latest {keep_latest_k}")
|
14 |
+
|
15 |
+
# Cleanup DCP checkpoints (step-*)
|
16 |
+
dcp_checkpoints = sorted(
|
17 |
+
glob.glob(os.path.join(checkpoint_dir, "step-*")),
|
18 |
+
key=lambda x: int(re.search(r"step-(\d+)", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)", os.path.basename(x)) and not x.endswith("-hf") else -1,
|
19 |
+
reverse=True
|
20 |
+
)
|
21 |
+
# Filter out HF format directories
|
22 |
+
dcp_checkpoints = [d for d in dcp_checkpoints if not d.endswith("-hf")]
|
23 |
+
|
24 |
+
if len(dcp_checkpoints) > keep_latest_k:
|
25 |
+
checkpoints_to_delete = dcp_checkpoints[keep_latest_k:]
|
26 |
+
logger.info(f"Deleting {len(checkpoints_to_delete)} old DCP checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
|
27 |
+
for ckpt_path in checkpoints_to_delete:
|
28 |
+
if os.path.isdir(ckpt_path): # Ensure it's a directory
|
29 |
+
try:
|
30 |
+
shutil.rmtree(ckpt_path)
|
31 |
+
except OSError as e:
|
32 |
+
logger.error(f"Error removing directory {ckpt_path}: {e}")
|
33 |
+
|
34 |
+
|
35 |
+
# Cleanup HF checkpoints (step-*-hf)
|
36 |
+
hf_checkpoints = sorted(
|
37 |
+
glob.glob(os.path.join(checkpoint_dir, "step-*-hf")),
|
38 |
+
key=lambda x: int(re.search(r"step-(\d+)-hf", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)-hf", os.path.basename(x)) else -1,
|
39 |
+
reverse=True
|
40 |
+
)
|
41 |
+
|
42 |
+
if len(hf_checkpoints) > keep_latest_k:
|
43 |
+
checkpoints_to_delete = hf_checkpoints[keep_latest_k:]
|
44 |
+
logger.info(f"Deleting {len(checkpoints_to_delete)} old HF checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
|
45 |
+
for ckpt_path in checkpoints_to_delete:
|
46 |
+
if os.path.isdir(ckpt_path): # Ensure it's a directory
|
47 |
+
try:
|
48 |
+
shutil.rmtree(ckpt_path)
|
49 |
+
except OSError as e:
|
50 |
+
logger.error(f"Error removing directory {ckpt_path}: {e}")
|
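
Note: a short usage sketch for the helper above; the checkpoint directory and retention count are illustrative placeholders.

from flame.utils.checkpoint import cleanup_local_checkpoints

# Keep only the three most recent step-* (DCP) and step-*-hf (HF) checkpoint directories.
cleanup_local_checkpoints("exp/checkpoint", keep_latest_k=3)  # placeholder path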
flame/utils/convert_dcp_to_hf.py
ADDED
@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import argparse
import io
import os
import tempfile
from datetime import timedelta

import torch
import torch.serialization
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

import fla  # noqa
from torchtitan.tools.logging import init_logger, logger


@torch.inference_mode()
def save_pretrained(
    path: str,
    step: int,
    config: str,
    tokenizer: str
):
    logger.info(f"Loading the config from {config}")
    config = AutoConfig.from_pretrained(config, trust_remote_code=True)

    logger.info(f"Saving the config to {path}")
    config.save_pretrained(path)
    logger.info(f"Loading the tokenizer from {tokenizer}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
    logger.info(f"Saving the tokenizer to {path}")
    tokenizer.save_pretrained(path)

    with tempfile.TemporaryDirectory() as tmpdir:
        # base_checkpoint_dir = os.path.dirname(path)
        base_checkpoint_dir = path
        checkpoint = os.path.join(base_checkpoint_dir, f'checkpoint/step-{step}')
        checkpoint_path = os.path.join(tmpdir, 'checkpoint.pt')
        logger.info(f"Saving the distributed checkpoint to {checkpoint_path}")
        dcp_to_torch_save(checkpoint, checkpoint_path)

        logger.info(f"Initializing the model from config\n{config}")
        model = AutoModelForCausalLM.from_config(config)
        logger.info(model)
        logger.info("Loading state dict from the checkpoint")

        # Add datetime.timedelta and io.BytesIO to safe globals
        torch.serialization.add_safe_globals([timedelta, io.BytesIO])
        # torch.load now with default weights_only=True will work
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model'])

        logger.info(f"Saving the model to {path}")
        model.save_pretrained(path)


if __name__ == "__main__":
    init_logger()
    parser = argparse.ArgumentParser("Convert DCP format model weights to huggingface-style.")
    parser.add_argument("--path", type=str, required=True)
    parser.add_argument("--step", type=int, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--tokenizer", type=str, required=True)
    args = parser.parse_args()
    save_pretrained(args.path, args.step, args.config, args.tokenizer)
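
Note: the converter above is exposed as an argparse CLI, and the same function can be called directly. The paths and tokenizer name below are placeholders, not values from this diff.

from flame.utils.convert_dcp_to_hf import save_pretrained

# Roughly equivalent to:
#   python -m flame.utils.convert_dcp_to_hf --path exp --step 1000 \
#       --config configs/model_config.json --tokenizer some-org/some-tokenizer
save_pretrained(
    path="exp",                              # placeholder output/checkpoint root
    step=1000,                               # placeholder step number
    config="configs/model_config.json",      # placeholder config path
    tokenizer="some-org/some-tokenizer",     # placeholder tokenizer id
)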
flame/utils/convert_hf_to_dcp.py
ADDED
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import argparse
from pathlib import Path

import torch
import torch.distributed.checkpoint as DCP
from transformers import AutoModelForCausalLM

import fla  # noqa
from torchtitan.tools.logging import init_logger, logger


@torch.inference_mode()
def convert_hf_weights(model: str, checkpoint: Path):
    logger.info(f"Loading model from {model}")
    model = AutoModelForCausalLM.from_pretrained(model)
    state_dict = model.state_dict()

    logger.info(f"Writing to DCP at '{checkpoint}'")
    checkpoint.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
    DCP.save({"model": state_dict}, storage_writer=storage_writer)


if __name__ == "__main__":
    init_logger()
    parser = argparse.ArgumentParser(description="Convert huggingface-style model weights to DCP format.")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--checkpoint", type=Path, required=True)
    args = parser.parse_args()

    convert_hf_weights(args.model, args.checkpoint)
|
flame/utils/hf_utils.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
from huggingface_hub import HfApi, HfFolder, logging as hf_logging, create_repo
|
4 |
+
from torchtitan.tools.logging import logger
|
5 |
+
|
6 |
+
def upload_checkpoint_to_hf(
|
7 |
+
local_path: str,
|
8 |
+
step: int,
|
9 |
+
hf_repo_id_for_run: str,
|
10 |
+
hf_keep_latest_k: int,
|
11 |
+
upload_format: str
|
12 |
+
):
|
13 |
+
"""Uploads a checkpoint directory to HF Hub and manages retention."""
|
14 |
+
if not os.path.isdir(local_path):
|
15 |
+
logger.error(f"Local path for upload does not exist or is not a directory: {local_path}")
|
16 |
+
return
|
17 |
+
|
18 |
+
api = HfApi()
|
19 |
+
token = HfFolder.get_token()
|
20 |
+
if not token:
|
21 |
+
logger.warning("Hugging Face Hub token not found. Skipping upload. Login via `huggingface-cli login` or set HF_TOKEN.")
|
22 |
+
return
|
23 |
+
|
24 |
+
# --- Ensure the specific repository for this run exists ---
|
25 |
+
try:
|
26 |
+
logger.info(f"Ensuring repository {hf_repo_id_for_run} exists...")
|
27 |
+
# Use create_repo which handles creation only if it doesn't exist
|
28 |
+
create_repo(repo_id=hf_repo_id_for_run, token=token, repo_type="model", exist_ok=True)
|
29 |
+
logger.info(f"Repository {hf_repo_id_for_run} ensured.")
|
30 |
+
except Exception as e:
|
31 |
+
logger.error(f"Failed to create or ensure repository {hf_repo_id_for_run}: {e}", exc_info=True)
|
32 |
+
return # Stop if repo interaction fails
|
33 |
+
|
34 |
+
commit_message = f"Upload {upload_format.upper()} checkpoint step {step}"
|
35 |
+
path_in_repo = f"step-{step}"
|
36 |
+
|
37 |
+
logger.info(f"Uploading {local_path} to {hf_repo_id_for_run}/{path_in_repo} on Hugging Face Hub...")
|
38 |
+
try:
|
39 |
+
api.upload_folder(
|
40 |
+
folder_path=local_path,
|
41 |
+
path_in_repo=path_in_repo,
|
42 |
+
repo_id=hf_repo_id_for_run,
|
43 |
+
repo_type="model",
|
44 |
+
commit_message=commit_message,
|
45 |
+
token=token,
|
46 |
+
)
|
47 |
+
logger.info(f"Successfully uploaded step {step} to {hf_repo_id_for_run}.")
|
48 |
+
except Exception as e:
|
49 |
+
logger.error(f"Failed to upload checkpoint step {step} to {hf_repo_id_for_run}: {e}", exc_info=True)
|
50 |
+
if hf_keep_latest_k > 0:
|
51 |
+
logger.info(f"Cleaning up old checkpoints on {hf_repo_id_for_run}, keeping latest {hf_keep_latest_k}")
|
52 |
+
try:
|
53 |
+
repo_files = api.list_repo_tree(hf_repo_id_for_run, repo_type="model", token=token, recursive=False)
|
54 |
+
step_folders = [
|
55 |
+
item.path for item in repo_files
|
56 |
+
if item.path.startswith("step-") and item.path[5:].isdigit()
|
57 |
+
]
|
58 |
+
|
59 |
+
step_folders.sort(key=lambda x: int(x.split('-')[1]), reverse=True)
|
60 |
+
|
61 |
+
if len(step_folders) > hf_keep_latest_k:
|
62 |
+
folders_to_delete = step_folders[hf_keep_latest_k:]
|
63 |
+
logger.info(f"Found {len(step_folders)} checkpoints on Hub. Deleting {len(folders_to_delete)} older ones: {folders_to_delete}")
|
64 |
+
for folder in folders_to_delete:
|
65 |
+
# Deleting requires repo_id, path_in_repo, and token
|
66 |
+
api.delete_folder(
|
67 |
+
repo_id=hf_repo_id_for_run,
|
68 |
+
path_in_repo=folder,
|
69 |
+
repo_type="model",
|
70 |
+
commit_message=f"Delete old checkpoint {folder}",
|
71 |
+
token=token
|
72 |
+
)
|
73 |
+
logger.info("Hub cleanup complete.")
|
74 |
+
else:
|
75 |
+
logger.info("No old checkpoints found on Hub to delete.")
|
76 |
+
except Exception as e:
|
77 |
+
logger.error(f"Error during Hub checkpoint cleanup for {hf_repo_id_for_run}: {e}", exc_info=True)
|
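
Note: a minimal usage sketch for the upload helper above, assuming a valid Hugging Face token is available locally. The local path, repo id, and step value are placeholders.

from flame.utils.hf_utils import upload_checkpoint_to_hf

upload_checkpoint_to_hf(
    local_path="exp/checkpoint/step-1000-hf",  # placeholder local checkpoint directory
    step=1000,                                 # placeholder step number
    hf_repo_id_for_run="username/run-name",    # placeholder Hub repo id
    hf_keep_latest_k=2,                        # keep only the two newest step-* folders on the Hub
    upload_format="hf",
)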
tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/logs/debug-internal.log
ADDED
@@ -0,0 +1,90 @@
torchtitan/components/__pycache__/float8.cpython-312.pyc
ADDED
Binary file (6.2 kB).

torchtitan/components/__pycache__/ft.cpython-312.pyc
ADDED
Binary file (6.75 kB).
torchtitan/components/float8.py
ADDED
@@ -0,0 +1,150 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# [Note] Getting the 'torchao' package:
# This script requires the 'torchao' package to function correctly.
# Please ensure you have this package installed from the appropriate repository.
# You can obtain it from https://github.com/pytorch/ao by following the
# installation instructions.

# Note: Performance
# Float8 experimental is intended to be run under `torch.compile` for competitive performance

import torch
import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.protocols.model_converter import (
    ModelConverter,
    register_model_converter,
)
from torchtitan.tools.logging import logger


def _is_sm89_or_later():
    # Float8 is only supported on SM89 or later (H100+ GPUs)
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


class Float8Converter(ModelConverter):
    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        self.enabled = False

        float8_config = job_config.float8
        if not _is_sm89_or_later():
            logger.warning(
                "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
            )
            return
        try:
            from torchao.float8 import Float8LinearConfig
        except ImportError as e:
            raise ImportError(
                "torchao is not installed. Please install it to use float8 linear layers."
            ) from e

        if float8_config.recipe_name is not None and not hasattr(
            Float8LinearConfig, "from_recipe_name"
        ):
            logger.warning(
                "Failed to swap to Float8Linear with recipe lookup because the torchao version "
                "is too old, please install torchao v0.9.0 or later and try again",
            )
            return

        self.enabled = True
        self.filter_fqns = float8_config.filter_fqns

        if float8_config.recipe_name is not None:
            assert (
                not float8_config.enable_fsdp_float8_all_gather
            ), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported"
            assert (
                not float8_config.force_recompute_fp8_weight_in_bwd
            ), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported"
            self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
            self.precompute_scale = False
            logger.info(
                f"Float8 training active with recipe {float8_config.recipe_name}"
            )

        else:
            # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
            enable_fsdp_float8_all_gather = (
                parallel_dims.dp_shard_enabled
                and float8_config.enable_fsdp_float8_all_gather
            )
            self.config = Float8LinearConfig(
                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
                force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
            )
            # for precompute_float8_dynamic_scale_for_fsdp
            self.precompute_scale = (
                enable_fsdp_float8_all_gather
                and float8_config.precompute_float8_dynamic_scale_for_fsdp
            )
            logger.info("Float8 tensorwise scaled training active")

    def convert(self, model: nn.Module):
        return self.convert_to_float8_training(model)

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
        return self.precompute_float8_dynamic_scale_for_fsdp(model)

    def convert_to_float8_training(self, model: nn.Module):
        """
        This function converts the linear layers of `model` to `Float8Linear`.
        Note that today, only dynamic tensor scaling (the default) is supported.
        This will mutate the model inplace.
        """
        if not self.enabled:
            return

        from torchao.float8 import convert_to_float8_training

        # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
        convert_to_float8_training(
            model,
            config=self.config,
            module_filter_fn=self._module_filter_fn,
        )
        logger.info(
            "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
            f"{self.config.enable_fsdp_float8_all_gather}"
        )

    def _module_filter_fn(self, mod: nn.Module, fqn: str) -> bool:
        if not isinstance(mod, nn.Linear):
            return False

        # All dims must be divisible by 16 due to float8 tensorcore hardware requirements.
        dims_multiples_of_16 = (
            mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0
        )

        # If the fqn matches any filtered fqn, then we should not convert this module.
        is_filtered_fqn = any(filtered_fqn in fqn for filtered_fqn in self.filter_fqns)

        return dims_multiples_of_16 and not is_filtered_fqn

    def precompute_float8_dynamic_scale_for_fsdp(
        self, model: nn.Module | list[nn.Module]
    ):
        if not self.enabled:
            return

        if not self.precompute_scale:
            return

        from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

        models = [model] if isinstance(model, nn.Module) else model
        for m in models:
            precompute_float8_dynamic_scale_for_fsdp(m)


register_model_converter(Float8Converter, "float8")
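
Note: a minimal, self-contained sketch of the `_module_filter_fn` rule above (the module must be an nn.Linear, both weight dims must be multiples of 16, and the FQN must not contain a filtered prefix). The filter list and layer names below are hypothetical.

import torch.nn as nn

def would_convert(mod: nn.Module, fqn: str, filter_fqns=("lm_head",)) -> bool:
    if not isinstance(mod, nn.Linear):
        return False
    dims_ok = mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0
    filtered = any(f in fqn for f in filter_fqns)
    return dims_ok and not filtered

print(would_convert(nn.Linear(4096, 4096), "layers.0.mlp.up_proj"))  # True
print(would_convert(nn.Linear(4096, 100), "lm_head"))                # False: dims and fqn both fail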
torchtitan/components/ft.py
ADDED
@@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import importlib
from dataclasses import dataclass
from typing import Optional

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims

if importlib.util.find_spec("torchft") is not None:
    import torchft as ft

    has_torchft = True
else:
    has_torchft = False


class FTManager:
    def __init__(
        self,
        manager: Optional["ft.Manager"],
        group_size: int = 1,
        replica_id: int = 0,
    ) -> None:
        self._manager = manager
        self.group_size = group_size
        self.replica_id = replica_id

    @property
    def enabled(self) -> bool:
        return self._manager is not None

    @property
    def manager(self) -> "ft.Manager":
        assert self._manager is not None
        return self._manager

    def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]:
        return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank


def init_ft_manager(job: JobConfig) -> FTManager:
    """Initialize the FT manager if TorchFT is enabled.

    Args:
        job (JobConfig): The job configuration.

    Returns:
        FTManager: A wrapper around the FT manager if TorchFT is enabled, otherwise a disabled FTManager.
    """
    if not job.fault_tolerance.enable:
        return FTManager(None)

    if not has_torchft:
        raise ImportError("torchft is not installed. Please install it.")

    if job.fault_tolerance.min_replica_size < 1:
        raise ValueError("At least one FT replica is required.")

    pg = ft.ProcessGroupBabyNCCL()

    return FTManager(
        ft.Manager(
            pg=pg,
            min_replica_size=job.fault_tolerance.min_replica_size,
            load_state_dict=None,
            state_dict=None,
            use_async_quorum=True,
            replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}",
        ),
        group_size=job.fault_tolerance.group_size,
        replica_id=job.fault_tolerance.replica_id,
    )


@dataclass
class FTParallelDims(ParallelDims):
    ft_manager: FTManager

    def build_mesh(self, device_type: str) -> DeviceMesh:
        def func(
            device_type: str, mesh_shape: list[int], mesh_dim_names: list[str]
        ) -> DeviceMesh:
            from torchft.process_group import ft_init_device_mesh

            return ft_init_device_mesh(
                device_type=device_type,
                mesh_shape=mesh_shape,
                mesh_dim_names=mesh_dim_names,
                replicate_dim=mesh_dim_names.index("dp_replicate"),
                manager=self.ft_manager.manager,
            )

        dims = []
        names = []
        for d, name in zip(
            [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
            ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
        ):
            if d > 1 or name == "dp_replicate":
                dims.append(d)
                names.append(name)

        return self._build_mesh(device_type, dims, names, func)

    @property
    def dp_replicate_enabled(self):
        return True


def ft_dist_reduce(
    x: torch.Tensor, reduceOp: str, mesh: DeviceMesh
) -> tuple[torch.Tensor, str, DeviceMesh]:
    if has_torchft and isinstance(mesh, ft.process_group._FlattenDeviceMesh):
        x = funcol.all_reduce(
            x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg
        )
        return x, reduceOp, mesh.managed_mesh.mesh
    return x, reduceOp, mesh


def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
    if has_torchft:
        mesh = total_norm._spec.mesh
        if isinstance(mesh, ft.process_group.ManagedDeviceMesh):
            # The gradients along the replicated dim have already been reduced,
            # so we don't need another reduction before removing the
            # replicate dimension.
            local_tensor = total_norm.to_local()
            placements = list(copy.copy(total_norm._spec.placements))
            placements.pop(mesh.replicate_dim)
            return DTensor.from_local(local_tensor, mesh.mesh, placements)

    return total_norm
torchtitan/experiments/deepseek_v3/model_config.py
ADDED
@@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field


@dataclass
class ModelArgs:
    r"""
    This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a DeepSeek
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the DeepSeek-V3.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 129280):
            Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`DeepseekV3Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 1407):
            Dimension of the MoE representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
            Number of nextn predict layers in the DeepSeekV3 Model.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        n_shared_experts (`int`, *optional*, defaults to None):
            Number of shared experts, None means dense model.
        n_routed_experts (`int`, *optional*, defaults to None):
            Number of routed experts, None means dense model.
        routed_scaling_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for routed experts.
        topk_method (`str`, *optional*, defaults to `gready`):
            Topk method used in routed gate.
        n_group (`int`, *optional*, defaults to None):
            Number of groups for routed experts.
        topk_group (`int`, *optional*, defaults to None):
            Number of selected groups for each token (for each token, ensuring the selected experts is only within
            `topk_group` groups).
        num_experts_per_tok (`int`, *optional*, defaults to None):
            Number of selected experts, None means dense model.
        moe_layer_freq (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
        first_k_dense_replace (`int`, *optional*, defaults to 0):
            Number of dense layers in shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).
                                                       \--k dense layers--/
        norm_topk_prob (`bool`, *optional*, defaults to False):
            Whether to normalize the weights of the routed experts.
        scoring_func (`str`, *optional*, defaults to 'softmax'):
            Method of computing expert weights.
        aux_loss_alpha (`float`, *optional*, defaults to 0.001):
            Auxiliary loss weight coefficient.
        seq_aux (`bool`, *optional*, defaults to True):
            Whether to compute the auxiliary loss for each individual sample.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
|
105 |
+
"""
|
106 |
+
|
107 |
+
vocab_size: int = 129280
|
108 |
+
hidden_size: int = 7168
|
109 |
+
intermediate_size: int = 18432
|
110 |
+
moe_intermediate_size: int = 2048
|
111 |
+
num_hidden_layers: int = 61
|
112 |
+
num_nextn_predict_layers: int = 1
|
113 |
+
num_attention_heads: int = 128
|
114 |
+
num_key_value_heads: int = 128
|
115 |
+
n_shared_experts: int = 1
|
116 |
+
n_routed_experts: int = 256
|
117 |
+
ep_size: int = 1
|
118 |
+
routed_scaling_factor: float = 2.5
|
119 |
+
kv_lora_rank: int = 512
|
120 |
+
q_lora_rank: int = 1536
|
121 |
+
qk_rope_head_dim: int = 64
|
122 |
+
v_head_dim: int = 128
|
123 |
+
qk_nope_head_dim: int = 128
|
124 |
+
topk_method: str = "noaux_tc"
|
125 |
+
n_group: int = 8
|
126 |
+
topk_group: int = 4
|
127 |
+
num_experts_per_tok: int = 8
|
128 |
+
moe_layer_freq: int = 1
|
129 |
+
first_k_dense_replace: int = 3
|
130 |
+
norm_topk_prob: bool = True
|
131 |
+
scoring_func: str = "sigmoid"
|
132 |
+
aux_loss_alpha: float = 0.001
|
133 |
+
seq_aux: bool = True
|
134 |
+
hidden_act: str = "silu"
|
135 |
+
max_position_embeddings: int = 163840
|
136 |
+
initializer_range: float = 0.02
|
137 |
+
rms_norm_eps: float = 1e-6
|
138 |
+
rope_theta: float = 10000.0
|
139 |
+
rope_scaling: dict = field(
|
140 |
+
default_factory=lambda: {
|
141 |
+
"beta_fast": 32,
|
142 |
+
"beta_slow": 1,
|
143 |
+
"factor": 40,
|
144 |
+
"mscale": 1.0,
|
145 |
+
"mscale_all_dim": 1.0,
|
146 |
+
"original_max_position_embeddings": 4096,
|
147 |
+
"type": "yarn",
|
148 |
+
}
|
149 |
+
)
|
150 |
+
attention_bias: bool = False
|
151 |
+
attention_dropout: float = 0.0
|
152 |
+
pad_token_id = None
|
153 |
+
# Added for symmetric memory
|
154 |
+
max_seq_len: int = 4096
|
155 |
+
dtype: str = "bfloat16"
|
156 |
+
# Added for pipeline parallel
|
157 |
+
num_stages: int = 1
|
158 |
+
stage_idx: int = 0
|
159 |
+
|
160 |
+
|
161 |
+
# This is the configuration for deepseek-ai/DeepSeek-V2-Lite.
|
162 |
+
deepseek_v2_lite_config = ModelArgs(
|
163 |
+
vocab_size=102400,
|
164 |
+
hidden_size=2048,
|
165 |
+
intermediate_size=10944,
|
166 |
+
moe_intermediate_size=1408,
|
167 |
+
num_hidden_layers=27,
|
168 |
+
num_attention_heads=16,
|
169 |
+
num_key_value_heads=16,
|
170 |
+
n_shared_experts=2,
|
171 |
+
n_routed_experts=64,
|
172 |
+
routed_scaling_factor=1.0,
|
173 |
+
kv_lora_rank=512,
|
174 |
+
q_lora_rank=None,
|
175 |
+
qk_rope_head_dim=64,
|
176 |
+
v_head_dim=128,
|
177 |
+
qk_nope_head_dim=128,
|
178 |
+
topk_method="greedy",
|
179 |
+
n_group=1,
|
180 |
+
topk_group=1,
|
181 |
+
num_experts_per_tok=6,
|
182 |
+
first_k_dense_replace=1,
|
183 |
+
norm_topk_prob=False,
|
184 |
+
scoring_func="softmax",
|
185 |
+
max_position_embeddings=4096,
|
186 |
+
rope_scaling={
|
187 |
+
"beta_fast": 32,
|
188 |
+
"beta_slow": 1,
|
189 |
+
"factor": 40,
|
190 |
+
"mscale": 0.707,
|
191 |
+
"mscale_all_dim": 0.707,
|
192 |
+
"original_max_position_embeddings": 4096,
|
193 |
+
"type": "yarn",
|
194 |
+
},
|
195 |
+
)
|
196 |
+
|
197 |
+
|
198 |
+
# Model configuration registry
|
199 |
+
# Key is the model distribution ID on HuggingFace Hub
|
200 |
+
deepseek_config_registry = {
|
201 |
+
"deepseek-ai/DeepSeek-V2-Lite": deepseek_v2_lite_config,
|
202 |
+
"deepseek-ai/DeepSeek-V2-Lite-Chat": deepseek_v2_lite_config,
|
203 |
+
"deepseek-ai/deepseek-v3": ModelArgs(),
|
204 |
+
}
|
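A minimal usage sketch (illustrative only, not part of the uploaded files): configs are looked up from `deepseek_config_registry` by Hub model ID, exactly as `train.py` below does, and since `ModelArgs` is a dataclass individual fields can be overridden for small test runs. The specific override values here are arbitrary examples.

# Illustrative sketch: fetch a registered config and shrink it for a smoke test.
from dataclasses import replace

from model_config import ModelArgs, deepseek_config_registry  # file above

args = deepseek_config_registry["deepseek-ai/DeepSeek-V2-Lite"]
small_args = replace(args, num_hidden_layers=4, max_seq_len=1024)

# Fields not overridden (e.g. the yarn rope_scaling dict) carry over unchanged.
print(small_args.topk_method, small_args.n_routed_experts, small_args.rope_scaling["type"])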
torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .triton_on_device_all_to_all_v import OnDeviceAllToAllV

__all__ = [
    "OnDeviceAllToAllV",
]
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py
ADDED
@@ -0,0 +1,159 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import triton
import triton.language as tl

from .triton_utils import get_flat_bid, get_flat_tid


@triton.jit
def send_signal(addrs, sem: tl.constexpr):
    if sem == "relaxed":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                send_signal:
                    atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1;
                    setp.eq.u32 %p0, %tmp32_0, 0;
                    @!%p0 bra send_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    elif sem == "acq_rel":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                send_signal:
                    atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1;
                    setp.eq.u32 %p0, %tmp32_0, 0;
                    @!%p0 bra send_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    else:
        raise RuntimeError(f"Unrecognized sem: {sem}")


@triton.jit
def wait_signal(addrs, sem: tl.constexpr):
    if sem == "relaxed":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                wait_signal:
                    atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0;
                    setp.eq.u32 %p0, %tmp32_0, 1;
                    @!%p0 bra wait_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    elif sem == "acq_rel":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                wait_signal:
                    atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0;
                    setp.eq.u32 %p0, %tmp32_0, 1;
                    @!%p0 bra wait_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    else:
        raise RuntimeError(f"Unrecognized sem: {sem}")


@triton.jit
def blockwise_barrier(
    signal_pad_ptrs,
    block_id,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    sem: tl.constexpr,
):
    """
    Synchronizes blocks with matching block_id across participating devices.

    Note: the function itself is not a system level barrier/fence. It is a
    building block for expressing different synchronization patterns.

    Pattern 0: Ensures that all writes to symm_mem buffers from previous
    kernels across all devices are visible to the current kernel:

        blockwise_barrier(..., sem="relaxed")
        sync_threads()

    Pattern 1: Ensures that all writes to symm_mem buffers from the current
    block are visible to all remote blocks with matching blockIdx:

        sync_threads()
        blockwise_barrier(..., sem="acq_rel")
        sync_threads()

    Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
    for writing by subsequent kernels across all devices.

        sync_threads()
        blockwise_barrier(..., sem="relaxed")

    CUDA graph friendliness:

        This barrier operates through atomic operations on a zero-filled signal
        pad, which resets to a zero-filled state after each successful
        synchronization. This design eliminates the need for incrementing a
        flag from host.
    """
    if block_id is None:
        block_id = get_flat_bid()
    flat_tid = get_flat_tid()

    remote_ranks = tl.arange(0, world_size)
    signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
    remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
        tl.pointer_type(tl.uint32)
    )
    send_addrs = remote_signal_pad_addrs + block_id * world_size + rank

    local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
        tl.pointer_type(tl.uint32)
    )
    wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks

    if flat_tid < world_size:
        send_signal(send_addrs, sem)
        wait_signal(wait_addrs, sem)
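To make "Pattern 1" from the docstring concrete, here is an illustrative sketch (not part of the uploaded files) of a toy kernel that writes to its local symmetric-memory buffer and then publishes the writes to the matching block on every peer. It assumes the sketch is importable against the package paths in this upload, that a `sync_threads` helper exists in the same `triton_utils` module the barrier file already imports from, and that `buf_ptrs`/`signal_pad_ptrs` are the `buffer_ptrs_dev`/`signal_pad_ptrs_dev` arrays of a symmetric-memory rendezvous handle.

# Illustrative sketch, assuming the package layout of this upload.
import triton
import triton.language as tl

from torchtitan.experiments.deepseek_v3.symm_mem_recipes.triton_barrier import blockwise_barrier
from torchtitan.experiments.deepseek_v3.symm_mem_recipes.triton_utils import sync_threads


@triton.jit
def write_then_publish_kernel(
    buf_ptrs,          # per-rank buffer pointers (uint64), from buffer_ptrs_dev
    signal_pad_ptrs,   # per-rank signal pad pointers, from signal_pad_ptrs_dev
    value,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Each block fills one tile of this rank's own symmetric buffer.
    local_buf = tl.load(buf_ptrs.to(tl.pointer_type(tl.uint64)) + rank).to(
        tl.pointer_type(tl.float32)
    )
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    tl.store(local_buf + offsets, value)

    # Pattern 1: make the writes visible to the matching block on all peers.
    sync_threads()
    blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="acq_rel")
    sync_threads()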
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py
ADDED
@@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import triton
import triton.language as tl

from .triton_barrier import blockwise_barrier
from .triton_utils import sync_threads


@triton.jit
def _exchange_row_offsets(
    split_sizes_ptrs,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCKS_PER_REMOTE_RANK: tl.constexpr,
):
    remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK

    # split_sizes_ptr for all ranks
    # All these vectors stack into split_sizes_matrix
    split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64))

    # split_sizes_matrix[remote_rank, :]
    input_split_sizes_ptr = tl.load(split_sizes_ptrs + remote_rank).to(
        tl.pointer_type(tl.int64)
    )

    offsets_ = tl.arange(0, world_size)
    input_split_sizes = tl.load(
        input_split_sizes_ptr + offsets_, mask=offsets_ <= rank, other=0
    )

    num_rows = tl.load(input_split_sizes_ptr + rank)
    input_row_offset = tl.sum(input_split_sizes) - num_rows

    # split_sizes_matrix[:, rank]
    output_split_sizes_ptrs = (
        tl.load(split_sizes_ptrs + offsets_).to(tl.pointer_type(tl.int64)) + rank
    )
    output_split_sizes = tl.load(
        output_split_sizes_ptrs, mask=offsets_ <= remote_rank, other=0
    )
    output_row_offset = tl.sum(output_split_sizes) - num_rows

    return input_row_offset, output_row_offset, num_rows


@triton.jit
def on_device_all_to_all_v_kernel(
    output_ptr,
    output_splits_ptr,
    input_ptrs,
    input_splits_ptr,
    signal_pad_ptrs,
    dim: tl.constexpr,  # Separate dim for easier vectorization
    rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCKS_PER_REMOTE_RANK: tl.constexpr,
    UNROLL_FACTOR: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
    sync_threads()

    remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
    block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK

    input_row_offset, output_row_offset, num_rows = _exchange_row_offsets(
        input_splits_ptr, rank, world_size, BLOCKS_PER_REMOTE_RANK
    )

    output_splits_ptr = output_splits_ptr.to(tl.pointer_type(tl.uint64))
    if block_offset == 0:
        # Update output_splits
        tl.store(output_splits_ptr + remote_rank, num_rows)

    input_ptr = (
        tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to(
            tl.pointer_type(tl.bfloat16)
        )
        + input_row_offset * dim
    )
    output_ptr = output_ptr + output_row_offset * dim

    outer_loop_step = BLOCK_SIZE * UNROLL_FACTOR
    outer_loop_iters_per_rank = tl.cdiv(
        tl.cdiv(num_rows * dim, outer_loop_step), BLOCKS_PER_REMOTE_RANK
    )
    numel_per_rank = outer_loop_step * outer_loop_iters_per_rank
    offset = numel_per_rank * block_offset
    end = tl.minimum(numel_per_rank * (block_offset + 1), num_rows * dim)

    unroll_region_size = (end - offset) // outer_loop_step * outer_loop_step
    for i in tl.range(offset, offset + unroll_region_size, outer_loop_step):
        datas = []
        for j in tl.range(
            i,
            i + outer_loop_step,
            BLOCK_SIZE,
            loop_unroll_factor=UNROLL_FACTOR,
        ):
            offsets = j + tl.arange(0, BLOCK_SIZE)
            data = tl.load(input_ptr + offsets)
            tl.store(output_ptr + offsets, data)

    offset += unroll_region_size
    while offset < end:
        offsets = offset + tl.arange(0, BLOCK_SIZE)
        mask = offsets < num_rows * dim
        data = tl.load(input_ptr + offsets, mask=mask)
        tl.store(output_ptr + offsets, data, mask=mask)
        offset += BLOCK_SIZE

    sync_threads()
    blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
    return


def _on_device_all_to_all_v(
    output: torch.Tensor,
    output_splits: torch.Tensor,
    input: torch.Tensor,
    input_splits: torch.Tensor,
    group: dist.ProcessGroup = dist.group.WORLD,
    BLOCKS_PER_REMOTE_RANK=8,
    UNROLL_FACTOR: int = 8,
    BLOCK_SIZE: int = 16384,
):
    assert output.dim() == 2, f"{output.shape}"
    assert input.dim() == 2, f"{input.shape}"
    assert output.shape[1] == input.shape[1]

    dim = output.shape[1]
    input_hdl = symm_mem.rendezvous(input, group=group)
    input_splits_hdl = symm_mem.rendezvous(input_splits, group=group)

    num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK
    kernel = on_device_all_to_all_v_kernel[(num_blocks, 1, 1)](
        output,
        output_splits,
        input_hdl.buffer_ptrs_dev,
        input_splits_hdl.buffer_ptrs_dev,
        input_hdl.signal_pad_ptrs_dev,
        dim=dim,
        rank=input_hdl.rank,
        world_size=input_hdl.world_size,
        BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK,
        UNROLL_FACTOR=UNROLL_FACTOR,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=16,
    )
    # log_triton_kernel(kernel)
    return output


class OnDeviceAllToAllV(torch.autograd.Function):
    # A symmetric memory holding the grad_output during backward
    grad_output_buf = None
    # A symmetric memory for exchanging split sizes during both forward and backward
    splits_buf = None
    # Maximum output length (needs to be set before use of OnDeviceAllToAllV)
    max_output_len = None

    @staticmethod
    def forward(
        ctx,
        input: torch.Tensor,
        input_splits: torch.Tensor,
        group: dist.ProcessGroup = dist.group.WORLD,
    ):
        """
        Args:
            input: input tensor with data for all ranks concatenated.
            input_splits: input splits of shape (group.world_size,)
            group: process group to scope the collective.
        """
        # Initialize input splits buffer (one time only)
        if OnDeviceAllToAllV.splits_buf is None:
            OnDeviceAllToAllV.splits_buf = symm_mem.empty(
                *input_splits.shape,
                dtype=input_splits.dtype,
                device=input_splits.device,
            )

        if OnDeviceAllToAllV.max_output_len is None:
            raise RuntimeError(
                "Please set max output length via `OnDeviceAllToAllV.max_output_len = ...`"
            )

        # Allocate output buffer
        output = input.new_empty(OnDeviceAllToAllV.max_output_len, *input.shape[1:])
        # Allocate output splits tensor
        output_splits = torch.empty_like(input_splits)
        # Copy input splits to the buffer
        OnDeviceAllToAllV.splits_buf.copy_(input_splits)

        # Shuffle input to output
        _on_device_all_to_all_v(
            output, output_splits, input, OnDeviceAllToAllV.splits_buf, group=group
        )

        # Output splits in forward is the input splits in backward
        ctx.save_for_backward(output_splits)
        ctx.group = group
        ctx.input_shape = input.shape
        return output, output_splits

    @staticmethod
    def backward(ctx, grad_output, grad_splits):
        """
        Backward is implemented as a shuffle of the output's gradients to the input.
        Args:
            `grad_output`: output's gradients passed from the downstream.
            `grad_splits`: unused.
        """

        # Initialize grad_output buffer (one time only)
        if OnDeviceAllToAllV.grad_output_buf is None:
            assert (
                OnDeviceAllToAllV.max_output_len is not None
            ), "`max_output_len` not set"
            OnDeviceAllToAllV.grad_output_buf = symm_mem.empty(
                OnDeviceAllToAllV.max_output_len,
                *grad_output.shape[1:],
                dtype=grad_output.dtype,
                device=grad_output.device,
            )

        # TODO: is there a way to tell autograd to feed grad_output directly to
        # our symm_mem buffer?
        OnDeviceAllToAllV.grad_output_buf.narrow(0, 0, grad_output.shape[0]).copy_(
            grad_output
        )

        # Size info
        (grad_output_splits,) = ctx.saved_tensors
        OnDeviceAllToAllV.splits_buf.copy_(grad_output_splits)
        grad_input_splits = torch.empty_like(grad_output_splits)  # unused
        grad_input = grad_output.new_empty(*ctx.input_shape)

        # Shuffle gradients back to the input
        _on_device_all_to_all_v(
            grad_input,
            grad_input_splits,
            OnDeviceAllToAllV.grad_output_buf,
            OnDeviceAllToAllV.splits_buf,
            group=ctx.group,
        )
        return grad_input, None, None


# Alias
on_device_all_to_all_v = OnDeviceAllToAllV.apply
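A minimal usage sketch (illustrative only, not part of the uploaded files) for the autograd wrapper above: the input tensor must live in symmetric memory (the helper rendezvouses it internally), `max_output_len` must be set before first use, and the splits tensor is int64 with one entry per peer. The sizes, the equal splits, and the NCCL/torchrun setup are all illustrative assumptions; symmetric-memory support on the GPUs is also assumed.

# Illustrative sketch; run one process per GPU under torchrun.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

from torchtitan.experiments.deepseek_v3.symm_mem_recipes import OnDeviceAllToAllV

dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device("cuda", rank)
torch.cuda.set_device(device)

dim = 256
rows_per_peer = 8
# The shuffled payload has to be a symmetric-memory tensor so peers can read it.
inp = symm_mem.empty(rows_per_peer * world_size, dim, dtype=torch.bfloat16, device=device)
inp.normal_()
# How many of this rank's rows go to each peer (equal splits here for simplicity).
input_splits = torch.full((world_size,), rows_per_peer, dtype=torch.int64, device=device)

# Upper bound on rows this rank may receive; required before the first call.
OnDeviceAllToAllV.max_output_len = rows_per_peer * world_size

output, output_splits = OnDeviceAllToAllV.apply(inp, input_splits)
print(rank, output.shape, output_splits)
dist.destroy_process_group()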
torchtitan/experiments/deepseek_v3/train.py
ADDED
@@ -0,0 +1,142 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# torchrun --standalone --nproc-per-node 8 run.py
import torch
import torch.distributed as dist
from checkpoint import load_weights_from_hf
from model import DeepseekForCausalLM
from model_config import deepseek_config_registry

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.pipelining import PipelineStage, Schedule1F1B


# Use DeepSeek-V2-Lite as a proxy
model_id = "deepseek-ai/DeepSeek-V2-Lite"


# Run full model
def run_full_model(
    mesh: DeviceMesh,
):
    rank = dist.get_rank()
    device_count = torch.cuda.device_count()
    device = torch.device("cuda", rank % device_count)

    pp_mesh = mesh["pp"]
    ep_mesh = mesh["ep"]
    pp_rank = pp_mesh.get_local_rank()
    ep_rank = ep_mesh.get_local_rank()
    pp_size = pp_mesh.size()
    ep_size = ep_mesh.size()

    # Get model configs
    model_args = deepseek_config_registry[model_id]
    # [Note]: I am making the model smaller for testing / avoiding OOM. If you
    # have sufficient GPUs for model parallelism, you can remove this line.
    model_args.num_hidden_layers = 16

    # Apply model parallelism
    model_args.ep_size = ep_size
    model_args.num_stages = pp_size
    model_args.stage_idx = pp_rank
    print(model_args)

    # Instantiate model
    with device, mesh:
        model = DeepseekForCausalLM(model_args)

    # Load weights
    load_weights_from_hf(model, model_id, device)
    model.train()

    # Apply data parallelism
    fsdp_mesh = mesh["fsdp"]
    hsdp_mesh = mesh["ep", "fsdp"]
    # Using `reshard_after_forward=False` to implement Zero-2, i.e. sharding the
    # optimizer (Zero-1) and gradients (Zero-2), but not the model weights.
    # Reason: the MoE is "sparsely activated" compared to the dense model, thus
    # it would be uneconomical to re-gather the weights.
    for layer in model.model.layers.values():
        # Apply FSDP to experts
        if hasattr(layer.mlp, "experts"):
            for expert in layer.mlp.experts.values():
                fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False)
        # Apply HSDP to other parts such as attention, layernorm, because they
        # are doing DDP on the EP dimension
        fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False)

    # Apply HSDP on root model (lm_head, embeddings, etc)
    fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False)

    # Synthetic setting
    microbatches = pp_size * 2

    # Use Symmetric Memory for MoE token shuffle.
    # TODO: we are rewriting `moe_on_device` function. `setup_symm_mem` is
    # currently supported for forward only. See `generate.py`.
    # model.setup_symm_mem(torch.bfloat16, device)

    # Example inputs
    torch.manual_seed(ep_rank)
    bs = 4
    seqlen = 128
    x = torch.randint(model_args.vocab_size, (microbatches * bs, seqlen), device=device)
    label = torch.rand(microbatches * bs, seqlen, model_args.vocab_size, device=device)

    # Create loss function
    loss_fn = torch.nn.functional.cross_entropy

    # Run forward and backward
    steps = 2
    for _ in range(steps):
        if pp_size > 1:
            # Create pipeline stage
            stage = PipelineStage(
                model,
                pp_rank,
                pp_size,
                device,
                group=pp_mesh.get_group(),
            )

            # Create pipeline schedule
            losses = []
            pp_schedule = Schedule1F1B(stage, microbatches, loss_fn=loss_fn)

            if pp_rank == 0:
                y = pp_schedule.step(x)
            elif pp_rank == pp_size - 1:
                y = pp_schedule.step(target=label, losses=losses)
                loss = torch.mean(torch.stack(losses))
            else:
                pp_schedule.step()
        else:
            y = model(x)
            loss = loss_fn(y, label)
            loss.backward()

        if pp_rank == pp_size - 1:
            print(f"logits: {y.shape}")
            print(f"{loss=}")

        if pp_rank == 0:
            param = model.get_parameter("model.layers.0.self_attn.q_proj.weight")
            print(f"{torch.linalg.norm(param.grad)=}")

        model.zero_grad()

    print("Backward done")


if __name__ == "__main__":
    mesh = dist.init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp"))

    run_full_model(mesh)

    dist.destroy_process_group()
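For orientation, a brief illustrative sketch (not part of the uploaded files): the 3-D mesh above needs pp * ep * fsdp = 2 * 2 * 2 = 8 ranks, matching the `torchrun --standalone --nproc-per-node 8` comment at the top of the file. Each rank can print where it sits along the three mesh dimensions like this:

# Illustrative sketch of the mesh layout used by train.py.
import torch.distributed as dist

mesh = dist.init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp"))
print(
    f"rank {dist.get_rank()}: "
    f"pp={mesh['pp'].get_local_rank()} "
    f"ep={mesh['ep'].get_local_rank()} "
    f"fsdp={mesh['fsdp'].get_local_rank()}"
)
dist.destroy_process_group()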
torchtitan/experiments/flux/dataset/tokenizer.py
ADDED
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.


from typing import List

from torchtitan.components.tokenizer import Tokenizer
from transformers import CLIPTokenizer, T5Tokenizer


class FluxTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the T5 or CLIP tokenizer.

    Args:
        model_path (str): Path to the tokenizer from Hugging Face.

    """

    def __init__(self, model_path: str = "t5-small", max_length: int = 77):
        super().__init__()
        self._n_words = 8  # TODO(jianiw): check
        self._max_length = max_length

        self.is_clip = model_path.startswith("openai")

        if self.is_clip:
            self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                model_path, max_length=max_length
            )
        else:
            self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                model_path, max_length=max_length
            )

    def encode(
        self,
        s: str,
    ) -> List[int]:
        """
        Encode the prompt text into tokens.
        """
        tokens = self._tokenizer(
            s,
            truncation=True,
            max_length=self._max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",  # return pytorch tensors, default return List[int]
        )["input_ids"]
        return tokens

    def decode(self, t: List[int]) -> str:
        """
        Decode function. This function will not be called.
        """
        return self._tokenizer.decode(t)
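A minimal usage sketch (illustrative only, not part of the uploaded files): the same class serves both encoders and switches on the model-path prefix (`openai/...` selects CLIP, anything else T5). The model names and prompt below are illustrative; the defaults appear elsewhere in this upload.

# Illustrative sketch of FluxTokenizer usage.
from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer

clip_tok = FluxTokenizer("openai/clip-vit-large-patch14", max_length=77)
t5_tok = FluxTokenizer("google/t5-v1_1-small", max_length=512)

prompt = "a photo of a cat wearing a tiny hat"
clip_ids = clip_tok.encode(prompt)  # tensor of shape [1, 77], padded to max_length
t5_ids = t5_tok.encode(prompt)      # tensor of shape [1, 512]
print(clip_ids.shape, t5_ids.shape)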
torchtitan/experiments/flux/flux_argparser.py
ADDED
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch


def extend_parser(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        "--training.guidance",
        type=float,
        default=3.5,
        help="guidance value used for guidance distillation",
    )
    parser.add_argument(
        "--encoder.t5_encoder",
        type=str,
        default="google/t5-v1_1-small",
        help="T5 encoder to use, HuggingFace model name.",
    )
    parser.add_argument(
        "--encoder.clip_encoder",
        type=str,
        default="openai/clip-vit-large-patch14",
        help="Clip encoder to use, HuggingFace model name.",
    )
    parser.add_argument(
        "--encoder.encoder_dtype",
        type=torch.dtype,
        default=torch.bfloat16,
        help="Which dtype to load for autoencoder. ",
    )
    parser.add_argument(
        "--encoder.max_t5_encoding_len",
        type=int,
        default=512,
        help="Maximum length of the T5 encoding.",
    )
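A short illustrative sketch (not part of the uploaded files) of how these extra flags plug into a plain argparse parser. Because the option strings contain dots, the parsed values end up under dotted attribute names and are read back with `getattr`; the command-line values below are arbitrary examples.

# Illustrative sketch of extend_parser usage.
import argparse

from torchtitan.experiments.flux.flux_argparser import extend_parser

parser = argparse.ArgumentParser("flux")
extend_parser(parser)
args = parser.parse_args(
    ["--training.guidance", "4.0", "--encoder.max_t5_encoding_len", "256"]
)
print(getattr(args, "training.guidance"), getattr(args, "encoder.max_t5_encoding_len"))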
torchtitan/experiments/flux/model/autoencoder.py
ADDED
@@ -0,0 +1,388 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from dataclasses import dataclass

import torch
from einops import rearrange
from safetensors.torch import load_file as load_sft
from torch import nn, Tensor


@dataclass
class AutoEncoderParams:
    resolution: int = 256
    in_channels: int = 3
    ch: int = 128
    out_ch: int = 3
    ch_mult: tuple[int] = (1, 2, 4, 4)
    num_res_blocks: int = 2
    z_channels: int = 16
    scale_factor: float = 0.3611
    shift_factor: float = 0.1159


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = nn.GroupNorm(
            num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = nn.Conv2d(
            block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # get dtype for proper tracing
        upscale_dtype = next(self.up.parameters()).dtype

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # cast to proper dtype
        h = h.to(upscale_dtype)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.params = params
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        z = self.reg(self.encoder(x))
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))


def load_ae(
    ckpt_path: str,
    autoencoder_params: AutoEncoderParams,
    device: str | torch.device = "cuda",
    dtype=torch.bfloat16,
) -> AutoEncoder:
    """
    Load the autoencoder from the given model name.
    Args:
        name (str): The name of the autoencoder.
        device (str or torch.device): The device to load the autoencoder to.
    Returns:
        AutoEncoder: The loaded autoencoder.
    """
    # Loading the autoencoder
    print("Init AE")
    with torch.device(device):
        ae = AutoEncoder(autoencoder_params)

    if not os.path.exists(ckpt_path):
        raise ValueError(
            f"Autoencoder path {ckpt_path} does not exist. Please download it first."
        )

    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        if len(missing) > 0:
            print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        if len(unexpected) > 0:
            print(
                f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
            )
    return ae.to(dtype=dtype)
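A minimal sketch (illustrative only, not part of the uploaded files) of round-tripping a dummy image through the autoencoder with the default parameters; checkpoint loading via `load_ae` is skipped, so the weights are randomly initialized and the reconstruction is meaningless, but the shapes show the 8x spatial compression into 16 latent channels.

# Illustrative sketch: encode/decode shapes with default AutoEncoderParams.
import torch

from torchtitan.experiments.flux.model.autoencoder import AutoEncoder, AutoEncoderParams

params = AutoEncoderParams()      # 256px input, 16 latent channels, ch_mult (1, 2, 4, 4)
ae = AutoEncoder(params).eval()

x = torch.randn(1, params.in_channels, params.resolution, params.resolution)
with torch.no_grad():
    z = ae.encode(x)              # [1, 16, 32, 32]: three downsamples -> 256 / 8
    recon = ae.decode(z)          # [1, 3, 256, 256]
print(z.shape, recon.shape)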
torchtitan/experiments/flux/model/hf_embedder.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn, Tensor
from transformers import CLIPTextModel, T5EncoderModel


class FluxEmbedder(nn.Module):
    def __init__(self, version: str, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
                version, **hf_kwargs
            )
        else:
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                version, **hf_kwargs
            )

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, batch_tokens: Tensor) -> Tensor:
        """
        batch_tokens: [bsz, embedding_length]

        For the T5 encoder, embedding_length is 768
        For CLIP, embedding_length is 256
        """
        outputs = self.hf_module(
            input_ids=batch_tokens.to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
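A short illustrative sketch (not part of the uploaded files) pairing `FluxEmbedder` with the `FluxTokenizer` added earlier: for a CLIP model the pooled output is returned, for T5 the per-token hidden states. The model name, dtype, and prompt are illustrative assumptions.

# Illustrative sketch: tokenize a prompt and embed it with the CLIP text encoder.
import torch

from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder

clip_version = "openai/clip-vit-large-patch14"
tokenizer = FluxTokenizer(clip_version, max_length=77)
embedder = FluxEmbedder(clip_version, torch_dtype=torch.bfloat16)

tokens = tokenizer.encode("a photo of a cat wearing a tiny hat")  # [1, 77]
with torch.no_grad():
    vec = embedder(tokens)  # pooler_output, e.g. [1, 768] for this CLIP model
print(vec.shape)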
torchtitan/experiments/flux/model/layers.py
ADDED
@@ -0,0 +1,286 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# imported from black-forest-labs/FLUX
import math
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import nn, Tensor

from torchtitan.experiments.flux.model.math import attention, rope


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(t.device)

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)  # TODO(jianiw): switch to pytorch nn.RMSNorm
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
            self.multiplier, dim=-1
        )

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class DoubleStreamBlock(nn.Module):
    def __init__(
        self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
    ):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(
        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
    ) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )
        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
225 |
+
"""
|
226 |
+
|
227 |
+
def __init__(
|
228 |
+
self,
|
229 |
+
hidden_size: int,
|
230 |
+
num_heads: int,
|
231 |
+
mlp_ratio: float = 4.0,
|
232 |
+
qk_scale: float | None = None,
|
233 |
+
):
|
234 |
+
super().__init__()
|
235 |
+
self.hidden_dim = hidden_size
|
236 |
+
self.num_heads = num_heads
|
237 |
+
head_dim = hidden_size // num_heads
|
238 |
+
self.scale = qk_scale or head_dim**-0.5
|
239 |
+
|
240 |
+
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
241 |
+
# qkv and mlp_in
|
242 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
243 |
+
# proj and mlp_out
|
244 |
+
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
245 |
+
|
246 |
+
self.norm = QKNorm(head_dim)
|
247 |
+
|
248 |
+
self.hidden_size = hidden_size
|
249 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
250 |
+
|
251 |
+
self.mlp_act = nn.GELU(approximate="tanh")
|
252 |
+
self.modulation = Modulation(hidden_size, double=False)
|
253 |
+
|
254 |
+
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
255 |
+
mod, _ = self.modulation(vec)
|
256 |
+
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
257 |
+
qkv, mlp = torch.split(
|
258 |
+
self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
|
259 |
+
)
|
260 |
+
|
261 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
262 |
+
q, k = self.norm(q, k, v)
|
263 |
+
|
264 |
+
# compute attention
|
265 |
+
attn = attention(q, k, v, pe=pe)
|
266 |
+
# compute activation in mlp stream, cat again and run second linear layer
|
267 |
+
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
268 |
+
return x + mod.gate * output
|
269 |
+
|
270 |
+
|
271 |
+
class LastLayer(nn.Module):
|
272 |
+
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
273 |
+
super().__init__()
|
274 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
275 |
+
self.linear = nn.Linear(
|
276 |
+
hidden_size, patch_size * patch_size * out_channels, bias=True
|
277 |
+
)
|
278 |
+
self.adaLN_modulation = nn.Sequential(
|
279 |
+
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
280 |
+
)
|
281 |
+
|
282 |
+
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
283 |
+
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
284 |
+
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
285 |
+
x = self.linear(x)
|
286 |
+
return x
|
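Quick shape check for the embedding layers above (a minimal sketch, not part of the diff; it assumes layers.py is importable as torchtitan.experiments.flux.model.layers, which is how model.py below imports it):

import torch
from torchtitan.experiments.flux.model.layers import MLPEmbedder, timestep_embedding

# timestep_embedding maps (N,) timesteps to (N, 256) sinusoidal features;
# MLPEmbedder then projects them to the model width consumed by the Modulation layers.
t = torch.rand(8)
emb = timestep_embedding(t, 256)      # (8, 256)
vec = MLPEmbedder(256, 3072)(emb)     # (8, 3072), matching FluxModelArgs.hidden_size
print(emb.shape, vec.shape)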
torchtitan/experiments/flux/model/model.py
ADDED
@@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import torch

from torch import nn, Tensor
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig

from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
from torchtitan.experiments.flux.model.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)

from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol
from torchtitan.tools.logging import logger


@dataclass
class FluxModelArgs(BaseModelArgs):
    in_channels: int = 64
    out_channels: int = 64
    vec_in_dim: int = 768
    context_in_dim: int = 512
    hidden_size: int = 3072
    mlp_ratio: float = 4.0
    num_heads: int = 24
    depth: int = 19
    depth_single_blocks: int = 38
    axes_dim: tuple = (16, 56, 56)
    theta: int = 10_000
    qkv_bias: bool = True
    guidance_embed: bool = True
    autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        # context_in_dim is the same as the T5 embedding dimension
        self.context_in_dim = job_config.encoder.max_t5_encoding_len

    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
        # TODO(jianiw): Add the number of flops for the autoencoder
        nparams = sum(p.numel() for p in model.parameters())
        logger.warning("FLUX model has not implemented get_nparams_and_flops() yet")
        return nparams, 1


class FluxModel(nn.Module, ModelProtocol):
    """
    Transformer model for flow matching on sequences.

    Args:
        model_args: FluxModelArgs.

    Attributes:
        model_args (FluxModelArgs): Model configuration arguments.
    """

    def __init__(self, model_args: FluxModelArgs):
        super().__init__()

        self.model_args = model_args
        self.in_channels = model_args.in_channels
        self.out_channels = model_args.out_channels
        if model_args.hidden_size % model_args.num_heads != 0:
            raise ValueError(
                f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}"
            )
        pe_dim = model_args.hidden_size // model_args.num_heads
        if sum(model_args.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {model_args.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = model_args.hidden_size
        self.num_heads = model_args.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim
        )
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(model_args.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if model_args.guidance_embed
            else nn.Identity()
        )
        self.txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=model_args.mlp_ratio,
                    qkv_bias=model_args.qkv_bias,
                )
                for _ in range(model_args.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size, self.num_heads, mlp_ratio=model_args.mlp_ratio
                )
                for _ in range(model_args.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def init_weights(self, buffer_device=None):
        # TODO(jianiw): replace placeholder with real weight init
        for param in self.parameters():
            param.data.uniform_(0, 0.1)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.model_args.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    @classmethod
    def from_model_args(cls, model_args: FluxModelArgs) -> "FluxModel":
        """
        Initialize a Flux model from a FluxModelArgs object.

        Args:
            model_args (FluxModelArgs): Model configuration arguments.

        Returns:
            FluxModel: FluxModel model.

        """
        return cls(model_args)
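The constructor validates that the per-head dimension equals the sum of the rotary axes; a minimal sketch of that check with the default arguments above (not part of the diff):

from torchtitan.experiments.flux.model.model import FluxModelArgs

args = FluxModelArgs()                       # hidden_size=3072, num_heads=24
pe_dim = args.hidden_size // args.num_heads  # 128 positional channels per head
assert args.hidden_size % args.num_heads == 0
assert sum(args.axes_dim) == pe_dim          # 16 + 56 + 56 == 128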
torchtitan/experiments/flux/scripts/download_autoencoder.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

from requests.exceptions import HTTPError


def hf_download(
    repo_id: str, file_path: str, local_dir: str, hf_token: Optional[str] = None
) -> None:
    from huggingface_hub import hf_hub_download

    try:
        hf_hub_download(
            repo_id=repo_id,
            filename=file_path,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            token=hf_token,
        )
    except HTTPError as e:
        if e.response.status_code == 401:
            print(
                "You need to pass a valid `--hf_token=...` to download private checkpoints."
            )
        else:
            raise e


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Download tokenizer from HuggingFace.")
    parser.add_argument(
        "--repo_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Repository ID to download from. default to Flux-dev model",
    )
    parser.add_argument(
        "--ae_path",
        type=str,
        default="ae.safetensors",
        help="the autoencoder path relative to repo_id",
    )
    parser.add_argument(
        "--hf_token", type=str, default=None, help="HuggingFace API token"
    )
    parser.add_argument(
        "--local_dir",
        type=str,
        default="torchtitan/experiments/flux/assets/autoencoder/",
        help="local directory to save the autoencoder",
    )

    args = parser.parse_args()
    hf_download(args.repo_id, args.ae_path, args.local_dir, args.hf_token)
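A hedged usage sketch for hf_download() above, mirroring the argparse defaults (it assumes the script is importable as a module; otherwise invoke it directly and pass the same values as flags):

from torchtitan.experiments.flux.scripts.download_autoencoder import hf_download

hf_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    file_path="ae.safetensors",
    local_dir="torchtitan/experiments/flux/assets/autoencoder/",
    hf_token=None,  # pass a token for gated/private repos
)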
torchtitan/experiments/flux/tests/test_flux_dataloader.py
ADDED
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys

from torchtitan.config_manager import JobConfig
from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
from torchtitan.tools.profiling import (
    maybe_enable_memory_snapshot,
    maybe_enable_profiling,
)


class TestFluxDataLoader:
    def test_flux_dataloader(self):
        dataset_name = "cc12m"
        batch_size = 32
        world_size = 4
        rank = 0

        num_steps = 10

        path = "torchtitan.experiments.flux.flux_argparser"
        sys.argv.append(f"--experimental.custom_args_module={path}")
        config = JobConfig()
        config.maybe_add_custom_args()
        config.parse_args(
            [
                # Profiling options
                # "--profiling.enable_profiling",
                # "--profiling.profile_freq",
                # "5",
                # "--profiling.enable_memory_snapshot",
                # "--profiling.save_memory_snapshot_folder",
                # "memory_snapshot_flux",
                "--training.dataset",
                dataset_name,
                "--training.batch_size",
                str(batch_size),
                "--encoder.t5_encoder",
                "google/t5-v1_1-small",
                "--encoder.clip_encoder",
                "openai/clip-vit-large-patch14",
                "--encoder.max_t5_encoding_len",
                "512",
            ]
        )

        with maybe_enable_profiling(
            config, global_step=0
        ) as torch_profiler, maybe_enable_memory_snapshot(
            config, global_step=0
        ) as memory_profiler:
            dl = self._build_dataloader(
                config,
                world_size,
                rank,
            )
            dl = iter(dl)

            for i in range(0, num_steps):
                input_data, labels = next(dl)
                print(f"Step {i} image size: {labels.shape}")
                if torch_profiler:
                    torch_profiler.step()
                if memory_profiler:
                    memory_profiler.step()

                print(len(input_data["clip_tokens"]))
                for k, v in input_data.items():
                    print(f"Step {i} {k} value: {type(v), v.shape}")

                assert len(input_data) == 2  # (clip_encodings, t5_encodings)
                assert labels.shape == (batch_size, 3, 256, 256)
                # assert input_data["clip_tokens"].shape[0] == batch_size
                # assert input_data["t5_tokens"].shape == (batch_size, 512, 512)

            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step(exit_ctx=True)

    def test_preprocess(self):
        # TODO
        pass

    def _build_dataloader(
        self,
        job_config,
        world_size,
        rank,
    ):

        return build_flux_dataloader(
            dp_world_size=world_size,
            dp_rank=rank,
            job_config=job_config,
            tokenizer=None,
            infinite=False,
        )
torchtitan/experiments/flux/tests/test_generate_image.py
ADDED
@@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import os
import time
from typing import Callable

import torch
from einops import rearrange

from PIL import ExifTags, Image

from torch import Tensor

from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer

from torchtitan.experiments.flux.model.autoencoder import (
    AutoEncoder,
    AutoEncoderParams,
    load_ae,
)
from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder

from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs
from torchtitan.experiments.flux.utils import (
    create_position_encoding_for_latents,
    generate_noise_latent,
    pack_latents,
    preprocess_flux_data,
    unpack_latents,
)


def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()


class TestGenerateImage:
    def test_generate_image(self):
        """
        Run a forward pass of flux model to generate an image.
        """
        name = "flux-dev"
        img_width = 512
        img_height = 512
        seed = None
        prompt = (
            "a photo of a forest with mist swirling around the tree trunks. The word "
            '"FLUX" is painted over it in big, red brush strokes with visible texture'
        )
        device = "cuda"
        num_steps = None
        loop = False
        guidance = 3.5
        output_dir = "output"
        add_sampling_metadata = True

        prompt = prompt.split("|")
        if len(prompt) == 1:
            prompt = prompt[0]
            additional_prompts = None
        else:
            additional_prompts = prompt[1:]
            prompt = prompt[0]

        assert not (
            (additional_prompts is not None) and loop
        ), "Do not provide additional prompts and set loop to True"

        torch_device = torch.device(device)
        if num_steps is None:
            num_steps = 30

        # allow for packing and conversion to latent space
        img_height = 16 * (img_height // 16)
        img_width = 16 * (img_width // 16)

        # init all components
        model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16)

        ae = load_ae(
            ckpt_path="assets/autoencoder/ae.safetensors",
            autoencoder_params=AutoEncoderParams(),
            device=torch_device,
            dtype=torch.bfloat16,
        )
        clip_tokenizer = FluxTokenizer(
            model_path="openai/clip-vit-large-patch14", max_length=77
        )
        t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512)
        clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to(
            torch_device, dtype=torch.bfloat16
        )
        t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to(
            torch_device, dtype=torch.bfloat16
        )

        rng = torch.Generator(device="cpu")

        if seed is None:
            seed = rng.seed()
        print(f"Generating with seed {seed}:\n{prompt}")
        t0 = time.perf_counter()
        output_name = os.path.join(output_dir, f"img_{seed}.jpg")

        # Tokenize the prompt, on CPU
        clip_tokens = clip_tokenizer.encode(prompt)
        t5_tokens = t5_tokenizer.encode(prompt)

        batch = preprocess_flux_data(
            device=torch_device,
            dtype=torch.bfloat16,
            autoencoder=None,
            clip_encoder=clip_encoder,
            t5_encoder=t5_encoder,
            batch={
                "clip_tokens": clip_tokens,
                "t5_tokens": t5_tokens,
            },
        )

        img = self._generate_images(
            device=torch_device,
            dtype=torch.bfloat16,
            model=model,
            decoder=ae,
            img_width=img_width,
            img_height=img_height,
            denoising_steps=num_steps,
            seed=seed,
            clip_encodings=batch["clip_encodings"],
            t5_encodings=batch["t5_encodings"],
            guidance=guidance,
        )

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()

        print(f"Done in {t1 - t0:.1f}s.")

        self._save_image(name, output_name, img, add_sampling_metadata, prompt)

    def _generate_images(
        self,
        device: torch.device,
        dtype: torch.dtype,
        model: FluxModel,
        decoder: AutoEncoder,
        # image params:
        img_width: int,
        img_height: int,
        # sampling params:
        denoising_steps: int,
        seed: int,
        clip_encodings: torch.Tensor,
        t5_encodings: torch.Tensor,
        guidance: float = 4.0,
    ):

        bsz = clip_encodings.shape[0]
        latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed)
        _, latent_channels, latent_height, latent_width = latents.shape

        # create denoising schedule
        timesteps = get_schedule(denoising_steps, latent_channels, shift=True)

        # create positional encodings
        POSITION_DIM = 3  # constant for Flux flow model
        latent_pos_enc = create_position_encoding_for_latents(
            bsz, latent_height, latent_width, POSITION_DIM
        ).to(latents)
        text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)

        # convert img-like latents into sequences of patches
        latents = pack_latents(latents)

        # this is ignored for schnell
        guidance_vec = torch.full((bsz,), guidance, device=device, dtype=dtype)
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
            pred = model(
                img=latents,
                img_ids=latent_pos_enc,
                txt=t5_encodings,
                txt_ids=text_pos_enc,
                y=clip_encodings,
                timesteps=t_vec,
                guidance=guidance_vec,
            )

            latents = latents + (t_prev - t_curr) * pred

        # convert sequences of patches into img-like latents
        latents = unpack_latents(latents, latent_height, latent_width)

        img = decoder.decode(latents)
        return img

    def _save_image(
        self,
        name: str,
        output_name: str,
        x: torch.Tensor,
        add_sampling_metadata: bool,
        prompt: str,
    ):
        print(f"Saving {output_name}")
        # bring into PIL format and save
        x = x.clamp(-1, 1)
        x = rearrange(x[0], "c h w -> h w c")

        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

        exif_data = Image.Exif()
        exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
        exif_data[ExifTags.Base.Make] = "Black Forest Labs"
        exif_data[ExifTags.Base.Model] = name
        if add_sampling_metadata:
            exif_data[ExifTags.Base.ImageDescription] = prompt
        img.save(output_name, exif=exif_data, quality=95, subsampling=0)
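The denoising loop in _generate_images above is a plain Euler step on the flow-matching ODE. A standalone illustration of just that arithmetic (not part of the diff; the model's velocity prediction is replaced by a stand-in function so only the update rule is shown):

import torch

def fake_velocity(latents, t):
    # stand-in for the FluxModel call; the real model predicts a velocity field
    return latents

latents = torch.randn(1, 16)
timesteps = torch.linspace(1.0, 0.0, 5).tolist()  # same shape of schedule as get_schedule()
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
    pred = fake_velocity(latents, t_curr)
    latents = latents + (t_prev - t_curr) * pred   # t decreases, so each step is negative
print(latents.norm())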
torchtitan/experiments/flux/train_configs/debug_model.toml
ADDED
@@ -0,0 +1,68 @@
[job]
dump_folder = "./outputs"
description = "Flux debug model"
print_args = false
use_for_integration_test = true

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "flux"
flavor = "flux-debug"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
# test tokenizer.model, for debug purpose only
# tokenizer_path = "./tests/assets/test_tiktoken.model"
# converters = "float8"

[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.8  # lr scheduler decay ratio, 80% of the train steps
decay_type = "linear"
lr_min = 0.0

[training]
batch_size = 32
seq_len = 512
max_norm = 1.0  # grad norm clipping
steps = 10
compile = false
dataset = "cc12m"
guidance = 3.5
seed = 0

[encoder]
t5_encoder = "google/t5-v1_1-small"
clip_encoder = "openai/clip-vit-large-patch14"
max_t5_encoding_len = 512
auto_encoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors"  # Autoencoder to use for image

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 1
fsdp_reshard_after_forward = "default"  # default / never / always
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1

[experimental]
custom_args_module = "torchtitan.experiments.flux.flux_argparser"
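The [experimental].custom_args_module entry is what registers the [encoder] options with JobConfig. A hedged sketch of loading a config like this programmatically, following the pattern in test_flux_dataloader.py (the --job.config_file flag is assumed to be the standard torchtitan way to point at a TOML file):

import sys
from torchtitan.config_manager import JobConfig

# Register the Flux-specific argparser before parsing, then point JobConfig at the TOML.
sys.argv.append("--experimental.custom_args_module=torchtitan.experiments.flux.flux_argparser")
config = JobConfig()
config.maybe_add_custom_args()
config.parse_args(
    ["--job.config_file", "torchtitan/experiments/flux/train_configs/debug_model.toml"]
)
print(config.encoder.max_t5_encoding_len)  # 512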
torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py
ADDED
@@ -0,0 +1,885 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import math
import time

from typing import Dict, List, Tuple

# import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# from torchao_pr.mg_grouped_gemm import mg_grouped_gemm

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Try to import the optimized MG GEMM implementation
try:
    from torchao_pr.mg_grouped_gemm import (  # grouped_gemm_backward,
        grouped_gemm_forward,
    )

    has_mg_gemm = True
except ImportError:
    logging.warning("MG GEMM implementation not found. Will use manual looping only.")
    has_mg_gemm = False


class Router(nn.Module):
    """
    Router module that assigns tokens to experts.
    """

    def __init__(self, input_dim: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.top_k = top_k

        # Routing layer
        self.router = nn.Linear(input_dim, num_experts)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """
        Route input tokens to experts.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_dim)

        Returns:
            Tuple containing:
                - router_logits: Raw routing probabilities
                - dispatch_tensor: One-hot tensor indicating expert assignment
                - expert_indices: List of indices for each expert's tokens
        """
        batch_size, seq_len, _ = x.shape

        # Flatten batch and sequence dimensions
        x_flat = x.reshape(-1, self.input_dim)  # (batch_size * seq_len, input_dim)

        # Compute routing probabilities
        router_logits = self.router(x_flat)  # (batch_size * seq_len, num_experts)

        # Apply softmax to get probabilities
        router_probs = F.softmax(router_logits, dim=-1)

        # Get top-k experts for each token
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)

        # Normalize top-k probabilities
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Create dispatch tensor (one-hot representation of assignments)
        dispatch_tensor = torch.zeros_like(router_probs)
        token_indices = (
            torch.arange(router_probs.size(0), device=router_probs.device)
            .unsqueeze(1)
            .expand(-1, self.top_k)
        )
        dispatch_tensor.scatter_(1, top_k_indices, top_k_probs)  # .unsqueeze(-1))

        # For each expert, get the indices of tokens routed to it
        expert_indices = []
        for expert_idx in range(self.num_experts):
            # Get indices of tokens that have non-zero probability for this expert
            indices = torch.nonzero(dispatch_tensor[:, expert_idx] > 0, as_tuple=True)[
                0
            ]
            expert_indices.append(indices)

        return router_logits, dispatch_tensor, expert_indices


class Expert(nn.Module):
    """
    Individual expert module.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x


class MixtureOfExperts(nn.Module):
    """
    Mixture of Experts layer with support for both manual looping and grouped GEMM.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int,
        top_k: int = 2,
        use_mg_gemm: bool = False,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_mg_gemm = use_mg_gemm and has_mg_gemm

        # Router
        self.router = Router(input_dim, num_experts, top_k)

        # Create expert modules
        if self.use_mg_gemm:
            # For MG GEMM, we need a single weight tensor for all experts
            # First layer (input -> hidden)
            self.expert_fc1_weight = nn.Parameter(
                torch.randn(num_experts * hidden_dim, input_dim) / math.sqrt(input_dim)
            )
            # self.expert_fc1_bias = nn.Parameter(torch.zeros(num_experts * hidden_dim))

            # Second layer (hidden -> output)
            self.expert_fc2_weight = nn.Parameter(
                torch.randn(num_experts * output_dim, hidden_dim)
                / math.sqrt(hidden_dim)
            )
            # self.expert_fc2_bias = nn.Parameter(torch.zeros(num_experts * output_dim))
        else:
            # For manual looping, create separate experts
            self.experts = nn.ModuleList(
                [Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
            )

    def forward_manual_loop(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass using manual looping over experts.
        """
        batch_size, seq_len, _ = x.shape
        x_flat = x.reshape(-1, self.input_dim)  # (batch_size * seq_len, input_dim)

        # Get routing information
        router_logits, dispatch_tensor, expert_indices = self.router(x)

        # Initialize output tensor
        final_output = torch.zeros(
            batch_size * seq_len, self.output_dim, device=x.device
        )

        # Process each expert
        for expert_idx, indices in enumerate(expert_indices):
            if indices.numel() > 0:
                # Get tokens routed to this expert
                expert_inputs = x_flat[indices]  # (num_tokens_for_expert, input_dim)

                # Process tokens through expert
                expert_outputs = self.experts[expert_idx](
                    expert_inputs
                )  # (num_tokens_for_expert, output_dim)

                # Scale outputs by router probabilities
                scaled_outputs = expert_outputs * dispatch_tensor[
                    indices, expert_idx
                ].unsqueeze(1)

                # Add to final output
                final_output.index_add_(0, indices, scaled_outputs)

        # Reshape back to original dimensions
        output = final_output.reshape(batch_size, seq_len, self.output_dim)

        return output, router_logits

    def forward_mg_gemm(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape
        x_flat = x.reshape(-1, self.input_dim)  # (batch_size * seq_len, input_dim)
        total_tokens = batch_size * seq_len

        # Get routing information
        router_logits, dispatch_tensor, expert_indices = self.router(x)

        # Get token counts for each expert
        token_counts = [indices.numel() for indices in expert_indices]
        m_sizes = torch.tensor(token_counts, dtype=torch.int32, device=x.device)

        print(f"Token counts per expert: {token_counts}")
        print(f"m_sizes: {m_sizes}")

        # Create the combined input tensor
        combined_input = torch.zeros(sum(token_counts), self.input_dim, device=x.device)

        start_idx = 0
        for expert_idx, indices in enumerate(expert_indices):
            if indices.numel() > 0:
                end_idx = start_idx + indices.numel()
                combined_input[start_idx:end_idx] = x_flat[indices]
                start_idx = end_idx

        print(f"combined_input shape: {combined_input.shape}")

        # First layer: input -> hidden
        fc1_weight_reshaped = self.expert_fc1_weight.reshape(
            self.num_experts, self.hidden_dim, self.input_dim
        )
        fc1_weight_combined = fc1_weight_reshaped.reshape(-1, self.input_dim)

        print(f"fc1_weight_combined shape: {fc1_weight_combined.shape}")

        # Run the grouped GEMM
        hidden_outputs = grouped_gemm_forward(
            combined_input, fc1_weight_combined, m_sizes
        )

        print(f"hidden_outputs shape after first GEMM: {hidden_outputs.shape}")

        # Apply activation
        hidden_outputs = F.gelu(hidden_outputs)

        print(f"hidden_outputs shape after activation: {hidden_outputs.shape}")

        # Second layer: hidden -> output
        # Reshape hidden_outputs to match expected dimensions
        reshaped_hidden_outputs = []
        start_idx = 0

        for expert_idx, count in enumerate(token_counts):
            if count > 0:
                end_idx = start_idx + count
                # Take this expert's outputs and reshape to [count, hidden_dim]
                expert_output = hidden_outputs[
                    start_idx:end_idx,
                    expert_idx * self.hidden_dim : (expert_idx + 1) * self.hidden_dim,
                ]
                reshaped_hidden_outputs.append(expert_output)
                start_idx = end_idx

        # Concatenate all reshaped outputs
        hidden_outputs = torch.cat(reshaped_hidden_outputs, dim=0)

        # Reshape expert weights for second layer
        fc2_weight_reshaped = self.expert_fc2_weight.reshape(
            self.num_experts, self.output_dim, self.hidden_dim
        )
        fc2_weight_combined = fc2_weight_reshaped.reshape(-1, self.hidden_dim)

        print(f"fc2_weight_combined shape: {fc2_weight_combined.shape}")

        # Run the second grouped GEMM
        expert_outputs_combined = grouped_gemm_forward(
            hidden_outputs, fc2_weight_combined, m_sizes
        )

        # Initialize final output tensor with correct shape
        final_output = torch.zeros(total_tokens, self.output_dim, device=x.device)

        # Distribute the outputs back to the original token positions
        start_idx = 0
        for expert_idx, indices in enumerate(expert_indices):
            if indices.numel() > 0:
                end_idx = start_idx + indices.numel()
                # Get this expert's outputs
                expert_outputs = expert_outputs_combined[start_idx:end_idx]

                print(
                    f"Expert {expert_idx} - indices shape: {indices.shape}, expert_outputs shape: {expert_outputs.shape}"
                )

                # Scale outputs by router probabilities
                scaled_outputs = expert_outputs * dispatch_tensor[
                    indices, expert_idx
                ].unsqueeze(1)

                # Ensure dimensions match before using index_add_
                if scaled_outputs.shape[1] != final_output.shape[1]:
                    # print(
                    #     f"Reshaping: Dimension mismatch: scaled_outputs {scaled_outputs.shape}, final_output {final_output.shape}"
                    # )
                    # Reshape if needed - make sure output_dim is correct
                    scaled_outputs = scaled_outputs[:, : self.output_dim]

                # Add to final output
                final_output.index_add_(0, indices, scaled_outputs)

                start_idx = end_idx

        # Reshape back to original dimensions
        output = final_output.reshape(batch_size, seq_len, self.output_dim)

        return output, router_logits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_mg_gemm and has_mg_gemm:
            return self.forward_mg_gemm(x)
        else:
            return self.forward_manual_loop(x)


class MoEModel(nn.Module):
    """
    Simple model using MoE layers.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dim: int,
        num_experts: int,
        top_k: int = 2,
        use_mg_gemm: bool = False,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.moe_layer = MixtureOfExperts(
            input_dim=embed_dim,
            hidden_dim=hidden_dim,
            output_dim=embed_dim,
            num_experts=num_experts,
            top_k=top_k,
            use_mg_gemm=use_mg_gemm,
        )
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x shape: (batch_size, seq_len)
        embedded = self.embedding(x)  # (batch_size, seq_len, embed_dim)
        moe_output, router_logits = self.moe_layer(
            embedded
        )  # (batch_size, seq_len, embed_dim)
        logits = self.output_layer(moe_output)  # (batch_size, seq_len, vocab_size)
        return logits, router_logits


def compute_load_balancing_loss(
    router_logits: torch.Tensor, num_experts: int
) -> torch.Tensor:
    """
    Compute the load balancing loss for MoE training.

    Args:
        router_logits (torch.Tensor): Router logits of shape (batch_size * seq_len, num_experts)
        num_experts (int): Number of experts

    Returns:
        torch.Tensor: Load balancing loss
    """
    # Get router probabilities
    router_probs = F.softmax(
        router_logits, dim=-1
    )  # (batch_size * seq_len, num_experts)

    # Compute fraction of tokens routed to each expert
    # Sum across the batch dimension and normalize
    router_probs_sum = router_probs.sum(dim=0)  # (num_experts,)
    router_probs_sum = router_probs_sum / router_probs_sum.sum()

    # Compute the mean probability per expert
    mean_prob = 1.0 / num_experts

    # Compute the fraction of tokens routed to each expert
    # The goal is to have uniform routing across experts
    load_balancing_loss = num_experts * torch.sum(router_probs_sum * router_probs_sum)

    return load_balancing_loss


def generate_sample_data(
    batch_size: int, seq_len: int, vocab_size: int, device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate sample data for training.

    Args:
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        device (str): Device to use

    Returns:
        Tuple of input tokens and target tokens
    """
    # Generate random input tokens
    inputs = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

    # Generate random target tokens
    targets = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

    return inputs, targets


def train_epoch(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
    load_balance_coef: float = 0.01,
) -> Dict[str, float]:
    """
    Train the model for one epoch.

    Args:
        model (nn.Module): Model to train
        optimizer (torch.optim.Optimizer): Optimizer
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches per epoch
        device (str): Device to use
        load_balance_coef (float): Coefficient for load balancing loss

    Returns:
        Dict containing training metrics
    """
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    start_time = time.time()

    for i in range(num_batches):
        # Generate sample data
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)

        # Forward pass
        optimizer.zero_grad()
        logits, router_logits = model(inputs)

        # Compute loss
        # Reshape for cross entropy loss
        logits_flat = logits.reshape(-1, vocab_size)
        targets_flat = targets.reshape(-1)

        # Cross entropy loss
        ce_loss = F.cross_entropy(logits_flat, targets_flat)

        # Load balancing loss
        lb_loss = compute_load_balancing_loss(
            router_logits, model.moe_layer.num_experts
        )

        # Combined loss
        loss = ce_loss + load_balance_coef * lb_loss

        # Backward pass
        loss.backward()
        optimizer.step()

        # Compute accuracy
        preds = logits_flat.argmax(dim=-1)
        correct = (preds == targets_flat).float().sum()
        acc = correct / (batch_size * seq_len)

        # Accumulate metrics
        total_loss += loss.item()
        total_acc += acc.item()

        # Log progress
        if (i + 1) % 10 == 0:
            logging.info(
                f"Batch {i + 1}/{num_batches} | "
                f"Loss: {loss.item():.4f} | "
                f"CE Loss: {ce_loss.item():.4f} | "
                f"LB Loss: {lb_loss.item():.4f} | "
                f"Acc: {acc.item():.4f}"
            )

    # Compute average metrics
    avg_loss = total_loss / num_batches
    avg_acc = total_acc / num_batches
    epoch_time = time.time() - start_time

    return {"loss": avg_loss, "acc": avg_acc, "time": epoch_time}


def evaluate(
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
) -> Dict[str, float]:
    """
    Evaluate the model.

    Args:
        model (nn.Module): Model to evaluate
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches for evaluation
        device (str): Device to use

    Returns:
        Dict containing evaluation metrics
    """
    model.eval()
    total_loss = 0.0
    total_acc = 0.0

    with torch.no_grad():
        for i in range(num_batches):
            # Generate sample data
            inputs, targets = generate_sample_data(
                batch_size, seq_len, vocab_size, device
            )

            # Forward pass
            logits, router_logits = model(inputs)

            # Compute loss
            logits_flat = logits.reshape(-1, vocab_size)
            targets_flat = targets.reshape(-1)

            # Cross entropy loss
            loss = F.cross_entropy(logits_flat, targets_flat)

            # Compute accuracy
            preds = logits_flat.argmax(dim=-1)
            correct = (preds == targets_flat).float().sum()
            acc = correct / (batch_size * seq_len)

            # Accumulate metrics
            total_loss += loss.item()
            total_acc += acc.item()

    # Compute average metrics
    avg_loss = total_loss / num_batches
    avg_acc = total_acc / num_batches

    return {"loss": avg_loss, "acc": avg_acc}


def measure_performance(
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
) -> Dict[str, float]:
    """
    Measure forward and backward pass performance.

    Args:
        model (nn.Module): Model to evaluate
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches for measurement
        device (str): Device to use

    Returns:
        Dict containing performance metrics
    """
    model.train()

    # Create dummy optimizer
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Warmup
    for _ in range(5):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        logits, router_logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.zero_grad()

    # Measure forward pass time
    torch.cuda.synchronize()
    forward_start = time.time()

    for _ in range(num_batches):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        with torch.no_grad():
            logits, router_logits = model(inputs)

    torch.cuda.synchronize()
    forward_end = time.time()
    forward_time = (forward_end - forward_start) / num_batches

    # Measure backward pass time
    torch.cuda.synchronize()
    backward_start = time.time()

    for _ in range(num_batches):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        logits, router_logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.zero_grad()

    torch.cuda.synchronize()
    backward_end = time.time()
    backward_time = (backward_end - backward_start) / num_batches

    return {
        "forward_time": forward_time * 1000,  # Convert to ms
        "backward_time": backward_time * 1000,  # Convert to ms
        "total_time": (forward_time + backward_time) * 1000,  # Convert to ms
    }


def compare_methods(args):
    """
    Compare manual looping and MG GEMM implementations.
    """
    device = torch.device(args.device)

    # Create models
    manual_model = MoEModel(
        vocab_size=args.vocab_size,
        embed_dim=args.embed_dim,
        hidden_dim=args.hidden_dim,
        num_experts=args.num_experts,
        top_k=args.top_k,
        use_mg_gemm=False,
    ).to(device)

    if has_mg_gemm:
        mg_model = MoEModel(
            vocab_size=args.vocab_size,
            embed_dim=args.embed_dim,
            hidden_dim=args.hidden_dim,
            num_experts=args.num_experts,
            top_k=args.top_k,
            use_mg_gemm=True,
        ).to(device)
    else:
        mg_model = None

    # Measure performance
    logging.info("Measuring performance of manual looping method...")
    manual_perf = measure_performance(
        manual_model,
        args.batch_size,
        args.seq_len,
        args.vocab_size,
        args.perf_batches,
        device,
    )

    if mg_model is not None:
        logging.info("Measuring performance of MG GEMM method...")
        mg_perf = measure_performance(
            mg_model,
            args.batch_size,
            args.seq_len,
            args.vocab_size,
            args.perf_batches,
            device,
        )
    else:
        mg_perf = {"forward_time": 0, "backward_time": 0, "total_time": 0}

    # Log results
    logging.info("\n===== Performance Comparison =====")
    logging.info("Model Configuration:")
    logging.info(f" - Batch Size: {args.batch_size}")
    logging.info(f" - Sequence Length: {args.seq_len}")
    logging.info(f" - Embed Dimension: {args.embed_dim}")
    logging.info(f" - Hidden Dimension: {args.hidden_dim}")
    logging.info(f" - Number of Experts: {args.num_experts}")
logging.info(f" - Number of Experts: {args.num_experts}")
|
699 |
+
logging.info(f" - Top-K: {args.top_k}")
|
700 |
+
logging.info("")
|
701 |
+
|
702 |
+
logging.info("Manual Looping Method:")
|
703 |
+
logging.info(f" - Forward Time: {manual_perf['forward_time']:.2f} ms")
|
704 |
+
logging.info(f" - Backward Time: {manual_perf['backward_time']:.2f} ms")
|
705 |
+
logging.info(f" - Total Time: {manual_perf['total_time']:.2f} ms")
|
706 |
+
logging.info("")
|
707 |
+
|
708 |
+
if mg_model is not None:
|
709 |
+
logging.info("MG GEMM Method:")
|
710 |
+
logging.info(f" - Forward Time: {mg_perf['forward_time']:.2f} ms")
|
711 |
+
logging.info(f" - Backward Time: {mg_perf['backward_time']:.2f} ms")
|
712 |
+
logging.info(f" - Total Time: {mg_perf['total_time']:.2f} ms")
|
713 |
+
logging.info("")
|
714 |
+
|
715 |
+
# Calculate speedup
|
716 |
+
forward_speedup = (
|
717 |
+
manual_perf["forward_time"] / mg_perf["forward_time"]
|
718 |
+
if mg_perf["forward_time"] > 0
|
719 |
+
else 0
|
720 |
+
)
|
721 |
+
backward_speedup = (
|
722 |
+
manual_perf["backward_time"] / mg_perf["backward_time"]
|
723 |
+
if mg_perf["backward_time"] > 0
|
724 |
+
else 0
|
725 |
+
)
|
726 |
+
total_speedup = (
|
727 |
+
manual_perf["total_time"] / mg_perf["total_time"]
|
728 |
+
if mg_perf["total_time"] > 0
|
729 |
+
else 0
|
730 |
+
)
|
731 |
+
|
732 |
+
logging.info("Speedup (MG GEMM vs Manual):")
|
733 |
+
logging.info(f" - Forward Speedup: {forward_speedup:.2f}x")
|
734 |
+
logging.info(f" - Backward Speedup: {backward_speedup:.2f}x")
|
735 |
+
logging.info(f" - Total Speedup: {total_speedup:.2f}x")
|
736 |
+
else:
|
737 |
+
logging.info("MG GEMM method not available.")
|
738 |
+
|
739 |
+
|
740 |
+
def train_model(args):
|
741 |
+
"""
|
742 |
+
Train an MoE model.
|
743 |
+
"""
|
744 |
+
device = torch.device(args.device)
|
745 |
+
|
746 |
+
# Create model
|
747 |
+
model = MoEModel(
|
748 |
+
vocab_size=args.vocab_size,
|
749 |
+
embed_dim=args.embed_dim,
|
750 |
+
hidden_dim=args.hidden_dim,
|
751 |
+
num_experts=args.num_experts,
|
752 |
+
top_k=args.top_k,
|
753 |
+
use_mg_gemm=args.use_mg_gemm and has_mg_gemm,
|
754 |
+
).to(device)
|
755 |
+
|
756 |
+
# Create optimizer
|
757 |
+
optimizer = optim.Adam(model.parameters(), lr=args.lr)
|
758 |
+
|
759 |
+
# Log model information
|
760 |
+
logging.info("Model configuration:")
|
761 |
+
logging.info(f" - Vocabulary Size: {args.vocab_size}")
|
762 |
+
logging.info(f" - Embedding Dimension: {args.embed_dim}")
|
763 |
+
logging.info(f" - Hidden Dimension: {args.hidden_dim}")
|
764 |
+
logging.info(f" - Number of Experts: {args.num_experts}")
|
765 |
+
logging.info(f" - Top-K: {args.top_k}")
|
766 |
+
logging.info(f" - Using MG GEMM: {args.use_mg_gemm and has_mg_gemm}")
|
767 |
+
|
768 |
+
# Training loop
|
769 |
+
for epoch in range(args.epochs):
|
770 |
+
logging.info(f"\nEpoch {epoch + 1}/{args.epochs}")
|
771 |
+
|
772 |
+
# Train
|
773 |
+
train_metrics = train_epoch(
|
774 |
+
model=model,
|
775 |
+
optimizer=optimizer,
|
776 |
+
batch_size=args.batch_size,
|
777 |
+
seq_len=args.seq_len,
|
778 |
+
vocab_size=args.vocab_size,
|
779 |
+
num_batches=args.train_batches,
|
780 |
+
device=device,
|
781 |
+
load_balance_coef=args.load_balance_coef,
|
782 |
+
)
|
783 |
+
|
784 |
+
# Evaluate
|
785 |
+
eval_metrics = evaluate(
|
786 |
+
model=model,
|
787 |
+
batch_size=args.batch_size,
|
788 |
+
seq_len=args.seq_len,
|
789 |
+
vocab_size=args.vocab_size,
|
790 |
+
num_batches=args.eval_batches,
|
791 |
+
device=device,
|
792 |
+
)
|
793 |
+
|
794 |
+
# Log metrics
|
795 |
+
logging.info(
|
796 |
+
f"Train Loss: {train_metrics['loss']:.4f} | Train Acc: {train_metrics['acc']:.4f}"
|
797 |
+
)
|
798 |
+
logging.info(
|
799 |
+
f"Eval Loss: {eval_metrics['loss']:.4f} | Eval Acc: {eval_metrics['acc']:.4f}"
|
800 |
+
)
|
801 |
+
logging.info(f"Epoch Time: {train_metrics['time']:.2f} seconds")
|
802 |
+
|
803 |
+
|
804 |
+
if __name__ == "__main__":
|
805 |
+
parser = argparse.ArgumentParser(description="Train MoE model")
|
806 |
+
|
807 |
+
# Model parameters
|
808 |
+
parser.add_argument("--vocab_size", type=int, default=10000, help="Vocabulary size")
|
809 |
+
parser.add_argument(
|
810 |
+
"--embed_dim", type=int, default=512, help="Embedding dimension"
|
811 |
+
)
|
812 |
+
parser.add_argument(
|
813 |
+
"--hidden_dim", type=int, default=1024, help="Hidden dimension in experts"
|
814 |
+
)
|
815 |
+
parser.add_argument("--num_experts", type=int, default=8, help="Number of experts")
|
816 |
+
parser.add_argument(
|
817 |
+
"--top_k", type=int, default=2, help="Top-k experts to route to"
|
818 |
+
)
|
819 |
+
|
820 |
+
# Training parameters
|
821 |
+
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
|
822 |
+
parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
|
823 |
+
parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
|
824 |
+
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
|
825 |
+
parser.add_argument(
|
826 |
+
"--train_batches",
|
827 |
+
type=int,
|
828 |
+
default=100,
|
829 |
+
help="Number of training batches per epoch",
|
830 |
+
)
|
831 |
+
parser.add_argument(
|
832 |
+
"--eval_batches", type=int, default=20, help="Number of evaluation batches"
|
833 |
+
)
|
834 |
+
parser.add_argument(
|
835 |
+
"--perf_batches",
|
836 |
+
type=int,
|
837 |
+
default=50,
|
838 |
+
help="Number of batches for performance testing",
|
839 |
+
)
|
840 |
+
parser.add_argument(
|
841 |
+
"--load_balance_coef",
|
842 |
+
type=float,
|
843 |
+
default=0.01,
|
844 |
+
help="Load balancing loss coefficient",
|
845 |
+
)
|
846 |
+
|
847 |
+
# Runtime parameters
|
848 |
+
parser.add_argument(
|
849 |
+
"--device",
|
850 |
+
type=str,
|
851 |
+
default="cuda" if torch.cuda.is_available() else "cpu",
|
852 |
+
help="Device to use (cuda or cpu)",
|
853 |
+
)
|
854 |
+
parser.add_argument(
|
855 |
+
"--use_mg_gemm",
|
856 |
+
action="store_true",
|
857 |
+
help="Use MG GEMM implementation if available",
|
858 |
+
)
|
859 |
+
parser.add_argument(
|
860 |
+
"--compare",
|
861 |
+
action="store_true",
|
862 |
+
help="Compare manual and MG GEMM implementations",
|
863 |
+
)
|
864 |
+
parser.add_argument("--train", action="store_true", help="Train the model")
|
865 |
+
|
866 |
+
args = parser.parse_args()
|
867 |
+
|
868 |
+
# Check for CUDA
|
869 |
+
if args.device == "cuda" and not torch.cuda.is_available():
|
870 |
+
logging.warning("CUDA not available, using CPU instead.")
|
871 |
+
args.device = "cpu"
|
872 |
+
|
873 |
+
# Log basic information
|
874 |
+
logging.info(f"PyTorch version: {torch.__version__}")
|
875 |
+
logging.info(f"Device: {args.device}")
|
876 |
+
logging.info(f"MG GEMM available: {has_mg_gemm}")
|
877 |
+
|
878 |
+
# Run the requested action
|
879 |
+
if args.compare:
|
880 |
+
compare_methods(args)
|
881 |
+
elif args.train:
|
882 |
+
train_model(args)
|
883 |
+
else:
|
884 |
+
# Default to comparison if no action specified
|
885 |
+
compare_methods(args)
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# pyre-unsafe
|
8 |
+
import logging
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from reference_utils import (
|
14 |
+
analyze_tensor_differences,
|
15 |
+
compute_reference_backward,
|
16 |
+
compute_reference_forward,
|
17 |
+
)
|
18 |
+
|
19 |
+
# Configure logging
|
20 |
+
logging.basicConfig(
|
21 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
22 |
+
)
|
23 |
+
|
24 |
+
# Import grouped GEMM implementations
|
25 |
+
try:
|
26 |
+
from mg_grouped_gemm import grouped_gemm_backward, grouped_gemm_forward
|
27 |
+
|
28 |
+
except ImportError:
|
29 |
+
logging.error(
|
30 |
+
"Error importing grouped GEMM modules. Make sure the implementation files are in the correct path."
|
31 |
+
)
|
32 |
+
raise
|
33 |
+
|
34 |
+
|
35 |
+
def test_forward_pass():
|
36 |
+
"""
|
37 |
+
A simple test for the M*G grouped GEMM forward pass with detailed error handling.
|
38 |
+
|
39 |
+
In M*G grouping:
|
40 |
+
- M dimension is partitioned into G groups (M_total = sum(M_sizes))
|
41 |
+
- N dimension is the same for all groups
|
42 |
+
"""
|
43 |
+
try:
|
44 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
45 |
+
|
46 |
+
# Test parameters for DeepSeek-like models
|
47 |
+
G = 1 # Number of groups
|
48 |
+
M_sizes = [
|
49 |
+
2048,
|
50 |
+
] # 2048, 2048, 2048] # Group sizes (will be adjusted)
|
51 |
+
M_total = sum(M_sizes) # Total M dimension
|
52 |
+
N = 4096 # Output dimension (same for all groups)
|
53 |
+
K = 7168 # Hidden dimension
|
54 |
+
|
55 |
+
# Create group sizes tensor
|
56 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
57 |
+
|
58 |
+
# Create input and weight tensors - using float16 for higher precision
|
59 |
+
x = torch.randn(M_total, K, dtype=torch.float16, device=device)
|
60 |
+
w = torch.randn(N, K, dtype=torch.float16, device=device)
|
61 |
+
|
62 |
+
# Log the setup
|
63 |
+
logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
|
64 |
+
logging.info(f"Group sizes: {m_sizes}")
|
65 |
+
logging.info(f"Input x shape: {x.shape}")
|
66 |
+
logging.info(f"Weight w shape: {w.shape}")
|
67 |
+
|
68 |
+
# Run forward pass
|
69 |
+
logging.info("Running forward pass with grouped GEMM")
|
70 |
+
result = grouped_gemm_forward(x, w, m_sizes)
|
71 |
+
logging.info(f"Forward result shape: {result.shape}")
|
72 |
+
|
73 |
+
# Compute reference result
|
74 |
+
logging.info("Computing reference result with PyTorch")
|
75 |
+
reference_result = compute_reference_forward(x, w, m_sizes)
|
76 |
+
|
77 |
+
# Compare results
|
78 |
+
logging.info("Comparing with PyTorch reference")
|
79 |
+
forward_close = analyze_tensor_differences(
|
80 |
+
result, reference_result, "Forward output"
|
81 |
+
)
|
82 |
+
|
83 |
+
return forward_close
|
84 |
+
|
85 |
+
except Exception as e:
|
86 |
+
logging.error(f"Test failed with error: {e}")
|
87 |
+
import traceback
|
88 |
+
|
89 |
+
logging.error(traceback.format_exc())
|
90 |
+
return False
|
91 |
+
|
92 |
+
|
93 |
+
def test_backward_pass():
|
94 |
+
"""
|
95 |
+
A simple test for the M*G grouped GEMM backward pass with detailed error handling.
|
96 |
+
|
97 |
+
In M*G grouping:
|
98 |
+
- M dimension is partitioned into G groups (M_total = sum(M_sizes))
|
99 |
+
- N dimension is the same for all groups
|
100 |
+
"""
|
101 |
+
try:
|
102 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
103 |
+
|
104 |
+
# Test parameters for DeepSeek-like models
|
105 |
+
G = 4 # Number of groups
|
106 |
+
M_sizes = [2048, 2048, 2048, 2048] # Group sizes (will be adjusted)
|
107 |
+
M_total = sum(M_sizes) # Total M dimension
|
108 |
+
N = 4096 # Output dimension (same for all groups)
|
109 |
+
K = 7168 # Hidden dimension
|
110 |
+
|
111 |
+
# Create group sizes tensor
|
112 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
113 |
+
|
114 |
+
# Create input and weight tensors - using float16 for higher precision
|
115 |
+
x = torch.randn(
|
116 |
+
M_total, K, dtype=torch.float16, device=device, requires_grad=True
|
117 |
+
)
|
118 |
+
w = torch.randn(N, K, dtype=torch.float16, device=device, requires_grad=True)
|
119 |
+
|
120 |
+
# Log the setup
|
121 |
+
logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
|
122 |
+
logging.info(f"Group sizes: {m_sizes}")
|
123 |
+
logging.info(f"Input x shape: {x.shape}")
|
124 |
+
logging.info(f"Weight w shape: {w.shape}")
|
125 |
+
|
126 |
+
# Step 1: Run forward pass
|
127 |
+
logging.info("Running forward pass")
|
128 |
+
result = grouped_gemm_forward(x, w, m_sizes)
|
129 |
+
logging.info(f"Forward result shape: {result.shape}")
|
130 |
+
|
131 |
+
# Create a gradient for backpropagation
|
132 |
+
grad_output = torch.randn_like(result)
|
133 |
+
logging.info(f"Created gradient with shape: {grad_output.shape}")
|
134 |
+
|
135 |
+
# Step 2: Run backward pass directly
|
136 |
+
logging.info("Running backward pass directly")
|
137 |
+
grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)
|
138 |
+
|
139 |
+
# Verify gradient shapes
|
140 |
+
logging.info(
|
141 |
+
f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}"
|
142 |
+
)
|
143 |
+
|
144 |
+
# Step 3: Verify gradient computation using PyTorch's autograd
|
145 |
+
logging.info("Running PyTorch reference implementation")
|
146 |
+
|
147 |
+
# Compute reference gradients
|
148 |
+
x_ref_grad, w_ref_grad = compute_reference_backward(x, w, m_sizes, grad_output)
|
149 |
+
|
150 |
+
# Compare gradients
|
151 |
+
logging.info("Comparing gradients with PyTorch reference")
|
152 |
+
grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
|
153 |
+
grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")
|
154 |
+
|
155 |
+
# Log overall result
|
156 |
+
if grad_x_close and grad_w_close:
|
157 |
+
logging.info("✓ SUCCESS: Gradients match the PyTorch reference")
|
158 |
+
else:
|
159 |
+
logging.error("✗ FAILURE: Gradient mismatch detected")
|
160 |
+
|
161 |
+
return grad_x_close and grad_w_close
|
162 |
+
|
163 |
+
except Exception as e:
|
164 |
+
logging.error(f"Test failed with error: {e}")
|
165 |
+
import traceback
|
166 |
+
|
167 |
+
logging.error(traceback.format_exc())
|
168 |
+
return False
|
169 |
+
|
170 |
+
|
171 |
+
def test_multiple_deepseek_configs():
|
172 |
+
"""
|
173 |
+
Test multiple DeepSeek model configurations with both forward and backward pass verification.
|
174 |
+
"""
|
175 |
+
# DeepSeek configurations: (G, M, K, N)
|
176 |
+
configs = [
|
177 |
+
(4, 8192, 7168, 4096), # Config 1
|
178 |
+
(4, 8192, 2048, 7168), # Config 2
|
179 |
+
(8, 4096, 7168, 4096), # Config 3
|
180 |
+
(8, 4096, 2048, 7168), # Config 4
|
181 |
+
]
|
182 |
+
|
183 |
+
results = []
|
184 |
+
|
185 |
+
for config_idx, (G, M, K, N) in enumerate(configs):
|
186 |
+
logging.info(f"\n\n===== Testing DeepSeek Config {config_idx+1} =====")
|
187 |
+
logging.info(f"G={G}, M={M}, K={K}, N={N}")
|
188 |
+
|
189 |
+
try:
|
190 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
191 |
+
|
192 |
+
# Create even group sizes
|
193 |
+
base_size = M // G
|
194 |
+
remainder = M % G
|
195 |
+
M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
|
196 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
197 |
+
|
198 |
+
# Create input and weight tensors using float16 for higher precision
|
199 |
+
x = torch.randn(
|
200 |
+
M, K, dtype=torch.float16, device=device, requires_grad=True
|
201 |
+
)
|
202 |
+
w = torch.randn(
|
203 |
+
N, K, dtype=torch.float16, device=device, requires_grad=True
|
204 |
+
)
|
205 |
+
|
206 |
+
logging.info(f"Input x shape: {x.shape}, Weight w shape: {w.shape}")
|
207 |
+
|
208 |
+
# Run forward pass
|
209 |
+
result = grouped_gemm_forward(x, w, m_sizes)
|
210 |
+
logging.info(f"Forward result shape: {result.shape}")
|
211 |
+
|
212 |
+
# ===== FORWARD PASS VERIFICATION =====
|
213 |
+
# Compute reference forward result
|
214 |
+
reference_result = compute_reference_forward(x, w, m_sizes)
|
215 |
+
|
216 |
+
# Compare forward results
|
217 |
+
forward_close = analyze_tensor_differences(
|
218 |
+
result, reference_result, "Forward output"
|
219 |
+
)
|
220 |
+
|
221 |
+
# ===== BACKWARD PASS VERIFICATION =====
|
222 |
+
# Create gradient for backpropagation
|
223 |
+
grad_output = torch.randn_like(result)
|
224 |
+
|
225 |
+
# Run backward pass
|
226 |
+
grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)
|
227 |
+
|
228 |
+
# Compute reference gradients
|
229 |
+
x_ref_grad, w_ref_grad = compute_reference_backward(
|
230 |
+
x, w, m_sizes, grad_output
|
231 |
+
)
|
232 |
+
|
233 |
+
# Compare backward results
|
234 |
+
grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
|
235 |
+
grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")
|
236 |
+
|
237 |
+
# Overall config result
|
238 |
+
backward_close = grad_x_close and grad_w_close
|
239 |
+
config_success = forward_close and backward_close
|
240 |
+
results.append(
|
241 |
+
(config_idx + 1, config_success, forward_close, backward_close)
|
242 |
+
)
|
243 |
+
|
244 |
+
# Log overall config result
|
245 |
+
if config_success:
|
246 |
+
logging.info(f"✓ SUCCESS: Config {config_idx+1} passed all tests!")
|
247 |
+
else:
|
248 |
+
logging.error(
|
249 |
+
f"✗ FAILURE: Config {config_idx+1} failed one or more tests"
|
250 |
+
)
|
251 |
+
|
252 |
+
except Exception as e:
|
253 |
+
logging.error(f"Config {config_idx+1} test failed with error: {e}")
|
254 |
+
import traceback
|
255 |
+
|
256 |
+
logging.error(traceback.format_exc())
|
257 |
+
results.append((config_idx + 1, False, False, False))
|
258 |
+
|
259 |
+
# Summary
|
260 |
+
logging.info("\n===== Test Results Summary =====")
|
261 |
+
for config_idx, overall_success, forward_success, backward_success in results:
|
262 |
+
overall_status = "✓ PASSED" if overall_success else "✗ FAILED"
|
263 |
+
forward_status = "✓ PASSED" if forward_success else "✗ FAILED"
|
264 |
+
backward_status = "✓ PASSED" if backward_success else "✗ FAILED"
|
265 |
+
|
266 |
+
logging.info(f"Config {config_idx}: {overall_status}")
|
267 |
+
logging.info(f" - Forward pass: {forward_status}")
|
268 |
+
logging.info(f" - Backward pass: {backward_status}")
|
269 |
+
|
270 |
+
return all(overall_success for _, overall_success, _, _ in results)
|
271 |
+
|
272 |
+
|
273 |
+
if __name__ == "__main__":
|
274 |
+
logging.info(
|
275 |
+
"Running verification for both forward and backward pass of M*G grouped GEMM"
|
276 |
+
)
|
277 |
+
|
278 |
+
# Run basic forward pass test
|
279 |
+
logging.info("\n===== Running basic forward pass test =====")
|
280 |
+
success_forward = test_forward_pass()
|
281 |
+
logging.info(f"Basic forward test {'succeeded' if success_forward else 'failed'}")
|
282 |
+
|
283 |
+
# Run basic backward pass test
|
284 |
+
logging.info("\n===== Running basic backward pass test =====")
|
285 |
+
success_backward = test_backward_pass()
|
286 |
+
logging.info(f"Basic backward test {'succeeded' if success_backward else 'failed'}")
|
287 |
+
|
288 |
+
# Run multiple DeepSeek configs with forward and backward verification
|
289 |
+
logging.info("\n===== Running tests for all DeepSeek configs =====")
|
290 |
+
success_configs = test_multiple_deepseek_configs()
|
291 |
+
logging.info(
|
292 |
+
f"DeepSeek configs tests {'all succeeded' if success_configs else 'had failures'}"
|
293 |
+
)
|
294 |
+
|
295 |
+
# Overall result
|
296 |
+
overall_success = success_forward and success_backward and success_configs
|
297 |
+
logging.info(
|
298 |
+
f"\nOverall test result: {'SUCCESS' if overall_success else 'FAILURE'}"
|
299 |
+
)
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py
ADDED
@@ -0,0 +1,1304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# credit - flat index forward kernel is derived from FBGemm:
|
8 |
+
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
|
9 |
+
|
10 |
+
# pyre-unsafe
|
11 |
+
import functools
|
12 |
+
import logging
|
13 |
+
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
from typing import Any, Dict, Optional, Tuple
|
17 |
+
|
18 |
+
import torch
|
19 |
+
|
20 |
+
import triton
|
21 |
+
import triton.language as tl
|
22 |
+
from triton import Config as TConfig
|
23 |
+
|
24 |
+
from triton.runtime import driver # @manual
|
25 |
+
|
26 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
27 |
+
|
28 |
+
from tma_autotuning import (
|
29 |
+
ALIGN_SIZE_M,
|
30 |
+
_NV_CONFIGS,
|
31 |
+
CudaUtils,
|
32 |
+
early_config_prune,
|
33 |
+
TmaDescriptorHelper,
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
# Configure logging
|
38 |
+
logging.basicConfig(
|
39 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
40 |
+
)
|
41 |
+
|
42 |
+
# ============== Start Triton Kernels ===============
|
43 |
+
|
44 |
+
|
45 |
+
@triton.autotune(
|
46 |
+
configs=_NV_CONFIGS,
|
47 |
+
key=["G", "M_BUCKET", "N", "K"],
|
48 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
49 |
+
)
|
50 |
+
@triton.jit
|
51 |
+
def _kernel_mg_forward_hopper(
|
52 |
+
a_desc_ptr,
|
53 |
+
b_desc_ptr,
|
54 |
+
c_ptr,
|
55 |
+
workspace,
|
56 |
+
m_sizes,
|
57 |
+
# problem sizes
|
58 |
+
G: tl.constexpr,
|
59 |
+
M_BUCKET: tl.constexpr,
|
60 |
+
N: tl.constexpr,
|
61 |
+
K: tl.constexpr,
|
62 |
+
# config
|
63 |
+
NUM_SMS: tl.constexpr,
|
64 |
+
TMA_SIZE: tl.constexpr,
|
65 |
+
USE_EPILOGUE_SUBTILING: tl.constexpr,
|
66 |
+
# tiles
|
67 |
+
BLOCK_SIZE_M: tl.constexpr,
|
68 |
+
BLOCK_SIZE_N: tl.constexpr,
|
69 |
+
BLOCK_SIZE_K: tl.constexpr,
|
70 |
+
) -> None:
|
71 |
+
"""
|
72 |
+
Flat index style forward kernel for Hopper.
|
73 |
+
For simplicity, we always use TMA Load and TMA Store
|
74 |
+
"""
|
75 |
+
tbidx = tl.program_id(0) # thread block index
|
76 |
+
|
77 |
+
c_dtype = c_ptr.dtype.element_ty # output dtype
|
78 |
+
|
79 |
+
c_desc_ptr = workspace + (tbidx * TMA_SIZE) # for TMA Store
|
80 |
+
|
81 |
+
M_end = 0
|
82 |
+
M_start = 0
|
83 |
+
processed_tiles = 0
|
84 |
+
# Size of individual weight matrix
|
85 |
+
n_size = N // G
|
86 |
+
n_start = 0
|
87 |
+
|
88 |
+
for g in range(G):
|
89 |
+
# Move down along groups
|
90 |
+
# reset to new M offset
|
91 |
+
M_start = M_end
|
92 |
+
m_size = tl.load(m_sizes + g)
|
93 |
+
M_end = M_start + m_size
|
94 |
+
n_start = n_size * g
|
95 |
+
|
96 |
+
if m_size > 0:
|
97 |
+
# Process this group
|
98 |
+
|
99 |
+
# Acquire hold on c_desc_ptr for TMA Store
|
100 |
+
tl.extra.cuda.experimental_device_tensormap_create2d(
|
101 |
+
desc_ptr=c_desc_ptr,
|
102 |
+
global_address=c_ptr + M_start * n_size,
|
103 |
+
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
104 |
+
global_size=[m_size, n_size],
|
105 |
+
element_ty=c_dtype,
|
106 |
+
)
|
107 |
+
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
108 |
+
|
109 |
+
# tiles for this group
|
110 |
+
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
111 |
+
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
|
112 |
+
group_num_tiles = num_m_tiles * num_n_tiles
|
113 |
+
|
114 |
+
while tbidx >= processed_tiles and tbidx < (
|
115 |
+
processed_tiles + group_num_tiles
|
116 |
+
):
|
117 |
+
group_index = tbidx - processed_tiles
|
118 |
+
|
119 |
+
# columnwise
|
120 |
+
tile_m_index = group_index % num_m_tiles
|
121 |
+
tile_n_index = group_index // num_m_tiles
|
122 |
+
|
123 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
124 |
+
|
125 |
+
m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
|
126 |
+
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
127 |
+
global_n_offset = (n_start + n_offset).to(tl.int32)
|
128 |
+
|
129 |
+
for k_offset in range(0, K, BLOCK_SIZE_K):
|
130 |
+
# input block [M,K]
|
131 |
+
a = tl._experimental_descriptor_load(
|
132 |
+
a_desc_ptr,
|
133 |
+
[m_offset, k_offset],
|
134 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
135 |
+
c_dtype,
|
136 |
+
)
|
137 |
+
# weight block [N, K]
|
138 |
+
b = tl._experimental_descriptor_load(
|
139 |
+
b_desc_ptr,
|
140 |
+
[global_n_offset, k_offset],
|
141 |
+
[BLOCK_SIZE_N, BLOCK_SIZE_K],
|
142 |
+
c_dtype,
|
143 |
+
)
|
144 |
+
|
145 |
+
accumulator += tl.dot(a, b.T)
|
146 |
+
|
147 |
+
# Store using TMA
|
148 |
+
|
149 |
+
m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
150 |
+
|
151 |
+
if USE_EPILOGUE_SUBTILING:
|
152 |
+
acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
|
153 |
+
acc = tl.permute(acc, (0, 2, 1))
|
154 |
+
acc0, acc1 = tl.split(acc)
|
155 |
+
c0 = acc0.to(c_dtype)
|
156 |
+
tl._experimental_descriptor_store(
|
157 |
+
c_desc_ptr, c0, [m_offset, n_offset]
|
158 |
+
)
|
159 |
+
c1 = acc1.to(c_dtype)
|
160 |
+
tl._experimental_descriptor_store(
|
161 |
+
c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2]
|
162 |
+
)
|
163 |
+
else:
|
164 |
+
tl._experimental_descriptor_store(
|
165 |
+
c_desc_ptr,
|
166 |
+
accumulator.to(c_dtype),
|
167 |
+
[m_offset, n_offset],
|
168 |
+
)
|
169 |
+
# move to next tile in group
|
170 |
+
tbidx += NUM_SMS
|
171 |
+
# Update the total tiles count for the next group
|
172 |
+
processed_tiles += group_num_tiles
|
173 |
+
|
174 |
+
|
175 |
+
@triton.autotune(
|
176 |
+
configs=_NV_CONFIGS,
|
177 |
+
key=["G", "M_BUCKET", "N", "K"],
|
178 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
179 |
+
)
|
180 |
+
@triton.jit
|
181 |
+
def _kernel_mg_forward_tma(
|
182 |
+
a_desc_ptr,
|
183 |
+
b_desc_ptr,
|
184 |
+
c_ptr,
|
185 |
+
workspace,
|
186 |
+
m_sizes,
|
187 |
+
a_scale_ptr,
|
188 |
+
b_scale_ptr,
|
189 |
+
# problem sizes
|
190 |
+
G: tl.constexpr,
|
191 |
+
M_BUCKET: tl.constexpr,
|
192 |
+
N: tl.constexpr,
|
193 |
+
K: tl.constexpr,
|
194 |
+
# config
|
195 |
+
NUM_SMS: tl.constexpr,
|
196 |
+
USE_TMA_LOAD: tl.constexpr,
|
197 |
+
USE_TMA_STORE: tl.constexpr,
|
198 |
+
TMA_SIZE: tl.constexpr,
|
199 |
+
USE_FP8: tl.constexpr,
|
200 |
+
# tiles
|
201 |
+
BLOCK_SIZE_M: tl.constexpr,
|
202 |
+
BLOCK_SIZE_N: tl.constexpr,
|
203 |
+
BLOCK_SIZE_K: tl.constexpr,
|
204 |
+
) -> None:
|
205 |
+
"""
|
206 |
+
Flat index style forward kernel.
|
207 |
+
For simplicity, we always use TMA Load and TMA Store
|
208 |
+
"""
|
209 |
+
tbidx = tl.program_id(0) # thread block index
|
210 |
+
|
211 |
+
c_dtype = c_ptr.dtype.element_ty
|
212 |
+
|
213 |
+
c_desc_ptr = workspace + (tbidx * TMA_SIZE)
|
214 |
+
|
215 |
+
M_end = 0
|
216 |
+
processed_tiles = 0
|
217 |
+
|
218 |
+
for g in range(G):
|
219 |
+
# Move down along groups
|
220 |
+
# reset to new M offset
|
221 |
+
M_start = M_end
|
222 |
+
m_size = tl.load(m_sizes + g)
|
223 |
+
M_end = M_start + m_size
|
224 |
+
|
225 |
+
if m_size > 0:
|
226 |
+
# Process this group
|
227 |
+
n_size = N
|
228 |
+
|
229 |
+
# TMA Store prep
|
230 |
+
tl.extra.cuda.experimental_device_tensormap_create2d(
|
231 |
+
desc_ptr=c_desc_ptr,
|
232 |
+
global_address=c_ptr + M_start * N,
|
233 |
+
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
234 |
+
global_size=[m_size, n_size],
|
235 |
+
element_ty=c_dtype,
|
236 |
+
)
|
237 |
+
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
238 |
+
|
239 |
+
# tiles for this group
|
240 |
+
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
241 |
+
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
|
242 |
+
group_num_tiles = num_m_tiles * num_n_tiles
|
243 |
+
|
244 |
+
while tbidx >= processed_tiles and tbidx < (
|
245 |
+
processed_tiles + group_num_tiles
|
246 |
+
):
|
247 |
+
group_index = tbidx - processed_tiles
|
248 |
+
|
249 |
+
tile_m_index = group_index % num_m_tiles
|
250 |
+
tile_n_index = group_index // num_m_tiles
|
251 |
+
|
252 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
253 |
+
|
254 |
+
m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
|
255 |
+
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
256 |
+
|
257 |
+
for k_offset in range(0, K, BLOCK_SIZE_K):
|
258 |
+
# input block [M,K]
|
259 |
+
a = tl._experimental_descriptor_load(
|
260 |
+
a_desc_ptr,
|
261 |
+
[m_offset, k_offset],
|
262 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
263 |
+
c_dtype,
|
264 |
+
)
|
265 |
+
# weight block [N, K]
|
266 |
+
b = tl._experimental_descriptor_load(
|
267 |
+
b_desc_ptr,
|
268 |
+
[n_offset, k_offset],
|
269 |
+
[BLOCK_SIZE_N, BLOCK_SIZE_K],
|
270 |
+
c_dtype,
|
271 |
+
)
|
272 |
+
|
273 |
+
accumulator += tl.dot(a, b.T)
|
274 |
+
|
275 |
+
# Store using TMA
|
276 |
+
|
277 |
+
m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
278 |
+
# n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
279 |
+
|
280 |
+
tl._experimental_descriptor_store(
|
281 |
+
c_desc_ptr,
|
282 |
+
accumulator.to(c_dtype),
|
283 |
+
[m_offset, n_offset],
|
284 |
+
)
|
285 |
+
|
286 |
+
# Move to the next tile
|
287 |
+
tbidx += NUM_SMS
|
288 |
+
# Update the total tiles count for the next group
|
289 |
+
processed_tiles += group_num_tiles
|
290 |
+
|
291 |
+
|
292 |
+
@triton.autotune(
|
293 |
+
configs=_NV_CONFIGS,
|
294 |
+
key=["G", "M_BUCKET", "N", "K"],
|
295 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
296 |
+
)
|
297 |
+
@triton.jit
|
298 |
+
def _kernel_mg_forward_no_tma(
|
299 |
+
a_ptr,
|
300 |
+
b_ptr,
|
301 |
+
c_ptr,
|
302 |
+
workspace,
|
303 |
+
m_sizes,
|
304 |
+
# problem sizes
|
305 |
+
G: tl.constexpr,
|
306 |
+
M_BUCKET: tl.constexpr,
|
307 |
+
N: tl.constexpr,
|
308 |
+
K: tl.constexpr,
|
309 |
+
# config
|
310 |
+
NUM_SMS: tl.constexpr,
|
311 |
+
USE_TMA_LOAD: tl.constexpr,
|
312 |
+
USE_TMA_STORE: tl.constexpr,
|
313 |
+
TMA_SIZE: tl.constexpr,
|
314 |
+
# tiles
|
315 |
+
BLOCK_SIZE_M: tl.constexpr,
|
316 |
+
BLOCK_SIZE_N: tl.constexpr,
|
317 |
+
BLOCK_SIZE_K: tl.constexpr,
|
318 |
+
) -> None:
|
319 |
+
"""
|
320 |
+
Flat index style forward kernel.
|
321 |
+
For bc and Ampere, we never use TMA Load and TMA Store
|
322 |
+
"""
|
323 |
+
tbidx = tl.program_id(0) # thread block index
|
324 |
+
|
325 |
+
c_dtype = c_ptr.dtype.element_ty
|
326 |
+
c_desc_ptr = None
|
327 |
+
|
328 |
+
M_end = 0
|
329 |
+
processed_tiles = 0
|
330 |
+
|
331 |
+
for g in range(G):
|
332 |
+
# Move down along groups
|
333 |
+
# reset to new M offset
|
334 |
+
M_start = M_end
|
335 |
+
m_size = tl.load(m_sizes + g)
|
336 |
+
M_end = M_start + m_size
|
337 |
+
|
338 |
+
if m_size > 0:
|
339 |
+
# Process this group
|
340 |
+
n_size = N
|
341 |
+
|
342 |
+
# tiles for this group
|
343 |
+
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
344 |
+
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
|
345 |
+
group_num_tiles = num_m_tiles * num_n_tiles
|
346 |
+
|
347 |
+
while tbidx >= processed_tiles and tbidx < (
|
348 |
+
processed_tiles + group_num_tiles
|
349 |
+
):
|
350 |
+
group_index = tbidx - processed_tiles
|
351 |
+
|
352 |
+
tile_m_index = group_index % num_m_tiles
|
353 |
+
tile_n_index = group_index // num_m_tiles
|
354 |
+
|
355 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
356 |
+
|
357 |
+
m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
|
358 |
+
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
359 |
+
|
360 |
+
offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
361 |
+
offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
362 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
363 |
+
|
364 |
+
a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :]
|
365 |
+
b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :]
|
366 |
+
|
367 |
+
for k_offset in range(0, K, BLOCK_SIZE_K):
|
368 |
+
# Load with bounds checking
|
369 |
+
a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
|
370 |
+
b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
|
371 |
+
|
372 |
+
# Main matmul
|
373 |
+
accumulator += tl.dot(a, b.T)
|
374 |
+
|
375 |
+
# Update pointers for next block
|
376 |
+
a_ptrs += BLOCK_SIZE_K
|
377 |
+
b_ptrs += BLOCK_SIZE_K
|
378 |
+
|
379 |
+
# Store without TMA
|
380 |
+
offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
381 |
+
offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
382 |
+
|
383 |
+
c = accumulator.to(c_dtype)
|
384 |
+
|
385 |
+
tl.store(
|
386 |
+
c_ptr
|
387 |
+
+ (M_start + offs_am[:, None]) * N # Row stride is N
|
388 |
+
+ offs_bn[None, :], # Column offset
|
389 |
+
c,
|
390 |
+
mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
|
391 |
+
)
|
392 |
+
# Move to the next tile
|
393 |
+
tbidx += NUM_SMS
|
394 |
+
# Update the total tiles count for the next group
|
395 |
+
processed_tiles += group_num_tiles
|
396 |
+
|
397 |
+
|
398 |
+
"""
|
399 |
+
Backward pass for grouped GEMM with Triton, where grouping is M*G
|
400 |
+
We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
|
401 |
+
"""
|
402 |
+
|
403 |
+
|
404 |
+
# ---- dx flat linear indexed ----
|
405 |
+
@triton.autotune(
|
406 |
+
configs=_NV_CONFIGS,
|
407 |
+
key=["G", "M_BUCKET", "N", "K"],
|
408 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
409 |
+
)
|
410 |
+
@triton.jit
|
411 |
+
def _kernel_mg_dx_tma(
|
412 |
+
grad_output_desc_ptr, # [MG, N]
|
413 |
+
w_desc_ptr, # [N, K]
|
414 |
+
grad_input_ptr, # output grad_x [MG, K]
|
415 |
+
workspace, # for TMA store
|
416 |
+
m_sizes, # group sizes [G]
|
417 |
+
# problem sizes
|
418 |
+
G: tl.constexpr,
|
419 |
+
M_BUCKET: tl.constexpr,
|
420 |
+
N: tl.constexpr,
|
421 |
+
K: tl.constexpr,
|
422 |
+
# config
|
423 |
+
NUM_SMS: tl.constexpr,
|
424 |
+
USE_TMA_LOAD: tl.constexpr,
|
425 |
+
USE_TMA_STORE: tl.constexpr,
|
426 |
+
TMA_SIZE: tl.constexpr,
|
427 |
+
# tiles
|
428 |
+
BLOCK_SIZE_M: tl.constexpr,
|
429 |
+
BLOCK_SIZE_N: tl.constexpr,
|
430 |
+
BLOCK_SIZE_K: tl.constexpr,
|
431 |
+
) -> None:
|
432 |
+
"""
|
433 |
+
TMA-optimized kernel for computing gradients with respect to input (dx).
|
434 |
+
For the forward pass Y = X @ W.T, the backward for input is:
|
435 |
+
grad_X = grad_Y @ W
|
436 |
+
|
437 |
+
This maps to [MG, N] @ [N, K] -> [MG, K]
|
438 |
+
|
439 |
+
Key differences from forward:
|
440 |
+
1. W is used directly and not transposed
|
441 |
+
2. The reduction dimension is now N (not K)
|
442 |
+
3. Output is [M, K] instead of [M, N]
|
443 |
+
"""
|
444 |
+
tbidx = tl.program_id(0) # thread block index
|
445 |
+
|
446 |
+
c_dtype = grad_input_ptr.dtype.element_ty
|
447 |
+
c_desc_ptr = workspace + (tbidx * TMA_SIZE)
|
448 |
+
|
449 |
+
M_end = 0
|
450 |
+
processed_tiles = 0
|
451 |
+
|
452 |
+
for g in range(G):
|
453 |
+
# Move down along groups - same as forward
|
454 |
+
M_start = M_end
|
455 |
+
m_size = tl.load(m_sizes + g)
|
456 |
+
M_end = M_start + m_size
|
457 |
+
|
458 |
+
if m_size > 0:
|
459 |
+
# Process this group
|
460 |
+
# tiles for this group - now producing [M, K] output
|
461 |
+
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
462 |
+
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
463 |
+
group_num_tiles = num_m_tiles * num_k_tiles
|
464 |
+
|
465 |
+
# TMA Store prep for [M, K] output
|
466 |
+
tl.extra.cuda.experimental_device_tensormap_create2d(
|
467 |
+
desc_ptr=c_desc_ptr,
|
468 |
+
global_address=grad_input_ptr + M_start * K,
|
469 |
+
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
470 |
+
global_size=[m_size, K],
|
471 |
+
element_ty=c_dtype,
|
472 |
+
)
|
473 |
+
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
474 |
+
|
475 |
+
while tbidx >= processed_tiles and tbidx < (
|
476 |
+
processed_tiles + group_num_tiles
|
477 |
+
):
|
478 |
+
group_index = tbidx - processed_tiles
|
479 |
+
|
480 |
+
# Different tiling scheme for [M, K] output
|
481 |
+
tile_m_index = group_index % num_m_tiles
|
482 |
+
tile_k_index = group_index // num_m_tiles
|
483 |
+
|
484 |
+
# for grad_input block [M, K]
|
485 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
486 |
+
|
487 |
+
# Position in full matrix
|
488 |
+
m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
|
489 |
+
k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)
|
490 |
+
|
491 |
+
# reduce along N dimension (instead of K in forward)
|
492 |
+
for n_offset in range(0, N, BLOCK_SIZE_N):
|
493 |
+
# grad_output block [M, N]
|
494 |
+
grad_output = tl._experimental_descriptor_load(
|
495 |
+
grad_output_desc_ptr,
|
496 |
+
[m_offset, n_offset],
|
497 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
498 |
+
c_dtype,
|
499 |
+
)
|
500 |
+
|
501 |
+
# weight block [N, K] - no transpose needed
|
502 |
+
w = tl._experimental_descriptor_load(
|
503 |
+
w_desc_ptr,
|
504 |
+
[n_offset, k_offset],
|
505 |
+
[BLOCK_SIZE_N, BLOCK_SIZE_K],
|
506 |
+
c_dtype,
|
507 |
+
)
|
508 |
+
|
509 |
+
# grad_x = grad_output @ w
|
510 |
+
# reducing along N dimension
|
511 |
+
accumulator += tl.dot(grad_output, w)
|
512 |
+
|
513 |
+
# Store using TMA
|
514 |
+
m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
515 |
+
# k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)
|
516 |
+
|
517 |
+
tl._experimental_descriptor_store(
|
518 |
+
c_desc_ptr,
|
519 |
+
accumulator.to(c_dtype),
|
520 |
+
[m_offset, k_offset],
|
521 |
+
)
|
522 |
+
|
523 |
+
# Move to the next tile
|
524 |
+
tbidx += NUM_SMS
|
525 |
+
|
526 |
+
# Update the total tiles count for the next group
|
527 |
+
processed_tiles += group_num_tiles
|
528 |
+
|
529 |
+
|
530 |
+
# ---- dw flat linear indexed ----
|
531 |
+
|
532 |
+
|
533 |
+
@triton.autotune(
|
534 |
+
configs=_NV_CONFIGS,
|
535 |
+
key=["G", "M_BUCKET", "N", "K"],
|
536 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
537 |
+
)
|
538 |
+
@triton.jit
|
539 |
+
def _kernel_mg_dw_tma(
|
540 |
+
x_desc_ptr, # input descriptor [M_total, K]
|
541 |
+
grad_output_desc_ptr, # grad_output descriptor [M_total, N]
|
542 |
+
grad_weight_ptr, # output grad_w [N, K]
|
543 |
+
workspace, # workspace for TMA store
|
544 |
+
m_sizes, # group sizes [G]
|
545 |
+
# problem sizes
|
546 |
+
G: tl.constexpr,
|
547 |
+
M_BUCKET: tl.constexpr,
|
548 |
+
N: tl.constexpr,
|
549 |
+
K: tl.constexpr,
|
550 |
+
# config
|
551 |
+
NUM_SMS: tl.constexpr,
|
552 |
+
USE_TMA_LOAD: tl.constexpr,
|
553 |
+
USE_TMA_STORE: tl.constexpr,
|
554 |
+
TMA_SIZE: tl.constexpr,
|
555 |
+
# tiles
|
556 |
+
BLOCK_SIZE_N: tl.constexpr,
|
557 |
+
BLOCK_SIZE_K: tl.constexpr,
|
558 |
+
BLOCK_SIZE_M: tl.constexpr, # block size for reduction dimension
|
559 |
+
) -> None:
|
560 |
+
"""
|
561 |
+
Improved TMA-optimized kernel for computing gradients with respect to weights (dw).
|
562 |
+
Uses flat index structure similar to forward.
|
563 |
+
|
564 |
+
For the forward pass Y = X @ W.T,
|
565 |
+
the backward for weights is:
|
566 |
+
grad_W = grad_Y.T @ X
|
567 |
+
|
568 |
+
Where:
|
569 |
+
- grad_Y is [MG, N]
|
570 |
+
- X is [MG, K]
|
571 |
+
- grad_W is [N, K]
|
572 |
+
- we return [N,K]
|
573 |
+
"""
|
574 |
+
# Get thread block index l
|
575 |
+
tbidx = tl.program_id(0)
|
576 |
+
|
577 |
+
# Get output data type
|
578 |
+
c_dtype = grad_weight_ptr.dtype.element_ty
|
579 |
+
|
580 |
+
# Calculate number of output tiles
|
581 |
+
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
582 |
+
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
583 |
+
total_output_tiles = num_n_tiles * num_k_tiles
|
584 |
+
|
585 |
+
# Process tiles in strided manner across SMs
|
586 |
+
for tile_idx in range(tbidx, total_output_tiles, NUM_SMS):
|
587 |
+
# Calculate tile indices
|
588 |
+
tile_n_idx = tile_idx % num_n_tiles
|
589 |
+
tile_k_idx = tile_idx // num_n_tiles
|
590 |
+
|
591 |
+
# Calculate global offsets
|
592 |
+
n_offset = tile_n_idx * BLOCK_SIZE_N
|
593 |
+
k_offset = tile_k_idx * BLOCK_SIZE_K
|
594 |
+
|
595 |
+
# Initialize accumulator for this output tile [N, K]
|
596 |
+
accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)
|
597 |
+
|
598 |
+
# Process each group
|
599 |
+
M_end = 0
|
600 |
+
for g in range(G):
|
601 |
+
# Get group boundaries
|
602 |
+
M_start = M_end
|
603 |
+
m_size = tl.load(m_sizes + g)
|
604 |
+
M_end = M_start + m_size
|
605 |
+
|
606 |
+
# Only process if group is non-empty
|
607 |
+
if m_size > 0:
|
608 |
+
# Process this group in chunks along the M dimension
|
609 |
+
for m_offset in range(0, m_size, BLOCK_SIZE_M):
|
610 |
+
# Calculate actual block size (handling boundary)
|
611 |
+
m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset)
|
612 |
+
|
613 |
+
# Only process if we have actual work to do
|
614 |
+
if m_block_size > 0:
|
615 |
+
# Global offset for this chunk
|
616 |
+
m_global_offset = M_start + m_offset
|
617 |
+
|
618 |
+
if USE_TMA_LOAD:
|
619 |
+
# Load input chunk [M_chunk, K] using TMA
|
620 |
+
x_block = tl._experimental_descriptor_load(
|
621 |
+
x_desc_ptr,
|
622 |
+
[m_global_offset, k_offset],
|
623 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
624 |
+
c_dtype,
|
625 |
+
)
|
626 |
+
|
627 |
+
# Load grad_output chunk [M_chunk, N] using TMA
|
628 |
+
grad_output_block = tl._experimental_descriptor_load(
|
629 |
+
grad_output_desc_ptr,
|
630 |
+
[m_global_offset, n_offset],
|
631 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
632 |
+
c_dtype,
|
633 |
+
)
|
634 |
+
|
635 |
+
# Apply masks for valid regions
|
636 |
+
offs_m = tl.arange(0, BLOCK_SIZE_M)
|
637 |
+
m_mask = offs_m < m_block_size
|
638 |
+
|
639 |
+
# Zero out invalid elements
|
640 |
+
x_block = tl.where(m_mask[:, None], x_block, 0.0)
|
641 |
+
grad_output_block = tl.where(
|
642 |
+
m_mask[:, None], grad_output_block, 0.0
|
643 |
+
)
|
644 |
+
else:
|
645 |
+
# Manual load with bounds checking
|
646 |
+
offs_m = tl.arange(0, BLOCK_SIZE_M)
|
647 |
+
offs_n = tl.arange(0, BLOCK_SIZE_N)
|
648 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
649 |
+
|
650 |
+
# Create masks
|
651 |
+
m_mask = offs_m < m_block_size
|
652 |
+
n_mask = offs_n < N - n_offset
|
653 |
+
k_mask = offs_k < K - k_offset
|
654 |
+
|
655 |
+
# Combined masks
|
656 |
+
mk_mask = m_mask[:, None] & k_mask[None, :]
|
657 |
+
mn_mask = m_mask[:, None] & n_mask[None, :]
|
658 |
+
|
659 |
+
# Global offsets for loading
|
660 |
+
m_global_offs = m_global_offset + offs_m
|
661 |
+
|
662 |
+
# Load x block [M_chunk, K]
|
663 |
+
x_block = tl.load(
|
664 |
+
x_desc_ptr
|
665 |
+
+ m_global_offs[:, None] * K
|
666 |
+
+ (k_offset + offs_k)[None, :],
|
667 |
+
mask=mk_mask,
|
668 |
+
other=0.0,
|
669 |
+
)
|
670 |
+
|
671 |
+
# Load grad_output block [M_chunk, N]
|
672 |
+
grad_output_block = tl.load(
|
673 |
+
grad_output_desc_ptr
|
674 |
+
+ m_global_offs[:, None] * N
|
675 |
+
+ (n_offset + offs_n)[None, :],
|
676 |
+
mask=mn_mask,
|
677 |
+
other=0.0,
|
678 |
+
)
|
679 |
+
|
680 |
+
# Compute partial contribution: grad_W += grad_Y.T @ X
|
681 |
+
# transpose grad_output for the matmul
|
682 |
+
contribution = tl.dot(
|
683 |
+
grad_output_block.to(tl.float32).T, # [N, M_chunk]
|
684 |
+
x_block.to(tl.float32), # [M_chunk, K]
|
685 |
+
)
|
686 |
+
|
687 |
+
# Accumulate
|
688 |
+
accumulator += contribution
|
689 |
+
|
690 |
+
# Store the result
|
691 |
+
if USE_TMA_STORE:
|
692 |
+
# Store using TMA
|
693 |
+
tl._experimental_descriptor_store(
|
694 |
+
workspace, # TMA store descriptor
|
695 |
+
accumulator.to(c_dtype),
|
696 |
+
[n_offset, k_offset],
|
697 |
+
)
|
698 |
+
else:
|
699 |
+
# Manual store with bounds checking
|
700 |
+
offs_n = tl.arange(0, BLOCK_SIZE_N)
|
701 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
702 |
+
|
703 |
+
# Create masks for bounds checking
|
704 |
+
n_mask = offs_n < N - n_offset
|
705 |
+
k_mask = offs_k < K - k_offset
|
706 |
+
output_mask = n_mask[:, None] & k_mask[None, :]
|
707 |
+
|
708 |
+
# Store the result
|
709 |
+
tl.store(
|
710 |
+
grad_weight_ptr
|
711 |
+
+ (n_offset + offs_n)[:, None] * K
|
712 |
+
+ (k_offset + offs_k)[None, :],
|
713 |
+
accumulator.to(c_dtype),
|
714 |
+
mask=output_mask,
|
715 |
+
)
|
716 |
+
|
717 |
+
|
718 |
+
# ======== End Triton kernels ========
|
719 |
+
|
720 |
+
# ======== Triton wrapper functions ========
|
721 |
+
|
722 |
+
# ----- main forward pass wrapper -----
|
723 |
+
|
724 |
+
|
725 |
+
def grouped_gemm_forward(
|
726 |
+
x: torch.Tensor,
|
727 |
+
w: torch.Tensor,
|
728 |
+
m_sizes: torch.Tensor,
|
729 |
+
tma_size: int = 128,
|
730 |
+
) -> torch.Tensor:
|
731 |
+
"""
|
732 |
+
M*G style grouped GEMM with TMA and Float8 support.
|
733 |
+
# Removed for now - FP8 support is triggered by passing x_scale and w_scale tensors.
|
734 |
+
|
735 |
+
"""
|
736 |
+
if not CudaUtils.verify_tma():
|
737 |
+
raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
|
738 |
+
|
739 |
+
G = m_sizes.shape[0]
|
740 |
+
|
741 |
+
assert x.is_contiguous()
|
742 |
+
assert w.is_contiguous()
|
743 |
+
assert m_sizes.is_contiguous()
|
744 |
+
|
745 |
+
# Total input size is now [M_total, K] where M_total is the sum of all group sizes
|
746 |
+
M_total, K = x.shape
|
747 |
+
N = w.shape[0] # N is now the same for all groups
|
748 |
+
|
749 |
+
assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"
|
750 |
+
|
751 |
+
# Verify that all group sizes are multiples of ALIGN_SIZE_M
|
752 |
+
# This check is commented out because it will involve a GPU-CPU sync
|
753 |
+
# assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0, "Group sizes must be a multiple of ALIGN_SIZE_M"
|
754 |
+
|
755 |
+
# Create output tensor with correct shape [M_total, N]
|
756 |
+
y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)
|
757 |
+
|
758 |
+
if M_total == 0:
|
759 |
+
return y
|
760 |
+
|
761 |
+
NUM_SMS = CudaUtils.get_num_sms()
|
762 |
+
USE_TMA_LOAD = True
|
763 |
+
USE_TMA_STORE = True
|
764 |
+
USE_EPILOGUE_SUBTILING = False
|
765 |
+
|
766 |
+
# TMA descriptor helper
|
767 |
+
desc_helper = None
|
768 |
+
desc_x = x
|
769 |
+
desc_w = w
|
770 |
+
workspace = None
|
771 |
+
|
772 |
+
if USE_TMA_LOAD:
|
773 |
+
desc_helper = TmaDescriptorHelper(tma_size=tma_size)
|
774 |
+
desc_helper.init_tma_descriptor("x")
|
775 |
+
desc_helper.init_tma_descriptor("w")
|
776 |
+
desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
|
777 |
+
desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
|
778 |
+
|
779 |
+
if USE_TMA_STORE:
|
780 |
+
workspace = torch.empty(
|
781 |
+
NUM_SMS * desc_helper.tma_size,
|
782 |
+
device=x.device,
|
783 |
+
dtype=torch.uint8,
|
784 |
+
)
|
785 |
+
|
786 |
+
def grid(META):
|
787 |
+
if USE_TMA_LOAD:
|
788 |
+
nonlocal desc_helper
|
789 |
+
desc_helper.fill_2d_tma_descriptor(
|
790 |
+
"x",
|
791 |
+
x.data_ptr(),
|
792 |
+
M_total,
|
793 |
+
K,
|
794 |
+
META["BLOCK_SIZE_M"],
|
795 |
+
META["BLOCK_SIZE_K"],
|
796 |
+
x.element_size(),
|
797 |
+
)
|
798 |
+
|
799 |
+
desc_helper.fill_2d_tma_descriptor(
|
800 |
+
"w",
|
801 |
+
w.data_ptr(),
|
802 |
+
N,
|
803 |
+
K,
|
804 |
+
META["BLOCK_SIZE_N"],
|
805 |
+
META["BLOCK_SIZE_K"],
|
806 |
+
w.element_size(),
|
807 |
+
)
|
808 |
+
return (NUM_SMS,)
|
809 |
+
|
810 |
+
M_BUCKET = triton.next_power_of_2(M_total)
|
811 |
+
|
812 |
+
_kernel_mg_forward_hopper[grid](
|
813 |
+
desc_x,
|
814 |
+
desc_w,
|
815 |
+
y,
|
816 |
+
workspace,
|
817 |
+
m_sizes,
|
818 |
+
G,
|
819 |
+
M_BUCKET,
|
820 |
+
N,
|
821 |
+
K,
|
822 |
+
NUM_SMS,
|
823 |
+
TMA_SIZE=tma_size,
|
824 |
+
USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
|
825 |
+
)
|
826 |
+
|
827 |
+
return y
|
828 |
+
|
829 |
+
|
830 |
+
# ======== Improved Backward =============
def grouped_gemm_backward(
    grad_output: torch.Tensor,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Unified backward pass for grouped GEMM with M*G grouping.
    Uses optimized TMA-based implementations for both dx and dw when available.

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        x: Input tensor from forward pass, shape [M_total, K]
        w: Weight tensor from forward pass, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        Tuple of gradients with respect to x and w: (grad_x, grad_w)
    """
    logging.info("Starting unified grouped_gemm_backward")

    # Do this once; it seems expensive.
    NUM_SMS = CudaUtils.get_num_sms()

    # Basic validation
    G = m_sizes.shape[0]
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape
    N_w, K_w = w.shape

    # Check dimensions
    if K_x != K_w:
        raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
    if M_total != M_grad:
        raise ValueError(
            f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
        )

    # Check that total M matches the sum of group sizes
    sum_m_sizes = m_sizes.sum().item()
    if M_total != sum_m_sizes:
        raise ValueError(
            f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
        )

    # Make sure inputs are contiguous
    grad_output = grad_output.contiguous()
    x = x.contiguous()
    w = w.contiguous()
    m_sizes = m_sizes.contiguous()

    # Check TMA support
    can_use_tma = use_tma and CudaUtils.verify_tma()
    if use_tma and not can_use_tma:
        logging.info("TMA requested but not supported on this device")
        use_tma = False

    # Compute grad_x using the flat linear implementation
    try:
        logging.info("Computing grad_x with flat linear kernel")

        # Use TMA-optimized implementation
        grad_x = grouped_gemm_dx_tma(
            grad_output=grad_output,
            w=w,
            m_sizes=m_sizes,
            num_sms=NUM_SMS,
            tma_size=tma_size,
        )

    except Exception as e:
        logging.error(f"Error in grad_x computation: {e}")
        raise

    # Compute grad_w using the flat linear style implementation
    try:
        logging.info("Computing grad_w with flat linear kernel")

        grad_w = grouped_gemm_dw_tma(
            x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
        )
    except Exception as e:
        logging.error(f"Error in grad_w computation: {e}")
        raise

    return grad_x, grad_w

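For reference, the gradients that the dx/dw TMA kernels below are expected to produce can be written directly in PyTorch. The sketch below mirrors the per-group convention used by compute_reference_forward/compute_reference_backward in reference_utils.py later in this diff (a shared w of shape [N, K], grad_output of shape [M_total, N]); the helper name is hypothetical and not part of the original files:

# Hypothetical reference for the backward identities, per group g:
#   y_g = x_g @ w.T          (forward)
#   dx_g = dy_g @ w          (gradient w.r.t. the group's rows of x)
#   dw  += dy_g.T @ x_g      (weight gradient accumulates over groups)
import torch

def reference_grouped_backward(grad_output, x, w, m_sizes):
    grad_x = torch.empty_like(x)
    grad_w = torch.zeros_like(w)
    start = 0
    for m in m_sizes.tolist():
        end = start + m
        grad_x[start:end] = grad_output[start:end] @ w     # [m, N] @ [N, K] -> [m, K]
        grad_w += grad_output[start:end].T @ x[start:end]  # [N, m] @ [m, K] -> [N, K]
        start = end
    return grad_x, grad_w
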
# ----- dx backward pass wrapper -----


def grouped_gemm_dx_tma(
    grad_output: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Optimized backward pass wrapper for computing the gradient with respect to the input (dx)
    using TMA patterns similar to the forward pass.

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        w: Weight tensor, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        num_sms: Number of SMs to launch over
        tma_size: Size of TMA descriptor
        # using_fp8: Whether to use FP8 quantization
        # grad_output_scale: Scale for grad_output in FP8 mode
        # w_scale: Scale for w in FP8 mode

    Returns:
        grad_x: Gradient with respect to x, shape [M_total, K]
    """
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Optimized dx computation requires TMA support")

    G = m_sizes.shape[0]

    assert grad_output.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M_total, N_grad = grad_output.shape
    N_w, K = w.shape

    # Check dimensions
    assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})"

    # Verify that the sum of m_sizes matches M_total
    sum_m_sizes = m_sizes.sum().item()
    assert (
        M_total == sum_m_sizes
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Create output tensor (grad_x) with shape [M_total, K]
    grad_x = torch.empty(
        (M_total, K), device=grad_output.device, dtype=grad_output.dtype
    )

    NUM_SMS = num_sms  # CudaUtils.get_num_sms()
    USE_TMA_LOAD = True
    USE_TMA_STORE = True

    # Set up TMA descriptors
    desc_helper = TmaDescriptorHelper(tma_size=tma_size)
    desc_helper.init_tma_descriptor("grad_output")
    desc_helper.init_tma_descriptor("w")
    desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output")
    desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    # Allocate workspace for TMA store
    workspace = torch.empty(
        NUM_SMS * desc_helper.tma_size,
        device=grad_output.device,
        dtype=torch.uint8,
    )

    def grid(META):
        # Fill TMA descriptors with appropriate dimensions
        desc_helper.fill_2d_tma_descriptor(
            "grad_output",
            grad_output.data_ptr(),
            M_total,
            N_grad,
            META["BLOCK_SIZE_M"],
            META["BLOCK_SIZE_N"],
            grad_output.element_size(),
        )

        desc_helper.fill_2d_tma_descriptor(
            "w",
            w.data_ptr(),
            N_w,
            K,
            META["BLOCK_SIZE_N"],
            META["BLOCK_SIZE_K"],
            w.element_size(),
        )
        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M_total)

    # Launch the flat linear kernel for computing grad_x
    _kernel_mg_dx_tma[grid](
        desc_grad_output,
        desc_w,
        grad_x,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N_grad,  # N dimension is now the reduction dimension
        K,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_x

# ======== dw wrapper function ==========


def grouped_gemm_dw_tma(
    x: torch.Tensor,
    grad_output: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Optimized flat linear kernel computation of gradients with respect to weights (dw) using TMA.
    For the forward pass Y = X @ W.T, the backward for weights is:
        grad_W = grad_Y.T @ X

    Args:
        x: Input tensor, shape [M_total, K]
        grad_output: Gradient of output, shape [M_total, N]
        m_sizes: Group sizes tensor, shape [G]
        tma_size: Size of TMA descriptor in bytes

    Returns:
        grad_w: Gradient with respect to weights, shape [N, K]
    """
    # Check TMA support
    has_tma_support = CudaUtils.verify_tma()

    # Get group count
    G = m_sizes.shape[0]

    # Ensure contiguous tensors
    x = x.contiguous()
    grad_output = grad_output.contiguous()
    m_sizes = m_sizes.contiguous()

    # Get dimensions
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape

    # Check dimensions
    assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})"

    # Verify that the sum of m_sizes matches M_total
    sum_m_sizes = m_sizes.sum().item()
    assert (
        sum_m_sizes == M_total
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Create output tensor (grad_w) with shape [N, K]
    grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype)

    NUM_SMS = num_sms

    # TODO - hardcoded for now...but should set TMA flags based on hardware support
    USE_TMA_LOAD = True  # has_tma_support
    USE_TMA_STORE = True  # has_tma_support

    # Set up TMA descriptors or direct pointers
    if USE_TMA_LOAD or USE_TMA_STORE:
        desc_helper = TmaDescriptorHelper(tma_size=tma_size)

        if USE_TMA_LOAD:
            desc_helper.init_tma_descriptor("x")
            desc_helper.init_tma_descriptor("grad_output")
            x_desc = desc_helper.get_tma_descriptor_kernel_param("x")
            grad_output_desc = desc_helper.get_tma_descriptor_kernel_param(
                "grad_output"
            )
        else:
            x_desc = x
            grad_output_desc = grad_output

        if USE_TMA_STORE:
            desc_helper.init_tma_descriptor("grad_w")
            workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w")
        else:
            workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
    else:
        # If not using TMA, just use the tensors directly
        x_desc = x
        grad_output_desc = grad_output
        workspace = torch.empty(1, device=x.device, dtype=torch.uint8)

    # M_BUCKET for grid size
    M_BUCKET = triton.next_power_of_2(M_total)

    # Define grid for kernel launch
    def grid(META):
        if USE_TMA_LOAD or USE_TMA_STORE:

            if USE_TMA_LOAD:
                desc_helper.fill_2d_tma_descriptor(
                    "x",
                    x.data_ptr(),
                    M_total,
                    K_x,
                    META["BLOCK_SIZE_M"],
                    META["BLOCK_SIZE_K"],
                    x.element_size(),
                )

                desc_helper.fill_2d_tma_descriptor(
                    "grad_output",
                    grad_output.data_ptr(),
                    M_total,
                    N,
                    META["BLOCK_SIZE_M"],
                    META["BLOCK_SIZE_N"],
                    grad_output.element_size(),
                )

            if USE_TMA_STORE:
                desc_helper.fill_2d_tma_descriptor(
                    "grad_w",
                    grad_w.data_ptr(),
                    N,
                    K_x,
                    META["BLOCK_SIZE_N"],
                    META["BLOCK_SIZE_K"],
                    grad_w.element_size(),
                )

        # Return grid size - one block per SM for balanced work distribution
        return (NUM_SMS,)

    # Launch the optimized kernel
    _kernel_mg_dw_tma[grid](
        x_desc,
        grad_output_desc,
        grad_w,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K_x,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_w


# ======== End Backwards Wrapper Functions =============

# ======== PyTorch wrapper functions ========


class GroupedGEMM_mg(torch.autograd.Function):
    """
    Autograd function for GroupedGEMM with M*G grouping.
    Supports both standard and FP8 quantized operations.
    """

    @staticmethod
    def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128):
        """
        Forward pass of GroupedGEMM.

        Args:
            x: Input tensor, shape [M_total, K]
            w: Weight tensor, shape [N, K]
            m_sizes: Tensor of shape [G] containing the size of each group
            use_tma: Whether to try using TMA acceleration (if available)
            tma_size: Size of TMA descriptor in bytes

        Returns:
            Output tensor, shape [M_total, N]
        """

        # Use the regular forward without quantization
        output = grouped_gemm_forward(x=x, w=w, m_sizes=m_sizes, tma_size=tma_size)

        # Save inputs and parameters for the backward pass
        ctx.save_for_backward(x, w, m_sizes)
        ctx.use_tma = use_tma
        ctx.tma_size = tma_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass of M*G GroupedGEMM.

        Args:
            grad_output: Gradient of output, shape [M_total, N]

        Returns:
            Tuple of gradients:
            - grad_x: Gradient with respect to x, shape [M_total, K]
            - grad_w: Gradient with respect to w, shape [N, K]
            - None: Gradient with respect to m_sizes (not differentiable)
            - None: Gradient with respect to use_tma (not differentiable)
            - None: Gradient with respect to tma_size (not differentiable)
        """
        # Retrieve saved tensors and parameters
        x, w, m_sizes = ctx.saved_tensors

        use_tma = ctx.use_tma
        tma_size = ctx.tma_size

        # Compute gradients using the unified implementation
        grad_x, grad_w = grouped_gemm_backward(
            grad_output=grad_output,
            x=x,
            w=w,
            m_sizes=m_sizes,
            use_tma=use_tma,
            tma_size=tma_size,
        )

        # Return gradients for all inputs (None for non-differentiable parameters)
        return grad_x, grad_w, None, None, None


def mg_grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
    using_fp8: bool = False,
) -> torch.Tensor:
    """
    Unified differentiable grouped GEMM operation for M*G grouped GEMM.
    Supports both standard precision and FP8 quantized operations.

    Args:
        x: Input tensor, shape [M_total, K]
        w: Weight tensor, shape [N, K]
        m_sizes: Tensor of shape [G] containing the size of each group
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes
        using_fp8: Whether to use FP8 quantization (currently unused; FP8 support is removed for now)

    Returns:
        Output tensor, shape [M_total, N]
    """
    # using_fp8 is not forwarded: GroupedGEMM_mg.forward does not take it while FP8 support is removed.
    return GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size)
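
A short end-to-end sketch of how the autograd wrapper above might be exercised (hypothetical; the module path, dtypes, and sizes are assumptions, and a TMA-capable GPU is required):

# Hypothetical autograd round-trip through mg_grouped_gemm.
import torch
from mg_grouped_gemm import mg_grouped_gemm  # assumed import path

device = "cuda"
m_sizes = torch.tensor([64, 64, 128], dtype=torch.int32, device=device)  # G = 3 groups
M_total, K, N = int(m_sizes.sum()), 128, 256

x = torch.randn(M_total, K, device=device, dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(N, K, device=device, dtype=torch.bfloat16, requires_grad=True)

y = mg_grouped_gemm(x, w, m_sizes)   # forward through the Triton kernel
y.sum().backward()                   # backward dispatches to the dx/dw TMA kernels
print(x.grad.shape, w.grad.shape)    # [M_total, K], [N, K]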
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py
ADDED
@@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import logging

import numpy as np
import torch

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


def compute_reference_forward(x, w, m_sizes):
    """
    Compute reference forward pass using PyTorch operations.

    Args:
        x (torch.Tensor): Input tensor of shape (M, K)
        w (torch.Tensor): Weight tensor of shape (N, K)
        m_sizes (torch.Tensor): Group sizes tensor of shape (G)

    Returns:
        torch.Tensor: Reference output tensor of shape (M, N)
    """
    result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)

    m_start = 0
    for g in range(len(m_sizes)):
        m_size = m_sizes[g].item()
        if m_size > 0:
            m_end = m_start + m_size

            # Extract group input
            x_g = x[m_start:m_end]

            # Compute group output: y_g = x_g @ w.T
            y_g = torch.matmul(x_g, w.T)

            # Store result
            result[m_start:m_end] = y_g

            # Update start index
            m_start = m_end

    return result

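An equivalent way to write the reference above is with torch.split, which makes the grouped row structure explicit (a sketch reusing the torch import above; the function name is hypothetical, and with a single shared w the loop reduces to one dense matmul):

# Hypothetical vectorized equivalent of compute_reference_forward.
def compute_reference_forward_split(x, w, m_sizes):
    # Split the stacked rows into per-group chunks, multiply each by the shared weight,
    # and re-stack; with one shared w this equals torch.matmul(x, w.T).
    chunks = torch.split(x, m_sizes.tolist(), dim=0)
    outs = [torch.matmul(x_g, w.T) for x_g in chunks]
    return torch.cat(outs, dim=0)
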
def compute_reference_backward(x, w, m_sizes, grad_output):
    """
    Compute reference backward pass using PyTorch autograd.

    Args:
        x (torch.Tensor): Input tensor of shape (M, K)
        w (torch.Tensor): Weight tensor of shape (N, K)
        m_sizes (torch.Tensor): Group sizes tensor of shape (G)
        grad_output (torch.Tensor): Gradient tensor of shape (M, N)

    Returns:
        tuple: (grad_x, grad_w) gradient tensors
    """
    # Create autograd-enabled copies
    x_autograd = x.detach().clone().requires_grad_(True)
    w_autograd = w.detach().clone().requires_grad_(True)

    # Compute forward pass
    output = compute_reference_forward(x_autograd, w_autograd, m_sizes)

    # Backpropagate
    output.backward(grad_output)

    return x_autograd.grad, w_autograd.grad


def analyze_tensor_differences(actual, expected, name):
    """
    Analyze differences between actual and expected tensors.

    Args:
        actual (torch.Tensor): Actual tensor
        expected (torch.Tensor): Expected tensor
        name (str): Name of the tensor for logging

    Returns:
        bool: True if tensors are close enough
    """
    rtol = 0.5  # Relative tolerance for float16
    atol = 0.5  # Absolute tolerance for float16

    # Analyze differences
    diff = (actual - expected).abs()
    max_idx = diff.argmax().item()
    idx = np.unravel_index(max_idx, actual.shape)
    max_diff = diff.max().item()

    logging.info(f"Largest {name} difference: {max_diff} at {idx}")
    logging.info(f"Values: {actual[idx].item()} vs {expected[idx].item()}")

    is_close = torch.allclose(actual, expected, rtol=rtol, atol=atol)

    if is_close:
        logging.info(f"✓ SUCCESS: {name} matches PyTorch reference")
    else:
        logging.error(f"✗ FAILURE: {name} mismatch detected")

        # Count zeros
        zeros_actual = (actual == 0).sum().item()
        zeros_expected = (expected == 0).sum().item()
        logging.info(
            f"Zeros in {name} (actual): {zeros_actual}/{actual.numel()} ({zeros_actual/actual.numel()*100:.2f}%)"
        )
        logging.info(
            f"Zeros in {name} (expected): {zeros_expected}/{expected.numel()} ({zeros_expected/expected.numel()*100:.2f}%)"
        )

        # Check for NaNs
        nan_actual = torch.isnan(actual).sum().item()
        if nan_actual > 0:
            logging.error(f"NaN values detected in {name}: {nan_actual}")

    return is_close
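
These helpers are presumably meant to validate the Triton grouped kernels against PyTorch. A minimal validation sketch follows (hypothetical wiring: the import path, sizes, and dtypes are assumptions, and only the backward path is compared here since its gradient shapes are unambiguous):

# Hypothetical wiring of the reference utilities against the grouped backward kernels.
import torch
from mg_grouped_gemm import grouped_gemm_backward  # assumed import path

device = "cuda"
m_sizes = torch.tensor([128, 128], dtype=torch.int32, device=device)
M_total, K, N = int(m_sizes.sum()), 64, 96

x = torch.randn(M_total, K, device=device, dtype=torch.float16)
w = torch.randn(N, K, device=device, dtype=torch.float16)
grad_output = torch.randn(M_total, N, device=device, dtype=torch.float16)

# PyTorch reference gradients via autograd
gx_ref, gw_ref = compute_reference_backward(x, w, m_sizes, grad_output)

# Triton TMA gradients
gx, gw = grouped_gemm_backward(grad_output, x, w, m_sizes)

analyze_tensor_differences(gx, gx_ref, "grad_x")
analyze_tensor_differences(gw, gw_ref, "grad_w")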