# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
from typing import Callable, ClassVar, Optional
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (
BlockMask,
create_block_mask,
flex_attention,
)


class FlexAttention(torch.nn.Module):
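    """Wrapper module around torch's flex_attention that shares compiled kernels
    and block masks across all instances.

    flex_attention and create_block_mask are compiled once and reused by every
    instance; block masks are built per batch via init_attention_mask and looked
    up by attn_mask_type in forward.
    """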
    # flex_attention-related attributes are registered as class variables so that
    # the cost of compilation is amortized across all instances.
flex_attn: ClassVar[Callable] = torch.compile(
flex_attention, mode="max-autotune-no-cudagraphs"
)
compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
used_attn_mask_types: ClassVar[set[str]] = set()
    # Maps an attention mask type to its created BlockMask.
    # This lets us keep track of the block masks created for each new batch:
    # the block mask is updated whenever a new batch is created, and users can
    # create different block masks for different layers.
    block_masks: ClassVar[dict[str, BlockMask]] = {}

    # Instance variables.
    attn_mask_type: str

    def __init__(self, attn_mask_type: str) -> None:
super().__init__()
if attn_mask_type not in ["causal", "block_causal"]:
raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.")
self.attn_mask_type = attn_mask_type
        FlexAttention.used_attn_mask_types.add(attn_mask_type)

    def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
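        """Run flex_attention with the block mask registered for this mask type.

        q, k, and v are expected in [batch, num_heads, seq_len, head_dim] layout,
        and init_attention_mask must have been called for the current batch.
        """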
block_mask = FlexAttention.block_masks[self.attn_mask_type]
        return FlexAttention.flex_attn(q, k, v, block_mask=block_mask)

    @staticmethod
def _get_causal_mask_fn() -> Callable:
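        """Return a mask_mod implementing the standard lower-triangular causal mask."""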
def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        return causal_mask

    @staticmethod
def _get_block_causal_mask_fn(batch: torch.Tensor, eos_id: int) -> Callable:
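        """Return a mask_mod that restricts causal attention to tokens within the
        same eos-delimited document, so queries never attend across document
        boundaries.
        """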
        # batch has shape [b, s] and contains token IDs; documents are delimited by eos_id.
mask = batch == eos_id
mask[:, -1] = True
acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1)
seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32)
seq_idx[:, 1:] = acc_mask[:, :-1]
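        # seq_idx is a per-token document index. Example for one row:
        # tokens [a, b, EOS, c, d] -> mask [0, 0, 1, 0, 1] (last position forced
        # to 1) -> cumsum [0, 0, 1, 1, 2] -> shifted seq_idx [0, 0, 0, 1, 1].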
def block_causal_mask(b, h, q_idx, kv_idx):
            return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)

        return block_causal_mask

    @staticmethod
@torch.no_grad()
def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
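        """Build the BlockMask for every attention mask type currently in use.

        Should be called for each new batch before any FlexAttention.forward call;
        eos_id is required only when a "block_causal" mask is in use.
        """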
        # batch has shape [b, s] and contains token IDs.
for attn_mask_type in FlexAttention.used_attn_mask_types:
match attn_mask_type:
case "causal":
if FlexAttention.block_masks.get(attn_mask_type, None) is not None:
continue
# We don't care about batch dimension --
# all samples have the same lower triangle mask.
batch_dimension = 1
mask_fn = FlexAttention._get_causal_mask_fn()
case "block_causal":
if eos_id is None:
raise RuntimeError(
"eos_id must be provided for block_causal mask."
)
batch_dimension = batch.shape[0]
mask_fn = FlexAttention._get_block_causal_mask_fn(batch, eos_id)
case _:
raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}")
seq_len = batch.shape[1]
block_mask = FlexAttention.compiled_create_block_mask(
mask_fn, batch_dimension, None, seq_len, seq_len
)
            FlexAttention.block_masks[attn_mask_type] = block_mask


class ScaledDotProductAttention(torch.nn.Module):
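    """Fallback attention module built on F.scaled_dot_product_attention.

    Only the plain causal mask is supported, so no block mask bookkeeping is needed.
    """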
def __init__(self, attn_mask_type: str) -> None:
super().__init__()
if attn_mask_type != "causal":
raise ValueError(
"TorchTitan with SDPA currently only supports causal mask."
)
def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)


def build_attention(use_flex_attn: bool, attn_mask_type: str):
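    """Return a FlexAttention or SDPA-based attention module for the given mask type."""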
if use_flex_attn:
return FlexAttention(attn_mask_type)
else:
        return ScaledDotProductAttention(attn_mask_type)


def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
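    """Module-level convenience wrapper around FlexAttention.init_attention_mask."""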
FlexAttention.init_attention_mask(batch, eos_id)
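

# A minimal usage sketch (illustrative only, not part of the module). It assumes a
# recent PyTorch build with FlexAttention support and a CUDA device; the shapes and
# eos_id below are made up for illustration.
if __name__ == "__main__":
    b, s, n_heads, head_dim = 2, 128, 8, 64
    eos_id = 0  # hypothetical EOS token id

    attn = build_attention(use_flex_attn=True, attn_mask_type="block_causal")

    # Token IDs for the current batch; the block mask must be rebuilt per batch.
    tokens = torch.randint(1, 1000, (b, s), device="cuda")
    tokens[:, 63] = eos_id  # pretend each sample contains a document boundary
    init_attention_mask(tokens, eos_id=eos_id)

    # flex_attention expects [batch, num_heads, seq_len, head_dim] inputs.
    q, k, v = (
        torch.randn(b, n_heads, s, head_dim, device="cuda", dtype=torch.bfloat16)
        for _ in range(3)
    )
    out = attn(q, k, v)
    print(out.shape)  # torch.Size([2, 8, 128, 64])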