diff --git a/flame/__pycache__/__init__.cpython-312.pyc b/flame/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0bcf73618d79ca7f9328bbf4ef75ec8c5aa8651 Binary files /dev/null and b/flame/__pycache__/__init__.cpython-312.pyc differ diff --git a/flame/__pycache__/config_manager.cpython-312.pyc b/flame/__pycache__/config_manager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42d49c5b24c7d9e03330d1d65c9892212983e6cb Binary files /dev/null and b/flame/__pycache__/config_manager.cpython-312.pyc differ diff --git a/flame/__pycache__/data.cpython-312.pyc b/flame/__pycache__/data.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cb8e84d966fca0701d5055beb7002ff81a86dbf Binary files /dev/null and b/flame/__pycache__/data.cpython-312.pyc differ diff --git a/flame/components/__init__.py b/flame/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/flame/components/__pycache__/__init__.cpython-312.pyc b/flame/components/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a02c989fd48052899975a391861955a2310936d Binary files /dev/null and b/flame/components/__pycache__/__init__.cpython-312.pyc differ diff --git a/flame/components/__pycache__/checkpoint.cpython-312.pyc b/flame/components/__pycache__/checkpoint.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c18dbee654393d0ed515599c513502fef262e581 Binary files /dev/null and b/flame/components/__pycache__/checkpoint.cpython-312.pyc differ diff --git a/flame/components/checkpoint.py b/flame/components/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a21fbbcaf9503c1f0f8d965acce420b223201b --- /dev/null +++ b/flame/components/checkpoint.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from datetime import timedelta +from io import BytesIO +from typing import Any, Dict, List + +import torch +from torch.distributed.checkpoint.stateful import Stateful + + +@dataclass +class TrainState(Stateful): + step: int = 0 + skipped_step: int = 0 + token: int = 0 + elapsed: timedelta = timedelta(0) + global_avg_losses: List[float] = field(default_factory=list) + global_max_losses: List[float] = field(default_factory=list) + log_steps: List[int] = field(default_factory=list) + + def state_dict(self) -> Dict[str, Any]: + # Only checkpoint global_avg_losses and global_max_losses per log frequency + # to avoid sync overhead in every iteration. 
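+        # The loss/step histories below are variable-length Python lists; they are
+        # serialized into BytesIO buffers with torch.save so the checkpoint can
+        # store them as opaque byte blobs rather than fixed-shape tensors.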
+ global_avg_losses_bytes = BytesIO() + torch.save(self.global_avg_losses, global_avg_losses_bytes) + global_max_losses_bytes = BytesIO() + torch.save(self.global_max_losses, global_max_losses_bytes) + log_steps_bytes = BytesIO() + torch.save(self.log_steps, log_steps_bytes) + return { + "step": torch.tensor(self.step, dtype=torch.int32), + "skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32), + "token": torch.tensor(self.token, dtype=torch.int64), + "elapsed": self.elapsed, + "global_avg_losses": global_avg_losses_bytes, + "global_max_losses": global_max_losses_bytes, + "log_steps": log_steps_bytes, + } + + def load_state_dict(self, state_dict) -> None: + self.step = state_dict["step"].item() + self.skipped_step = state_dict.get("skipped_step", 0).item() + self.token = state_dict["token"].item() + self.elapsed = state_dict["elapsed"] + state_dict["global_avg_losses"].seek(0) + self.global_avg_losses = torch.load( + state_dict["global_avg_losses"], weights_only=False + ) + state_dict["global_max_losses"].seek(0) + self.global_max_losses = torch.load( + state_dict["global_max_losses"], weights_only=False + ) + state_dict["log_steps"].seek(0) + self.log_steps = torch.load(state_dict["log_steps"], weights_only=False) diff --git a/flame/models/__pycache__/__init__.cpython-312.pyc b/flame/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92371b191dc8ce79b7536c0f5990f056169da6ca Binary files /dev/null and b/flame/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/flame/models/__pycache__/parallelize_fla.cpython-312.pyc b/flame/models/__pycache__/parallelize_fla.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69adefd3f273aea0a37159d94931b26aec371ee2 Binary files /dev/null and b/flame/models/__pycache__/parallelize_fla.cpython-312.pyc differ diff --git a/flame/models/__pycache__/pipeline_fla.cpython-312.pyc b/flame/models/__pycache__/pipeline_fla.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07152bb020c117d039f57386f4c95f4baf955c89 Binary files /dev/null and b/flame/models/__pycache__/pipeline_fla.cpython-312.pyc differ diff --git a/flame/models/activation_offloading.py b/flame/models/activation_offloading.py new file mode 100644 index 0000000000000000000000000000000000000000..80012c714ab205dd5529176c3c2e2ab18263ad63 --- /dev/null +++ b/flame/models/activation_offloading.py @@ -0,0 +1,447 @@ +# Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/training/_activation_offloading.py +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +from typing import Union +from warnings import warn + +import psutil +import torch +from torch import nn +from torch.autograd.graph import saved_tensors_hooks + +from torchtitan.tools.logging import logger + +try: + import torchao + from torchao.dtypes.nf4tensor import NF4Tensor +except ImportError: + torchao = None + NF4Tensor = None + logger.warning("torchao not found. ") + +# from torchtune.modules import TiedLinear + + +class OffloadActivations(saved_tensors_hooks): + """Context manager under which activation tensors created in the forward pass will be offloaded. 
+
+    Enable the memory efficiency technique of activation offloading, where activations bigger than
+    min_offload_size bytes will be offloaded to CPU in the forward and brought back in the backward.
+    This is in contrast to maintaining the activation on GPU VRAM throughout the program.
+
+    This manager contains the option of using one additional CUDA stream to handle the communication
+    between CUDA and CPU, which is intended to overlap with the default computation stream to improve
+    runtime. We designed synchronization with a few heuristics for optimizing the tradeoff between
+    runtime and memory usage.
+
+    Args:
+        use_pin_memory (bool): Whether or not the offloaded Tensor will be placed in pinned
+            memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly
+            but is a limited resource. Default: True.
+
+        use_streams (bool): Whether or not to use streams for performance optimization where
+            the communications get overlapped with the computation. Requires a torch build
+            after torch-2.5.0. Default: True.
+
+        max_fwd_stash_size (int): The maximum size of the forward stash, or the maximum number of
+            consecutive activations to keep alive during the forward pass. This number must be at
+            least 1. Keeping alive more activations will potentially allow more overlap between the
+            communication and compute streams at the cost of increasing memory usage. Keeping alive
+            fewer activations will conserve memory, but may cause poor overlap between the streams,
+            increasing runtime. Default: 5.
+
+        min_offload_size (int): The minimum number of bytes a Tensor must be in order to qualify
+            for offloading. If the tensor is too small, we do not want to waste bandwidth and resources
+            moving it to CPU and back. Default: 1024 bytes.
+
+    Raises:
+        ValueError: if max_fwd_stash_size is not at least 1.
+
+    Example:
+        >>> with OffloadActivations():
+        >>>     logits = model(inputs)
+        >>>     loss = ...
+ >>> loss.backward() + """ + + def __init__( + self, + use_pin_memory: bool = True, + use_streams: bool = True, + max_fwd_stash_size: int = 5, + min_offload_size: int = 1024, + ) -> None: + + self.use_streams: bool = use_streams + + self.min_tensor_size_bytes = ( + min_offload_size # we don't want to bother with small tensors + ) + self.tracker = ( + {} + ) # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where + self.tensor_id: int = 0 + self.is_first_forward_call = True + self.is_first_backward_call = True + self.is_first_forward_pass = True + + # managing cpu memory + self.use_pin_memory: bool = use_pin_memory + self.virtual_memory_safe_pct = ( + 60 # we should not exceed this percentage of memory + ) + + self.s0 = torch.cuda.default_stream() # comp stream + + # for streaming + if self.use_streams: + self.s1 = torch.cuda.Stream() # comms stream + self.fwd_stash = {} # tensor_id => (activation, ev1) + if max_fwd_stash_size < 1: + raise ValueError( + f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}" + ) + self.max_fwd_stash_size = max_fwd_stash_size + self.bwd_tensor_stash = {} # tensor_id => activation + self.bwd_ev_stash = {} # tensor_id => ev0 + self.curr_graph_id = None + self.curr_autograd_node = None + + # -------- platform util functions -------- # + def verify_sufficient_virtual_memory(): + curr_pct = get_cpu_ram_pct() + if curr_pct > self.virtual_memory_safe_pct: + warn( + f"***** WARNING: {curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used" + ) + + def get_cpu_ram_pct() -> float: + # get the percentage of memory used by the system + return psutil.virtual_memory().percent + + def get_tensor_id() -> int: + # create a unique id for each tensor we are managing + self.tensor_id += 1 + return self.tensor_id + + def get_num_bytes_tensor(x: torch.Tensor) -> int: + # get the number of bytes in a tensor, for memory management purposes + return ( + x.element_size() * x.nelement() + ) # x.element_size() * x._base_storage().nbytes() + + # -------- core pack / unpack work -------- # + def pack_tensor(activation: torch.Tensor) -> int: + # activations are passed in during forward pass - from here we take over and return a unique id + if self.is_first_forward_call: + assert ( + len(self.tracker) == 0 + ), "backward pass should have cleared tracker of all tensors" + + # set training phase trackers + self.is_first_forward_call = False + self.is_first_backward_call = True + + # query for basic tensor info + num_bytes = get_num_bytes_tensor(activation) + tensor_id = get_tensor_id() + + # only offload hefty bois if they're activations on CUDA (our heuristic + # for that is to check if they're not params or buffers)! + if ( + activation.is_cuda + and num_bytes >= self.min_tensor_size_bytes + and ( + not isinstance(activation, torch.nn.Parameter) + and not isinstance(activation, torch.nn.Buffer) + ) + ): + if self.use_streams: + # First, sync back and dereference previously offloaded tensors + # as the offloading should be done sufficiently long ago. 
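+                    # Only the most recent max_fwd_stash_size activations stay in the
+                    # stash; for older entries the compute stream first waits on their
+                    # copy events, then the GPU-side references are dropped.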
+ for id in [k for k in self.fwd_stash.keys()]: + if id <= tensor_id - self.max_fwd_stash_size: + _, ev = self.fwd_stash[id] + self.s0.wait_event(ev) + del self.fwd_stash[id] + else: + break + + # Sync in, offload, and add an event to sync back later + self.s1.wait_stream(self.s0) + + stream = self.s1 if self.use_streams else self.s0 + with torch.cuda.stream(stream): + try: + cpu_tensor = torch.empty_like( + activation, pin_memory=self.use_pin_memory, device="cpu" + ) + except NotImplementedError as e: + if ( + isinstance(activation, NF4Tensor) + and torchao.__version__ < "0.6.0.dev20240917" + ): + raise RuntimeError( + "Offloading NF4Tensors requires torchao-0.6.0.dev20240917 or later" + ) from e + raise e + cpu_tensor.copy_(activation, non_blocking=True) + self.tracker[tensor_id] = ( + cpu_tensor, + True, + ) # True = (in future) modified + + if self.use_streams: + event = self.s1.record_event() + + # Stash to keep activation alive til s1 is done + self.fwd_stash[tensor_id] = (activation, event) + else: + self.tracker[tensor_id] = ( + activation, + False, + ) # False = not modified, tensor is as is + + return tensor_id + + def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor: + # backward pass - we are called with the tensor_id, which + # we will use to retrieve the saved/offloaded tensor + if self.is_first_backward_call: + if self.is_first_forward_pass: + self.is_first_forward_pass = False + if self.use_pin_memory: + verify_sufficient_virtual_memory() + + self.is_first_backward_call = False + self.is_first_forward_call = True + + assert ( + unpack_tensor_id in self.tracker + ), f"untracked tensor with id {unpack_tensor_id}" + + maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id] + if modified: + gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True) + maybe_gpu_tensor = gpu_tensor + + # clear tensor from tracking + del self.tracker[unpack_tensor_id] + return maybe_gpu_tensor + + def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor: + # backward pass - we are called with the tensor_id, which + # we will use to retrieve the saved/offloaded tensor + if self.is_first_backward_call: + self.curr_graph_id = torch._C._current_graph_task_id() + + def wait_and_del_remaining_references() -> None: + for id in [k for k in self.bwd_tensor_stash.keys()]: + event = self.bwd_ev_stash[id] + self.s1.wait_event(event) + del self.bwd_tensor_stash[id] + + # Register a callback to the end of autograd to clean everything up + torch.autograd.variable.Variable._execution_engine.queue_callback( + wait_and_del_remaining_references + ) + + if self.is_first_forward_pass: + self.is_first_forward_pass = False + if self.use_pin_memory: + verify_sufficient_virtual_memory() + + self.is_first_backward_call = False + self.is_first_forward_call = True + + assert ( + unpack_tensor_id in self.tracker + ), f"untracked tensor with id {unpack_tensor_id}" + + maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id] + if modified: + # Get data on the current autograd node + graph_id = torch._C._current_graph_task_id() + node = torch._C._current_autograd_node() + prev_node_ids = [] + + # If we're on a new node, mark prev node's tensors to be freed later + if graph_id == self.curr_graph_id and self.curr_autograd_node != node: + self.curr_autograd_node = node + prev_node_ids = [id for id in self.bwd_tensor_stash.keys()] + + brought_back_from_cpu = True + if unpack_tensor_id in self.fwd_stash: + maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0] + brought_back_from_cpu = False + 
else: + # Kick off the process to bring tensors back + with torch.cuda.stream(self.s1): + gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True) + maybe_gpu_tensor = gpu_tensor + + # Tell comp stream to wait for the info to be loaded before executing + self.s0.wait_stream(self.s1) + + # Stash the tensor to keep memory alive until compute stream is complete + self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor + + # Note: [Track views of the unpacked] + # Why do we get the use count of the unpacked tensor here? We want an + # initial count to compare to later, during the post-hook of the + # backward node, when we need to decide whether we're allowed to free + # the tensor yet. In what obscure cases must we delay freeing the + # tensor (and thus call record_stream)? + # 1. Any of the outputs of the backward node is a view of the unpacked + # tensor. + # 2. In the case that this unpacked tensor will be used in a + # checkpointed region, if one of the recomputed saved tensors ends + # up as a view of the unpacked tensor. + # 3. The user abuses the system somehow and manually relies on the + # unpacked tensor to exist after the backward node has executed. + storage_refcount = torch._C._storage_Use_Count( + maybe_gpu_tensor.untyped_storage()._cdata + ) + + def hook(outputs, inputs): + # create events for the current node inputs/outputs if they were streamed in + if brought_back_from_cpu: + # See Note: [Track views of the unpacked] + # IF any of the outputs is a view of the tensor, OR if a view of + # the tensor has been saved as a part of checkpoint's recompute + # process, OR the user has abusedly incurred a reference on the + # unpacked tensor, THEN the tensor might be used later and we + # cannot presume to delete it after only the current node is + # done! So we use our frenemy, record_stream, to ensure the + # Tensor stays unmessed with until it's done getting used in the + # compute stream (s0 here). Note that the con here is we introduce + # non-deterministic (thus higher) memory usage, but this case + # should not happen often. + unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id] + if ( + torch._C._storage_Use_Count( + unpacked_tensor.untyped_storage()._cdata + ) + > storage_refcount + ): + unpacked_tensor.record_stream(self.s0) + del self.bwd_tensor_stash[unpack_tensor_id] + else: + event = self.s0.record_event() + self.bwd_ev_stash[unpack_tensor_id] = event + + # if there are still things in the fwd_stash, get rid of them as we're in bwd now + for id in [k for k in self.fwd_stash.keys()]: + _, ev = self.fwd_stash[id] + self.s0.wait_event(ev) + del self.fwd_stash[id] + + # wait on prev node's events and del those + for id in prev_node_ids: + event = self.bwd_ev_stash[id] + self.s1.wait_event(event) + del self.bwd_tensor_stash[id] + + return outputs + + node.register_hook(hook) + + # clear tensor from tracking + del self.tracker[unpack_tensor_id] + return maybe_gpu_tensor + + unpack_tensor = ( + unpack_tensor_with_streams + if self.use_streams + else unpack_tensor_single_stream + ) + super().__init__(pack_tensor, unpack_tensor) + + +class NoOpManager(saved_tensors_hooks): + """ + A saved_tensors_hook manager used to disable any other saved_tensors_hook manager + applied before. This relies on the behavior that only the most recently registered + saved_tensors_hook will run. + + One example usage is to opt a local region of code out of activations offloading, + which is usually applied globally to best track state. 
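+
+    Example (illustrative; ``some_module`` and ``x`` are placeholders):
+        >>> with OffloadActivations():
+        >>>     with NoOpManager():
+        >>>         # saved tensors created here are not offloaded
+        >>>         y = some_module(x)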
+    """
+
+    def __init__(self) -> None:
+        def noop(tensor):
+            return tensor
+
+        super().__init__(noop, noop)
+
+
+def get_act_offloading_ctx_manager(
+    model: nn.Module, enable_activation_offloading: bool
+) -> Union[OffloadActivations, contextlib.nullcontext]:
+    """Returns the activation offloading context manager for the model, which will be
+    a null context if enable_activation_offloading is False.
+
+    If activation offloading is enabled, we return the OffloadActivations context manager.
+    If activation offloading is disabled, we return a contextlib.nullcontext() instance.
+
+    Args:
+        model (nn.Module): the model to wrap with the activation offloading context manager.
+        enable_activation_offloading (bool): whether or not to enable activation offloading
+            for the model.
+
+    Returns:
+        Union[OffloadActivations, contextlib.nullcontext]: the activation offloading context
+            manager for the model.
+    """
+    if enable_activation_offloading:
+        activations_handling_ctx = OffloadActivations()
+
+        # Below is our hack to disable offloading the last output Linear in every
+        # step, as the cost for offloading the activation and then soon after bringing
+        # it back is expensive. Moreover, due to heuristics in our streaming API,
+        # we actually use more memory if we offload it as it interferes with chunkedCE.
+        output_head_detected = False
+        noop_ctx = NoOpManager()
+
+        if hasattr(model, "output"):
+            if isinstance(model.output, nn.Module):
+                model.output.register_forward_pre_hook(
+                    lambda *args: noop_ctx.__enter__()
+                )
+                model.output.register_forward_hook(
+                    lambda *args: noop_ctx.__exit__(), always_call=True
+                )
+                logger.info("Registered no-op offloading hooks for model.output")
+                output_head_detected = True
+            # ================================
+            # ! TODO[flame] check if we need to deal with TiedLinear
+            # The following code appears in `torchtune`
+            # elif isinstance(model.output, TiedLinear):
+            #     model.output.linear.register_forward_pre_hook(
+            #         lambda *args: noop_ctx.__enter__()
+            #     )
+            #     model.output.linear.register_forward_hook(
+            #         lambda *args: noop_ctx.__exit__(), always_call=True
+            #     )
+            #     output_head_detected = True
+
+        if not output_head_detected:
+            logger.warning(
+                "During activation offloading, no output head was detected. "
+                "If your model has an output head, it will be offloaded. "
+                "This usually greatly slows training, given the large vocabulary size. "
+                "To change this behavior, set your output head as model.output and make it "
+                "an nn.Module."
+ ) + + else: + activations_handling_ctx = contextlib.nullcontext() + + return activations_handling_ctx diff --git a/flame/models/fla.toml b/flame/models/fla.toml new file mode 100644 index 0000000000000000000000000000000000000000..afd3a212bbef8206c10714307c6df738051e83db --- /dev/null +++ b/flame/models/fla.toml @@ -0,0 +1,67 @@ +[model] +config = "fla-hub/transformer-1.3B-100B" +tokenizer_path = "fla-hub/transformer-1.3B-100B" + +[job] +dump_folder = "exp" +print_args = true + +[training] +batch_size = 32 +seq_len = 2048 +context_len = 2048 +gradient_accumulation_steps = 1 +steps = 20480 +max_norm = 1.0 +skip_nan_inf = true +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 1 +compile = false +dataset = "HuggingFaceFW/fineweb-edu" +dataset_name = "default" +num_workers = 32 +pin_memory = false +persistent_workers = false +prefetch_factor = 2 +seed = 42 +varlen = false + +[optimizer] +name = "AdamW" +eps = 1e-15 +lr = 3e-4 + +[lr_scheduler] +warmup_steps = 1024 +decay_type = "cosine" +lr_min = 0.1 + +[checkpoint] +enable_checkpoint = true +folder = "checkpoint" +interval_type = "steps" +interval = 2048 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 512 + +[metrics] +log_freq = 32 +enable_wandb = true + +[experimental] +context_parallel_degree = 1 +pipeline_parallel_degree = 1 + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false + +[activation_checkpoint] +mode = "none" \ No newline at end of file diff --git a/flame/models/parallelize_fla.py b/flame/models/parallelize_fla.py new file mode 100644 index 0000000000000000000000000000000000000000..37178af1bf365b3f5179cefc62000bf8f2f4ded3 --- /dev/null +++ b/flame/models/parallelize_fla.py @@ -0,0 +1,550 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D parallelisms (except pipeline parallelism) and various +# training techniques (e.g. activation checkpointing and compile) to the Llama model. 
+ +from collections import defaultdict + +import torch +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard +from torch.distributed._composable.replicate import replicate +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + PrepareModuleOutput, + RowwiseParallel, + SequenceParallel, + parallelize_module +) + +from fla.modules.fused_linear_cross_entropy import LinearLossParallel +from fla.modules.mlp import SwiGLULinearParallel +from fla.modules.parallel import PrepareModuleWeight +from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig +from torchtitan.distributed.parallel_dims import ParallelDims +from torchtitan.tools.logging import logger + + +def parallelize_fla( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + if parallel_dims.tp_enabled: + if ( + job_config.experimental.enable_async_tensor_parallel + and not job_config.training.compile + ): + raise RuntimeError("Async TP requires --training.compile") + enable_float8_linear = "float8" in job_config.model.converters + apply_tp( + model, + world_mesh["tp"], + loss_parallel=parallel_dims.loss_parallel_enabled, + enable_float8=enable_float8_linear, + enable_async_tp=job_config.experimental.enable_async_tensor_parallel, + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac(model, job_config.activation_checkpoint) + + # turn on per-block compile after AC wrapping and before FSDP + if job_config.training.compile: + apply_compile(model) + + if ( + parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled + ): # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + + apply_fsdp( + model, + world_mesh[tuple(dp_mesh_dim_names)], + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward, + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + apply_ddp( + model, + world_mesh, + enable_compile=job_config.training.compile, + enable_compiled_autograd=job_config.experimental.enable_compiled_autograd, + ) + + +class TPPlan: + def __init__( + self, + model=None, + loss_parallel=False, + enable_float8=False, + ): + self.model = model + self.loss_parallel = loss_parallel + self.enable_float8 = enable_float8 + 
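+        # HF-style models expose their decoder stack under `base_model_prefix`
+        # (e.g. "model" or "transformer"); the TP plan keys below are built
+        # relative to that prefix.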
self.base_model_prefix = getattr(model, "base_model_prefix", "model") + + # TODO(vkuzo): once float8 configuration supports delayed scaling, + # add a check here to enforce supported float8 all-gather configurations + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + try: + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput + ) + except ImportError: + Float8ColwiseParallel = None + Float8RowwiseParallel = None + PrepareFloat8ModuleInput = None + if self.enable_float8 and Float8ColwiseParallel is not None: + self.rowwise_parallel = Float8RowwiseParallel + self.colwise_parallel = Float8ColwiseParallel + self.prepare_module_input = PrepareFloat8ModuleInput + self.prepare_module_output = PrepareModuleOutput + else: + self.rowwise_parallel = RowwiseParallel + self.colwise_parallel = ColwiseParallel + self.prepare_module_input = PrepareModuleInput + self.prepare_module_output = PrepareModuleOutput + + @property + def model_plan(self): + plans = { + f"{self.base_model_prefix}.embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + f"{self.base_model_prefix}.norm": SequenceParallel(), + } + if self.loss_parallel: + plans.update( + { + "lm_head": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if self.loss_parallel else Replicate(), + use_local_output=not self.loss_parallel, + ), + } + ) + else: + plans.update( + { + "lm_head": PrepareModuleWeight(layouts=Replicate()), + "criterion": LinearLossParallel(), + } + ) + return plans + + @property + def layer_plan(self): + return { + "attn_norm": SequenceParallel(), + **self.attn_plan, + "mlp_norm": SequenceParallel(), + **self.mlp_plan, + } + + @property + def attn_plan(self): + raise NotImplementedError( + f"TP plans for token mixing layers of {self.model.config.model_type} not implemented" + ) + + @property + def mlp_plan(self): + return { + "mlp": self.prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "mlp.gate_proj": self.colwise_parallel(), + "mlp.up_proj": self.colwise_parallel(), + "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)), + "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)), + } + + +class TransformerTPPlan(TPPlan): + + @property + def attn_plan(self): + return { + "attn": self.prepare_module_input( + input_kwarg_layouts={"hidden_states": Shard(1)}, + desired_input_kwarg_layouts={"hidden_states": Replicate()}, + ), + "attn.q_proj": self.colwise_parallel(), + "attn.k_proj": self.colwise_parallel(), + "attn.v_proj": self.colwise_parallel(), + "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)), + } + + +class GLATPPlan(TPPlan): + + @property + def attn_plan(self): + return { + "attn": self.prepare_module_input( + input_kwarg_layouts={"hidden_states": Shard(1)}, + desired_input_kwarg_layouts={"hidden_states": Replicate()}, + ), + "attn.q_proj": self.colwise_parallel(), + "attn.k_proj": self.colwise_parallel(), + "attn.v_proj": self.colwise_parallel(), + "attn.g_proj": self.colwise_parallel(), + "attn.gk_proj.0": PrepareModuleWeight(layouts=Replicate()), + "attn.gk_proj.1": self.colwise_parallel(), + "attn.g_norm": SequenceParallel(sequence_dim=-1), + "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)), + } + + +TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan} + + +def apply_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + 
enable_float8: bool, + enable_async_tp: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + tp_plan = TP_PLAN_MAP[model.config.model_type]( + model, loss_parallel=loss_parallel, enable_float8=enable_float8 + ) + parallelize_module(model, tp_mesh, tp_plan.model_plan) + + blocks = get_blocks(model) + if blocks is None: + logger.warning("No block found for tensor parallelism") + else: + for _, block in enumerate(blocks): + parallelize_module( + module=block, + device_mesh=tp_mesh, + parallelize_plan=tp_plan.layer_plan, + ) + + if enable_async_tp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + torch._inductor.config._micro_pipeline_tp = True + enable_symm_mem_for_group(tp_mesh.get_group().group_name) + + logger.info( + f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}" + "Tensor Parallelism to the model" + ) + + +# for selective op activation checkpointing +_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, +} + + +def _apply_ac_to_block(module: nn.Module, ac_config): + valid_ac_modes = ("full", "selective") + if ac_config.mode not in valid_ac_modes: + raise ValueError( + f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}" + ) + + if ac_config.mode == "full": + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + + assert ac_config.mode == "selective", f"{ac_config.mode}" + use_op_sac = ac_config.selective_ac_option == "op" + use_layer_sac = ac_config.selective_ac_option.isdigit() + if not use_op_sac and not use_layer_sac: + raise ValueError( + f"Invalid selective AC option: {ac_config.selective_ac_option}. 
" + f"Valid options: 'op' or a positive int representing layer frequency" + ) + if use_op_sac: + from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts + + def _get_custom_policy(meta): + def _custom_policy(ctx, func, *args, **kwargs): + mode = "recompute" if ctx.is_recompute else "forward" + mm_count_key = f"{mode}_mm_count" + if func == torch.ops.aten.mm.default: + meta[mm_count_key] += 1 + # Saves output of all compute ops, except every second mm + to_save = func in _save_list and not ( + func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 + ) + return ( + CheckpointPolicy.MUST_SAVE + if to_save + else CheckpointPolicy.PREFER_RECOMPUTE + ) + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = defaultdict(int) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + return ptd_checkpoint_wrapper( + module, + context_fn=selective_checkpointing_context_fn, + preserve_rng_state=False, + ) + elif use_layer_sac: + # Checkpoint every `ac_freq` of the modules passed to this function + ac_freq = int(ac_config.selective_ac_option) + ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) + ptd_checkpoint_wrapper._count += 1 + if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + else: + return module + + +def apply_ac(model: nn.Module, ac_config): + """Apply activation checkpointing to the model.""" + blocks = get_blocks(model) + if blocks is None: + logger.warning("No block found for activation checkpointing") + return + + for layer_id, block in blocks.named_children(): + block = _apply_ac_to_block(block, ac_config) + blocks.register_module(layer_id, block) + + logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") + + +def apply_compile(model: nn.Module): + """ + Apply torch.compile to each block, which makes compilation efficient due to + repeated structure. Alternatively one can compile the whole model (after applying DP). + """ + + blocks = get_blocks(model) + if blocks is None: + logger.warning("No block found for torch.compile") + else: + for layer_id, block in blocks.named_children(): + block = torch.compile(block) + blocks.register_module(layer_id, block) + logger.info("Compiling each block with torch.compile") + + real_model = get_model(model) + + logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile") + embeddings_key = get_components_name(real_model, "tok_embeddings") + if embeddings_key is not None: + embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True) + real_model.register_module(embeddings_key, embeddings) + + norm_key = get_components_name(real_model, "norm") + if norm_key is not None: + norm = torch.compile(getattr(real_model, norm_key), fullgraph=True) + real_model.register_module(norm_key, norm) + + lm_head_key = get_components_name(model, "lm_head") + if lm_head_key is not None: + lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True) + model.register_module(lm_head_key, lm_head) + + logger.info("Compiling the entire model with torch.compile") + model = torch.compile(model) + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", +): + """ + Apply data parallelism (via FSDP2) to the model. 
+ + Args: + model (nn.Module): The model to apply data parallelism to. + dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. + reduce_dtype (torch.dtype): The data type to use for reduction operations. + pp_enabled (bool): Whether pipeline parallelism is enabled. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. + reshard_after_forward_policy (str, optional): + The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. + + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + blocks = get_blocks(model) + if blocks is None: + logger.warning("No block found for FSDP") + else: + total_blocks = len(blocks) + for layer_id, block in enumerate(blocks): + if reshard_after_forward_policy == "always": + reshard_after_forward = True + elif reshard_after_forward_policy == "never": + reshard_after_forward = False + elif reshard_after_forward_policy == "default": + if pp_enabled: + # For PP, do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = False + else: + # As an optimization, do not reshard after forward for the last + # transformer block since FSDP would prefetch it immediately + reshard_after_forward = int(layer_id) < total_blocks - 1 + else: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." 
+ ) + fully_shard( + block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + + fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) + + +def apply_ddp( + model: nn.Module, + dp_mesh: DeviceMesh, + enable_compile: bool, + enable_compiled_autograd: bool, +): + if enable_compile: + if enable_compiled_autograd: + torch._dynamo.config.optimize_ddp = ( + "python_reducer_without_compiled_forward" + ) + else: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" + + replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) + + logger.info("Applied DDP to the model") + + +def get_model(model): + base_model_prefix = getattr(model, "base_model_prefix", "model") + if not hasattr(model, base_model_prefix): + return None + model = getattr(model, base_model_prefix) + return model + + +def get_blocks(model): + # TODO[flame]: adapt for network not using 'layers' attribute + model = get_model(model) + if not hasattr(model, "layers"): + logger.warning('no "layers" in model can be found') + return None + return model.layers + + +def get_components_name(model, component_name): + """ + We try to catch tok_embeddings, norm layers and lm_head layers + We do not catch the layer names in the blocks, for blocks see `get_blocks` + We assume the model has the following structure: + LlamaForCausalLM: + Model: + embed_tokens, + layers, + norm, + lm_head + *** + so, to search 'tok_embeddings' and 'norm' we need to pass `get_model(model)` + and for 'lm_head' we need to pass `model` + *** + """ + + if component_name == "tok_embeddings": + if hasattr(model, "tok_embeddings"): + return "tok_embeddings" + elif hasattr(model, "embed_tokens"): + return "embed_tokens" + elif hasattr(model, "embeddings"): + return "embeddings" + else: + logger.warning("No tok_embeddings found in model") + return None + + elif component_name == "norm": + if hasattr(model, "norm"): + return "norm" + elif hasattr(model, "norms"): + return "norms" + elif hasattr(model, "layernorm"): + return "layernorm" + else: + logger.warning("No norm found in model") + return None + + elif component_name == "lm_head": + if hasattr(model, "lm_head"): + return "lm_head" + else: + logger.warning("No lm_head found in model") + return None diff --git a/flame/models/pipeline_fla.py b/flame/models/pipeline_fla.py new file mode 100644 index 0000000000000000000000000000000000000000..7f2b29f521c25b607ec49b04bb240f59c61641f6 --- /dev/null +++ b/flame/models/pipeline_fla.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D pipeline parallelism to the Llama model. 
+ +import copy +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble, _PipelineSchedule, get_schedule_class +from transformers import PretrainedConfig + +from flame.models.parallelize_fla import get_blocks, get_components_name, get_model +from torchtitan.config_manager import JobConfig +from torchtitan.distributed.parallel_dims import ParallelDims +from torchtitan.distributed.pipeline import build_pipeline_schedule, generate_split_points, stage_ids_this_rank +from torchtitan.tools.logging import logger + +DeviceType = Union[int, str, torch.device] + + +def pipeline_fla( + model: nn.Module, + pp_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, + device: DeviceType, + model_config: PretrainedConfig, + loss_fn: Callable[..., torch.Tensor], +) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: + stages, models = pipeline_fla_manual_split( + model, pp_mesh, parallel_dims, job_config, device, model_config + ) + + pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn) + + # This is used in the train loop to determine whether to pass in the input_ids and labels + has_first_stage = False + has_last_stage = False + for stage in stages: + if stage.is_first: + has_first_stage = True + if stage.is_last: + has_last_stage = True + + return pp_schedule, models, has_first_stage, has_last_stage + + +def pipeline_fla_manual_split( + whole_model: nn.Module, + pp_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, + device: DeviceType, + model_config: PretrainedConfig, +) -> tuple[list[PipelineStage], list[nn.Module]]: + """ + This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage. + + It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects. + + The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD + parallelism. 
+ """ + pp_rank = pp_mesh.get_local_rank() + pp_size = pp_mesh.size() + + splits = ( + job_config.experimental.pipeline_parallel_split_points + or generate_split_points( + job_config, parallel_dims.pp, model_config.num_hidden_layers + ) + ) + + def _build_stage( + stage_idx: int, + start_layer: Optional[str], + stop_layer: Optional[str], + is_first: bool = False, + is_last: bool = False, + ) -> tuple[PipelineStage, nn.Module]: + model = copy.deepcopy(whole_model) + if not is_first: + # we do `model.tok_embeddings = None` here + real_model = get_model(model) + tok_embeddings_name = get_components_name(real_model, "tok_embeddings") + setattr(real_model, tok_embeddings_name, None) + + drop_layers = start_layer is not None + # Get module dictionary from get_blocks(model) + # and Create a list of keys before modifying dictionary + module_dict = get_blocks(model)._modules # Store reference + layer_names = list(module_dict.keys()) + + # Iterate over the list of keys instead of `_modules.items()` + for name in layer_names: + # Dynamically determine prefix (blocks.* or layers.*) + prefix = start_layer.split(".")[0] if start_layer else "layers" + layer_name = f"{prefix}.{name}" # Construct the correct name format + + # Ensure `drop_layers` activation is based on actual naming + if layer_name == start_layer: + drop_layers = False + if layer_name == stop_layer: + drop_layers = True + + # Delete layer if drop_layers is active + if drop_layers: + del module_dict[name] # Safe deletion from stored dictionary + + if not is_last: + # we do `model.norm = None` and `model.output = None` + real_model = get_model(model) + norm_name = get_components_name(real_model, "norm") + setattr(real_model, norm_name, None) + + head_name = get_components_name(model, "lm_head") + setattr(model, head_name, None) + + stage = PipelineStage( + model, + stage_idx, + num_stages, + device, + group=pp_mesh.get_group("pp"), + ) + return stage, model + + num_stages = len(splits) + 1 + stage_idx = pp_rank + + stages = [] + models = [] + + schedule_class = get_schedule_class( + job_config.experimental.pipeline_parallel_schedule + ) + style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop" + + for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style): + start_layer = splits[stage_idx - 1] if stage_idx > 0 else None + stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None + stage, model_chunk = _build_stage( + stage_idx, + start_layer, + stop_layer, + is_first=stage_idx == 0, + is_last=stage_idx == num_stages - 1, + ) + logger.info( + f"PP rank {pp_rank} is building stage_idx {stage_idx}" + f" with start_layer {start_layer}, stop_layer {stop_layer}" + ) + stages.append(stage) + models.append(model_chunk) + return stages, models diff --git a/flame/tools/__init__.py b/flame/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/flame/tools/__pycache__/__init__.cpython-312.pyc b/flame/tools/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6134de6455767b7b141903ce75536a10d146a9bb Binary files /dev/null and b/flame/tools/__pycache__/__init__.cpython-312.pyc differ diff --git a/flame/tools/__pycache__/utils.cpython-312.pyc b/flame/tools/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..742edce5495c39d47801e1b7c374a6cedc51f6e4 Binary files /dev/null and b/flame/tools/__pycache__/utils.cpython-312.pyc 
differ diff --git a/flame/utils/__init__.py b/flame/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/flame/utils/__pycache__/__init__.cpython-312.pyc b/flame/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9f73e478d776ed3a28dca57e580580546befd30 Binary files /dev/null and b/flame/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/flame/utils/__pycache__/checkpoint.cpython-312.pyc b/flame/utils/__pycache__/checkpoint.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0733b186380dc625c91b3a1639655d201d5f8a15 Binary files /dev/null and b/flame/utils/__pycache__/checkpoint.cpython-312.pyc differ diff --git a/flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc b/flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..588b942603f45c83a11e9167b76ce8ce664145d2 Binary files /dev/null and b/flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc differ diff --git a/flame/utils/__pycache__/hf_utils.cpython-312.pyc b/flame/utils/__pycache__/hf_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8601f94e59c68d724b98f20a1698f99e1451eb36 Binary files /dev/null and b/flame/utils/__pycache__/hf_utils.cpython-312.pyc differ diff --git a/flame/utils/checkpoint.py b/flame/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..839ac7df075c3bfca6747855781953c8a82a4c28 --- /dev/null +++ b/flame/utils/checkpoint.py @@ -0,0 +1,50 @@ +import os +import glob +import re +import shutil +from torchtitan.tools.logging import logger + + +def cleanup_local_checkpoints(checkpoint_dir: str, keep_latest_k: int): + """Removes older checkpoint directories locally, keeping only the latest k for both DCP and HF formats.""" + if keep_latest_k <= 0: + return # Keep all checkpoints + + logger.info(f"Cleaning up local checkpoints in {checkpoint_dir}, keeping latest {keep_latest_k}") + + # Cleanup DCP checkpoints (step-*) + dcp_checkpoints = sorted( + glob.glob(os.path.join(checkpoint_dir, "step-*")), + key=lambda x: int(re.search(r"step-(\d+)", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)", os.path.basename(x)) and not x.endswith("-hf") else -1, + reverse=True + ) + # Filter out HF format directories + dcp_checkpoints = [d for d in dcp_checkpoints if not d.endswith("-hf")] + + if len(dcp_checkpoints) > keep_latest_k: + checkpoints_to_delete = dcp_checkpoints[keep_latest_k:] + logger.info(f"Deleting {len(checkpoints_to_delete)} old DCP checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}") + for ckpt_path in checkpoints_to_delete: + if os.path.isdir(ckpt_path): # Ensure it's a directory + try: + shutil.rmtree(ckpt_path) + except OSError as e: + logger.error(f"Error removing directory {ckpt_path}: {e}") + + + # Cleanup HF checkpoints (step-*-hf) + hf_checkpoints = sorted( + glob.glob(os.path.join(checkpoint_dir, "step-*-hf")), + key=lambda x: int(re.search(r"step-(\d+)-hf", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)-hf", os.path.basename(x)) else -1, + reverse=True + ) + + if len(hf_checkpoints) > keep_latest_k: + checkpoints_to_delete = hf_checkpoints[keep_latest_k:] + logger.info(f"Deleting {len(checkpoints_to_delete)} old HF checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}") + for ckpt_path in checkpoints_to_delete: + if 
os.path.isdir(ckpt_path): # Ensure it's a directory + try: + shutil.rmtree(ckpt_path) + except OSError as e: + logger.error(f"Error removing directory {ckpt_path}: {e}") diff --git a/flame/utils/convert_dcp_to_hf.py b/flame/utils/convert_dcp_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..2240578d54b6be51241ddad0e253c548cb362492 --- /dev/null +++ b/flame/utils/convert_dcp_to_hf.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import argparse +import io +import os +import tempfile +from datetime import timedelta + +import torch +import torch.serialization +from torch.distributed.checkpoint.format_utils import dcp_to_torch_save +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +import fla # noqa +from torchtitan.tools.logging import init_logger, logger + + +@torch.inference_mode() +def save_pretrained( + path: str, + step: int, + config: str, + tokenizer: str +): + logger.info(f"Loading the config from {config}") + config = AutoConfig.from_pretrained(config, trust_remote_code=True) + + logger.info(f"Saving the config to {path}") + config.save_pretrained(path) + logger.info(f"Loading the tokenizer from {tokenizer}") + tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + logger.info(f"Saving the tokenizer to {path}") + tokenizer.save_pretrained(path) + + with tempfile.TemporaryDirectory() as tmpdir: + # base_checkpoint_dir = os.path.dirname(path) + base_checkpoint_dir = path + checkpoint = os.path.join(base_checkpoint_dir, f'checkpoint/step-{step}') + checkpoint_path = os.path.join(tmpdir, 'checkpoint.pt') + logger.info(f"Saving the distributed checkpoint to {checkpoint_path}") + dcp_to_torch_save(checkpoint, checkpoint_path) + + logger.info(f"Initializing the model from config\n{config}") + model = AutoModelForCausalLM.from_config(config) + logger.info(model) + logger.info("Loading state dict from the checkpoint") + + # Add datetime.timedelta and io.BytesIO to safe globals + torch.serialization.add_safe_globals([timedelta, io.BytesIO]) + # torch.load now with default weights_only=True will work + model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model']) + + logger.info(f"Saving the model to {path}") + model.save_pretrained(path) + + +if __name__ == "__main__": + init_logger() + parser = argparse.ArgumentParser("Convert DCP format model weights to huggingface-style.") + parser.add_argument("--path", type=str, required=True) + parser.add_argument("--step", type=int, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--tokenizer", type=str, required=True) + args = parser.parse_args() + save_pretrained(args.path, args.step, args.config, args.tokenizer) diff --git a/flame/utils/convert_hf_to_dcp.py b/flame/utils/convert_hf_to_dcp.py new file mode 100644 index 0000000000000000000000000000000000000000..bab94ebf80ea8822139b851e0c64b95854c2e78b --- /dev/null +++ b/flame/utils/convert_hf_to_dcp.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import argparse +from pathlib import Path + +import torch +import torch.distributed.checkpoint as DCP +from transformers import AutoModelForCausalLM + +import fla # noqa +from torchtitan.tools.logging import init_logger, logger + + +@torch.inference_mode() +def convert_hf_weights(model: str, checkpoint: str): + logger.info(f"Loading model from {model}") + model = AutoModelForCausalLM.from_pretrained(model) + state_dict 
= model.state_dict() + + logger.info(f"Writing to DCP at '{checkpoint}'") + checkpoint.mkdir(parents=True, exist_ok=True) + storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8) + DCP.save({"model": state_dict}, storage_writer=storage_writer) + + +if __name__ == "__main__": + init_logger() + parser = argparse.ArgumentParser(description="Convert huggingface-style model weights to DCP format.") + parser.add_argument("--model", type=str, required=True) + parser.add_argument("--checkpoint", type=Path, required=True) + args = parser.parse_args() + + convert_hf_weights(args.model, args.checkpoint) diff --git a/flame/utils/hf_utils.py b/flame/utils/hf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c8954965dbde4c33131bdf05811fdf803c247168 --- /dev/null +++ b/flame/utils/hf_utils.py @@ -0,0 +1,77 @@ +import os +import re +from huggingface_hub import HfApi, HfFolder, logging as hf_logging, create_repo +from torchtitan.tools.logging import logger + +def upload_checkpoint_to_hf( + local_path: str, + step: int, + hf_repo_id_for_run: str, + hf_keep_latest_k: int, + upload_format: str +): + """Uploads a checkpoint directory to HF Hub and manages retention.""" + if not os.path.isdir(local_path): + logger.error(f"Local path for upload does not exist or is not a directory: {local_path}") + return + + api = HfApi() + token = HfFolder.get_token() + if not token: + logger.warning("Hugging Face Hub token not found. Skipping upload. Login via `huggingface-cli login` or set HF_TOKEN.") + return + + # --- Ensure the specific repository for this run exists --- + try: + logger.info(f"Ensuring repository {hf_repo_id_for_run} exists...") + # Use create_repo which handles creation only if it doesn't exist + create_repo(repo_id=hf_repo_id_for_run, token=token, repo_type="model", exist_ok=True) + logger.info(f"Repository {hf_repo_id_for_run} ensured.") + except Exception as e: + logger.error(f"Failed to create or ensure repository {hf_repo_id_for_run}: {e}", exc_info=True) + return # Stop if repo interaction fails + + commit_message = f"Upload {upload_format.upper()} checkpoint step {step}" + path_in_repo = f"step-{step}" + + logger.info(f"Uploading {local_path} to {hf_repo_id_for_run}/{path_in_repo} on Hugging Face Hub...") + try: + api.upload_folder( + folder_path=local_path, + path_in_repo=path_in_repo, + repo_id=hf_repo_id_for_run, + repo_type="model", + commit_message=commit_message, + token=token, + ) + logger.info(f"Successfully uploaded step {step} to {hf_repo_id_for_run}.") + except Exception as e: + logger.error(f"Failed to upload checkpoint step {step} to {hf_repo_id_for_run}: {e}", exc_info=True) + if hf_keep_latest_k > 0: + logger.info(f"Cleaning up old checkpoints on {hf_repo_id_for_run}, keeping latest {hf_keep_latest_k}") + try: + repo_files = api.list_repo_tree(hf_repo_id_for_run, repo_type="model", token=token, recursive=False) + step_folders = [ + item.path for item in repo_files + if item.path.startswith("step-") and item.path[5:].isdigit() + ] + + step_folders.sort(key=lambda x: int(x.split('-')[1]), reverse=True) + + if len(step_folders) > hf_keep_latest_k: + folders_to_delete = step_folders[hf_keep_latest_k:] + logger.info(f"Found {len(step_folders)} checkpoints on Hub. 
Deleting {len(folders_to_delete)} older ones: {folders_to_delete}") + for folder in folders_to_delete: + # Deleting requires repo_id, path_in_repo, and token + api.delete_folder( + repo_id=hf_repo_id_for_run, + path_in_repo=folder, + repo_type="model", + commit_message=f"Delete old checkpoint {folder}", + token=token + ) + logger.info("Hub cleanup complete.") + else: + logger.info("No old checkpoints found on Hub to delete.") + except Exception as e: + logger.error(f"Error during Hub checkpoint cleanup for {hf_repo_id_for_run}: {e}", exc_info=True) \ No newline at end of file diff --git a/tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/logs/debug-internal.log b/tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/logs/debug-internal.log new file mode 100644 index 0000000000000000000000000000000000000000..2f730f162a9b90ed3037eb91d7141d5d17aef034 --- /dev/null +++ b/tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/logs/debug-internal.log @@ -0,0 +1,90 @@ +{"time":"2025-07-16T22:10:00.785425491Z","level":"INFO","msg":"stream: starting","core version":"0.21.0"} +{"time":"2025-07-16T22:10:01.508654924Z","level":"INFO","msg":"stream: created new stream","id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"} +{"time":"2025-07-16T22:10:01.508690211Z","level":"INFO","msg":"stream: started","id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"} +{"time":"2025-07-16T22:10:01.508739999Z","level":"INFO","msg":"writer: Do: started","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"} +{"time":"2025-07-16T22:10:01.508759314Z","level":"INFO","msg":"handler: started","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"} +{"time":"2025-07-16T22:10:01.508803829Z","level":"INFO","msg":"sender: started","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"} +{"time":"2025-07-16T23:09:45.740737848Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-16T23:18:29.56428269Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2025-07-16T23:19:01.917480335Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-16T23:19:36.868918826Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2025-07-16T23:20:16.297827588Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2025-07-16T23:20:18.619477493Z","level":"INFO","msg":"api: retrying error","error":"Post 
\"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp: lookup api.wandb.ai on 127.0.0.53:53: read udp 127.0.0.1:46470->127.0.0.53:53: i/o timeout"} +{"time":"2025-07-16T23:20:30.740650327Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp: lookup api.wandb.ai on 127.0.0.53:53: read udp 127.0.0.1:47482->127.0.0.53:53: i/o timeout"} +{"time":"2025-07-16T23:21:04.536690541Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-16T23:21:49.291673175Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-16T23:22:07.542159208Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-16T23:23:23.103733736Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-16T23:23:37.543151076Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-16T23:25:07.544031298Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-16T23:26:37.545971769Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-16T23:27:42.194377246Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-16T23:27:59.564813743Z","level":"WARN","msg":"sender: taking a long time","seconds":600.000912631,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"ft8cf3fgtodg\" connection_id:\"1(@)\")"} +{"time":"2025-07-16T23:28:07.547697617Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2025-07-16T23:29:37.549836886Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-16T23:31:01.930916994Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000672411,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"} 
+{"time":"2025-07-16T23:31:02.101966833Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000995925,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"} +{"time":"2025-07-16T23:31:07.103368571Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000796336,"work":"WorkRecord(*service_go_proto.Request_PartialHistory); Control(local:true connection_id:\"1(@)\")"} +{"time":"2025-07-16T23:31:07.551682713Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-16T23:32:37.553473869Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-16T23:33:58.248779065Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": unexpected EOF"} +{"time":"2025-07-16T23:34:58.351555112Z","level":"INFO","msg":"sender: succeeded after taking longer than expected","seconds":1018.787711083,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"ft8cf3fgtodg\" connection_id:\"1(@)\")"} +{"time":"2025-07-16T23:34:58.351650283Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":836.421498346,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"} +{"time":"2025-07-16T23:34:58.351778293Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":831.249242004,"work":"WorkRecord(*service_go_proto.Request_PartialHistory); Control(local:true connection_id:\"1(@)\")"} +{"time":"2025-07-16T23:34:58.351785775Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":836.250829923,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"} +{"time":"2025-07-17T01:31:13.353253854Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-17T08:06:16.748740406Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-17T09:50:19.526737851Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": read tcp 10.0.2.15:54882->35.186.228.49:443: read: connection reset by peer"} +{"time":"2025-07-17T09:52:30.348552703Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-17T09:53:02.422139335Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-17T09:53:36.600890938Z","level":"INFO","msg":"api: retrying error","error":"Post 
\"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2025-07-17T09:54:16.203516351Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2025-07-17T09:55:05.357439477Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2025-07-17T09:56:15.05960959Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-17T09:57:45.061688428Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-17T09:59:15.063226591Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2025-07-17T10:00:45.065259852Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2025-07-17T10:01:04.518171545Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-17T10:02:00.347889757Z","level":"WARN","msg":"sender: taking a long time","seconds":600.000372919,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"it0uq1ptdf5l\" connection_id:\"1(@)\")"} +{"time":"2025-07-17T10:02:15.066174619Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-17T10:03:45.067145051Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-17T10:05:02.098970791Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000073665,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"} +{"time":"2025-07-17T10:05:07.474477054Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000841939,"work":"WorkRecord(*service_go_proto.Request_PartialHistory); Control(local:true connection_id:\"1(@)\")"} +{"time":"2025-07-17T10:05:15.068468165Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2025-07-17T10:05:16.930808745Z","level":"WARN","msg":"runwork: taking a long time","seconds":600.000229861,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"} +{"time":"2025-07-17T10:06:07.008582668Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-17T10:06:45.070340311Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} 
+{"time":"2025-07-17T10:07:57.799911415Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": unexpected EOF"} +{"time":"2025-07-17T10:08:57.969386735Z","level":"INFO","msg":"sender: succeeded after taking longer than expected","seconds":1017.621908973,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"it0uq1ptdf5l\" connection_id:\"1(@)\")"} +{"time":"2025-07-17T10:08:57.969579361Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":835.870728331,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"} +{"time":"2025-07-17T10:08:57.969680501Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":821.039158554,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"} +{"time":"2025-07-17T10:08:57.969682134Z","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":830.496074059,"work":"WorkRecord(*service_go_proto.Request_PartialHistory); Control(local:true connection_id:\"1(@)\")"} +{"time":"2025-07-17T12:53:12.780364188Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-17T16:43:31.998287109Z","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream","body":"\n\n\n502 Server Error\n\n\n

Error: Server Error\nThe server encountered a temporary error and could not complete your request.\nPlease try again in 30 seconds.
\n\n"} +{"time":"2025-07-18T00:01:06.015630566Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-18T06:56:24.118529653Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-18T14:32:12.830145916Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-18T19:51:31.703829065Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"} +{"time":"2025-07-19T03:35:03.743864446Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-19T21:22:32.639517404Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": read tcp 10.0.2.15:51870->35.186.228.49:443: read: connection reset by peer"} +{"time":"2025-07-19T21:31:32.643369264Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": read tcp 10.0.2.15:38482->35.186.228.49:443: read: connection reset by peer"} +{"time":"2025-07-20T00:27:42.221361901Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-20T09:40:16.319872482Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-20T09:45:18.218885403Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-20T19:19:37.674808147Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-20T20:26:46.102126738Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: 
connect: connection refused"} +{"time":"2025-07-20T21:40:42.245223721Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-20T21:42:31.526229193Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-20T22:42:07.859288654Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-21T03:41:28.397742169Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-21T04:49:16.742257697Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-21T05:48:28.62347913Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-21T06:22:31.529351974Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream\": dial tcp 35.186.228.49:443: connect: connection refused"} +{"time":"2025-07-21T14:47:44.545628902Z","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/zaydzuhri/fla/mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/file_stream","body":"\n\n\n502 Server Error\n\n\n

Error: Server Error\nThe server encountered a temporary error and could not complete your request.\nPlease try again in 30 seconds.
\n\n"} +{"time":"2025-07-21T21:19:44.840025606Z","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"} +{"time":"2025-07-21T21:19:44.94975041Z","level":"INFO","msg":"handler: operation stats","stats":{}} +{"time":"2025-07-21T21:19:44.958211652Z","level":"INFO","msg":"stream: closing","id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"} +{"time":"2025-07-21T21:19:44.958407771Z","level":"INFO","msg":"writer: Close: closed","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"} +{"time":"2025-07-21T21:19:44.958426934Z","level":"INFO","msg":"handler: closed","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"} +{"time":"2025-07-21T21:19:44.958428316Z","level":"INFO","msg":"sender: closed","stream_id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"} +{"time":"2025-07-21T21:19:44.958480192Z","level":"INFO","msg":"stream: closed","id":"mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201"} diff --git a/torchtitan/components/__pycache__/float8.cpython-312.pyc b/torchtitan/components/__pycache__/float8.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfd989dbf1575a29357b1f7c75e6d76c01fe03c9 Binary files /dev/null and b/torchtitan/components/__pycache__/float8.cpython-312.pyc differ diff --git a/torchtitan/components/__pycache__/ft.cpython-312.pyc b/torchtitan/components/__pycache__/ft.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81f232d3764d4051a9a873f24caf9657b828a061 Binary files /dev/null and b/torchtitan/components/__pycache__/ft.cpython-312.pyc differ diff --git a/torchtitan/components/float8.py b/torchtitan/components/float8.py new file mode 100644 index 0000000000000000000000000000000000000000..b01c5063bcb90aacdd77c8ad3f078898ea6bfd62 --- /dev/null +++ b/torchtitan/components/float8.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# [Note] Getting the 'torchao' package: +# This script requires the 'torchao' package to function correctly. +# Please ensure you have this package installed from the appropriate repository. +# You can obtain it from https://github.com/pytorch/ao by following the +# installation instructions. 
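As a quick illustration of the torchao prerequisite described in the note above, a pre-flight check might look like the sketch below. It relies only on facts stated in this file (the torchao import and the SM89 capability requirement that `_is_sm89_or_later` also enforces); the helper name `float8_prerequisites_met` is illustrative and not part of the diff.

import importlib.util

import torch


def float8_prerequisites_met() -> bool:
    # torchao must be importable; see the note above on obtaining it.
    if importlib.util.find_spec("torchao") is None:
        return False
    # Float8 training needs SM89 or later (e.g. H100-class GPUs).
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)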
+ +# Note: Performance +# Float8 experimental is intended to be ran under `torch.compile`` for competitive performance + +import torch +import torch.nn as nn + +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims +from torchtitan.protocols.model_converter import ( + ModelConverter, + register_model_converter, +) +from torchtitan.tools.logging import logger + + +def _is_sm89_or_later(): + # Float8 is only supported on SM89 or later (H100+ GPUs) + return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) + + +class Float8Converter(ModelConverter): + def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): + self.enabled = False + + float8_config = job_config.float8 + if not _is_sm89_or_later(): + logger.warning( + "Failed to swap to Float8Linear because float8 is only supported on SM89 or later", + ) + return + try: + from torchao.float8 import Float8LinearConfig + except ImportError as e: + raise ImportError( + "torchao is not installed. Please install it to use float8 linear layers." + ) from e + + if float8_config.recipe_name is not None and not hasattr( + Float8LinearConfig, "from_recipe_name" + ): + logger.warning( + "Failed to swap to Float8Linear with recipe lookup because the torchao version " + "is too old, please install torchao v0.9.0 or later and try again", + ) + return + + self.enabled = True + self.filter_fqns = float8_config.filter_fqns + + if float8_config.recipe_name is not None: + assert ( + not float8_config.enable_fsdp_float8_all_gather + ), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported" + assert ( + not float8_config.force_recompute_fp8_weight_in_bwd + ), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported" + self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name) + self.precompute_scale = False + logger.info( + f"Float8 training active with recipe {float8_config.recipe_name}" + ) + + else: + # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear + enable_fsdp_float8_all_gather = ( + parallel_dims.dp_shard_enabled + and float8_config.enable_fsdp_float8_all_gather + ) + self.config = Float8LinearConfig( + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, + force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd, + ) + # for precompute_float8_dynamic_scale_for_fsdp + self.precompute_scale = ( + enable_fsdp_float8_all_gather + and float8_config.precompute_float8_dynamic_scale_for_fsdp + ) + logger.info("Float8 tensorwise scaled training active") + + def convert(self, model: nn.Module): + return self.convert_to_float8_training(model) + + def post_optimizer_hook(self, model: nn.Module | list[nn.Module]): + return self.precompute_float8_dynamic_scale_for_fsdp(model) + + def convert_to_float8_training(self, model: nn.Module): + """ + This function converts the linear layers of `model` to `Float8Linear`. + Note that today, only dynamic tensor scaling (the default) is supported. + This will mutate the model inplace. 
+ """ + if not self.enabled: + return + + from torchao.float8 import convert_to_float8_training + + # Mutates the model inplace replacing instances of nn.Linear with Float8Linear + convert_to_float8_training( + model, + config=self.config, + module_filter_fn=self._module_filter_fn, + ) + logger.info( + "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather=" + f"{self.config.enable_fsdp_float8_all_gather}" + ) + + def _module_filter_fn(self, mod: nn.Module, fqn: str) -> bool: + if not isinstance(mod, nn.Linear): + return False + + # All dims must be divisible by 16 due to float8 tensorcore hardware requirements. + dims_multiples_of_16 = ( + mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0 + ) + + # If the fqn matches any filtered fqn, then we should not convert this module. + is_filtered_fqn = any(filtered_fqn in fqn for filtered_fqn in self.filter_fqns) + + return dims_multiples_of_16 and not is_filtered_fqn + + def precompute_float8_dynamic_scale_for_fsdp( + self, model: nn.Module | list[nn.Module] + ): + if not self.enabled: + return + + if not self.precompute_scale: + return + + from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp + + models = [model] if isinstance(model, nn.Module) else model + for m in models: + precompute_float8_dynamic_scale_for_fsdp(m) + + +register_model_converter(Float8Converter, "float8") diff --git a/torchtitan/components/ft.py b/torchtitan/components/ft.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d270d26dd9ceb604f1189c2caee0b6733dd73e --- /dev/null +++ b/torchtitan/components/ft.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import importlib +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.distributed._functional_collectives as funcol +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims + +if importlib.util.find_spec("torchft") is not None: + import torchft as ft + + has_torchft = True +else: + has_torchft = False + + +class FTManager: + def __init__( + self, + manager: Optional["ft.Manager"], + group_size: int = 1, + replica_id: int = 0, + ) -> None: + self._manager = manager + self.group_size = group_size + self.replica_id = replica_id + + @property + def enabled(self) -> bool: + return self._manager is not None + + @property + def manager(self) -> "ft.Manager": + assert self._manager is not None + return self._manager + + def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]: + return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank + + +def init_ft_manager(job: JobConfig) -> FTManager: + """Initialize the FT manager if TorchFT is enabled. + + Args: + job (JobConfig): The job configuration. + + Returns: + Optional[ft.Manager]: The FT manager if TorchFT is enabled, otherwise None. + """ + if not job.fault_tolerance.enable: + return FTManager(None) + + if not has_torchft: + raise ImportError("torchft is not installed. 
Please install it.") + + if job.fault_tolerance.min_replica_size < 1: + raise ValueError("At least one FT replica is required.") + + pg = ft.ProcessGroupBabyNCCL() + + return FTManager( + ft.Manager( + pg=pg, + min_replica_size=job.fault_tolerance.min_replica_size, + load_state_dict=None, + state_dict=None, + use_async_quorum=True, + replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}", + ), + group_size=job.fault_tolerance.group_size, + replica_id=job.fault_tolerance.replica_id, + ) + + +@dataclass +class FTParallelDims(ParallelDims): + ft_manager: FTManager + + def build_mesh(self, device_type: str) -> DeviceMesh: + def func( + device_type: str, mesh_shape: list[int], mesh_dim_names: list[str] + ) -> DeviceMesh: + from torchft.process_group import ft_init_device_mesh + + return ft_init_device_mesh( + device_type=device_type, + mesh_shape=mesh_shape, + mesh_dim_names=mesh_dim_names, + replicate_dim=mesh_dim_names.index("dp_replicate"), + manager=self.ft_manager.manager, + ) + + dims = [] + names = [] + for d, name in zip( + [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], + ["pp", "dp_replicate", "dp_shard", "cp", "tp"], + ): + if d > 1 or name == "dp_replicate": + dims.append(d) + names.append(name) + + return self._build_mesh(device_type, dims, names, func) + + @property + def dp_replicate_enabled(self): + return True + + +def ft_dist_reduce( + x: torch.Tensor, reduceOp: str, mesh: DeviceMesh +) -> tuple[torch.Tensor, str, DeviceMesh]: + if has_torchft and isinstance(mesh, ft.process_group._FlattenDeviceMesh): + x = funcol.all_reduce( + x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg + ) + return x, reduceOp, mesh.managed_mesh.mesh + return x, reduceOp, mesh + + +def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor: + if has_torchft: + mesh = total_norm._spec.mesh + if isinstance(mesh, ft.process_group.ManagedDeviceMesh): + # The gradients along the replicated dim has already been reduced. + # So we don't need another reducution beforing removing the + # replicate dimension + local_tensor = total_norm.to_local() + placements = list(copy.copy(total_norm._spec.placements)) + placements.pop(mesh.replicate_dim) + return DTensor.from_local(local_tensor, mesh.mesh, placements) + + return total_norm diff --git a/torchtitan/experiments/deepseek_v3/model_config.py b/torchtitan/experiments/deepseek_v3/model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d559d4ee94ecf7fccc933cf1a243161d1796a123 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/model_config.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + + +@dataclass +class ModelArgs: + r""" + This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V3. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 129280): + Vocabulary size of the Deep model. 
Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_nextn_predict_layers (`int`, *optional*, defaults to 1): + Number of nextn predict layers in the DeepSeekV3 Model. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within + `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. 
+ use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + """ + + vocab_size: int = 129280 + hidden_size: int = 7168 + intermediate_size: int = 18432 + moe_intermediate_size: int = 2048 + num_hidden_layers: int = 61 + num_nextn_predict_layers: int = 1 + num_attention_heads: int = 128 + num_key_value_heads: int = 128 + n_shared_experts: int = 1 + n_routed_experts: int = 256 + ep_size: int = 1 + routed_scaling_factor: float = 2.5 + kv_lora_rank: int = 512 + q_lora_rank: int = 1536 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + qk_nope_head_dim: int = 128 + topk_method: str = "noaux_tc" + n_group: int = 8 + topk_group: int = 4 + num_experts_per_tok: int = 8 + moe_layer_freq: int = 1 + first_k_dense_replace: int = 3 + norm_topk_prob: bool = True + scoring_func: str = "sigmoid" + aux_loss_alpha: float = 0.001 + seq_aux: bool = True + hidden_act: str = "silu" + max_position_embeddings: int = 163840 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + rope_theta: float = 10000.0 + rope_scaling: dict = field( + default_factory=lambda: { + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn", + } + ) + attention_bias: bool = False + attention_dropout: float = 0.0 + pad_token_id = None + # Added for symmetric memory + max_seq_len: int = 4096 + dtype: str = "bfloat16" + # Added for pipeline parallel + num_stages: int = 1 + stage_idx: int = 0 + + +# This is the configuration for deepseek-ai/DeepSeek-V2-Lite. 
+deepseek_v2_lite_config = ModelArgs( + vocab_size=102400, + hidden_size=2048, + intermediate_size=10944, + moe_intermediate_size=1408, + num_hidden_layers=27, + num_attention_heads=16, + num_key_value_heads=16, + n_shared_experts=2, + n_routed_experts=64, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=None, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="greedy", + n_group=1, + topk_group=1, + num_experts_per_tok=6, + first_k_dense_replace=1, + norm_topk_prob=False, + scoring_func="softmax", + max_position_embeddings=4096, + rope_scaling={ + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 0.707, + "mscale_all_dim": 0.707, + "original_max_position_embeddings": 4096, + "type": "yarn", + }, +) + + +# Model configuration registry +# Key is the model distribution ID on HuggingFace Hub +deepseek_config_registry = { + "deepseek-ai/DeepSeek-V2-Lite": deepseek_v2_lite_config, + "deepseek-ai/DeepSeek-V2-Lite-Chat": deepseek_v2_lite_config, + "deepseek-ai/deepseek-v3": ModelArgs(), +} diff --git a/torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..335bc2d966efbe486418525cb784078a6ec879d5 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .triton_on_device_all_to_all_v import OnDeviceAllToAllV + +__all__ = [ + "OnDeviceAllToAllV", +] diff --git a/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd9b283f41daffab3f4ce4d1e0a5d844f2a2c70 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
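Before the Triton barrier kernel below, here is a small usage sketch of the `deepseek_config_registry` defined in model_config.py above. The import path simply follows the file location in this diff, the override values are arbitrary, and the pattern mirrors how train.py later selects a config by its HuggingFace model id and adjusts it for parallelism.

import copy

from torchtitan.experiments.deepseek_v3.model_config import deepseek_config_registry

# Configs are registered under their HuggingFace model ids.
args = copy.deepcopy(deepseek_config_registry["deepseek-ai/DeepSeek-V2-Lite"])

# Shrink and parallelize for a smoke test, as train.py does further down with
# num_hidden_layers, ep_size, num_stages and stage_idx.
args.num_hidden_layers = 4
args.ep_size = 2
args.num_stages, args.stage_idx = 2, 0
print(args.hidden_size, args.n_routed_experts, args.num_experts_per_tok)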
+ +import triton +import triton.language as tl + +from .triton_utils import get_flat_bid, get_flat_tid + + +@triton.jit +def send_signal(addrs, sem: tl.constexpr): + if sem == "relaxed": + tl.inline_asm_elementwise( + """ + { + .reg .u32 %tmp32_<1>; + .reg .pred %p<1>; + + send_signal: + atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1; + setp.eq.u32 %p0, %tmp32_0, 0; + @!%p0 bra send_signal; + } + """, + "=r, l", + [addrs], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + elif sem == "acq_rel": + tl.inline_asm_elementwise( + """ + { + .reg .u32 %tmp32_<1>; + .reg .pred %p<1>; + + send_signal: + atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1; + setp.eq.u32 %p0, %tmp32_0, 0; + @!%p0 bra send_signal; + } + """, + "=r, l", + [addrs], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + else: + raise RuntimeError(f"Unrecognized sem: {sem}") + + +@triton.jit +def wait_signal(addrs, sem: tl.constexpr): + if sem == "relaxed": + tl.inline_asm_elementwise( + """ + { + .reg .u32 %tmp32_<1>; + .reg .pred %p<1>; + + wait_signal: + atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0; + setp.eq.u32 %p0, %tmp32_0, 1; + @!%p0 bra wait_signal; + } + """, + "=r, l", + [addrs], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + elif sem == "acq_rel": + tl.inline_asm_elementwise( + """ + { + .reg .u32 %tmp32_<1>; + .reg .pred %p<1>; + + wait_signal: + atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0; + setp.eq.u32 %p0, %tmp32_0, 1; + @!%p0 bra wait_signal; + } + """, + "=r, l", + [addrs], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + else: + raise RuntimeError(f"Unrecognized sem: {sem}") + + +@triton.jit +def blockwise_barrier( + signal_pad_ptrs, + block_id, + rank: tl.constexpr, + world_size: tl.constexpr, + sem: tl.constexpr, +): + """ + Synchronizes blocks with matching block_id across participating devices. + + Note: the function itself is not a system level barrier/fence. It is a + building block for expressing different synchronization patterns. + + Pattern 0: Ensures that all writes to symm_mem buffers from previous + kernels across all devices are visible to the current kernel: + + blockwise_barrier(..., sem="relaxed") + sync_threads() + + Pattern 1: Ensures that all writes to symm_mem buffers from the current + block are visible to all remote blocks with matching blockIdx: + + sync_threads() + blockwise_barrier(..., sem="acq_rel") + sync_threads() + + Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe + for writing by subsequent kernels across all devices. + + sync_threads() + blockwise_barrier(..., sem="relaxed") + + CUDA graph friendliness: + + This barrier operates through atomic operations on a zero-filled signal + pad, which resets to a zero-filled state after each successful + synchronization. This design eliminates the need for incrementing a + flag from host. 
+ """ + if block_id is None: + block_id = get_flat_bid() + flat_tid = get_flat_tid() + + remote_ranks = tl.arange(0, world_size) + signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64)) + remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to( + tl.pointer_type(tl.uint32) + ) + send_addrs = remote_signal_pad_addrs + block_id * world_size + rank + + local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to( + tl.pointer_type(tl.uint32) + ) + wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks + + if flat_tid < world_size: + send_signal(send_addrs, sem) + wait_signal(wait_addrs, sem) diff --git a/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py new file mode 100644 index 0000000000000000000000000000000000000000..5cd023c36bd9737bfb03da22ea38ef57a448eb80 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py @@ -0,0 +1,260 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem +import triton +import triton.language as tl + +from .triton_barrier import blockwise_barrier +from .triton_utils import sync_threads + + +@triton.jit +def _exchange_row_offsets( + split_sizes_ptrs, + rank: tl.constexpr, + world_size: tl.constexpr, + BLOCKS_PER_REMOTE_RANK: tl.constexpr, +): + remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK + + # split_sizes_ptr for all ranks + # All these vector stacks into split_sizes_matrix + split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64)) + + # split_sizes_matrix[remote_rank, :] + input_split_sizes_ptr = tl.load(split_sizes_ptrs + remote_rank).to( + tl.pointer_type(tl.int64) + ) + + offsets_ = tl.arange(0, world_size) + input_split_sizes = tl.load( + input_split_sizes_ptr + offsets_, mask=offsets_ <= rank, other=0 + ) + + num_rows = tl.load(input_split_sizes_ptr + rank) + input_row_offset = tl.sum(input_split_sizes) - num_rows + + # split_sizes_matrix[:, rank] + output_split_sizes_ptrs = ( + tl.load(split_sizes_ptrs + offsets_).to(tl.pointer_type(tl.int64)) + rank + ) + output_split_sizes = tl.load( + output_split_sizes_ptrs, mask=offsets_ <= remote_rank, other=0 + ) + output_row_offset = tl.sum(output_split_sizes) - num_rows + + return input_row_offset, output_row_offset, num_rows + + +@triton.jit +def on_device_all_to_all_v_kernel( + output_ptr, + output_splits_ptr, + input_ptrs, + input_splits_ptr, + signal_pad_ptrs, + dim: tl.constexpr, # Separate dim for easier vectorization + rank: tl.constexpr, + world_size: tl.constexpr, + BLOCKS_PER_REMOTE_RANK: tl.constexpr, + UNROLL_FACTOR: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed") + sync_threads() + + remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK + block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK + + input_row_offset, output_row_offset, num_rows = _exchange_row_offsets( + input_splits_ptr, rank, world_size, BLOCKS_PER_REMOTE_RANK + ) + + output_splits_ptr = output_splits_ptr.to(tl.pointer_type(tl.uint64)) + if block_offset == 0: + # Update output_splits + tl.store(output_splits_ptr + remote_rank, num_rows) + + input_ptr = ( + 
tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to( + tl.pointer_type(tl.bfloat16) + ) + + input_row_offset * dim + ) + output_ptr = output_ptr + output_row_offset * dim + + outer_loop_step = BLOCK_SIZE * UNROLL_FACTOR + outer_loop_iters_per_rank = tl.cdiv( + tl.cdiv(num_rows * dim, outer_loop_step), BLOCKS_PER_REMOTE_RANK + ) + numel_per_rank = outer_loop_step * outer_loop_iters_per_rank + offset = numel_per_rank * block_offset + end = tl.minimum(numel_per_rank * (block_offset + 1), num_rows * dim) + + unroll_region_size = (end - offset) // outer_loop_step * outer_loop_step + for i in tl.range(offset, offset + unroll_region_size, outer_loop_step): + datas = [] + for j in tl.range( + i, + i + outer_loop_step, + BLOCK_SIZE, + loop_unroll_factor=UNROLL_FACTOR, + ): + offsets = j + tl.arange(0, BLOCK_SIZE) + data = tl.load(input_ptr + offsets) + tl.store(output_ptr + offsets, data) + + offset += unroll_region_size + while offset < end: + offsets = offset + tl.arange(0, BLOCK_SIZE) + mask = offsets < num_rows * dim + data = tl.load(input_ptr + offsets, mask=mask) + tl.store(output_ptr + offsets, data, mask=mask) + offset += BLOCK_SIZE + + sync_threads() + blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed") + return + + +def _on_device_all_to_all_v( + output: torch.Tensor, + output_splits: torch.Tensor, + input: torch.Tensor, + input_splits: torch.Tensor, + group: dist.ProcessGroup = dist.group.WORLD, + BLOCKS_PER_REMOTE_RANK=8, + UNROLL_FACTOR: int = 8, + BLOCK_SIZE: int = 16384, +): + assert output.dim() == 2, f"{output.shape}" + assert input.dim() == 2, f"{input.shape}" + assert output.shape[1] == input.shape[1] + + dim = output.shape[1] + input_hdl = symm_mem.rendezvous(input, group=group) + input_splits_hdl = symm_mem.rendezvous(input_splits, group=group) + + num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK + kernel = on_device_all_to_all_v_kernel[(num_blocks, 1, 1)]( + output, + output_splits, + input_hdl.buffer_ptrs_dev, + input_splits_hdl.buffer_ptrs_dev, + input_hdl.signal_pad_ptrs_dev, + dim=dim, + rank=input_hdl.rank, + world_size=input_hdl.world_size, + BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK, + UNROLL_FACTOR=UNROLL_FACTOR, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=16, + ) + # log_triton_kernel(kernel) + return output + + +class OnDeviceAllToAllV(torch.autograd.Function): + # A symmetric memory holding the grad_output during backward + grad_output_buf = None + # A symmetric memory for exchanges split sizes during both forward and backward + splits_buf = None + # Maximum output length (need to be set before use of OnDeviceAllToAllV) + max_output_len = None + + @staticmethod + def forward( + ctx, + input: torch.Tensor, + input_splits: torch.Tensor, + group: dist.ProcessGroup = dist.group.WORLD, + ): + """ + Args: + input: input tensor with data for all ranks concatenated. + input_splits: input splits of shape (group.world_size,) + group: process group to scope the collective. 
+ """ + # Initialize input splits buffer (one time only) + if OnDeviceAllToAllV.splits_buf is None: + OnDeviceAllToAllV.splits_buf = symm_mem.empty( + *input_splits.shape, + dtype=input_splits.dtype, + device=input_splits.device, + ) + + if OnDeviceAllToAllV.max_output_len is None: + raise RuntimeError( + "Please set max output length via `OnDeviceAllToAllV.max_output_len = ...`" + ) + + # Allocate output buffer + output = input.new_empty(OnDeviceAllToAllV.max_output_len, *input.shape[1:]) + # Allocate output splits tensor + output_splits = torch.empty_like(input_splits) + # Copy input splits to the buffer + OnDeviceAllToAllV.splits_buf.copy_(input_splits) + + # Shuffle input to output + _on_device_all_to_all_v( + output, output_splits, input, OnDeviceAllToAllV.splits_buf, group=group + ) + + # Output splits in forward is the input splits in backward + ctx.save_for_backward(output_splits) + ctx.group = group + ctx.input_shape = input.shape + return output, output_splits + + @staticmethod + def backward(ctx, grad_output, grad_splits): + """ + Backward is implemented as a shuffle of the output's gradients to the input. + Args: + `grad_output`: output's gradients passed from the downstream. + `grad_splits`: unused. + """ + + # Initialize grad_output buffer (one time only) + if OnDeviceAllToAllV.grad_output_buf is None: + assert ( + OnDeviceAllToAllV.max_output_len is not None + ), "`max_output_len` not set" + OnDeviceAllToAllV.grad_output_buf = symm_mem.empty( + OnDeviceAllToAllV.max_output_len, + *grad_output.shape[1:], + dtype=grad_output.dtype, + device=grad_output.device, + ) + + # TODO: is there a way to tell autograd to feed grad_output directly to + # our symm_mem buffer? + OnDeviceAllToAllV.grad_output_buf.narrow(0, 0, grad_output.shape[0]).copy_( + grad_output + ) + + # Size info + (grad_output_splits,) = ctx.saved_tensors + OnDeviceAllToAllV.splits_buf.copy_(grad_output_splits) + grad_input_splits = torch.empty_like(grad_output_splits) # unused + grad_input = grad_output.new_empty(*ctx.input_shape) + + # Shuffle gradients back to the input + _on_device_all_to_all_v( + grad_input, + grad_input_splits, + OnDeviceAllToAllV.grad_output_buf, + OnDeviceAllToAllV.splits_buf, + group=ctx.group, + ) + return grad_input, None, None + + +# Alias +on_device_all_to_all_v = OnDeviceAllToAllV.apply diff --git a/torchtitan/experiments/deepseek_v3/train.py b/torchtitan/experiments/deepseek_v3/train.py new file mode 100644 index 0000000000000000000000000000000000000000..1b9ed2dd65164744686647964de3ffdfa3813771 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/train.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
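Before the training driver below, a minimal single-call sketch of `OnDeviceAllToAllV` from the recipe above. The shapes and split sizes are arbitrary, and it assumes a NCCL process group has already been initialized (e.g. via torchrun) with one CUDA device per rank.

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

from torchtitan.experiments.deepseek_v3.symm_mem_recipes import OnDeviceAllToAllV

# Assumes torchrun has already set up the process group, one GPU per rank.
rank, world_size = dist.get_rank(), dist.get_world_size()
device = torch.device("cuda", rank % torch.cuda.device_count())

rows, dim = 1024, 256
# The token buffer must live in symmetric memory (the kernel rendezvous-es on
# it across ranks) and be bfloat16, matching the pointer cast in the kernel.
inp = symm_mem.empty(rows, dim, dtype=torch.bfloat16, device=device)
inp.copy_(torch.randn(rows, dim, dtype=torch.bfloat16, device=device))
# Even split of the rows across ranks; a plain tensor is fine here because
# forward copies it into the class-level symmetric-memory splits buffer.
input_splits = torch.full((world_size,), rows // world_size, dtype=torch.int64, device=device)

# Must be set before the first call: an upper bound on rows any rank can receive.
OnDeviceAllToAllV.max_output_len = 2 * rows

output, output_splits = OnDeviceAllToAllV.apply(inp, input_splits, dist.group.WORLD)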
+ +# torchrun --standalone --nproc-per-node 8 run.py +import torch +import torch.distributed as dist +from checkpoint import load_weights_from_hf +from model import DeepseekForCausalLM +from model_config import deepseek_config_registry + +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import fully_shard +from torch.distributed.pipelining import PipelineStage, Schedule1F1B + + +# Use DeepSeek-V2-Lite as a proxy +model_id = "deepseek-ai/DeepSeek-V2-Lite" + + +# Run full model +def run_full_model( + mesh: DeviceMesh, +): + rank = dist.get_rank() + device_count = torch.cuda.device_count() + device = torch.device("cuda", rank % device_count) + + pp_mesh = mesh["pp"] + ep_mesh = mesh["ep"] + pp_rank = pp_mesh.get_local_rank() + ep_rank = ep_mesh.get_local_rank() + pp_size = pp_mesh.size() + ep_size = ep_mesh.size() + + # Get model configs + model_args = deepseek_config_registry[model_id] + # [Note]: I am making the model smaller for testing / avoiding OOM. If you + # have sufficient GPUs for model parallelism, you can remove this line. + model_args.num_hidden_layers = 16 + + # Apply model parallelism + model_args.ep_size = ep_size + model_args.num_stages = pp_size + model_args.stage_idx = pp_rank + print(model_args) + + # Instantiate model + with device, mesh: + model = DeepseekForCausalLM(model_args) + + # Load weights + load_weights_from_hf(model, model_id, device) + model.train() + + # Apply data parallelism + fsdp_mesh = mesh["fsdp"] + hsdp_mesh = mesh["ep", "fsdp"] + # Using `reshard_after_forward=False` to implement Zero-2, i.e. sharding the + # optimizer (Zero-1) and gradients (Zero-2), but not the model weights. + # Reason: the MoE is "sparsely activated" compared to the dense model, thus + # it will be ineconomical re-gather the weights. + for layer in model.model.layers.values(): + # Apply FSDP to experts + if hasattr(layer.mlp, "experts"): + for expert in layer.mlp.experts.values(): + fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False) + # Apply HSDP to other parts such as attention, layernorm, because they + # are doing DDP on EP dimension + fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False) + + # Apply HSDP on root model (lm_head, embeddings, etc) + fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False) + + # Synthetic setting + microbatches = pp_size * 2 + + # Use Symmetric Memory for MoE token shuffle. + # TODO: we are rewriting `moe_on_device` function. `setup_symm_mem` is + # currently supported for forward only. See `generate.py`. 
+ # model.setup_symm_mem(torch.bfloat16, device) + + # Example inputs + torch.manual_seed(ep_rank) + bs = 4 + seqlen = 128 + x = torch.randint(model_args.vocab_size, (microbatches * bs, seqlen), device=device) + label = torch.rand(microbatches * bs, seqlen, model_args.vocab_size, device=device) + + # Create loss function + loss_fn = torch.nn.functional.cross_entropy + + # Run forward and backward + steps = 2 + for _ in range(steps): + if pp_size > 1: + # Create pipeline stage + stage = PipelineStage( + model, + pp_rank, + pp_size, + device, + group=pp_mesh.get_group(), + ) + + # Create pipeline schedule + losses = [] + pp_schedule = Schedule1F1B(stage, microbatches, loss_fn=loss_fn) + + if pp_rank == 0: + y = pp_schedule.step(x) + elif pp_rank == pp_size - 1: + y = pp_schedule.step(target=label, losses=losses) + loss = torch.mean(torch.stack(losses)) + else: + pp_schedule.step() + else: + y = model(x) + loss = loss_fn(y, label) + loss.backward() + + if pp_rank == pp_size - 1: + print(f"logits: {y.shape}") + print(f"{loss=}") + + if pp_rank == 0: + param = model.get_parameter("model.layers.0.self_attn.q_proj.weight") + print(f"{torch.linalg.norm(param.grad)=}") + + model.zero_grad() + + print("Backward done") + + +if __name__ == "__main__": + mesh = dist.init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp")) + + run_full_model(mesh) + + dist.destroy_process_group() diff --git a/torchtitan/experiments/flux/dataset/tokenizer.py b/torchtitan/experiments/flux/dataset/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..090bfc955152d87614f03793fd606330995da39d --- /dev/null +++ b/torchtitan/experiments/flux/dataset/tokenizer.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + + +from typing import List + +from torchtitan.components.tokenizer import Tokenizer +from transformers import CLIPTokenizer, T5Tokenizer + + +class FluxTokenizer(Tokenizer): + """ + Tokenizing and encoding/decoding text using the T5 or Clip tokenizer. + + Args: + model_path (str): Path to the tokenzier from hugging face. + + """ + + def __init__(self, model_path: str = "t5-small", max_length: int = 77): + super().__init__() + self._n_words = 8 # TODO(jianiw): check + self._max_length = max_length + + self.is_clip = model_path.startswith("openai") + + if self.is_clip: + self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained( + model_path, max_length=max_length + ) + else: + self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained( + model_path, max_length=max_length + ) + + def encode( + self, + s: str, + ) -> List[int]: + """ + Encode the prompt text into tokens. + """ + tokens = self._tokenizer( + s, + truncation=True, + max_length=self._max_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", # return pytorch tensors, default return List[int] + )["input_ids"] + return tokens + + def decode(self, t: List[int]) -> str: + """ + Decode function. This function will not be called. 
+ """ + return self._tokenizer.decode(t) diff --git a/torchtitan/experiments/flux/flux_argparser.py b/torchtitan/experiments/flux/flux_argparser.py new file mode 100644 index 0000000000000000000000000000000000000000..f6d290134ce1b502e6268afe58d3a231af2d447f --- /dev/null +++ b/torchtitan/experiments/flux/flux_argparser.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse + +import torch + + +def extend_parser(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--training.guidance", + type=float, + default=3.5, + help="guidance value used for guidance distillation", + ) + parser.add_argument( + "--encoder.t5_encoder", + type=str, + default="google/t5-v1_1-small", + help="T5 encoder to use, HuggingFace model name.", + ) + parser.add_argument( + "--encoder.clip_encoder", + type=str, + default="openai/clip-vit-large-patch14", + help="Clip encoder to use, HuggingFace model name.", + ) + parser.add_argument( + "--encoder.encoder_dtype", + type=torch.dtype, + default=torch.bfloat16, + help="Which dtype to load for autoencoder. ", + ) + parser.add_argument( + "--encoder.max_t5_encoding_len", + type=int, + default=512, + help="Maximum length of the T5 encoding.", + ) diff --git a/torchtitan/experiments/flux/model/autoencoder.py b/torchtitan/experiments/flux/model/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a68d5fb750d04b37d059dbef1de1f399bd3caea2 --- /dev/null +++ b/torchtitan/experiments/flux/model/autoencoder.py @@ -0,0 +1,388 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +from dataclasses import dataclass + +import torch +from einops import rearrange +from safetensors.torch import load_file as load_sft +from torch import nn, Tensor + + +@dataclass +class AutoEncoderParams: + resolution: int = 256 + in_channels: int = 3 + ch: int = 128 + out_ch: int = 3 + ch_mult: tuple[int] = (1, 2, 4, 4) + num_res_blocks: int = 2 + z_channels: int = 16 + scale_factor: float = 0.3611 + shift_factor: float = 0.1159 + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = nn.GroupNorm( + num_groups=32, num_channels=out_channels, eps=1e-6, affine=True + ) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = 
num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=block_in, eps=1e-6, affine=True + ) + self.conv_out = nn.Conv2d( + block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=block_in, eps=1e-6, affine=True + ) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, 
stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # get dtype for proper tracing + upscale_dtype = next(self.up.parameters()).dtype + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # cast to proper dtype + h = h.to(upscale_dtype) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.params = params + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) + + +def load_ae( + ckpt_path: str, + autoencoder_params: AutoEncoderParams, + device: str | torch.device = "cuda", + dtype=torch.bfloat16, +) -> AutoEncoder: + """ + Load the autoencoder from the given model name. + Args: + name (str): The name of the autoencoder. + device (str or torch.device): The device to load the autoencoder to. + Returns: + AutoEncoder: The loaded autoencoder. + """ + # Loading the autoencoder + print("Init AE") + with torch.device(device): + ae = AutoEncoder(autoencoder_params) + + if not os.path.exists(ckpt_path): + raise ValueError( + f"Autoencoder path {ckpt_path} does not exist. Please download it first." + ) + + if ckpt_path is not None: + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) + if len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + if len(unexpected) > 0: + print( + f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected) + ) + return ae.to(dtype=dtype) diff --git a/torchtitan/experiments/flux/model/hf_embedder.py b/torchtitan/experiments/flux/model/hf_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..495fd7a81d16cc0cadeaab3b390a638339ff0f94 --- /dev/null +++ b/torchtitan/experiments/flux/model/hf_embedder.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn, Tensor +from transformers import CLIPTextModel, T5EncoderModel + + +class FluxEmbedder(nn.Module): + def __init__(self, version: str, **hf_kwargs): + super().__init__() + self.is_clip = version.startswith("openai") + self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" + + if self.is_clip: + self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained( + version, **hf_kwargs + ) + else: + self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained( + version, **hf_kwargs + ) + + self.hf_module = self.hf_module.eval().requires_grad_(False) + + def forward(self, batch_tokens: Tensor) -> Tensor: + """ + batch_tokens: [bsz, embedding_length] + + For T5 Encoder, embeding_length is 768 + For CLIP, embedding_length is 256 + """ + outputs = self.hf_module( + input_ids=batch_tokens.to(self.hf_module.device), + attention_mask=None, + output_hidden_states=False, + ) + return outputs[self.output_key] diff --git a/torchtitan/experiments/flux/model/layers.py b/torchtitan/experiments/flux/model/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..73141b373a5d579b8c8988fa66d1f9594e5bad3f --- /dev/null +++ b/torchtitan/experiments/flux/model/layers.py @@ -0,0 +1,286 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# imported from black-forest-labs/FLUX +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import nn, Tensor + +from torchtitan.experiments.flux.model.math import attention, rope + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
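+
+    Example (illustrative shapes only):
+        t = torch.tensor([0.25, 0.5])          # fractional timesteps in [0, 1]
+        emb = timestep_embedding(t, dim=256)   # -> shape (2, 256)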
+ """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) # TODO(jianiw): switch to pytorch nn.RMSNorm + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + def forward(self, x: Tensor, pe: Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk( + self.multiplier, dim=-1 + ) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__( + self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False + ): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = 
SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + def forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor + ) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange( + img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange( + txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp( + (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift + ) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp( + (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift + ) + return img, txt + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. 
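+
+    Illustrative shapes (assuming hidden_size=3072, num_heads=24):
+        x: (B, L, 3072), vec: (B, 3072), pe: position encoding from EmbedND
+        block(x, vec, pe) -> (B, L, 3072)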
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split( + self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1 + ) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, bias=True + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x diff --git a/torchtitan/experiments/flux/model/model.py b/torchtitan/experiments/flux/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..67b9e6aeaacee709c4fdc7d86f338eec050bf322 --- /dev/null +++ b/torchtitan/experiments/flux/model/model.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass, field + +import torch + +from torch import nn, Tensor +from torchtitan.components.tokenizer import Tokenizer +from torchtitan.config_manager import JobConfig + +from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams +from torchtitan.experiments.flux.model.layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, +) + +from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol +from torchtitan.tools.logging import logger + + +@dataclass +class FluxModelArgs(BaseModelArgs): + in_channels: int = 64 + out_channels: int = 64 + vec_in_dim: int = 768 + context_in_dim: int = 512 + hidden_size: int = 3072 + mlp_ratio: float = 4.0 + num_heads: int = 24 + depth: int = 19 + depth_single_blocks: int = 38 + axes_dim: tuple = (16, 56, 56) + theta: int = 10_000 + qkv_bias: bool = True + guidance_embed: bool = True + autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams) + + def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None: + # context_in_dim is the same as the T5 embedding dimension + self.context_in_dim = job_config.encoder.max_t5_encoding_len + + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: + # TODO(jianiw): Add the number of flops for the autoencoder + nparams = sum(p.numel() for p in model.parameters()) + logger.warning("FLUX model haven't implement get_nparams_and_flops() function") + return nparams, 1 + + +class FluxModel(nn.Module, ModelProtocol): + """ + Transformer model for flow matching on sequences. + + Agrs: + model_args: FluxModelArgs. + + Attributes: + model_args (TransformerModelArgs): Model configuration arguments. + """ + + def __init__(self, model_args: FluxModelArgs): + super().__init__() + + self.model_args = model_args + self.in_channels = model_args.in_channels + self.out_channels = model_args.out_channels + if model_args.hidden_size % model_args.num_heads != 0: + raise ValueError( + f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}" + ) + pe_dim = model_args.hidden_size // model_args.num_heads + if sum(model_args.axes_dim) != pe_dim: + raise ValueError( + f"Got {model_args.axes_dim} but expected positional dim {pe_dim}" + ) + self.hidden_size = model_args.hidden_size + self.num_heads = model_args.num_heads + self.pe_embedder = EmbedND( + dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim + ) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(model_args.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + if model_args.guidance_embed + else nn.Identity() + ) + self.txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=model_args.mlp_ratio, + qkv_bias=model_args.qkv_bias, + ) + for _ in range(model_args.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + self.hidden_size, self.num_heads, mlp_ratio=model_args.mlp_ratio + ) + for _ in range(model_args.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + def init_weights(self, buffer_device=None): + # TODO(jianiw): replace placeholder with real weight init + for param in 
self.parameters(): + param.data.uniform_(0, 0.1) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.model_args.guidance_embed: + if guidance is None: + raise ValueError( + "Didn't get guidance strength for guidance distilled model." + ) + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img + + @classmethod + def from_model_args(cls, model_args: FluxModelArgs) -> "FluxModel": + """ + Initialize a Flux model from a FluxModelArgs object. + + Args: + model_args (FluxModelArgs): Model configuration arguments. + + Returns: + FluxModel: FluxModel model. + + """ + return cls(model_args) diff --git a/torchtitan/experiments/flux/scripts/download_autoencoder.py b/torchtitan/experiments/flux/scripts/download_autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c4dd4437bc583987da69ace57e61ef1b8314d582 --- /dev/null +++ b/torchtitan/experiments/flux/scripts/download_autoencoder.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +from requests.exceptions import HTTPError + + +def hf_download( + repo_id: str, file_path: str, local_dir: str, hf_token: Optional[str] = None +) -> None: + from huggingface_hub import hf_hub_download + + try: + hf_hub_download( + repo_id=repo_id, + filename=file_path, + local_dir=local_dir, + local_dir_use_symlinks=False, + token=hf_token, + ) + except HTTPError as e: + if e.response.status_code == 401: + print( + "You need to pass a valid `--hf_token=...` to download private checkpoints." + ) + else: + raise e + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Download tokenizer from HuggingFace.") + parser.add_argument( + "--repo_id", + type=str, + default="black-forest-labs/FLUX.1-dev", + help="Repository ID to download from. 
default to Flux-dev model", + ) + parser.add_argument( + "--ae_path", + type=str, + default="ae.safetensors", + help="the autoencoder path relative to repo_id", + ) + parser.add_argument( + "--hf_token", type=str, default=None, help="HuggingFace API token" + ) + parser.add_argument( + "--local_dir", + type=str, + default="torchtitan/experiments/flux/assets/autoencoder/", + help="local directory to save the autoencoder", + ) + + args = parser.parse_args() + hf_download(args.repo_id, args.ae_path, args.local_dir, args.hf_token) diff --git a/torchtitan/experiments/flux/tests/test_flux_dataloader.py b/torchtitan/experiments/flux/tests/test_flux_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..fc87f1b8b4ae3ad7daf1558835716720127e3b42 --- /dev/null +++ b/torchtitan/experiments/flux/tests/test_flux_dataloader.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys + +from torchtitan.config_manager import JobConfig +from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader +from torchtitan.tools.profiling import ( + maybe_enable_memory_snapshot, + maybe_enable_profiling, +) + + +class TestFluxDataLoader: + def test_flux_dataloader(self): + dataset_name = "cc12m" + batch_size = 32 + world_size = 4 + rank = 0 + + num_steps = 10 + + path = "torchtitan.experiments.flux.flux_argparser" + sys.argv.append(f"--experimental.custom_args_module={path}") + config = JobConfig() + config.maybe_add_custom_args() + config.parse_args( + [ + # Profiling options + # "--profiling.enable_profiling", + # "--profiling.profile_freq", + # "5", + # "--profiling.enable_memory_snapshot", + # "--profiling.save_memory_snapshot_folder", + # "memory_snapshot_flux", + "--training.dataset", + dataset_name, + "--training.batch_size", + str(batch_size), + "--encoder.t5_encoder", + "google/t5-v1_1-small", + "--encoder.clip_encoder", + "openai/clip-vit-large-patch14", + "--encoder.max_t5_encoding_len", + "512", + ] + ) + + with maybe_enable_profiling( + config, global_step=0 + ) as torch_profiler, maybe_enable_memory_snapshot( + config, global_step=0 + ) as memory_profiler: + dl = self._build_dataloader( + config, + world_size, + rank, + ) + dl = iter(dl) + + for i in range(0, num_steps): + input_data, labels = next(dl) + print(f"Step {i} image size: {labels.shape}") + if torch_profiler: + torch_profiler.step() + if memory_profiler: + memory_profiler.step() + + print(len(input_data["clip_tokens"])) + for k, v in input_data.items(): + print(f"Step {i} {k} value: {type(v), v.shape}") + + assert len(input_data) == 2 # (clip_encodings, t5_encodings) + assert labels.shape == (batch_size, 3, 256, 256) + # assert input_data["clip_tokens"].shape[0] == batch_size + # assert input_data["t5_tokens"].shape == (batch_size, 512, 512) + + if torch_profiler: + torch_profiler.step() + if memory_profiler: + memory_profiler.step(exit_ctx=True) + + def test_preprocess(self): + # TODO + pass + + def _build_dataloader( + self, + job_config, + world_size, + rank, + ): + + return build_flux_dataloader( + dp_world_size=world_size, + dp_rank=rank, + job_config=job_config, + tokenizer=None, + infinite=False, + ) diff --git a/torchtitan/experiments/flux/tests/test_generate_image.py b/torchtitan/experiments/flux/tests/test_generate_image.py new file mode 100644 index 
0000000000000000000000000000000000000000..86d8d16cfbbcbfaa706e6ff6713403520744efd5 --- /dev/null +++ b/torchtitan/experiments/flux/tests/test_generate_image.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +import os +import time +from typing import Callable + +import torch +from einops import rearrange + +from PIL import ExifTags, Image + +from torch import Tensor + +from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer + +from torchtitan.experiments.flux.model.autoencoder import ( + AutoEncoder, + AutoEncoderParams, + load_ae, +) +from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder + +from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs +from torchtitan.experiments.flux.utils import ( + create_position_encoding_for_latents, + generate_noise_latent, + pack_latents, + preprocess_flux_data, + unpack_latents, +) + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # estimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +class TestGenerateImage: + def test_generate_image(self): + """ + Run a forward pass of flux model to generate an image. + """ + name = "flux-dev" + img_width = 512 + img_height = 512 + seed = None + prompt = ( + "a photo of a forest with mist swirling around the tree trunks. 
The word " + '"FLUX" is painted over it in big, red brush strokes with visible texture' + ) + device = "cuda" + num_steps = None + loop = False + guidance = 3.5 + output_dir = "output" + add_sampling_metadata = True + + prompt = prompt.split("|") + if len(prompt) == 1: + prompt = prompt[0] + additional_prompts = None + else: + additional_prompts = prompt[1:] + prompt = prompt[0] + + assert not ( + (additional_prompts is not None) and loop + ), "Do not provide additional prompts and set loop to True" + + torch_device = torch.device(device) + if num_steps is None: + num_steps = 30 + + # allow for packing and conversion to latent space + img_height = 16 * (img_height // 16) + img_width = 16 * (img_width // 16) + + # init all components + model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16) + + ae = load_ae( + ckpt_path="assets/autoencoder/ae.safetensors", + autoencoder_params=AutoEncoderParams(), + device=torch_device, + dtype=torch.bfloat16, + ) + clip_tokenizer = FluxTokenizer( + model_path="openai/clip-vit-large-patch14", max_length=77 + ) + t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512) + clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to( + torch_device, dtype=torch.bfloat16 + ) + t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to( + torch_device, dtype=torch.bfloat16 + ) + + rng = torch.Generator(device="cpu") + + if seed is None: + seed = rng.seed() + print(f"Generating with seed {seed}:\n{prompt}") + t0 = time.perf_counter() + output_name = os.path.join(output_dir, f"img_{seed}.jpg") + + # Tokenize the prompt, on CPU + clip_tokens = clip_tokenizer.encode(prompt) + t5_tokens = t5_tokenizer.encode(prompt) + + batch = preprocess_flux_data( + device=torch_device, + dtype=torch.bfloat16, + autoencoder=None, + clip_encoder=clip_encoder, + t5_encoder=t5_encoder, + batch={ + "clip_tokens": clip_tokens, + "t5_tokens": t5_tokens, + }, + ) + + img = self._generate_images( + device=torch_device, + dtype=torch.bfloat16, + model=model, + decoder=ae, + img_width=img_width, + img_height=img_height, + denoising_steps=num_steps, + seed=seed, + clip_encodings=batch["clip_encodings"], + t5_encodings=batch["t5_encodings"], + guidance=guidance, + ) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t1 = time.perf_counter() + + print(f"Done in {t1 - t0:.1f}s.") + + self._save_image(name, output_name, img, add_sampling_metadata, prompt) + + def _generate_images( + self, + device: torch.device, + dtype: torch.dtype, + model: FluxModel, + decoder: AutoEncoder, + # image params: + img_width: int, + img_height: int, + # sampling params: + denoising_steps: int, + seed: int, + clip_encodings: torch.Tensor, + t5_encodings: torch.Tensor, + guidance: float = 4.0, + ): + + bsz = clip_encodings.shape[0] + latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed) + _, latent_channels, latent_height, latent_width = latents.shape + + # create denoising schedule + timesteps = get_schedule(denoising_steps, latent_channels, shift=True) + + # create positional encodings + POSITION_DIM = 3 # constant for Flux flow model + latent_pos_enc = create_position_encoding_for_latents( + bsz, latent_height, latent_width, POSITION_DIM + ).to(latents) + text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents) + + # convert img-like latents into sequences of patches + latents = pack_latents(latents) + + # this is ignored for schnell + guidance_vec = torch.full((bsz,), guidance, 
device=device, dtype=dtype) + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device) + pred = model( + img=latents, + img_ids=latent_pos_enc, + txt=t5_encodings, + txt_ids=text_pos_enc, + y=clip_encodings, + timesteps=t_vec, + guidance=guidance_vec, + ) + + latents = latents + (t_prev - t_curr) * pred + + # convert sequences of patches into img-like latents + latents = unpack_latents(latents, latent_height, latent_width) + + img = decoder.decode(latents) + return img + + def _save_image( + self, + name: str, + output_name: str, + x: torch.Tensor, + add_sampling_metadata: bool, + prompt: str, + ): + print(f"Saving {output_name}") + # bring into PIL format and save + x = x.clamp(-1, 1) + x = rearrange(x[0], "c h w -> h w c") + + img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + + exif_data = Image.Exif() + exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" + exif_data[ExifTags.Base.Make] = "Black Forest Labs" + exif_data[ExifTags.Base.Model] = name + if add_sampling_metadata: + exif_data[ExifTags.Base.ImageDescription] = prompt + img.save(output_name, exif=exif_data, quality=95, subsampling=0) diff --git a/torchtitan/experiments/flux/train_configs/debug_model.toml b/torchtitan/experiments/flux/train_configs/debug_model.toml new file mode 100644 index 0000000000000000000000000000000000000000..250a71d60ec28028b548803bad7f14b6b3a6db62 --- /dev/null +++ b/torchtitan/experiments/flux/train_configs/debug_model.toml @@ -0,0 +1,68 @@ + +[job] +dump_folder = "./outputs" +description = "Flux debug model" +print_args = false +use_for_integration_test = true + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "flux" +flavor = "flux-debug" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm +# test tokenizer.model, for debug purpose only +# tokenizer_path = "./tests/assets/test_tiktoken.model" +# converters = "float8" + + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +lr_min = 0.0 + +[training] +batch_size = 32 +seq_len = 512 +max_norm = 1.0 # grad norm clipping +steps = 10 +compile = false +dataset = "cc12m" +guidance = 3.5 +seed = 0 + +[encoder] +t5_encoder="google/t5-v1_1-small" +clip_encoder="openai/clip-vit-large-patch14" +max_t5_encoding_len=512 +auto_encoder_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = 1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[experimental] +custom_args_module = "torchtitan.experiments.flux.flux_argparser" diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py new file mode 100644 index 0000000000000000000000000000000000000000..7e893a54443a6c05a548b35325421e66db321d43 --- /dev/null +++ 
b/torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py @@ -0,0 +1,885 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import math +import time + +from typing import Dict, List, Tuple + +# import numpy as np +import torch # +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +# from torchao_pr.mg_grouped_gemm import mg_grouped_gemm + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +# Try to import the optimized MG GEMM implementation +try: + from torchao_pr.mg_grouped_gemm import ( # grouped_gemm_backward, + grouped_gemm_forward, + ) + + has_mg_gemm = True +except ImportError: + logging.warning("MG GEMM implementation not found. Will use manual looping only.") + has_mg_gemm = False + + +class Router(nn.Module): + """ + Router module that assigns tokens to experts. + """ + + def __init__(self, input_dim: int, num_experts: int, top_k: int = 2): + super().__init__() + self.input_dim = input_dim + self.num_experts = num_experts + self.top_k = top_k + + # Routing layer + self.router = nn.Linear(input_dim, num_experts) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: + """ + Route input tokens to experts. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_dim) + + Returns: + Tuple containing: + - router_logits: Raw routing probabilities + - dispatch_tensor: One-hot tensor indicating expert assignment + - expert_indices: List of indices for each expert's tokens + """ + batch_size, seq_len, _ = x.shape + + # Flatten batch and sequence dimensions + x_flat = x.reshape(-1, self.input_dim) # (batch_size * seq_len, input_dim) + + # Compute routing probabilities + router_logits = self.router(x_flat) # (batch_size * seq_len, num_experts) + + # Apply softmax to get probabilities + router_probs = F.softmax(router_logits, dim=-1) + + # Get top-k experts for each token + top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1) + + # Normalize top-k probabilities + top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True) + + # Create dispatch tensor (one-hot representation of assignments) + dispatch_tensor = torch.zeros_like(router_probs) + token_indices = ( + torch.arange(router_probs.size(0), device=router_probs.device) + .unsqueeze(1) + .expand(-1, self.top_k) + ) + dispatch_tensor.scatter_(1, top_k_indices, top_k_probs) # .unsqueeze(-1)) + + # For each expert, get the indices of tokens routed to it + expert_indices = [] + for expert_idx in range(self.num_experts): + # Get indices of tokens that have non-zero probability for this expert + indices = torch.nonzero(dispatch_tensor[:, expert_idx] > 0, as_tuple=True)[ + 0 + ] + expert_indices.append(indices) + + return router_logits, dispatch_tensor, expert_indices + + +class Expert(nn.Module): + """ + Individual expert module. 
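+
+    A two-layer GELU MLP. Illustrative usage:
+        expert = Expert(input_dim=64, hidden_dim=128, output_dim=64)
+        expert(torch.randn(10, 64)).shape   # torch.Size([10, 64])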
+ """ + + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): + super().__init__() + self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False) + self.activation = nn.GELU() + self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + + +class MixtureOfExperts(nn.Module): + """ + Mixture of Experts layer with support for both manual looping and grouped GEMM. + """ + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_experts: int, + top_k: int = 2, + use_mg_gemm: bool = False, + ): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.num_experts = num_experts + self.top_k = top_k + self.use_mg_gemm = use_mg_gemm and has_mg_gemm + + # Router + self.router = Router(input_dim, num_experts, top_k) + + # Create expert modules + if self.use_mg_gemm: + # For MG GEMM, we need a single weight tensor for all experts + # First layer (input -> hidden) + self.expert_fc1_weight = nn.Parameter( + torch.randn(num_experts * hidden_dim, input_dim) / math.sqrt(input_dim) + ) + # self.expert_fc1_bias = nn.Parameter(torch.zeros(num_experts * hidden_dim)) + + # Second layer (hidden -> output) + self.expert_fc2_weight = nn.Parameter( + torch.randn(num_experts * output_dim, hidden_dim) + / math.sqrt(hidden_dim) + ) + # self.expert_fc2_bias = nn.Parameter(torch.zeros(num_experts * output_dim)) + else: + # For manual looping, create separate experts + self.experts = nn.ModuleList( + [Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)] + ) + + def forward_manual_loop(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass using manual looping over experts. 
+ """ + batch_size, seq_len, _ = x.shape + x_flat = x.reshape(-1, self.input_dim) # (batch_size * seq_len, input_dim) + + # Get routing information + router_logits, dispatch_tensor, expert_indices = self.router(x) + + # Initialize output tensor + final_output = torch.zeros( + batch_size * seq_len, self.output_dim, device=x.device + ) + + # Process each expert + for expert_idx, indices in enumerate(expert_indices): + if indices.numel() > 0: + # Get tokens routed to this expert + expert_inputs = x_flat[indices] # (num_tokens_for_expert, input_dim) + + # Process tokens through expert + expert_outputs = self.experts[expert_idx]( + expert_inputs + ) # (num_tokens_for_expert, output_dim) + + # Scale outputs by router probabilities + scaled_outputs = expert_outputs * dispatch_tensor[ + indices, expert_idx + ].unsqueeze(1) + + # Add to final output + final_output.index_add_(0, indices, scaled_outputs) + + # Reshape back to original dimensions + output = final_output.reshape(batch_size, seq_len, self.output_dim) + + return output, router_logits + + def forward_mg_gemm(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, _ = x.shape + x_flat = x.reshape(-1, self.input_dim) # (batch_size * seq_len, input_dim) + total_tokens = batch_size * seq_len + + # Get routing information + router_logits, dispatch_tensor, expert_indices = self.router(x) + + # Get token counts for each expert + token_counts = [indices.numel() for indices in expert_indices] + m_sizes = torch.tensor(token_counts, dtype=torch.int32, device=x.device) + + print(f"Token counts per expert: {token_counts}") + print(f"m_sizes: {m_sizes}") + + # Create the combined input tensor + combined_input = torch.zeros(sum(token_counts), self.input_dim, device=x.device) + + start_idx = 0 + for expert_idx, indices in enumerate(expert_indices): + if indices.numel() > 0: + end_idx = start_idx + indices.numel() + combined_input[start_idx:end_idx] = x_flat[indices] + start_idx = end_idx + + print(f"combined_input shape: {combined_input.shape}") + + # First layer: input -> hidden + fc1_weight_reshaped = self.expert_fc1_weight.reshape( + self.num_experts, self.hidden_dim, self.input_dim + ) + fc1_weight_combined = fc1_weight_reshaped.reshape(-1, self.input_dim) + + print(f"fc1_weight_combined shape: {fc1_weight_combined.shape}") + + # Run the grouped GEMM + hidden_outputs = grouped_gemm_forward( + combined_input, fc1_weight_combined, m_sizes + ) + + print(f"hidden_outputs shape after first GEMM: {hidden_outputs.shape}") + + # Apply activation + hidden_outputs = F.gelu(hidden_outputs) + + print(f"hidden_outputs shape after activation: {hidden_outputs.shape}") + + # Second layer: hidden -> output + # Reshape hidden_outputs to match expected dimensions + reshaped_hidden_outputs = [] + start_idx = 0 + + for expert_idx, count in enumerate(token_counts): + if count > 0: + end_idx = start_idx + count + # Take this expert's outputs and reshape to [count, hidden_dim] + expert_output = hidden_outputs[ + start_idx:end_idx, + expert_idx * self.hidden_dim : (expert_idx + 1) * self.hidden_dim, + ] + reshaped_hidden_outputs.append(expert_output) + start_idx = end_idx + + # Concatenate all reshaped outputs + hidden_outputs = torch.cat(reshaped_hidden_outputs, dim=0) + + # Reshape expert weights for second layer + fc2_weight_reshaped = self.expert_fc2_weight.reshape( + self.num_experts, self.output_dim, self.hidden_dim + ) + fc2_weight_combined = fc2_weight_reshaped.reshape(-1, self.hidden_dim) + + print(f"fc2_weight_combined shape: {fc2_weight_combined.shape}") + 
+ # Run the second grouped GEMM + expert_outputs_combined = grouped_gemm_forward( + hidden_outputs, fc2_weight_combined, m_sizes + ) + + # Initialize final output tensor with correct shape + final_output = torch.zeros(total_tokens, self.output_dim, device=x.device) + + # Distribute the outputs back to the original token positions + start_idx = 0 + for expert_idx, indices in enumerate(expert_indices): + if indices.numel() > 0: + end_idx = start_idx + indices.numel() + # Get this expert's outputs + expert_outputs = expert_outputs_combined[start_idx:end_idx] + + print( + f"Expert {expert_idx} - indices shape: {indices.shape}, expert_outputs shape: {expert_outputs.shape}" + ) + + # Scale outputs by router probabilities + scaled_outputs = expert_outputs * dispatch_tensor[ + indices, expert_idx + ].unsqueeze(1) + + # Ensure dimensions match before using index_add_ + if scaled_outputs.shape[1] != final_output.shape[1]: + # print( + # f"Reshaping: Dimension mismatch: scaled_outputs {scaled_outputs.shape}, final_output {final_output.shape}" + # ) + # Reshape if needed - make sure output_dim is correct + scaled_outputs = scaled_outputs[:, : self.output_dim] + + # Add to final output + final_output.index_add_(0, indices, scaled_outputs) + + start_idx = end_idx + + # Reshape back to original dimensions + output = final_output.reshape(batch_size, seq_len, self.output_dim) + + return output, router_logits + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_mg_gemm and has_mg_gemm: + return self.forward_mg_gemm(x) + else: + return self.forward_manual_loop(x) + + +class MoEModel(nn.Module): + """ + Simple model using MoE layers. + """ + + def __init__( + self, + vocab_size: int, + embed_dim: int, + hidden_dim: int, + num_experts: int, + top_k: int = 2, + use_mg_gemm: bool = False, + ): + super().__init__() + self.embedding = nn.Embedding(vocab_size, embed_dim) + self.moe_layer = MixtureOfExperts( + input_dim=embed_dim, + hidden_dim=hidden_dim, + output_dim=embed_dim, + num_experts=num_experts, + top_k=top_k, + use_mg_gemm=use_mg_gemm, + ) + self.output_layer = nn.Linear(embed_dim, vocab_size) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # x shape: (batch_size, seq_len) + embedded = self.embedding(x) # (batch_size, seq_len, embed_dim) + moe_output, router_logits = self.moe_layer( + embedded + ) # (batch_size, seq_len, embed_dim) + logits = self.output_layer(moe_output) # (batch_size, seq_len, vocab_size) + return logits, router_logits + + +def compute_load_balancing_loss( + router_logits: torch.Tensor, num_experts: int +) -> torch.Tensor: + """ + Compute the load balancing loss for MoE training. 
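+
+    The loss equals num_experts * sum_e(p_e ** 2), where p_e is the fraction
+    of routing probability mass assigned to expert e; it is minimized (at 1.0)
+    when routing is uniform across experts.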
+ + Args: + router_logits (torch.Tensor): Router logits of shape (batch_size * seq_len, num_experts) + num_experts (int): Number of experts + + Returns: + torch.Tensor: Load balancing loss + """ + # Get router probabilities + router_probs = F.softmax( + router_logits, dim=-1 + ) # (batch_size * seq_len, num_experts) + + # Compute fraction of tokens routed to each expert + # Sum across the batch dimension and normalize + router_probs_sum = router_probs.sum(dim=0) # (num_experts,) + router_probs_sum = router_probs_sum / router_probs_sum.sum() + + # Compute the mean probability per expert + mean_prob = 1.0 / num_experts + + # Compute the fraction of tokens routed to each expert + # The goal is to have uniform routing across experts + load_balancing_loss = num_experts * torch.sum(router_probs_sum * router_probs_sum) + + return load_balancing_loss + + +def generate_sample_data( + batch_size: int, seq_len: int, vocab_size: int, device: str = "cuda" +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Generate sample data for training. + + Args: + batch_size (int): Batch size + seq_len (int): Sequence length + vocab_size (int): Vocabulary size + device (str): Device to use + + Returns: + Tuple of input tokens and target tokens + """ + # Generate random input tokens + inputs = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + + # Generate random target tokens + targets = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + + return inputs, targets + + +def train_epoch( + model: nn.Module, + optimizer: torch.optim.Optimizer, + batch_size: int, + seq_len: int, + vocab_size: int, + num_batches: int, + device: str, + load_balance_coef: float = 0.01, +) -> Dict[str, float]: + """ + Train the model for one epoch. + + Args: + model (nn.Module): Model to train + optimizer (torch.optim.Optimizer): Optimizer + batch_size (int): Batch size + seq_len (int): Sequence length + vocab_size (int): Vocabulary size + num_batches (int): Number of batches per epoch + device (str): Device to use + load_balance_coef (float): Coefficient for load balancing loss + + Returns: + Dict containing training metrics + """ + model.train() + total_loss = 0.0 + total_acc = 0.0 + start_time = time.time() + + for i in range(num_batches): + # Generate sample data + inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device) + + # Forward pass + optimizer.zero_grad() + logits, router_logits = model(inputs) + + # Compute loss + # Reshape for cross entropy loss + logits_flat = logits.reshape(-1, vocab_size) + targets_flat = targets.reshape(-1) + + # Cross entropy loss + ce_loss = F.cross_entropy(logits_flat, targets_flat) + + # Load balancing loss + lb_loss = compute_load_balancing_loss( + router_logits, model.moe_layer.num_experts + ) + + # Combined loss + loss = ce_loss + load_balance_coef * lb_loss + + # Backward pass + loss.backward() + optimizer.step() + + # Compute accuracy + preds = logits_flat.argmax(dim=-1) + correct = (preds == targets_flat).float().sum() + acc = correct / (batch_size * seq_len) + + # Accumulate metrics + total_loss += loss.item() + total_acc += acc.item() + + # Log progress + if (i + 1) % 10 == 0: + logging.info( + f"Batch {i + 1}/{num_batches} | " + f"Loss: {loss.item():.4f} | " + f"CE Loss: {ce_loss.item():.4f} | " + f"LB Loss: {lb_loss.item():.4f} | " + f"Acc: {acc.item():.4f}" + ) + + # Compute average metrics + avg_loss = total_loss / num_batches + avg_acc = total_acc / num_batches + epoch_time = time.time() - start_time + + return {"loss": avg_loss, 
"acc": avg_acc, "time": epoch_time} + + +def evaluate( + model: nn.Module, + batch_size: int, + seq_len: int, + vocab_size: int, + num_batches: int, + device: str, +) -> Dict[str, float]: + """ + Evaluate the model. + + Args: + model (nn.Module): Model to evaluate + batch_size (int): Batch size + seq_len (int): Sequence length + vocab_size (int): Vocabulary size + num_batches (int): Number of batches for evaluation + device (str): Device to use + + Returns: + Dict containing evaluation metrics + """ + model.eval() + total_loss = 0.0 + total_acc = 0.0 + + with torch.no_grad(): + for i in range(num_batches): + # Generate sample data + inputs, targets = generate_sample_data( + batch_size, seq_len, vocab_size, device + ) + + # Forward pass + logits, router_logits = model(inputs) + + # Compute loss + logits_flat = logits.reshape(-1, vocab_size) + targets_flat = targets.reshape(-1) + + # Cross entropy loss + loss = F.cross_entropy(logits_flat, targets_flat) + + # Compute accuracy + preds = logits_flat.argmax(dim=-1) + correct = (preds == targets_flat).float().sum() + acc = correct / (batch_size * seq_len) + + # Accumulate metrics + total_loss += loss.item() + total_acc += acc.item() + + # Compute average metrics + avg_loss = total_loss / num_batches + avg_acc = total_acc / num_batches + + return {"loss": avg_loss, "acc": avg_acc} + + +def measure_performance( + model: nn.Module, + batch_size: int, + seq_len: int, + vocab_size: int, + num_batches: int, + device: str, +) -> Dict[str, float]: + """ + Measure forward and backward pass performance. + + Args: + model (nn.Module): Model to evaluate + batch_size (int): Batch size + seq_len (int): Sequence length + vocab_size (int): Vocabulary size + num_batches (int): Number of batches for measurement + device (str): Device to use + + Returns: + Dict containing performance metrics + """ + model.train() + + # Create dummy optimizer + optimizer = optim.Adam(model.parameters(), lr=0.001) + + # Warmup + for _ in range(5): + inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device) + logits, router_logits = model(inputs) + loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1)) + loss.backward() + optimizer.zero_grad() + + # Measure forward pass time + torch.cuda.synchronize() + forward_start = time.time() + + for _ in range(num_batches): + inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device) + with torch.no_grad(): + logits, router_logits = model(inputs) + + torch.cuda.synchronize() + forward_end = time.time() + forward_time = (forward_end - forward_start) / num_batches + + # Measure backward pass time + torch.cuda.synchronize() + backward_start = time.time() + + for _ in range(num_batches): + inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device) + logits, router_logits = model(inputs) + loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1)) + loss.backward() + optimizer.zero_grad() + + torch.cuda.synchronize() + backward_end = time.time() + backward_time = (backward_end - backward_start) / num_batches + + return { + "forward_time": forward_time * 1000, # Convert to ms + "backward_time": backward_time * 1000, # Convert to ms + "total_time": (forward_time + backward_time) * 1000, # Convert to ms + } + + +def compare_methods(args): + """ + Compare manual looping and MG GEMM implementations. 
+ """ + device = torch.device(args.device) + + # Create models + manual_model = MoEModel( + vocab_size=args.vocab_size, + embed_dim=args.embed_dim, + hidden_dim=args.hidden_dim, + num_experts=args.num_experts, + top_k=args.top_k, + use_mg_gemm=False, + ).to(device) + + if has_mg_gemm: + mg_model = MoEModel( + vocab_size=args.vocab_size, + embed_dim=args.embed_dim, + hidden_dim=args.hidden_dim, + num_experts=args.num_experts, + top_k=args.top_k, + use_mg_gemm=True, + ).to(device) + else: + mg_model = None + + # Measure performance + logging.info("Measuring performance of manual looping method...") + manual_perf = measure_performance( + manual_model, + args.batch_size, + args.seq_len, + args.vocab_size, + args.perf_batches, + device, + ) + + if mg_model is not None: + logging.info("Measuring performance of MG GEMM method...") + mg_perf = measure_performance( + mg_model, + args.batch_size, + args.seq_len, + args.vocab_size, + args.perf_batches, + device, + ) + else: + mg_perf = {"forward_time": 0, "backward_time": 0, "total_time": 0} + + # Log results + logging.info("\n===== Performance Comparison =====") + logging.info("Model Configuration:") + logging.info(f" - Batch Size: {args.batch_size}") + logging.info(f" - Sequence Length: {args.seq_len}") + logging.info(f" - Embed Dimension: {args.embed_dim}") + logging.info(f" - Hidden Dimension: {args.hidden_dim}") + logging.info(f" - Number of Experts: {args.num_experts}") + logging.info(f" - Top-K: {args.top_k}") + logging.info("") + + logging.info("Manual Looping Method:") + logging.info(f" - Forward Time: {manual_perf['forward_time']:.2f} ms") + logging.info(f" - Backward Time: {manual_perf['backward_time']:.2f} ms") + logging.info(f" - Total Time: {manual_perf['total_time']:.2f} ms") + logging.info("") + + if mg_model is not None: + logging.info("MG GEMM Method:") + logging.info(f" - Forward Time: {mg_perf['forward_time']:.2f} ms") + logging.info(f" - Backward Time: {mg_perf['backward_time']:.2f} ms") + logging.info(f" - Total Time: {mg_perf['total_time']:.2f} ms") + logging.info("") + + # Calculate speedup + forward_speedup = ( + manual_perf["forward_time"] / mg_perf["forward_time"] + if mg_perf["forward_time"] > 0 + else 0 + ) + backward_speedup = ( + manual_perf["backward_time"] / mg_perf["backward_time"] + if mg_perf["backward_time"] > 0 + else 0 + ) + total_speedup = ( + manual_perf["total_time"] / mg_perf["total_time"] + if mg_perf["total_time"] > 0 + else 0 + ) + + logging.info("Speedup (MG GEMM vs Manual):") + logging.info(f" - Forward Speedup: {forward_speedup:.2f}x") + logging.info(f" - Backward Speedup: {backward_speedup:.2f}x") + logging.info(f" - Total Speedup: {total_speedup:.2f}x") + else: + logging.info("MG GEMM method not available.") + + +def train_model(args): + """ + Train an MoE model. 
+ """ + device = torch.device(args.device) + + # Create model + model = MoEModel( + vocab_size=args.vocab_size, + embed_dim=args.embed_dim, + hidden_dim=args.hidden_dim, + num_experts=args.num_experts, + top_k=args.top_k, + use_mg_gemm=args.use_mg_gemm and has_mg_gemm, + ).to(device) + + # Create optimizer + optimizer = optim.Adam(model.parameters(), lr=args.lr) + + # Log model information + logging.info("Model configuration:") + logging.info(f" - Vocabulary Size: {args.vocab_size}") + logging.info(f" - Embedding Dimension: {args.embed_dim}") + logging.info(f" - Hidden Dimension: {args.hidden_dim}") + logging.info(f" - Number of Experts: {args.num_experts}") + logging.info(f" - Top-K: {args.top_k}") + logging.info(f" - Using MG GEMM: {args.use_mg_gemm and has_mg_gemm}") + + # Training loop + for epoch in range(args.epochs): + logging.info(f"\nEpoch {epoch + 1}/{args.epochs}") + + # Train + train_metrics = train_epoch( + model=model, + optimizer=optimizer, + batch_size=args.batch_size, + seq_len=args.seq_len, + vocab_size=args.vocab_size, + num_batches=args.train_batches, + device=device, + load_balance_coef=args.load_balance_coef, + ) + + # Evaluate + eval_metrics = evaluate( + model=model, + batch_size=args.batch_size, + seq_len=args.seq_len, + vocab_size=args.vocab_size, + num_batches=args.eval_batches, + device=device, + ) + + # Log metrics + logging.info( + f"Train Loss: {train_metrics['loss']:.4f} | Train Acc: {train_metrics['acc']:.4f}" + ) + logging.info( + f"Eval Loss: {eval_metrics['loss']:.4f} | Eval Acc: {eval_metrics['acc']:.4f}" + ) + logging.info(f"Epoch Time: {train_metrics['time']:.2f} seconds") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Train MoE model") + + # Model parameters + parser.add_argument("--vocab_size", type=int, default=10000, help="Vocabulary size") + parser.add_argument( + "--embed_dim", type=int, default=512, help="Embedding dimension" + ) + parser.add_argument( + "--hidden_dim", type=int, default=1024, help="Hidden dimension in experts" + ) + parser.add_argument("--num_experts", type=int, default=8, help="Number of experts") + parser.add_argument( + "--top_k", type=int, default=2, help="Top-k experts to route to" + ) + + # Training parameters + parser.add_argument("--batch_size", type=int, default=32, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Sequence length") + parser.add_argument("--epochs", type=int, default=3, help="Number of epochs") + parser.add_argument("--lr", type=float, default=0.001, help="Learning rate") + parser.add_argument( + "--train_batches", + type=int, + default=100, + help="Number of training batches per epoch", + ) + parser.add_argument( + "--eval_batches", type=int, default=20, help="Number of evaluation batches" + ) + parser.add_argument( + "--perf_batches", + type=int, + default=50, + help="Number of batches for performance testing", + ) + parser.add_argument( + "--load_balance_coef", + type=float, + default=0.01, + help="Load balancing loss coefficient", + ) + + # Runtime parameters + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to use (cuda or cpu)", + ) + parser.add_argument( + "--use_mg_gemm", + action="store_true", + help="Use MG GEMM implementation if available", + ) + parser.add_argument( + "--compare", + action="store_true", + help="Compare manual and MG GEMM implementations", + ) + parser.add_argument("--train", action="store_true", help="Train the model") + + args = 
parser.parse_args() + + # Check for CUDA + if args.device == "cuda" and not torch.cuda.is_available(): + logging.warning("CUDA not available, using CPU instead.") + args.device = "cpu" + + # Log basic information + logging.info(f"PyTorch version: {torch.__version__}") + logging.info(f"Device: {args.device}") + logging.info(f"MG GEMM available: {has_mg_gemm}") + + # Run the requested action + if args.compare: + compare_methods(args) + elif args.train: + train_model(args) + else: + # Default to comparison if no action specified + compare_methods(args) diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py new file mode 100644 index 0000000000000000000000000000000000000000..76e0b12d882fa46ed1f11139352141f06d899f59 --- /dev/null +++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py @@ -0,0 +1,299 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +import logging + +import numpy as np +import torch + +from reference_utils import ( + analyze_tensor_differences, + compute_reference_backward, + compute_reference_forward, +) + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +# Import grouped GEMM implementations +try: + from mg_grouped_gemm import grouped_gemm_backward, grouped_gemm_forward + +except ImportError: + logging.error( + "Error importing grouped GEMM modules. Make sure the implementation files are in the correct path." + ) + raise + + +def test_forward_pass(): + """ + A simple test for the M*G grouped GEMM forward pass with detailed error handling. 
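+
+ Shape sketch for the configuration used below (G = 1 group of 2048 rows,
+ K = 7168, N = 4096): x is [2048, 7168], w is [4096, 7168], and the grouped
+ GEMM returns a [2048, 4096] output that is compared against the per-group
+ x @ w.T reference.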
+
+ In M*G grouping:
+ - M dimension is partitioned into G groups (M_total = sum(M_sizes))
+ - N dimension is the same for all groups
+ """
+ try:
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Test parameters for DeepSeek-like models
+ G = 1 # Number of groups
+ M_sizes = [
+ 2048,
+ ] # Group sizes (a single group for this basic test)
+ M_total = sum(M_sizes) # Total M dimension
+ N = 4096 # Output dimension (same for all groups)
+ K = 7168 # Hidden dimension
+
+ # Create group sizes tensor
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
+
+ # Create input and weight tensors in float16 (more mantissa bits than bfloat16)
+ x = torch.randn(M_total, K, dtype=torch.float16, device=device)
+ w = torch.randn(N, K, dtype=torch.float16, device=device)
+
+ # Log the setup
+ logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
+ logging.info(f"Group sizes: {m_sizes}")
+ logging.info(f"Input x shape: {x.shape}")
+ logging.info(f"Weight w shape: {w.shape}")
+
+ # Run forward pass
+ logging.info("Running forward pass with grouped GEMM")
+ result = grouped_gemm_forward(x, w, m_sizes)
+ logging.info(f"Forward result shape: {result.shape}")
+
+ # Compute reference result
+ logging.info("Computing reference result with PyTorch")
+ reference_result = compute_reference_forward(x, w, m_sizes)
+
+ # Compare results
+ logging.info("Comparing with PyTorch reference")
+ forward_close = analyze_tensor_differences(
+ result, reference_result, "Forward output"
+ )
+
+ return forward_close
+
+ except Exception as e:
+ logging.error(f"Test failed with error: {e}")
+ import traceback
+
+ logging.error(traceback.format_exc())
+ return False
+
+
+def test_backward_pass():
+ """
+ A simple test for the M*G grouped GEMM backward pass with detailed error handling. 
+ + In M*G grouping: + - M dimension is partitioned into G groups (M_total = sum(M_sizes)) + - N dimension is the same for all groups + """ + try: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Test parameters for DeepSeek-like models + G = 4 # Number of groups + M_sizes = [2048, 2048, 2048, 2048] # Group sizes (will be adjusted) + M_total = sum(M_sizes) # Total M dimension + N = 4096 # Output dimension (same for all groups) + K = 7168 # Hidden dimension + + # Create group sizes tensor + m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32) + + # Create input and weight tensors - using float16 for higher precision + x = torch.randn( + M_total, K, dtype=torch.float16, device=device, requires_grad=True + ) + w = torch.randn(N, K, dtype=torch.float16, device=device, requires_grad=True) + + # Log the setup + logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}") + logging.info(f"Group sizes: {m_sizes}") + logging.info(f"Input x shape: {x.shape}") + logging.info(f"Weight w shape: {w.shape}") + + # Step 1: Run forward pass + logging.info("Running forward pass") + result = grouped_gemm_forward(x, w, m_sizes) + logging.info(f"Forward result shape: {result.shape}") + + # Create a gradient for backpropagation + grad_output = torch.randn_like(result) + logging.info(f"Created gradient with shape: {grad_output.shape}") + + # Step 2: Run backward pass directly + logging.info("Running backward pass directly") + grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes) + + # Verify gradient shapes + logging.info( + f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}" + ) + + # Step 3: Verify gradient computation using PyTorch's autograd + logging.info("Running PyTorch reference implementation") + + # Compute reference gradients + x_ref_grad, w_ref_grad = compute_reference_backward(x, w, m_sizes, grad_output) + + # Compare gradients + logging.info("Comparing gradients with PyTorch reference") + grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x") + grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w") + + # Log overall result + if grad_x_close and grad_w_close: + logging.info("✓ SUCCESS: Gradients match the PyTorch reference") + else: + logging.error("✗ FAILURE: Gradient mismatch detected") + + return grad_x_close and grad_w_close + + except Exception as e: + logging.error(f"Test failed with error: {e}") + import traceback + + logging.error(traceback.format_exc()) + return False + + +def test_multiple_deepseek_configs(): + """ + Test multiple DeepSeek model configurations with both forward and backward pass verification. 
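+
+ Each configuration below is a (G, M, K, N) tuple; the M rows are split as
+ evenly as possible across the G groups (e.g. G = 8, M = 4096 gives eight
+ groups of 512 rows), then both the forward output and the grad_x / grad_w
+ results are checked against the PyTorch reference.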
+ """ + # DeepSeek configurations: (G, M, K, N) + configs = [ + (4, 8192, 7168, 4096), # Config 1 + (4, 8192, 2048, 7168), # Config 2 + (8, 4096, 7168, 4096), # Config 3 + (8, 4096, 2048, 7168), # Config 4 + ] + + results = [] + + for config_idx, (G, M, K, N) in enumerate(configs): + logging.info(f"\n\n===== Testing DeepSeek Config {config_idx+1} =====") + logging.info(f"G={G}, M={M}, K={K}, N={N}") + + try: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Create even group sizes + base_size = M // G + remainder = M % G + M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)] + m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32) + + # Create input and weight tensors using float16 for higher precision + x = torch.randn( + M, K, dtype=torch.float16, device=device, requires_grad=True + ) + w = torch.randn( + N, K, dtype=torch.float16, device=device, requires_grad=True + ) + + logging.info(f"Input x shape: {x.shape}, Weight w shape: {w.shape}") + + # Run forward pass + result = grouped_gemm_forward(x, w, m_sizes) + logging.info(f"Forward result shape: {result.shape}") + + # ===== FORWARD PASS VERIFICATION ===== + # Compute reference forward result + reference_result = compute_reference_forward(x, w, m_sizes) + + # Compare forward results + forward_close = analyze_tensor_differences( + result, reference_result, "Forward output" + ) + + # ===== BACKWARD PASS VERIFICATION ===== + # Create gradient for backpropagation + grad_output = torch.randn_like(result) + + # Run backward pass + grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes) + + # Compute reference gradients + x_ref_grad, w_ref_grad = compute_reference_backward( + x, w, m_sizes, grad_output + ) + + # Compare backward results + grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x") + grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w") + + # Overall config result + backward_close = grad_x_close and grad_w_close + config_success = forward_close and backward_close + results.append( + (config_idx + 1, config_success, forward_close, backward_close) + ) + + # Log overall config result + if config_success: + logging.info(f"✓ SUCCESS: Config {config_idx+1} passed all tests!") + else: + logging.error( + f"✗ FAILURE: Config {config_idx+1} failed one or more tests" + ) + + except Exception as e: + logging.error(f"Config {config_idx+1} test failed with error: {e}") + import traceback + + logging.error(traceback.format_exc()) + results.append((config_idx + 1, False, False, False)) + + # Summary + logging.info("\n===== Test Results Summary =====") + for config_idx, overall_success, forward_success, backward_success in results: + overall_status = "✓ PASSED" if overall_success else "✗ FAILED" + forward_status = "✓ PASSED" if forward_success else "✗ FAILED" + backward_status = "✓ PASSED" if backward_success else "✗ FAILED" + + logging.info(f"Config {config_idx}: {overall_status}") + logging.info(f" - Forward pass: {forward_status}") + logging.info(f" - Backward pass: {backward_status}") + + return all(overall_success for _, overall_success, _, _ in results) + + +if __name__ == "__main__": + logging.info( + "Running verification for both forward and backward pass of M*G grouped GEMM" + ) + + # Run basic forward pass test + logging.info("\n===== Running basic forward pass test =====") + success_forward = test_forward_pass() + logging.info(f"Basic forward test {'succeeded' if success_forward else 'failed'}") + + # Run basic backward pass test + 
logging.info("\n===== Running basic backward pass test =====") + success_backward = test_backward_pass() + logging.info(f"Basic backward test {'succeeded' if success_backward else 'failed'}") + + # Run multiple DeepSeek configs with forward and backward verification + logging.info("\n===== Running tests for all DeepSeek configs =====") + success_configs = test_multiple_deepseek_configs() + logging.info( + f"DeepSeek configs tests {'all succeeded' if success_configs else 'had failures'}" + ) + + # Overall result + overall_success = success_forward and success_backward and success_configs + logging.info( + f"\nOverall test result: {'SUCCESS' if overall_success else 'FAILURE'}" + ) diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..37bf59f29e89b0bd3abb69d3e5d75bc14721b97b --- /dev/null +++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py @@ -0,0 +1,1304 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# credit - flat index forward kernel is derived from FBGemm: +# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm + +# pyre-unsafe +import functools +import logging + +import os +import sys +from typing import Any, Dict, Optional, Tuple + +import torch + +import triton +import triton.language as tl +from triton import Config as TConfig + +from triton.runtime import driver # @manual + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from tma_autotuning import ( + ALIGN_SIZE_M, + _NV_CONFIGS, + CudaUtils, + early_config_prune, + TmaDescriptorHelper, +) + + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +# ============== Start Triton Kernels =============== + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_forward_hopper( + a_desc_ptr, + b_desc_ptr, + c_ptr, + workspace, + m_sizes, + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + TMA_SIZE: tl.constexpr, + USE_EPILOGUE_SUBTILING: tl.constexpr, + # tiles + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + Flat index style forward kernel for Hopper. 
+ For simplicity, we always use TMA Load and TMA Store + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = c_ptr.dtype.element_ty # output dtype + + c_desc_ptr = workspace + (tbidx * TMA_SIZE) # for TMA Store + + M_end = 0 + M_start = 0 + processed_tiles = 0 + # Size of individual weight matrix + n_size = N // G + n_start = 0 + + for g in range(G): + # Move down along groups + # reset to new M offset + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + n_start = n_size * g + + if m_size > 0: + # Process this group + + # Acquire hold on c_desc_ptr for TMA Store + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr + M_start * n_size, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], + global_size=[m_size, n_size], + element_ty=c_dtype, + ) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + # tiles for this group + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + group_num_tiles = num_m_tiles * num_n_tiles + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + # columnwise + tile_m_index = group_index % num_m_tiles + tile_n_index = group_index // num_m_tiles + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + global_n_offset = (n_start + n_offset).to(tl.int32) + + for k_offset in range(0, K, BLOCK_SIZE_K): + # input block [M,K] + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + c_dtype, + ) + # weight block [N, K] + b = tl._experimental_descriptor_load( + b_desc_ptr, + [global_n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + c_dtype, + ) + + accumulator += tl.dot(a, b.T) + + # Store using TMA + + m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + + if USE_EPILOGUE_SUBTILING: + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(c_dtype) + tl._experimental_descriptor_store( + c_desc_ptr, c0, [m_offset, n_offset] + ) + c1 = acc1.to(c_dtype) + tl._experimental_descriptor_store( + c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2] + ) + else: + tl._experimental_descriptor_store( + c_desc_ptr, + accumulator.to(c_dtype), + [m_offset, n_offset], + ) + # move to next tile in group + tbidx += NUM_SMS + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_forward_tma( + a_desc_ptr, + b_desc_ptr, + c_ptr, + workspace, + m_sizes, + a_scale_ptr, + b_scale_ptr, + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + USE_FP8: tl.constexpr, + # tiles + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + Flat index style forward kernel. 
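+
+ Unlike the Hopper-specialized kernel above, this variant keeps the full N
+ dimension for every group (no N // G split of the weight) and carries
+ a_scale_ptr / b_scale_ptr arguments reserved for an FP8 path.
+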
+ For simplicity, we always use TMA Load and TMA Store + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = c_ptr.dtype.element_ty + + c_desc_ptr = workspace + (tbidx * TMA_SIZE) + + M_end = 0 + processed_tiles = 0 + + for g in range(G): + # Move down along groups + # reset to new M offset + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + if m_size > 0: + # Process this group + n_size = N + + # TMA Store prep + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr + M_start * N, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], + global_size=[m_size, n_size], + element_ty=c_dtype, + ) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + # tiles for this group + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + group_num_tiles = num_m_tiles * num_n_tiles + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + tile_m_index = group_index % num_m_tiles + tile_n_index = group_index // num_m_tiles + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + + for k_offset in range(0, K, BLOCK_SIZE_K): + # input block [M,K] + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + c_dtype, + ) + # weight block [N, K] + b = tl._experimental_descriptor_load( + b_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + c_dtype, + ) + + accumulator += tl.dot(a, b.T) + + # Store using TMA + + m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + # n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + + tl._experimental_descriptor_store( + c_desc_ptr, + accumulator.to(c_dtype), + [m_offset, n_offset], + ) + + # Move to the next tile + tbidx += NUM_SMS + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_forward_no_tma( + a_ptr, + b_ptr, + c_ptr, + workspace, + m_sizes, + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + # tiles + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + Flat index style forward kernel. 
+ For bc and Ampere, we never use TMA Load and TMA Store + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = c_ptr.dtype.element_ty + c_desc_ptr = None + + M_end = 0 + processed_tiles = 0 + + for g in range(G): + # Move down along groups + # reset to new M offset + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + if m_size > 0: + # Process this group + n_size = N + + # tiles for this group + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + group_num_tiles = num_m_tiles * num_n_tiles + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + tile_m_index = group_index % num_m_tiles + tile_n_index = group_index // num_m_tiles + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + + offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :] + b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :] + + for k_offset in range(0, K, BLOCK_SIZE_K): + # Load with bounds checking + a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) + b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) + + # Main matmul + accumulator += tl.dot(a, b.T) + + # Update pointers for next block + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + + # Store without TMA + offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + c = accumulator.to(c_dtype) + + tl.store( + c_ptr + + (M_start + offs_am[:, None]) * N # Row stride is N + + offs_bn[None, :], # Column offset + c, + mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size, + ) + # Move to the next tile + tbidx += NUM_SMS + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +""" +Backward pass for grouped GEMM with Triton, where grouping is M*G +We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`). +""" + + +# ---- dx flat linear indexed ---- +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_dx_tma( + grad_output_desc_ptr, # [MG, N] + w_desc_ptr, # [N, K] + grad_input_ptr, # output grad_x [MG, K] + workspace, # for TMA store + m_sizes, # group sizes [G] + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + # tiles + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + TMA-optimized kernel for computing gradients with respect to input (dx). + For the forward pass Y = X @ W.T, the backward for input is: + grad_X = grad_Y @ W + + This maps to [MG, N] @ [N, K] -> [MG, K] + + Key differences from forward: + 1. W is used directly and not transposed + 2. The reduction dimension is now N (not K) + 3. 
Output is [M, K] instead of [M, N] + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = grad_input_ptr.dtype.element_ty + c_desc_ptr = workspace + (tbidx * TMA_SIZE) + + M_end = 0 + processed_tiles = 0 + + for g in range(G): + # Move down along groups - same as forward + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + if m_size > 0: + # Process this group + # tiles for this group - now producing [M, K] output + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + group_num_tiles = num_m_tiles * num_k_tiles + + # TMA Store prep for [M, K] output + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=grad_input_ptr + M_start * K, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], + global_size=[m_size, K], + element_ty=c_dtype, + ) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + # Different tiling scheme for [M, K] output + tile_m_index = group_index % num_m_tiles + tile_k_index = group_index // num_m_tiles + + # for grad_input block [M, K] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + + # Position in full matrix + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32) + + # reduce along N dimension (instead of K in forward) + for n_offset in range(0, N, BLOCK_SIZE_N): + # grad_output block [M, N] + grad_output = tl._experimental_descriptor_load( + grad_output_desc_ptr, + [m_offset, n_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + c_dtype, + ) + + # weight block [N, K] - no transpose needed + w = tl._experimental_descriptor_load( + w_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + c_dtype, + ) + + # grad_x = grad_output @ w + # reducing along N dimension + accumulator += tl.dot(grad_output, w) + + # Store using TMA + m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + # k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32) + + tl._experimental_descriptor_store( + c_desc_ptr, + accumulator.to(c_dtype), + [m_offset, k_offset], + ) + + # Move to the next tile + tbidx += NUM_SMS + + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +# ---- dw flat linear indexed ---- + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_dw_tma( + x_desc_ptr, # input descriptor [M_total, K] + grad_output_desc_ptr, # grad_output descriptor [M_total, N] + grad_weight_ptr, # output grad_w [N, K] + workspace, # workspace for TMA store + m_sizes, # group sizes [G] + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + # tiles + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, # block size for reduction dimension +) -> None: + """ + Improved TMA-optimized kernel for computing gradients with respect to weights (dw). + Uses flat index structure similar to forward. 
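+
+ Scheduling note: each program owns a fixed [BLOCK_SIZE_N, BLOCK_SIZE_K]
+ tile of grad_W, then loops over every group and over its rows in
+ BLOCK_SIZE_M chunks, accumulating the partial grad_Y.T @ X products in
+ float32 before a single store of that tile.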
+ + For the forward pass Y = X @ W.T, + the backward for weights is: + grad_W = grad_Y.T @ X + + Where: + - grad_Y is [MG, N] + - X is [MG, K] + - grad_W is [N, K] + - we return [N,K] + """ + # Get thread block index l + tbidx = tl.program_id(0) + + # Get output data type + c_dtype = grad_weight_ptr.dtype.element_ty + + # Calculate number of output tiles + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + total_output_tiles = num_n_tiles * num_k_tiles + + # Process tiles in strided manner across SMs + for tile_idx in range(tbidx, total_output_tiles, NUM_SMS): + # Calculate tile indices + tile_n_idx = tile_idx % num_n_tiles + tile_k_idx = tile_idx // num_n_tiles + + # Calculate global offsets + n_offset = tile_n_idx * BLOCK_SIZE_N + k_offset = tile_k_idx * BLOCK_SIZE_K + + # Initialize accumulator for this output tile [N, K] + accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) + + # Process each group + M_end = 0 + for g in range(G): + # Get group boundaries + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + # Only process if group is non-empty + if m_size > 0: + # Process this group in chunks along the M dimension + for m_offset in range(0, m_size, BLOCK_SIZE_M): + # Calculate actual block size (handling boundary) + m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset) + + # Only process if we have actual work to do + if m_block_size > 0: + # Global offset for this chunk + m_global_offset = M_start + m_offset + + if USE_TMA_LOAD: + # Load input chunk [M_chunk, K] using TMA + x_block = tl._experimental_descriptor_load( + x_desc_ptr, + [m_global_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + c_dtype, + ) + + # Load grad_output chunk [M_chunk, N] using TMA + grad_output_block = tl._experimental_descriptor_load( + grad_output_desc_ptr, + [m_global_offset, n_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + c_dtype, + ) + + # Apply masks for valid regions + offs_m = tl.arange(0, BLOCK_SIZE_M) + m_mask = offs_m < m_block_size + + # Zero out invalid elements + x_block = tl.where(m_mask[:, None], x_block, 0.0) + grad_output_block = tl.where( + m_mask[:, None], grad_output_block, 0.0 + ) + else: + # Manual load with bounds checking + offs_m = tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Create masks + m_mask = offs_m < m_block_size + n_mask = offs_n < N - n_offset + k_mask = offs_k < K - k_offset + + # Combined masks + mk_mask = m_mask[:, None] & k_mask[None, :] + mn_mask = m_mask[:, None] & n_mask[None, :] + + # Global offsets for loading + m_global_offs = m_global_offset + offs_m + + # Load x block [M_chunk, K] + x_block = tl.load( + x_desc_ptr + + m_global_offs[:, None] * K + + (k_offset + offs_k)[None, :], + mask=mk_mask, + other=0.0, + ) + + # Load grad_output block [M_chunk, N] + grad_output_block = tl.load( + grad_output_desc_ptr + + m_global_offs[:, None] * N + + (n_offset + offs_n)[None, :], + mask=mn_mask, + other=0.0, + ) + + # Compute partial contribution: grad_W += grad_Y.T @ X + # transpose grad_output for the matmul + contribution = tl.dot( + grad_output_block.to(tl.float32).T, # [N, M_chunk] + x_block.to(tl.float32), # [M_chunk, K] + ) + + # Accumulate + accumulator += contribution + + # Store the result + if USE_TMA_STORE: + # Store using TMA + tl._experimental_descriptor_store( + workspace, # TMA store descriptor + accumulator.to(c_dtype), + [n_offset, k_offset], + ) + else: + # Manual store with bounds checking + offs_n = 
tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Create masks for bounds checking + n_mask = offs_n < N - n_offset + k_mask = offs_k < K - k_offset + output_mask = n_mask[:, None] & k_mask[None, :] + + # Store the result + tl.store( + grad_weight_ptr + + (n_offset + offs_n)[:, None] * K + + (k_offset + offs_k)[None, :], + accumulator.to(c_dtype), + mask=output_mask, + ) + + +# ======== End Triton kernels ======== + +# ======== Triton wrapper functions ======== + +# ----- main forward pass wrapper ----- + + +def grouped_gemm_forward( + x: torch.Tensor, + w: torch.Tensor, + m_sizes: torch.Tensor, + tma_size: int = 128, +) -> torch.Tensor: + """ + M*G style grouped GEMM with TMA and Float8 support. + # Removed for now - FP8 support is triggered by passing x_scale and w_scale tensors. + + """ + if not CudaUtils.verify_tma(): + raise NotImplementedError("Grouped GEMM without TMA is not supported yet") + + G = m_sizes.shape[0] + + assert x.is_contiguous() + assert w.is_contiguous() + assert m_sizes.is_contiguous() + + # Total input size is now [M_total, K] where M_total is the sum of all group sizes + M_total, K = x.shape + N = w.shape[0] # N is now the same for all groups + + assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})" + + # Verify that all group sizes are multiples of ALIGN_SIZE_M + # This check is commented out because it will involve a GPU-CPU sync + # assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0, "Group sizes must be a multiple of ALIGN_SIZE_M" + + # Create output tensor with correct shape [M_total, N] + y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype) + + if M_total == 0: + return y + + NUM_SMS = CudaUtils.get_num_sms() + USE_TMA_LOAD = True + USE_TMA_STORE = True + USE_EPILOGUE_SUBTILING = False + + # TMA descriptor helper + desc_helper = None + desc_x = x + desc_w = w + workspace = None + + if USE_TMA_LOAD: + desc_helper = TmaDescriptorHelper(tma_size=tma_size) + desc_helper.init_tma_descriptor("x") + desc_helper.init_tma_descriptor("w") + desc_x = desc_helper.get_tma_descriptor_kernel_param("x") + desc_w = desc_helper.get_tma_descriptor_kernel_param("w") + + if USE_TMA_STORE: + workspace = torch.empty( + NUM_SMS * desc_helper.tma_size, + device=x.device, + dtype=torch.uint8, + ) + + def grid(META): + if USE_TMA_LOAD: + nonlocal desc_helper + desc_helper.fill_2d_tma_descriptor( + "x", + x.data_ptr(), + M_total, + K, + META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_K"], + x.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "w", + w.data_ptr(), + N, + K, + META["BLOCK_SIZE_N"], + META["BLOCK_SIZE_K"], + w.element_size(), + ) + return (NUM_SMS,) + + M_BUCKET = triton.next_power_of_2(M_total) + + _kernel_mg_forward_hopper[grid]( + desc_x, + desc_w, + y, + workspace, + m_sizes, + G, + M_BUCKET, + N, + K, + NUM_SMS, + TMA_SIZE=tma_size, + USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING, + ) + + return y + + +# ======== Improved Backward ============= +def grouped_gemm_backward( + grad_output: torch.Tensor, + x: torch.Tensor, + w: torch.Tensor, + m_sizes: torch.Tensor, + use_tma: bool = True, + tma_size: int = 128, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Unified backward pass for grouped GeMM with M*G grouping. + Uses optimized TMA-based implementations for both dx and dw when available. 
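+
+ Reference semantics (an illustrative PyTorch sketch of what the Triton
+ kernels are expected to compute, matching compute_reference_backward in
+ the test utilities; this is not the code path used here):
+
+     grad_x = torch.empty_like(x)
+     grad_w = torch.zeros_like(w)
+     start = 0
+     for m in m_sizes.tolist():
+         rows = slice(start, start + m)
+         grad_x[rows] = grad_output[rows] @ w       # [m, N] @ [N, K]
+         grad_w += grad_output[rows].T @ x[rows]    # [N, m] @ [m, K]
+         start += m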
+ + Args: + grad_output: Gradient of output, shape [M_total, N] + x: Input tensor from forward pass, shape [M_total, K] + w: Weight tensor from forward pass, shape [N, K] + m_sizes: Group sizes tensor, shape [G] + use_tma: Whether to try using TMA acceleration (if available) + tma_size: Size of TMA descriptor in bytes + + + Returns: + Tuple of gradients with respect to x and w: (grad_x, grad_w) + """ + logging.info("Starting unified grouped_gemm_backward") + + # do this once, seems expensive + NUM_SMS = CudaUtils.get_num_sms() + + # Basic validation + G = m_sizes.shape[0] + M_total, K_x = x.shape + M_grad, N = grad_output.shape + N_w, K_w = w.shape + + # Check dimensions + if K_x != K_w: + raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}") + if M_total != M_grad: + raise ValueError( + f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}" + ) + + # Check total M matches sum of group sizes + sum_m_sizes = m_sizes.sum().item() + if M_total != sum_m_sizes: + raise ValueError( + f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})" + ) + + # Make sure inputs are contiguous + grad_output = grad_output.contiguous() + x = x.contiguous() + w = w.contiguous() + m_sizes = m_sizes.contiguous() + + # Check TMA support + can_use_tma = use_tma and CudaUtils.verify_tma() + if use_tma and not can_use_tma: + logging.info("TMA requested but not supported on this device") + use_tma = False + + # Compute grad_x using flat linear implementation + try: + logging.info(f"Computing grad_x with flat linear kernel") + + # Use TMA-optimized implementation + grad_x = grouped_gemm_dx_tma( + grad_output=grad_output, + w=w, + m_sizes=m_sizes, + num_sms=NUM_SMS, + tma_size=tma_size, + ) + + except Exception as e: + logging.error(f"Error in grad_x computation: {e}") + raise + + # Compute grad_w using flat linear style implementation + try: + logging.info(f"Computing grad_w with flat linear kernel") + + grad_w = grouped_gemm_dw_tma( + x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size + ) + except Exception as e: + logging.error(f"Error in grad_w computation: {e}") + raise + + return grad_x, grad_w + + +# ----- dx backward pass wrapper ----- + + +def grouped_gemm_dx_tma( + grad_output: torch.Tensor, + w: torch.Tensor, + m_sizes: torch.Tensor, + num_sms: int = 132, + tma_size: int = 128, +) -> torch.Tensor: + """ + Optimized backward pass wrapper for computing gradient with respect to input (dx) + using TMA patterns similar to the forward pass. + + Args: + grad_output: Gradient of output, shape [M_total, N] + w: Weight tensor, shape [N, K] + m_sizes: Group sizes tensor, shape [G] + tma_size: Size of TMA descriptor + # using_fp8: Whether to use FP8 quantization + # grad_output_scale: Scale for grad_output in FP8 mode + # w_scale: Scale for w in FP8 mode + + Returns: + grad_x: Gradient with respect to x, shape [M_total, K] + """ + """ + Optimized backward pass for computing gradient with respect to input (dx) + using TMA patterns similar to the forward pass. 
+ + Args: + grad_output: Gradient of output, shape [M_total, N] + w: Weight tensor, shape [N, K] + m_sizes: Group sizes tensor, shape [G] + tma_size: Size of TMA descriptor + using_fp8: Whether to use FP8 quantization + # grad_output_scale: Scale for grad_output in FP8 mode + # w_scale: Scale for w in FP8 mode + + Returns: + grad_x: Gradient with respect to x, shape [M_total, K] + """ + if not CudaUtils.verify_tma(): + raise NotImplementedError("Optimized dx computation requires TMA support") + + G = m_sizes.shape[0] + + assert grad_output.is_contiguous() + assert w.is_contiguous() + assert m_sizes.is_contiguous() + + M_total, N_grad = grad_output.shape + N_w, K = w.shape + + # Check dimensions + assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})" + + # Verify that the sum of m_sizes matches M_total + sum_m_sizes = m_sizes.sum().item() + assert ( + M_total == sum_m_sizes + ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})" + + # Create output tensor (grad_x) with shape [M_total, K] + grad_x = torch.empty( + (M_total, K), device=grad_output.device, dtype=grad_output.dtype + ) + + NUM_SMS = num_sms # CudaUtils.get_num_sms() + USE_TMA_LOAD = True + USE_TMA_STORE = True + + # Set up TMA descriptors + desc_helper = TmaDescriptorHelper(tma_size=tma_size) + desc_helper.init_tma_descriptor("grad_output") + desc_helper.init_tma_descriptor("w") + desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output") + desc_w = desc_helper.get_tma_descriptor_kernel_param("w") + + # Allocate workspace for TMA store + workspace = torch.empty( + NUM_SMS * desc_helper.tma_size, + device=grad_output.device, + dtype=torch.uint8, + ) + + def grid(META): + # Fill TMA descriptors with appropriate dimensions + desc_helper.fill_2d_tma_descriptor( + "grad_output", + grad_output.data_ptr(), + M_total, + N_grad, + META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_N"], + grad_output.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "w", + w.data_ptr(), + N_w, + K, + META["BLOCK_SIZE_N"], + META["BLOCK_SIZE_K"], + w.element_size(), + ) + return (NUM_SMS,) + + M_BUCKET = triton.next_power_of_2(M_total) + + # Launch the flat linear kernel for computing grad_x + _kernel_mg_dx_tma[grid]( + desc_grad_output, + desc_w, + grad_x, + workspace, + m_sizes, + G, + M_BUCKET, + N_grad, # N dimension is now the reduction dimension + K, + NUM_SMS, + USE_TMA_LOAD, + USE_TMA_STORE, + TMA_SIZE=tma_size, + ) + + return grad_x + + +# ======== dw wrapper function ========== + + +def grouped_gemm_dw_tma( + x: torch.Tensor, + grad_output: torch.Tensor, + m_sizes: torch.Tensor, + num_sms: int = 132, + tma_size: int = 128, +) -> torch.Tensor: + """ + Optimized flat linear kernel computation of gradients with respect to weights (dw) using TMA. 
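+
+ Note: grad_w is allocated as zeros and each [BLOCK_SIZE_N, BLOCK_SIZE_K]
+ tile accumulates contributions from every group inside the kernel.
+ USE_TMA_LOAD / USE_TMA_STORE are currently hard-coded to True (see the
+ TODO below), so a TMA-capable GPU is effectively required.
+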
+ For the forward pass Y = X @ W.T, the backward for weights is: + grad_W = grad_Y.T @ X + + Args: + x: Input tensor, shape [M_total, K] + grad_output: Gradient of output, shape [M_total, N] + m_sizes: Group sizes tensor, shape [G] + tma_size: Size of TMA descriptor in bytes + + + Returns: + grad_w: Gradient with respect to weights, shape [N, K] + """ + # Check TMA support + has_tma_support = CudaUtils.verify_tma() + + # Get group count + G = m_sizes.shape[0] + + # Ensure contiguous tensors + x = x.contiguous() + grad_output = grad_output.contiguous() + m_sizes = m_sizes.contiguous() + + # Get dimensions + M_total, K_x = x.shape + M_grad, N = grad_output.shape + + # Check dimensions + assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})" + + # Verify that the sum of m_sizes matches M_total + sum_m_sizes = m_sizes.sum().item() + assert ( + sum_m_sizes == M_total + ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})" + + # Create output tensor (grad_w) with shape [N, K] + grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype) + + NUM_SMS = num_sms + + # TODO - hardcoded for now...but should set TMA flags based on hardware support + USE_TMA_LOAD = True # has_tma_support + USE_TMA_STORE = True # has_tma_support + + # Set up TMA descriptors or direct pointers + if USE_TMA_LOAD or USE_TMA_STORE: + desc_helper = TmaDescriptorHelper(tma_size=tma_size) + + if USE_TMA_LOAD: + desc_helper.init_tma_descriptor("x") + desc_helper.init_tma_descriptor("grad_output") + x_desc = desc_helper.get_tma_descriptor_kernel_param("x") + grad_output_desc = desc_helper.get_tma_descriptor_kernel_param( + "grad_output" + ) + else: + x_desc = x + grad_output_desc = grad_output + + if USE_TMA_STORE: + desc_helper.init_tma_descriptor("grad_w") + workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w") + else: + workspace = torch.empty(1, device=x.device, dtype=torch.uint8) + else: + # If not using TMA, just use the tensors directly + x_desc = x + grad_output_desc = grad_output + workspace = torch.empty(1, device=x.device, dtype=torch.uint8) + + # M_BUCKET for grid size + M_BUCKET = triton.next_power_of_2(M_total) + + # Define grid for kernel launch + def grid(META): + if USE_TMA_LOAD or USE_TMA_STORE: + + if USE_TMA_LOAD: + desc_helper.fill_2d_tma_descriptor( + "x", + x.data_ptr(), + M_total, + K_x, + META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_K"], + x.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "grad_output", + grad_output.data_ptr(), + M_total, + N, + META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_N"], + grad_output.element_size(), + ) + + if USE_TMA_STORE: + desc_helper.fill_2d_tma_descriptor( + "grad_w", + grad_w.data_ptr(), + N, + K_x, + META["BLOCK_SIZE_N"], + META["BLOCK_SIZE_K"], + grad_w.element_size(), + ) + + # Return grid size - one block per SM for balanced work distribution + return (NUM_SMS,) + + # Launch the optimized kernel + _kernel_mg_dw_tma[grid]( + x_desc, + grad_output_desc, + grad_w, + workspace, + m_sizes, + G, + M_BUCKET, + N, + K_x, + NUM_SMS, + USE_TMA_LOAD, + USE_TMA_STORE, + TMA_SIZE=tma_size, + ) + + return grad_w + + +# ======== End Backwards Wrapper Functions ============= + +# ======== PyTorch wrapper functions ======== + + +class GroupedGEMM_mg(torch.autograd.Function): + """ + Autograd function for GroupedGEMM with M*G grouping. + Supports both standard and FP8 quantized operations. + """ + + @staticmethod + def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128): + """ + Forward pass of GroupedGEMM. 
+ + Args: + x: Input tensor, shape [M_total, K] + w: Weight tensor, shape [N, K] + m_sizes: Tensor of shape [G] containing the size of each group + use_tma: Whether to try using TMA acceleration (if available) + tma_size: Size of TMA descriptor in bytes + using_fp8: Whether to use FP8 quantization + + Returns: + Output tensor, shape [M_total, N] + """ + + # Use regular forward without quantization + output = grouped_gemm_forward( + x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False + ) + + # Save inputs and parameters for backward pass + ctx.save_for_backward(x, w, m_sizes) + ctx.use_tma = use_tma + ctx.tma_size = tma_size + + ctx.save_for_backward(x, w, m_sizes) + + return output + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass of M*G GroupedGEMM. + + Args: + grad_output: Gradient of output, shape [M_total, N] + + Returns: + Tuple of gradients: + - grad_x: Gradient with respect to x, shape [M_total, K] + - grad_w: Gradient with respect to w, shape [N, K] + - None: Gradient with respect to m_sizes (not differentiable) + - None: Gradient with respect to use_tma (not differentiable) + - None: Gradient with respect to tma_size (not differentiable) + + """ + # Retrieve saved tensors and parameters + + x, w, m_sizes = ctx.saved_tensors + + use_tma = ctx.use_tma + tma_size = ctx.tma_size + + # Compute gradients using the unified implementation + grad_x, grad_w = grouped_gemm_backward( + grad_output=grad_output, + x=x, + w=w, + m_sizes=m_sizes, + use_tma=use_tma, + tma_size=tma_size, + ) + + # Return gradients for all inputs (None for non-differentiable parameters) + return grad_x, grad_w, None, None + + +def mg_grouped_gemm( + x: torch.Tensor, + w: torch.Tensor, + m_sizes: torch.Tensor, + use_tma: bool = True, + tma_size: int = 128, + using_fp8: bool = False, +) -> torch.Tensor: + """ + Unified differentiable grouped GEMM operation for M*G grouped GEMM. + Supports both standard precision and FP8 quantized operations. + + Args: + x: Input tensor, shape [M_total, K] + w: Weight tensor, shape [N, K] + m_sizes: Tensor of shape [G] containing the size of each group + use_tma: Whether to try using TMA acceleration (if available) + tma_size: Size of TMA descriptor in bytes + using_fp8: Whether to use FP8 quantization + + Returns: + Output tensor, shape [M_total, N] + """ + return GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size, using_fp8) diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0835132c3ebf31f8c88a066e5bf19eed4c4acd69 --- /dev/null +++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +import logging + +import numpy as np +import torch + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +def compute_reference_forward(x, w, m_sizes): + """ + Compute reference forward pass using PyTorch operations. 
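+
+ Note: as written, every group is multiplied by the full weight matrix, so
+ when sum(m_sizes) equals x.shape[0] the loop is numerically equivalent to
+ a single torch.matmul(x, w.T); the per-group loop mirrors how the grouped
+ kernel partitions the rows.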
+ + Args: + x (torch.Tensor): Input tensor of shape (M, K) + w (torch.Tensor): Weight tensor of shape (N, K) + m_sizes (torch.Tensor): Group sizes tensor of shape (G) + + Returns: + torch.Tensor: Reference output tensor of shape (M, N) + """ + result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device) + + m_start = 0 + for g in range(len(m_sizes)): + m_size = m_sizes[g].item() + if m_size > 0: + m_end = m_start + m_size + + # Extract group input + x_g = x[m_start:m_end] + + # Compute group output: y_g = x_g @ w.T + y_g = torch.matmul(x_g, w.T) + + # Store result + result[m_start:m_end] = y_g + + # Update start index + m_start = m_end + + return result + + +def compute_reference_backward(x, w, m_sizes, grad_output): + """ + Compute reference backward pass using PyTorch autograd. + + Args: + x (torch.Tensor): Input tensor of shape (M, K) + w (torch.Tensor): Weight tensor of shape (N, K) + m_sizes (torch.Tensor): Group sizes tensor of shape (G) + grad_output (torch.Tensor): Gradient tensor of shape (M, N) + + Returns: + tuple: (grad_x, grad_w) gradient tensors + """ + # Create autograd-enabled copies + x_autograd = x.detach().clone().requires_grad_(True) + w_autograd = w.detach().clone().requires_grad_(True) + + # Compute forward pass + output = compute_reference_forward(x_autograd, w_autograd, m_sizes) + + # Backpropagate + output.backward(grad_output) + + return x_autograd.grad, w_autograd.grad + + +def analyze_tensor_differences(actual, expected, name): + """ + Analyze differences between actual and expected tensors. + + Args: + actual (torch.Tensor): Actual tensor + expected (torch.Tensor): Expected tensor + name (str): Name of the tensor for logging + + Returns: + bool: True if tensors are close enough + """ + rtol = 0.5 # Relative tolerance for float16 + atol = 0.5 # Absolute tolerance for float16 + + # Analyze differences + diff = (actual - expected).abs() + max_idx = diff.argmax().item() + idx = np.unravel_index(max_idx, actual.shape) + max_diff = diff.max().item() + + logging.info(f"Largest {name} difference: {max_diff} at {idx}") + logging.info(f"Values: {actual[idx].item()} vs {expected[idx].item()}") + + is_close = torch.allclose(actual, expected, rtol=rtol, atol=atol) + + if is_close: + logging.info(f"✓ SUCCESS: {name} matches PyTorch reference") + else: + logging.error(f"✗ FAILURE: {name} mismatch detected") + + # Count zeros + zeros_actual = (actual == 0).sum().item() + zeros_expected = (expected == 0).sum().item() + logging.info( + f"Zeros in {name} (actual): {zeros_actual}/{actual.numel()} ({zeros_actual/actual.numel()*100:.2f}%)" + ) + logging.info( + f"Zeros in {name} (expected): {zeros_expected}/{expected.numel()} ({zeros_expected/expected.numel()*100:.2f}%)" + ) + + # Check for NaNs + nan_actual = torch.isnan(actual).sum().item() + if nan_actual > 0: + logging.error(f"NaN values detected in {name}: {nan_actual}") + + return is_close diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py new file mode 100644 index 0000000000000000000000000000000000000000..2429432d756ae4d5bb6f91a6108c7ba8a4b9c627 --- /dev/null +++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe +import logging +import unittest +from typing import Tuple + +import torch +import torch.nn as nn + +from mg_grouped_gemm import grouped_gemm_forward + + +class TestMG_GroupedGEMM(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(2020) + + def _run_grouped_gemm_test( + self, + shape: Tuple[int, int, int, int], + device: torch.device, + dtype: torch.dtype = torch.bfloat16, + atol: float = 1e-5, + rtol: float = 1.6e-2, + ) -> None: + G, M, N, K = shape + # In M*G grouping, input is [M*G, K] and weights are [N*G, K] + a = torch.randn(M * G, K, dtype=dtype, device=device) + b = torch.randn(N * G, K, dtype=dtype, device=device) + + # Create equal-sized groups for simplicity + m_size = M + m_sizes = torch.full((G,), m_size, device=device, dtype=torch.int32) + + result = grouped_gemm_forward(a, b, m_sizes) + self.assertTrue(result.shape == (M * G, N)) + + expected_result = torch.zeros(M * G, N, dtype=dtype, device=device) + m_start = 0 + for g in range(G): + m_end = m_start + m_sizes[g] + b_slice = b[N * g : N * (g+1), :] + expected_result[m_start:m_end, :] = a[m_start:m_end, :] @ b_slice.T + m_start = m_end + + # Convert result to match input dtype if needed + result = result.to(dtype) + torch.testing.assert_close(result, expected_result, atol=atol, rtol=rtol) + + def test_MG_grouped_gemm_bf16(self) -> None: + for G in (1, 4, 16): + for M in (128, 512, 1024): + print(f"Testing BF16 M*G GroupGeMM with G={G}, M={M}") + self._run_grouped_gemm_test( + (G, M, 1024, 1024), + torch.device("cuda"), + dtype=torch.bfloat16, + atol=1e-5, + rtol=1.6e-2, + ) + + def test_MG_grouped_gemm_deepseek_shapes(self) -> None: + """Test with shapes from Deepseek model.""" + deepseek_shapes = [ + (4, 2048, 4096, 7168), # G, M, N, K + (4, 2048, 7168, 2048), + (8, 512, 4096, 7168), + (8, 512, 7168, 2048), + ] + + device = torch.device("cuda") + + for shape in deepseek_shapes: + G, M, N, K = shape + print(f"Testing BF16 M*G Deepseek shape: G={G}, M={M}, N={N}, K={K}") + self._run_grouped_gemm_test( + shape, device, dtype=torch.bfloat16, atol=1e-5, rtol=1.6e-2 + ) diff --git a/torchtitan/experiments/llama4/__pycache__/__init__.cpython-312.pyc b/torchtitan/experiments/llama4/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4e2ce93498f5520041d85a163300092d6d46d1d Binary files /dev/null and b/torchtitan/experiments/llama4/__pycache__/__init__.cpython-312.pyc differ diff --git a/torchtitan/experiments/llama4/infra/__pycache__/parallelize_llama.cpython-312.pyc b/torchtitan/experiments/llama4/infra/__pycache__/parallelize_llama.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0325212c50e53e2f328491835825173c0e8c3008 Binary files /dev/null and b/torchtitan/experiments/llama4/infra/__pycache__/parallelize_llama.cpython-312.pyc differ diff --git a/torchtitan/experiments/llama4/model/__pycache__/model.cpython-312.pyc b/torchtitan/experiments/llama4/model/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d369912405f6ed0b931adec9e4ac3041ead006cd Binary files /dev/null and b/torchtitan/experiments/llama4/model/__pycache__/model.cpython-312.pyc differ diff --git a/torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc b/torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c0bf6f5f7f71063afb7983f29ab137acb93da83 Binary files /dev/null and 
b/torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc differ diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..0b925b36207875dedc13a16be10890c3671cdabb --- /dev/null +++ b/torchtitan/experiments/llama4/model/moe.py @@ -0,0 +1,228 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from torch import nn + +from .args import TransformerModelArgs + + +class GroupedExperts(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + num_experts: int, + ): + super().__init__() + self.num_experts = num_experts + self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) + self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + + def forward( + self, + x: torch.Tensor, + num_local_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_local_tokens_per_expert is not None: + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x, + split_size_or_sections=num_local_tokens_per_expert.tolist(), + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + w1, w2, w3 = ( + self.w1[expert_idx], + self.w2[expert_idx], + self.w3[expert_idx], + ) + h = F.silu(torch.matmul(x_expert, w1)) + h = h * torch.matmul(x_expert, w3) + h = torch.matmul(h, w2) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # TODO:optimize with GroupedGEMM + # https://github.com/pytorch/pytorch/pull/150374 + # _gouped_mm requires shapes to be multiple of 8 + # offsets = torch.cumsum(num_local_tokens_per_expert, dim=0, dtype=torch.int32) + # h = F.silu(torch._grouped_mm(x, self.w1.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)) + # h = h * torch._grouped_mm(x, self.w3.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16) + # out = torch._grouped_mm(h, self.w2.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16) + else: + # x shape (num_experts, tokens_per_expert, dim) + h = F.silu(torch.bmm(x, self.w1)) + h = h * torch.bmm(x, self.w3) + # out shape (num_experts, tokens_per_expert, dim) + out = torch.bmm(h, self.w2) + return out + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std) + + +class TokenChoiceTopKRouter(nn.Module): + """This class implements token-choice routing. In token-choice top-K routing, each token is + routed to top K experts based on the router scores. + + Args: + gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts). + dim (int): Dimension of input tokens. + num_experts (int): Number of experts in each moe layer. + top_k (int): Number of experts each token will be routed to in token-choice routing. + use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False. 
+ """ + + def __init__( + self, + dim: int, + num_experts: int, + top_k: int, + use_sigmoid: bool = False, + ): + super().__init__() + self.gate = nn.Linear(dim, num_experts, bias=False) + self.num_experts = num_experts + self.top_k = top_k + self.use_sigmoid = use_sigmoid + + def forward( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. + + Returns: + routed_input (torch.Tensor): + Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``. + token_indices (torch.Tensor): + Token indices for routed_input with shape ``(bs*slen*top_k,)``. + num_local_tokens_per_expert (torch.Tensor): + Number of tokens assigned to each expert with shape ``(num_experts,)``. + """ + # scores shape (bs*slen, num_experts) + scores = self.gate(x) + + # By default, sigmoid or softmax is performed in float32 to avoid loss explosion + if self.use_sigmoid: + scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype) + else: + scores = F.softmax(scores.to(torch.float32), dim=1).to(x.dtype) + + # top scores shape (bs*slen, top_k) + top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1) + # top_scores /= top_scores.sum(dim=-1, keep_dim=True).to(x.dtype) + + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward + num_local_tokens_per_expert = torch.histc( + selected_experts_indices.view(-1), + bins=self.num_experts, + min=0, + max=self.num_experts, + ) + # token_indices_experts_sorted shape (bs*slen*top_k,) + token_indices_experts_sorted = torch.argsort( + selected_experts_indices.view(-1), stable=True + ) + top_scores = top_scores.view(-1)[token_indices_experts_sorted] + token_indices_experts_sorted = token_indices_experts_sorted // self.top_k + + return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) + + +# TODO: implement load balancing auxiliary loss for token-choice routing +class MoE(nn.Module): + def __init__(self, model_args: TransformerModelArgs): + super().__init__() + dim = model_args.dim + hidden_dim = 4 * model_args.dim + ffn_dim_multiplier = model_args.ffn_dim_multiplier + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + + num_experts = model_args.num_experts + + hidden_dim_denom = 1 + if model_args.auto_scale_hidden_dim: + hidden_dim_denom = model_args.top_k + int(model_args.use_shared_expert) + + if model_args.auto_scale_hidden_dim: + hidden_dim = int(hidden_dim / hidden_dim_denom) + hidden_dim += -hidden_dim % model_args.multiple_of + + self.experts = GroupedExperts( + dim=dim, hidden_dim=hidden_dim, num_experts=num_experts + ) + self.router = TokenChoiceTopKRouter( + dim=dim, num_experts=num_experts, top_k=model_args.top_k + ) + self.shared_expert = ( + GroupedExperts(dim=dim, hidden_dim=hidden_dim, num_experts=1) + if model_args.use_shared_expert + else None + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. + + Returns: + out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. 
+ """ + bs, slen, dim = x.shape + # top_scores and selected_indices shape (bs*slen*top_k,) + # num_local_tokens_per_expert shape (num_experts,) + ( + top_scores, + token_indices, + num_local_tokens_per_expert, + ) = self.router(x.reshape(bs * slen, dim)) + + # shape (bs*slen*top_k, dim) + token_indices = token_indices.reshape(-1, 1).expand(-1, dim) + + # shape (bs*slen*top_k, dim) + routed_input = torch.gather( + x.view(-1, dim), + dim=0, + index=token_indices, + ) + routed_input = routed_input * top_scores.reshape(-1, 1) + + # shape (bs*slen*top_k, dim) + routed_output = self.experts(routed_input, num_local_tokens_per_expert) + + # shared expert + if self.shared_expert is not None: + out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( + bs * slen, dim + ) + else: + out = torch.zeros_like(x.reshape(bs * slen, dim)) + + out = out.scatter_add(dim=0, index=token_indices, src=routed_output) + out = out.reshape(bs, slen, dim) + return out + + def init_weights(self, init_std: float): + self.experts.init_weights(init_std) + self.router.init_weights(init_std) + if self.shared_expert is not None: + self.shared_expert.init_weights(init_std) diff --git a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.sh b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.sh new file mode 100644 index 0000000000000000000000000000000000000000..6530b8ce992c8c33ccec94614e026d73964710ee --- /dev/null +++ b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.sh @@ -0,0 +1,26 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +set -ex + +# use envs as local overrides for convenience +# e.g. +# LOG_RANK=0,1 NGPU=4 ./convert_hf_to_dcp_with_gpus.sh +NGPU=${NGPU:-"8"} +LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7} +CONFIG_FILE=${CONFIG_FILE:-"../train_configs/llama4_17bx16e.toml"} + +overrides="" +if [ $# -ne 0 ]; then + overrides="$*" +fi + +PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +convert_hf_to_dcp_with_gpus.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/torchtitan/experiments/multimodal/tests/test_multimodal_model.py b/torchtitan/experiments/multimodal/tests/test_multimodal_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b5acc51bb3d186674267a4fc47d9075f04122a60 --- /dev/null +++ b/torchtitan/experiments/multimodal/tests/test_multimodal_model.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import pytest +import torch + +from torchtitan.experiments.llama_multimodal import ( + ModelArgs, + MultimodalDecoder, + VisionEncoder, +) + +from .test_utils import fixed_init_model, fixed_init_tensor + + +@pytest.fixture +def encoder_config(): + return ModelArgs( + encoder_embed_dim=32, + encoder_num_layers=2, + encoder_num_heads=4, + tile_size=49, + patch_size=9, + max_num_tiles=4, + in_channels=3, + return_intermediates=[0, 1], + num_layers_projection=2, + decoder_embed_dim=128, + ) + + +@pytest.fixture +def decoder_config(): + return ModelArgs( + decoder_embed_dim=512, + vocab_size=10000, + fusion_interval=2, + num_special_tokens=3, + decoder_num_layers=6, + decoder_num_heads=8, + decoder_num_kv_heads=4, + max_seq_len=512, + rope_theta=50000.0, + ) + + +class TestMultimodalModelVisionEncoder: + @pytest.fixture(autouse=True) + def setup_class(self, encoder_config): + self.model_args = encoder_config + self.batch_size = 1 + self.num_imgs = 2 + self.num_tiles = 4 + self.aspect_ratio = torch.tensor([[1, 3], [2, 2]]).reshape( + self.batch_size, self.num_imgs, 2 + ) + image = torch.rand( + ( + self.batch_size, + self.num_imgs, + self.num_tiles, + self.model_args.in_channels, + self.model_args.tile_size, + self.model_args.tile_size, + ) + ) + self.image = fixed_init_tensor(image.shape, min_val=-1, max_val=1) + + def test_llama_mm_vision_encoder(self): + model = VisionEncoder(self.model_args) + fixed_init_model(model, min_val=-1, max_val=1) + output = model(self.image, self.aspect_ratio) + expected_shape = ( + self.batch_size, + self.num_imgs * self.num_tiles * (model.vit.patches_per_tile + 1), + self.model_args.decoder_embed_dim, + ) + assert ( + output.shape == expected_shape + ), f"Expected shape {expected_shape}, but got {output.shape}" + + # TODO: Need to ensure numerical stability before doing convergence test. + # output.mean() = 3.994, we need to debug why it is not close to 5.28800, which is + # the test value from the original torch tune test + # assert torch.allclose( + # output.mean(), torch.tensor(5.28800), atol=1e-3, rtol=1e-3 + # ) + + +class TestMultimodalModelDecoder: + @pytest.fixture(autouse=True) + def setup_class(self, decoder_config): + self.model_args = decoder_config + self.batch_size = 1 + self.decoder_embed_dim = self.model_args.decoder_embed_dim + self.vocab_size = self.model_args.vocab_size + self.seq_len = 128 + self.input = { + "tokens": torch.arange(self.batch_size * self.seq_len).reshape( + self.batch_size, self.seq_len + ), + "encoder_input": fixed_init_tensor( + (self.batch_size, self.seq_len, self.decoder_embed_dim), + min_val=-1, + max_val=1, + ), + "encoder_mask": None, + } + + @torch.no_grad() + def test_llama_mm_decoder(self): + model = MultimodalDecoder(self.model_args) + fixed_init_model(model, min_val=-1, max_val=1) + output = model(**self.input) + expected_shape = (self.batch_size, self.seq_len, self.vocab_size) + assert ( + output.shape == expected_shape + ), f"Expected shape {expected_shape}, but got {output.shape}" + + # TODO: Need to ensure numerical stability before doing convergence test. 
+ # output.mean() = -0.0134, we need to debug why it is not close to -9.47548e-5, which is + # the test value from the original torch tune test + # assert torch.allclose( + # output.mean(), torch.tensor(-9.47548e-5), atol=1e-3, rtol=1e-3 + # ) diff --git a/torchtitan/experiments/multimodal/tests/test_utils.py b/torchtitan/experiments/multimodal/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7c3817db8699966a8d848ad744ccd6b6dabb3836 --- /dev/null +++ b/torchtitan/experiments/multimodal/tests/test_utils.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from typing import Optional, Union + +import torch +from torch import nn + + +def fixed_init_tensor( + shape: torch.Size, + min_val: Union[float, int] = 0.0, + max_val: Union[float, int] = 1.0, + nonlinear: bool = False, + dtype: torch.dtype = torch.float, +): + """ + Utility for generating deterministic tensors of a given shape. In general stuff + like torch.ones, torch.eye, etc can result in trivial outputs. This utility + generates a range tensor [min_val, max_val) of a specified dtype, applies + a sine function if nonlinear=True, then reshapes to the appropriate shape. + """ + n_elements = math.prod(shape) + step_size = (max_val - min_val) / n_elements + x = torch.arange(min_val, max_val, step_size, dtype=dtype) + x = x.reshape(shape) + if nonlinear: + return torch.sin(x) + return x + + +@torch.no_grad +def fixed_init_model( + model: nn.Module, + min_val: Union[float, int] = 0.0, + max_val: Union[float, int] = 1.0, + nonlinear: bool = False, + dtype: Optional[torch.dtype] = None, +): + """ + This utility initializes all parameters of a model deterministically using the + function fixed_init_tensor above. See that docstring for details of each parameter. + """ + for _, param in model.named_parameters(): + param.copy_( + fixed_init_tensor( + param.shape, + min_val=min_val, + max_val=max_val, + nonlinear=nonlinear, + dtype=param.dtype if dtype is None else dtype, + ) + ) diff --git a/torchtitan/experiments/simple_fsdp/README.md b/torchtitan/experiments/simple_fsdp/README.md new file mode 100644 index 0000000000000000000000000000000000000000..887653ac0298369a04df9b791b9676bd7c6107c1 --- /dev/null +++ b/torchtitan/experiments/simple_fsdp/README.md @@ -0,0 +1,40 @@ +## SimpleFSDP + +This folder includes an experimental frontend implementation for [SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile](https://arxiv.org/abs/2411.00284). SimpleFSDP is a compiler-based Fully Sharded Data Parallel (FSDP) framework, which has a simple implementation for maintenance and composability, allows full computation-communication graph tracing, and brings performance enhancement via compiler backend optimizations. + +### Enable SimpleFSDP Training + +```bash +CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ./run_train.sh --model.name llama3_simple_fsdp --training.compile --training.mixed_precision_param float32 +``` + +Note: The mixed precision training support is on-going. We set `training.mixed_precision_param` to `float32` for now and will remove it once the integration is completed. 
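+
+For reference, the same wrapper can also be applied programmatically. The snippet below is a minimal sketch that mirrors how the unit tests in this experiment call `data_parallel`; the mesh construction, the mesh dimension name, and the toy model are illustrative assumptions rather than a prescribed API:
+
+```python
+import torch
+import torch.nn as nn
+from torch.distributed.device_mesh import init_device_mesh
+
+from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel
+
+# assumes torch.distributed is already initialized (e.g. launched via torchrun)
+world_size = torch.distributed.get_world_size()
+mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp_shard_cp",))
+
+model = nn.Linear(8, 8).cuda()  # placeholder model for illustration
+# mode can be "replicate", "fully_shard", or "hybrid_shard"
+model = data_parallel(model, device_mesh=mesh, mode="fully_shard")
+model = torch.compile(model)  # compile to trace the full computation-communication graph
+```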
+
+### Composability Support
+
+Some of these features depend on PyTorch updates that are still in progress; we are working on providing composability support for the following features:
+
+| Feature | Support |
+| :--------: | :--------: |
+|Meta Initialization| ✅ |
+|Activation Checkpointing| ✅ |
+|Mixed Precision Training| 🚧 |
+|Tensor Parallelism| 🚧 |
+|Context Parallelism| ✅ |
+|Pipeline Parallelism| ✅ |
+|Distributed Checkpointing| 🚧 |
+|Float8 Training| ❌ |
+
+
+### Citation
+
+If you find SimpleFSDP useful, please kindly consider citing the following paper:
+
+```latex
+@article{zhang2024simplefsdp,
+  title={SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile},
+  author={Zhang, Ruisi and Liu, Tianyu and Feng, Will and Gu, Andrew and Purandare, Sanket and Liang, Wanchao and Massa, Francisco},
+  journal={arXiv preprint arXiv:2411.00284},
+  year={2024}
+}
+```
diff --git a/torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc b/torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..768031b67289693c95faa4a21071230016672373
Binary files /dev/null and b/torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc differ
diff --git a/torchtitan/experiments/simple_fsdp/__pycache__/model.cpython-312.pyc b/torchtitan/experiments/simple_fsdp/__pycache__/model.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9137fe4ce47e54f6c3a8a7ec409c25b2c0040014
Binary files /dev/null and b/torchtitan/experiments/simple_fsdp/__pycache__/model.cpython-312.pyc differ
diff --git a/torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc b/torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e12df1cd3fa60199d0c1e4895ccfbc4fac77af69
Binary files /dev/null and b/torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc differ
diff --git a/torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc b/torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..21df1f66a6a2c80c648fe2aa57b2bc49b37e4c8f
Binary files /dev/null and b/torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc differ
diff --git a/torchtitan/experiments/simple_fsdp/model.py b/torchtitan/experiments/simple_fsdp/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..63104169b8fa14ed7032182c1ad08b782cd715fe
--- /dev/null
+++ b/torchtitan/experiments/simple_fsdp/model.py
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
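+#
+# Note: SimpleFSDPTransformer below reuses the llama3 Transformer unchanged; it
+# only overrides init_weights so that initialization runs inside the
+# disable_data_parallel() context provided by simple_fsdp.py.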
+
+from torchtitan.models.llama3 import Transformer, TransformerModelArgs
+from .simple_fsdp import disable_data_parallel
+
+
+class SimpleFSDPTransformer(Transformer):
+    def __init__(self, model_args: TransformerModelArgs):
+        super().__init__(model_args)
+        self.init_weights()
+
+    def init_weights(self, *args, **kwargs):
+        with disable_data_parallel():
+            super().init_weights(*args, **kwargs)
diff --git a/torchtitan/experiments/simple_fsdp/tests/__init__.py b/torchtitan/experiments/simple_fsdp/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e41cd717f6a439a9c08d76a9d0e4a54e190fc5a
--- /dev/null
+++ b/torchtitan/experiments/simple_fsdp/tests/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/torchtitan/experiments/simple_fsdp/tests/test_numerics.py b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c15ce573b9c65f9f26cefcbdbcd0f5b2f5c9713
--- /dev/null
+++ b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py
@@ -0,0 +1,128 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import copy
+
+import torch
+from torch.distributed._composable.fsdp import fully_shard
+
+from torch.testing._internal.common_fsdp import FSDPTest
+
+from torchtitan.components.loss import cross_entropy_loss
+from torchtitan.distributed import ParallelDims
+from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel
+
+
+class TestSimpleFSDP(FSDPTest):
+    def init_test(self):
+        self.optimizer = torch.optim.Adam
+        self.loss_fn = cross_entropy_loss
+        data_parallel_shard_degree = -1
+        if self.mode == "replicate":
+            self.dp_mesh_dim_names = ("dp_replicate",)
+            data_parallel_replicate_degree = self.world_size
+        elif self.mode == "fully_shard":
+            self.dp_mesh_dim_names = ("dp_shard_cp",)
+            data_parallel_replicate_degree = 1
+        elif self.mode == "hybrid_shard":
+            self.dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
+            data_parallel_replicate_degree = self.world_size // 2
+        else:
+            raise ValueError(f"Unsupported mode {self.mode}")
+
+        self.parallel_dims = ParallelDims(
+            dp_shard=data_parallel_shard_degree,
+            dp_replicate=data_parallel_replicate_degree,
+            cp=1,
+            tp=1,
+            pp=1,
+            world_size=self.world_size,
+            enable_loss_parallel=True,
+        )
+        self.device_mesh = self.parallel_dims.build_mesh(device_type="cuda")
+
+    def get_input(self):
+        inputs = torch.randn(8, 8).cuda()
+        labels = torch.randn(8, 8).cuda()
+        model = torch.nn.Linear(8, 8)
+        return model, inputs, labels
+
+    def run_fsdp2(self, model, inputs, labels, epoch=20):
+        fully_shard(model, mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)])
+        optim = self.optimizer(model.parameters(), lr=1e-4)
+        losses = []
+        for _ in range(epoch):
+            optim.zero_grad()
+            out = model(inputs)
+            loss = self.loss_fn(out, labels)
+            loss.backward()
+            optim.step()
+            losses.append(loss)
+        return losses
+
+    def run_simple_fsdp(self, model, inputs, labels, epoch=20):
+        model = data_parallel(
+            model,
+            device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)],
+            mode=self.mode,
+        )
+        optim = self.optimizer(model.parameters(), lr=1e-4)
+        losses = []
+        for _ in range(epoch):
+            optim.zero_grad()
+            out =
model(inputs) + loss = self.loss_fn(out, labels) + loss.backward() + optim.step() + losses.append(loss) + return losses + + def test_replicate_convergence(self): + # unit test for replicate mode + self.mode = "replicate" + self.init_test() + model, inputs, labels = self.get_input() + + fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels) + simple_fsdp_replicate_losses = self.run_simple_fsdp( + copy.deepcopy(model), inputs, labels + ) + + for fsdp2_loss, simple_fsdp_replicate_loss in zip( + fsdp2_losses, simple_fsdp_replicate_losses + ): + assert torch.allclose(fsdp2_loss, simple_fsdp_replicate_loss) + + def test_fullyshard_convergence(self): + # unit test for fully_shard mode + self.mode = "fully_shard" + self.init_test() + model, inputs, labels = self.get_input() + + fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels) + simple_fsdp_fullyshard_losses = self.run_simple_fsdp( + copy.deepcopy(model), inputs, labels + ) + + for fsdp2_loss, simple_fsdp_fullyshard_loss in zip( + fsdp2_losses, simple_fsdp_fullyshard_losses + ): + assert torch.allclose(fsdp2_loss, simple_fsdp_fullyshard_loss) + + def test_hybridshard_convergence(self): + # unit test for hybrid_shard mode + self.mode = "hybrid_shard" + self.init_test() + model, inputs, labels = self.get_input() + + fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels) + simple_fsdp_hybridshard_losses = self.run_simple_fsdp( + copy.deepcopy(model), inputs, labels + ) + + for fsdp2_loss, simple_fsdp_hybridshard_loss in zip( + fsdp2_losses, simple_fsdp_hybridshard_losses + ): + assert torch.allclose(fsdp2_loss, simple_fsdp_hybridshard_loss) diff --git a/torchtitan/models/__pycache__/attention.cpython-312.pyc b/torchtitan/models/__pycache__/attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4019fef0644b8e03933e75b242cef8f0999e9e72 Binary files /dev/null and b/torchtitan/models/__pycache__/attention.cpython-312.pyc differ diff --git a/torchtitan/models/__pycache__/norms.cpython-312.pyc b/torchtitan/models/__pycache__/norms.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4078d8d97327acf24d7b51fd3bf1589526b0b44c Binary files /dev/null and b/torchtitan/models/__pycache__/norms.cpython-312.pyc differ diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea183c8bbfa8cfe2ed387c298e4940a2c6a890d1 --- /dev/null +++ b/torchtitan/models/llama3/__init__.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
+ +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec + +from .model import Transformer, TransformerModelArgs +from .parallelize_llama import parallelize_llama +from .pipeline_llama import pipeline_llama + +__all__ = [ + "parallelize_llama", + "pipeline_llama", + "TransformerModelArgs", + "Transformer", + "llama3_configs", +] + + +llama3_configs = { + "debugmodel": TransformerModelArgs( + dim=256, n_layers=8, n_heads=16, rope_theta=500000 + ), + "8B": TransformerModelArgs( + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=1024, + rope_theta=500000, + ), + "70B": TransformerModelArgs( + dim=8192, + n_layers=80, + n_heads=64, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=4096, + rope_theta=500000, + ), + "405B": TransformerModelArgs( + dim=16384, + n_layers=126, + n_heads=128, + n_kv_heads=8, + ffn_dim_multiplier=1.2, + multiple_of=4096, + rope_theta=500000, + ), +} + + +register_train_spec( + TrainSpec( + name="llama3", + cls=Transformer, + config=llama3_configs, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_tiktoken_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) +) diff --git a/torchtitan/models/llama3/__pycache__/__init__.cpython-312.pyc b/torchtitan/models/llama3/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6c915395dfd41a7bb4542079e8e9b79ebc55409 Binary files /dev/null and b/torchtitan/models/llama3/__pycache__/__init__.cpython-312.pyc differ diff --git a/torchtitan/models/llama3/__pycache__/model.cpython-312.pyc b/torchtitan/models/llama3/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b8116601e7f28fcf0f0c10cd807686ad21624c5 Binary files /dev/null and b/torchtitan/models/llama3/__pycache__/model.cpython-312.pyc differ diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml new file mode 100644 index 0000000000000000000000000000000000000000..e956030a2d8e9676ecce3e6f8df22dd7dfed68e1 --- /dev/null +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -0,0 +1,74 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "Llama 3 debug training" +print_args = false +use_for_integration_test = true + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "llama3" +flavor = "debugmodel" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm +# test tokenizer.model, for debug purpose only +tokenizer_path = "./tests/assets/test_tiktoken.model" +# converters = "float8" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps 
+decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +lr_min = 0.0 + +[training] +batch_size = 8 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +compile = false +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 10 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'selective' # ['none', 'selective', 'full'] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = "output" diff --git a/torchtitan/models/llama3/train_configs/llama3_405b.toml b/torchtitan/models/llama3/train_configs/llama3_405b.toml new file mode 100644 index 0000000000000000000000000000000000000000..4532e451043597a51c9316d07c971fc77f0055c8 --- /dev/null +++ b/torchtitan/models/llama3/train_configs/llama3_405b.toml @@ -0,0 +1,63 @@ +# torchtitan Config.toml +# NOTE: this toml config is a preset for 128 H100 GPUs. + +[job] +dump_folder = "./outputs" +description = "Llama 3 405B training" + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = true +save_tb_folder = "tb" + +[model] +name = "llama3" +flavor = "405B" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm +tokenizer_path = "./assets/tokenizer/original/tokenizer.model" +converters = "float8" + +[optimizer] +name = "AdamW" +lr = 8e-5 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps + +[training] +batch_size = 2 +seq_len = 8192 +max_norm = 1.0 # grad norm clipping +steps = 3000 +compile = true +dataset = "c4" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 8 # 8-way TP +enable_async_tensor_parallel = true +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 500 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'full' # ['none', 'selective', 'full'] + +[float8] +enable_fsdp_float8_all_gather = true +precompute_float8_dynamic_scale_for_fsdp = true +filter_fqns = "output" diff --git a/torchtitan/models/llama3/train_configs/llama3_70b.toml b/torchtitan/models/llama3/train_configs/llama3_70b.toml new file mode 100644 index 0000000000000000000000000000000000000000..64ef62ebfe96a6a8fbdbbc0aaa1849992769a1b8 --- /dev/null +++ b/torchtitan/models/llama3/train_configs/llama3_70b.toml @@ -0,0 +1,62 @@ +# torchtitan Config.toml +# NOTE: this toml config is a preset for 64 A100 GPUs. 
+ +[job] +dump_folder = "./outputs" +description = "Llama 3 70B training" + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = true +save_tb_folder = "tb" + +[model] +name = "llama3" +flavor = "70B" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm +tokenizer_path = "./assets/tokenizer/original/tokenizer.model" +# converters = "float8" + +[optimizer] +name = "AdamW" +lr = 1.5e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps + +[training] +batch_size = 8 +seq_len = 8192 +max_norm = 1.0 # grad norm clipping +steps = 1000 +compile = false +dataset = "c4" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 8 # 8-way TP +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 500 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'full' + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = "output"