from typing import Callable, ClassVar, Optional

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (
    BlockMask,
    create_block_mask,
    flex_attention,
)


class FlexAttention(torch.nn.Module):
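    """
    FlexAttention module built on torch.nn.attention.flex_attention.

    The compiled flex_attention kernel, the compiled create_block_mask
    helper, the set of mask types in use, and the cached BlockMasks are
    ClassVars, so every FlexAttention instance shares them and each mask
    type is materialized once per init_attention_mask() call rather than
    once per layer.
    """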
    flex_attn: ClassVar[Callable] = torch.compile(
        flex_attention, mode="max-autotune-no-cudagraphs"
    )
    compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
    used_attn_mask_types: ClassVar[set[str]] = set()

    # BlockMasks built by init_attention_mask(), keyed by mask type. The
    # "causal" mask is built once and reused; the "block_causal" mask is
    # rebuilt for every batch because it depends on the EOS positions.
    block_masks: ClassVar[dict[str, BlockMask]] = {}

    # The mask type this instance uses; must be "causal" or "block_causal".
    attn_mask_type: str

    def __init__(self, attn_mask_type: str) -> None:
        super().__init__()
        if attn_mask_type not in ["causal", "block_causal"]:
            raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.")
        self.attn_mask_type = attn_mask_type
        FlexAttention.used_attn_mask_types.add(attn_mask_type)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
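        """Run flex attention with the BlockMask cached for this mask type.

        init_attention_mask() must have been called for the current batch
        before forward(), otherwise the mask lookup below raises a KeyError.
        """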
        block_mask = FlexAttention.block_masks[self.attn_mask_type]
        return FlexAttention.flex_attn(q, k, v, block_mask=block_mask)

    @staticmethod
    def _get_causal_mask_fn() -> Callable:
        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        return causal_mask

    @staticmethod
    def _get_block_causal_mask_fn(batch: torch.Tensor, eos_id: int) -> Callable:
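        """Build a mask_mod that restricts attention to within each document.

        Documents are delimited by eos_id: each token is numbered by how many
        EOS tokens precede it, and a query may only attend to earlier tokens
        that carry the same document index.
        """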
        # Mark EOS positions; force the last token of each row to be a
        # document boundary so every token gets a valid document id.
        mask = batch == eos_id
        mask[:, -1] = True
        # Cumulative count of EOS tokens, shifted right by one position so
        # that each token is assigned the id of the document it belongs to.
        acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1)
        seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32)
        seq_idx[:, 1:] = acc_mask[:, :-1]

        def block_causal_mask(b, h, q_idx, kv_idx):
            # Attend only within the same document, and only causally.
            return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)

        return block_causal_mask

    @staticmethod
    @torch.no_grad()
    def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
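        """(Re)build the BlockMasks for every mask type currently in use.

        Should be called once per batch before the forward pass: the causal
        mask is batch-independent and built only once, while the block_causal
        mask depends on the EOS positions in `batch` and is rebuilt on every
        call (eos_id is required in that case).
        """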
        for attn_mask_type in FlexAttention.used_attn_mask_types:
            match attn_mask_type:
                case "causal":
                    # The causal mask does not depend on the batch content,
                    # so reuse it once it has been built.
                    if FlexAttention.block_masks.get(attn_mask_type, None) is not None:
                        continue
                    batch_dimension = 1
                    mask_fn = FlexAttention._get_causal_mask_fn()
                case "block_causal":
                    if eos_id is None:
                        raise RuntimeError(
                            "eos_id must be provided for block_causal mask."
                        )
                    batch_dimension = batch.shape[0]
                    mask_fn = FlexAttention._get_block_causal_mask_fn(batch, eos_id)
                case _:
                    raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}")

            seq_len = batch.shape[1]
            block_mask = FlexAttention.compiled_create_block_mask(
                mask_fn, batch_dimension, None, seq_len, seq_len
            )
            FlexAttention.block_masks[attn_mask_type] = block_mask


class ScaledDotProductAttention(torch.nn.Module):
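    """Fallback attention module using F.scaled_dot_product_attention.

    Only the causal mask is supported on this path.
    """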
    def __init__(self, attn_mask_type: str) -> None:
        super().__init__()
        if attn_mask_type != "causal":
            raise ValueError(
                "TorchTitan with SDPA currently only supports causal mask."
            )

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)


def build_attention(use_flex_attn: bool, attn_mask_type: str):
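    """Return a FlexAttention module if use_flex_attn is set, else an SDPA one.

    Illustrative call order (names such as token_ids and tokenizer.eos_id are
    hypothetical):

        attn = build_attention(use_flex_attn=True, attn_mask_type="block_causal")
        init_attention_mask(token_ids, eos_id=tokenizer.eos_id)  # once per batch
        out = attn(q, k, v)
    """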
    if use_flex_attn:
        return FlexAttention(attn_mask_type)
    else:
        return ScaledDotProductAttention(attn_mask_type)


def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
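    """Module-level convenience wrapper around FlexAttention.init_attention_mask."""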
    FlexAttention.init_attention_mask(batch, eos_id)