# SPDX-License-Identifier: Apache-2.0
import math
from dataclasses import dataclass
from typing import Any

import torch

from vsa import video_sparse_attn

VSA_TILE_SIZE = (4, 4, 4)
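# Each VSA tile spans 4 x 4 x 4 = 64 tokens of the (T, H, W) token grid;
# the helpers below reorder a flattened sequence into tile-contiguous blocks.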


def get_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Return a permutation that reorders a flattened (T, H, W) token
    sequence so that tokens belonging to the same tile are contiguous."""
    T, H, W = dit_seq_shape
    ts, hs, ws = tile_size
    indices = torch.arange(T * H * W, device=device,
                           dtype=torch.long).reshape(T, H, W)
    ls = []
    # Walk the tiles in (t, h, w) order; tiles at the boundary are
    # clipped to the valid extent.
    for t in range(math.ceil(T / ts)):
        for h in range(math.ceil(H / hs)):
            for w in range(math.ceil(W / ws)):
                ls.append(indices[t * ts:min(t * ts + ts, T),
                                  h * hs:min(h * hs + hs, H),
                                  w * ws:min(w * ws + ws, W)].flatten())
    return torch.cat(ls, dim=0)
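# Example: with dit_seq_shape=(1, 2, 4) and tile_size=(1, 2, 2), the
# row-major token ids [[0, 1, 2, 3], [4, 5, 6, 7]] are regrouped per tile,
# giving the permutation [0, 1, 4, 5, 2, 3, 6, 7].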


def get_reverse_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Inverse permutation of `get_tile_partition_indices`."""
    return torch.argsort(
        get_tile_partition_indices(dit_seq_shape, tile_size, device))
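# argsort inverts the permutation: for any x of matching length,
# x[perm][reverse_perm] == x, which is what untile() relies on below.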


def construct_variable_block_sizes(
    dit_seq_shape: tuple[int, int, int],
    num_tiles: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """
    Compute the number of valid (non-padded) tokens inside every
    (ts_t x ts_h x ts_w) tile after padding, flattened in the same
    (t-tile, h-tile, w-tile) row-major order the tile partition uses.

    Returns
    -------
    torch.LongTensor  # shape: [n_t * n_h * n_w]
    """
    # unpack
    t, h, w = dit_seq_shape
    ts_t, ts_h, ts_w = VSA_TILE_SIZE
    n_t, n_h, n_w = num_tiles

    def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor:
        """Vector with the size of each tile along one dimension."""
        sizes = torch.full((n_tiles, ), tile, dtype=torch.int, device=device)
        # Size of the last (possibly partial) tile; with
        # n_tiles == ceil(dim_len / tile) the remainder is always >= 1.
        remainder = dim_len - (n_tiles - 1) * tile
        sizes[-1] = remainder if remainder > 0 else tile
        return sizes

    t_sizes = _sizes(t, ts_t, n_t)  # [n_t]
    h_sizes = _sizes(h, ts_h, n_h)  # [n_h]
    w_sizes = _sizes(w, ts_w, n_w)  # [n_w]
    # broadcast-multiply to get voxels per tile, then flatten
    block_sizes = (
        t_sizes[:, None, None]  # [n_t, 1, 1]
        * h_sizes[None, :, None]  # [1, n_h, 1]
        * w_sizes[None, None, :]  # [1, 1, n_w]
    ).reshape(-1)  # [n_t * n_h * n_w]
    return block_sizes
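# Example: dit_seq_shape=(5, 4, 4) yields num_tiles=(2, 1, 1); the first
# t-tile holds 4 * 4 * 4 = 64 tokens and the trailing partial t-tile holds
# 1 * 4 * 4 = 16, so block_sizes == [64, 16].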


def get_non_pad_index(
    variable_block_sizes: torch.LongTensor,
    max_block_size: int,
) -> torch.LongTensor:
    """Flat indices of the real (non-padded) tokens in the padded layout."""
    n_win = variable_block_sizes.shape[0]
    device = variable_block_sizes.device
    # Start offset of each block in the padded sequence.
    starts_pad = torch.arange(n_win, device=device) * max_block_size
    index_pad = starts_pad[:, None] + torch.arange(max_block_size,
                                                   device=device)[None, :]
    # Keep only the first `variable_block_sizes[i]` slots of block i.
    index_mask = torch.arange(
        max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
    return index_pad[index_mask]
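# Example: variable_block_sizes=[2, 1] with max_block_size=2 produces
# padded positions [[0, 1], [2, 3]]; masking out the pad slot of the
# second block leaves the flat index [0, 1, 2].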


@dataclass
class VideoSparseAttentionMetadata:
    current_timestep: int
    dit_seq_shape: tuple[int, int, int]
    VSA_sparsity: float
    num_tiles: tuple[int, int, int]
    total_seq_length: int
    tile_partition_indices: torch.LongTensor
    reverse_tile_partition_indices: torch.LongTensor
    variable_block_sizes: torch.LongTensor
    non_pad_index: torch.LongTensor


def build(
    current_timestep: int,
    raw_latent_shape: tuple[int, int, int],
    patch_size: tuple[int, int, int],
    VSA_sparsity: float,
    device: torch.device,
    **kwargs: Any,
) -> VideoSparseAttentionMetadata:
    """Derive all VSA bookkeeping tensors from the raw latent shape."""
    # Token grid seen by the DiT after patchification.
    dit_seq_shape = (raw_latent_shape[0] // patch_size[0],
                     raw_latent_shape[1] // patch_size[1],
                     raw_latent_shape[2] // patch_size[2])
    num_tiles = (math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]),
                 math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
                 math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]))
    total_seq_length = math.prod(dit_seq_shape)
    tile_partition_indices = get_tile_partition_indices(
        dit_seq_shape, VSA_TILE_SIZE, device)
    reverse_tile_partition_indices = get_reverse_tile_partition_indices(
        dit_seq_shape, VSA_TILE_SIZE, device)
    variable_block_sizes = construct_variable_block_sizes(
        dit_seq_shape, num_tiles, device)
    non_pad_index = get_non_pad_index(variable_block_sizes,
                                      math.prod(VSA_TILE_SIZE))
    return VideoSparseAttentionMetadata(
        current_timestep=current_timestep,
        dit_seq_shape=dit_seq_shape,
        VSA_sparsity=VSA_sparsity,
        num_tiles=num_tiles,
        total_seq_length=total_seq_length,
        tile_partition_indices=tile_partition_indices,
        reverse_tile_partition_indices=reverse_tile_partition_indices,
        variable_block_sizes=variable_block_sizes,
        non_pad_index=non_pad_index)
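# Sketch (illustrative values; CPU is fine for the index math):
#   meta = build(current_timestep=0,
#                raw_latent_shape=(16, 32, 32),
#                patch_size=(1, 2, 2),
#                VSA_sparsity=0.9,
#                device=torch.device("cpu"))
#   # dit_seq_shape == (16, 16, 16), num_tiles == (4, 4, 4),
#   # total_seq_length == 4096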


class VideoSparseAttentionImpl:

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        causal: bool,
        softmax_scale: float,
        num_kv_heads: int | None = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        # Only the prefix is retained; the remaining arguments belong to the
        # common attention-impl interface and are unused by VSA here.
        self.prefix = prefix

    def tile(self, x: torch.Tensor, num_tiles: tuple[int, int, int],
             tile_partition_indices: torch.LongTensor,
             non_pad_index: torch.LongTensor) -> torch.Tensor:
        """Scatter tokens into zero-padded, tile-contiguous order.

        `x` is [batch, seq_len, heads, head_dim]; the output sequence is
        padded so every tile owns exactly prod(VSA_TILE_SIZE) slots.
        """
        t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
        h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
        w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]
        x_padded = torch.zeros(
            (x.shape[0], t_padded_size * h_padded_size * w_padded_size,
             x.shape[-2], x.shape[-1]),
            device=x.device,
            dtype=x.dtype)
        x_padded[:, non_pad_index] = x[:, tile_partition_indices]
        return x_padded
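    # Example: dit_seq_shape=(5, 4, 4) pads 80 real tokens out to
    # 8 * 4 * 4 = 128 slots (two full 64-token t-tiles).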

    def untile(self, x: torch.Tensor,
               reverse_tile_partition_indices: torch.LongTensor,
               non_pad_index: torch.LongTensor) -> torch.Tensor:
        """Drop the pad slots and restore the original raster token order."""
        return x[:, non_pad_index][:, reverse_tile_partition_indices]
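    # Round trip: untile(tile(x, ...), ...) returns x exactly, since
    # reverse_tile_partition_indices is the argsort-inverse of the
    # forward permutation.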

    def preprocess_qkv(
        self,
        qkv: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        return self.tile(qkv, attn_metadata.num_tiles,
                         attn_metadata.tile_partition_indices,
                         attn_metadata.non_pad_index)

    def postprocess_output(
        self,
        output: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        return self.untile(output,
                           attn_metadata.reverse_tile_partition_indices,
                           attn_metadata.non_pad_index)

    def forward(  # type: ignore[override]
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        # [batch, seq, heads, dim] -> [batch, heads, seq, dim] for the kernel.
        query = query.transpose(1, 2).contiguous()
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()
        VSA_sparsity = attn_metadata.VSA_sparsity
        # Number of key/value blocks each query tile keeps: the fraction of
        # the total block count that survives the requested sparsity.
        cur_topk = math.ceil(
            (1 - VSA_sparsity) *
            (attn_metadata.total_seq_length / math.prod(VSA_TILE_SIZE)))
        hidden_states = video_sparse_attn(
            query,
            key,
            value,
            variable_block_sizes=attn_metadata.variable_block_sizes,
            topk=cur_topk,
            block_size=VSA_TILE_SIZE,
            compress_attn_weight=None).transpose(1, 2)
        return hidden_states
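

# Minimal round-trip self-check (a sketch, not part of the module's API;
# it assumes the `vsa` package is importable but never launches its CUDA
# kernel, so plain CPU tensors suffice).
if __name__ == "__main__":
    meta = build(current_timestep=0,
                 raw_latent_shape=(5, 8, 8),
                 patch_size=(1, 2, 2),
                 VSA_sparsity=0.9,
                 device=torch.device("cpu"))
    impl = VideoSparseAttentionImpl(num_heads=2,
                                    head_size=8,
                                    causal=False,
                                    softmax_scale=1.0)
    # dit_seq_shape == (5, 4, 4): 80 real tokens padded out to 128 slots.
    x = torch.randn(1, meta.total_seq_length, 2, 8)
    tiled = impl.preprocess_qkv(x, meta)
    assert tiled.shape[1] == 128
    assert torch.equal(impl.postprocess_output(tiled, meta), x)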