# SPDX-License-Identifier: Apache-2.0
import functools
import math
from dataclasses import dataclass
from typing import Any

import torch
from vsa import video_sparse_attn

VSA_TILE_SIZE = (4, 4, 4)


@functools.lru_cache(maxsize=10)
def get_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Flattened token indices that regroup the (T, H, W) sequence tile by tile."""
    T, H, W = dit_seq_shape
    ts, hs, ws = tile_size
    indices = torch.arange(T * H * W, device=device,
                           dtype=torch.long).reshape(T, H, W)
    ls = []
    for t in range(math.ceil(T / ts)):
        for h in range(math.ceil(H / hs)):
            for w in range(math.ceil(W / ws)):
                ls.append(indices[t * ts:min(t * ts + ts, T),
                                  h * hs:min(h * hs + hs, H),
                                  w * ws:min(w * ws + ws, W)].flatten())
    index = torch.cat(ls, dim=0)
    return index


@functools.lru_cache(maxsize=10)
def get_reverse_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Inverse permutation of `get_tile_partition_indices`."""
    return torch.argsort(
        get_tile_partition_indices(dit_seq_shape, tile_size, device))


@functools.lru_cache(maxsize=10)
def construct_variable_block_sizes(
    dit_seq_shape: tuple[int, int, int],
    num_tiles: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """
    Compute the number of valid (non-padded) tokens inside every
    (ts_t × ts_h × ts_w) tile after padding -- flattened in the order
    (t-tile, h-tile, w-tile) that `rearrange` uses.

    Returns
    -------
    torch.LongTensor  # shape: [n_t * n_h * n_w]
    """
    # unpack
    t, h, w = dit_seq_shape
    ts_t, ts_h, ts_w = VSA_TILE_SIZE
    n_t, n_h, n_w = num_tiles

    def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor:
        """Vector with the size of each tile along one dimension."""
        sizes = torch.full((n_tiles, ), tile, dtype=torch.int, device=device)
        # size of last (possibly partial) tile
        remainder = dim_len - (n_tiles - 1) * tile
        sizes[-1] = remainder if remainder > 0 else tile
        return sizes

    t_sizes = _sizes(t, ts_t, n_t)  # [n_t]
    h_sizes = _sizes(h, ts_h, n_h)  # [n_h]
    w_sizes = _sizes(w, ts_w, n_w)  # [n_w]

    # broadcast-multiply to get voxels per tile, then flatten
    block_sizes = (
        t_sizes[:, None, None]  # [n_t, 1, 1]
        * h_sizes[None, :, None]  # [1, n_h, 1]
        * w_sizes[None, None, :]  # [1, 1, n_w]
    ).reshape(-1)  # [n_t * n_h * n_w]

    return block_sizes


@functools.lru_cache(maxsize=10)
def get_non_pad_index(
    variable_block_sizes: torch.LongTensor,
    max_block_size: int,
):
    """Indices of the real (non-padded) positions in the padded, tiled layout."""
    n_win = variable_block_sizes.shape[0]
    device = variable_block_sizes.device

    starts_pad = torch.arange(n_win, device=device) * max_block_size
    index_pad = starts_pad[:, None] + torch.arange(max_block_size,
                                                   device=device)[None, :]
    index_mask = torch.arange(
        max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
    return index_pad[index_mask]
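
# Worked example (illustrative assumption, not taken from the original source):
# for a DiT token grid of shape (2, 6, 4) and the (4, 4, 4) tile size,
# num_tiles is (1, 2, 1).  `get_tile_partition_indices` then groups the 48
# tokens tile by tile, `construct_variable_block_sizes` yields [32, 16] (the
# second h-tile only covers rows 4-5), and `get_non_pad_index` selects
# positions 0-31 and 64-79 of the 2 * 64 zero-padded slots, so that `tile()`
# below can scatter the real tokens into the padded buffer and `untile()` can
# recover the original raster order.
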
@dataclass
class VideoSparseAttentionMetadata:
    current_timestep: int
    dit_seq_shape: list[int]
    VSA_sparsity: float
    num_tiles: list[int]
    total_seq_length: int
    tile_partition_indices: torch.LongTensor
    reverse_tile_partition_indices: torch.LongTensor
    variable_block_sizes: torch.LongTensor
    non_pad_index: torch.LongTensor


def build(
    current_timestep: int,
    raw_latent_shape: tuple[int, int, int],
    patch_size: tuple[int, int, int],
    VSA_sparsity: float,
    device: torch.device,
    **kwargs: dict[str, Any],
) -> VideoSparseAttentionMetadata:
    dit_seq_shape = (raw_latent_shape[0] // patch_size[0],
                     raw_latent_shape[1] // patch_size[1],
                     raw_latent_shape[2] // patch_size[2])
    num_tiles = (math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]),
                 math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
                 math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]))
    total_seq_length = math.prod(dit_seq_shape)

    tile_partition_indices = get_tile_partition_indices(
        dit_seq_shape, VSA_TILE_SIZE, device)
    reverse_tile_partition_indices = get_reverse_tile_partition_indices(
        dit_seq_shape, VSA_TILE_SIZE, device)
    variable_block_sizes = construct_variable_block_sizes(
        dit_seq_shape, num_tiles, device)
    non_pad_index = get_non_pad_index(variable_block_sizes,
                                      math.prod(VSA_TILE_SIZE))

    return VideoSparseAttentionMetadata(
        current_timestep=current_timestep,
        dit_seq_shape=dit_seq_shape,  # type: ignore
        VSA_sparsity=VSA_sparsity,  # type: ignore
        num_tiles=num_tiles,  # type: ignore
        total_seq_length=total_seq_length,  # type: ignore
        tile_partition_indices=tile_partition_indices,  # type: ignore
        reverse_tile_partition_indices=reverse_tile_partition_indices,
        variable_block_sizes=variable_block_sizes,
        non_pad_index=non_pad_index)


class VideoSparseAttentionImpl:

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        causal: bool,
        softmax_scale: float,
        num_kv_heads: int | None = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        self.prefix = prefix

    def tile(self, x: torch.Tensor, num_tiles: list[int],
             tile_partition_indices: torch.LongTensor,
             non_pad_index: torch.LongTensor) -> torch.Tensor:
        """Reorder tokens tile by tile and scatter them into a zero-padded buffer."""
        t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
        h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
        w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]
        x_padded = torch.zeros(
            (x.shape[0], t_padded_size * h_padded_size * w_padded_size,
             x.shape[-2], x.shape[-1]),
            device=x.device,
            dtype=x.dtype)
        x_padded[:, non_pad_index] = x[:, tile_partition_indices]
        return x_padded

    def untile(self, x: torch.Tensor,
               reverse_tile_partition_indices: torch.LongTensor,
               non_pad_index: torch.LongTensor) -> torch.Tensor:
        """Drop padding and restore the original raster token order."""
        x = x[:, non_pad_index][:, reverse_tile_partition_indices]
        return x

    def preprocess_qkv(
        self,
        qkv: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        return self.tile(qkv, attn_metadata.num_tiles,
                         attn_metadata.tile_partition_indices,
                         attn_metadata.non_pad_index)

    def postprocess_output(
        self,
        output: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        return self.untile(output,
                           attn_metadata.reverse_tile_partition_indices,
                           attn_metadata.non_pad_index)

    def forward(  # type: ignore[override]
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        query = query.transpose(1, 2).contiguous()
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()

        VSA_sparsity = attn_metadata.VSA_sparsity
        # Number of key tiles each query tile attends to at the current sparsity.
        cur_topk = math.ceil(
            (1 - VSA_sparsity) *
            (attn_metadata.total_seq_length / math.prod(VSA_TILE_SIZE)))

        hidden_states = video_sparse_attn(
            query,
            key,
            value,
            variable_block_sizes=attn_metadata.variable_block_sizes,
            topk=cur_topk,
            block_size=VSA_TILE_SIZE,
            compress_attn_weight=None).transpose(1, 2)

        return hidden_states
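

# Minimal usage sketch (illustrative only, not part of the original module):
# the latent shape, patch size, head counts, and sparsity below are assumptions
# chosen so the example runs on CPU.  It builds the attention metadata and
# checks that tile() followed by untile() restores the original token order;
# the actual forward() call is skipped because it requires the `vsa` kernel on
# a GPU.
if __name__ == "__main__":
    device = torch.device("cpu")
    metadata = build(
        current_timestep=0,
        raw_latent_shape=(8, 12, 12),  # hypothetical latent (T, H, W)
        patch_size=(1, 2, 2),          # hypothetical DiT patch size
        VSA_sparsity=0.9,
        device=device,
    )
    # dit_seq_shape is (8, 6, 6), so 288 tokens are split across 2 * 2 * 2 tiles.
    impl = VideoSparseAttentionImpl(num_heads=4,
                                    head_size=8,
                                    causal=False,
                                    softmax_scale=8**-0.5)

    # [batch, seq_len, num_heads, head_dim] tokens in raster (T, H, W) order.
    x = torch.randn(1, metadata.total_seq_length, 4, 8)
    tiled = impl.preprocess_qkv(x, metadata)  # [1, 512, 4, 8], zero-padded
    restored = impl.postprocess_output(tiled, metadata)
    assert torch.equal(restored, x)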