# SPDX-License-Identifier: Apache-2.0
import functools
import math
from dataclasses import dataclass
from typing import Any

import torch
from vsa import video_sparse_attn

# Spatiotemporal tile size (t, h, w) used to group tokens for the VSA kernel.
VSA_TILE_SIZE = (4, 4, 4)


@functools.lru_cache(maxsize=10)
def get_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Indices that regroup a flattened (T, H, W) token sequence so that the
    tokens of each (ts, hs, ws) tile are contiguous, tile after tile."""
    T, H, W = dit_seq_shape
    ts, hs, ws = tile_size
    indices = torch.arange(T * H * W, device=device,
                           dtype=torch.long).reshape(T, H, W)
    ls = []
    # Walk the tile grid in (t, h, w) order; edge tiles may be partial.
    for t in range(math.ceil(T / ts)):
        for h in range(math.ceil(H / hs)):
            for w in range(math.ceil(W / ws)):
                ls.append(indices[t * ts:min(t * ts + ts, T),
                                  h * hs:min(h * hs + hs, H),
                                  w * ws:min(w * ws + ws, W)].flatten())
    index = torch.cat(ls, dim=0)
    return index


@functools.lru_cache(maxsize=10)
def get_reverse_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Inverse permutation of `get_tile_partition_indices`."""
    return torch.argsort(
        get_tile_partition_indices(dit_seq_shape, tile_size, device))
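

# Illustrative sketch (not part of the original API): the forward and reverse
# index maps are mutually inverse permutations, so gathering by one and then
# the other restores the original token order. Shapes here are hypothetical.
def _tile_round_trip_demo() -> None:
    shape = (6, 6, 6)  # deliberately not a multiple of the tile size
    device = torch.device("cpu")
    fwd = get_tile_partition_indices(shape, VSA_TILE_SIZE, device)
    inv = get_reverse_tile_partition_indices(shape, VSA_TILE_SIZE, device)
    x = torch.randn(math.prod(shape))
    assert torch.equal(x[fwd][inv], x)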


@functools.lru_cache(maxsize=10)
def construct_variable_block_sizes(
    dit_seq_shape: tuple[int, int, int],
    num_tiles: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Compute the number of valid (non-padded) tokens inside every
    (ts_t × ts_h × ts_w) tile after padding, flattened in the
    (t-tile, h-tile, w-tile) order that `rearrange` uses.

    Returns
    -------
    torch.LongTensor
        Shape: ``[n_t * n_h * n_w]`` (the flattened tile grid).
    """
    # unpack
    t, h, w = dit_seq_shape
    ts_t, ts_h, ts_w = VSA_TILE_SIZE
    n_t, n_h, n_w = num_tiles

    def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.Tensor:
        """Vector with the size of each tile along one dimension."""
        sizes = torch.full((n_tiles, ), tile, dtype=torch.int, device=device)
        # The last tile may be partial; it holds whatever tokens remain.
        remainder = dim_len - (n_tiles - 1) * tile
        sizes[-1] = remainder if remainder > 0 else tile
        return sizes

    t_sizes = _sizes(t, ts_t, n_t)  # [n_t]
    h_sizes = _sizes(h, ts_h, n_h)  # [n_h]
    w_sizes = _sizes(w, ts_w, n_w)  # [n_w]

    # broadcast‑multiply to get voxels per tile, then flatten
    block_sizes = (
        t_sizes[:, None, None]  # [n_t, 1,   1]
        * h_sizes[None, :, None]  # [1,   n_h, 1]
        * w_sizes[None, None, :]  # [1,   1,   n_w]
    ).reshape(-1)  # [n_t * n_h * n_w]

    return block_sizes
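

# Illustrative sketch (hypothetical numbers): for a (6, 6, 6) token grid and a
# (2, 2, 2) tile grid, the last tile along each axis holds 2 of its 4 slots,
# and the per-tile counts must sum to the 216 real tokens.
def _block_sizes_demo() -> None:
    sizes = construct_variable_block_sizes((6, 6, 6), (2, 2, 2),
                                           torch.device("cpu"))
    assert sizes.sum().item() == 216  # 6 * 6 * 6 real tokens
    assert sizes[0].item() == 64  # full 4 * 4 * 4 corner tile
    assert sizes[-1].item() == 8  # partial 2 * 2 * 2 corner tile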


@functools.lru_cache(maxsize=10)
def get_non_pad_index(
    variable_block_sizes: torch.LongTensor,
    max_block_size: int,
) -> torch.LongTensor:
    """Indices of the real (non-padded) slots in the padded tiled layout.

    Note: `lru_cache` hashes the tensor argument by identity, which is stable
    here because `variable_block_sizes` itself comes from a cached function.
    """
    n_win = variable_block_sizes.shape[0]
    device = variable_block_sizes.device
    # First slot of each padded block, then every slot within each block.
    starts_pad = torch.arange(n_win, device=device) * max_block_size
    index_pad = starts_pad[:, None] + torch.arange(max_block_size,
                                                   device=device)[None, :]
    # Keep only the first `variable_block_sizes[i]` slots of block i.
    index_mask = torch.arange(
        max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
    return index_pad[index_mask]
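

# Illustrative sketch: with real block sizes [3, 2] padded to 4 slots each,
# the non-pad index keeps slots {0, 1, 2} of block 0 and {4, 5} of block 1.
def _non_pad_index_demo() -> None:
    sizes = torch.tensor([3, 2])
    idx = get_non_pad_index(sizes, 4)
    assert idx.tolist() == [0, 1, 2, 4, 5]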


@dataclass
class VideoSparseAttentionMetadata:
    current_timestep: int
    dit_seq_shape: tuple[int, int, int]
    VSA_sparsity: float
    num_tiles: tuple[int, int, int]
    total_seq_length: int
    tile_partition_indices: torch.LongTensor
    reverse_tile_partition_indices: torch.LongTensor
    variable_block_sizes: torch.LongTensor
    non_pad_index: torch.LongTensor


def build(
    current_timestep: int,
    raw_latent_shape: tuple[int, int, int],
    patch_size: tuple[int, int, int],
    VSA_sparsity: float,
    device: torch.device,
    **kwargs: Any,
) -> VideoSparseAttentionMetadata:
    # Token grid after patchification: one token per (pt, ph, pw) latent patch.
    dit_seq_shape = (raw_latent_shape[0] // patch_size[0],
                     raw_latent_shape[1] // patch_size[1],
                     raw_latent_shape[2] // patch_size[2])

    num_tiles = (math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]),
                 math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
                 math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]))
    total_seq_length = math.prod(dit_seq_shape)

    tile_partition_indices = get_tile_partition_indices(
        dit_seq_shape, VSA_TILE_SIZE, device)
    reverse_tile_partition_indices = get_reverse_tile_partition_indices(
        dit_seq_shape, VSA_TILE_SIZE, device)
    variable_block_sizes = construct_variable_block_sizes(
        dit_seq_shape, num_tiles, device)
    non_pad_index = get_non_pad_index(variable_block_sizes,
                                      math.prod(VSA_TILE_SIZE))

    return VideoSparseAttentionMetadata(
        current_timestep=current_timestep,
        dit_seq_shape=dit_seq_shape,
        VSA_sparsity=VSA_sparsity,
        num_tiles=num_tiles,
        total_seq_length=total_seq_length,
        tile_partition_indices=tile_partition_indices,
        reverse_tile_partition_indices=reverse_tile_partition_indices,
        variable_block_sizes=variable_block_sizes,
        non_pad_index=non_pad_index)
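

# Illustrative usage (hypothetical shapes): a 12x28x28 latent with 1x2x2
# patches gives a 12x14x14 token grid, i.e. a 3x4x4 tile grid whose edge
# tiles along H and W are partial.
def _build_demo() -> VideoSparseAttentionMetadata:
    return build(
        current_timestep=0,
        raw_latent_shape=(12, 28, 28),
        patch_size=(1, 2, 2),
        VSA_sparsity=0.9,
        device=torch.device("cpu"),
    )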


class VideoSparseAttentionImpl:

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        causal: bool,
        softmax_scale: float,
        num_kv_heads: int | None = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        # Only `prefix` is used; the other arguments are accepted for
        # interface compatibility with other attention implementations.
        self.prefix = prefix

    def tile(self, x: torch.Tensor, num_tiles: tuple[int, int, int],
             tile_partition_indices: torch.LongTensor,
             non_pad_index: torch.LongTensor) -> torch.Tensor:
        """Scatter tokens into the padded, tile-contiguous layout; padded
        slots in partial edge tiles stay zero."""
        t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
        h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
        w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]

        # [B, padded_seq_len, num_heads, head_dim]
        x_padded = torch.zeros(
            (x.shape[0], t_padded_size * h_padded_size * w_padded_size,
             x.shape[-2], x.shape[-1]),
            device=x.device,
            dtype=x.dtype)
        x_padded[:, non_pad_index] = x[:, tile_partition_indices]
        return x_padded

    def untile(self, x: torch.Tensor,
               reverse_tile_partition_indices: torch.LongTensor,
               non_pad_index: torch.LongTensor) -> torch.Tensor:
        """Drop padded slots and restore the original flattened token order."""
        x = x[:, non_pad_index][:, reverse_tile_partition_indices]
        return x

    def preprocess_qkv(
        self,
        qkv: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        return self.tile(qkv, attn_metadata.num_tiles,
                         attn_metadata.tile_partition_indices,
                         attn_metadata.non_pad_index)

    def postprocess_output(
        self,
        output: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        return self.untile(output,
                           attn_metadata.reverse_tile_partition_indices,
                           attn_metadata.non_pad_index)

    def forward(  # type: ignore[override]
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        # The kernel expects [batch, heads, seq, head_dim].
        query = query.transpose(1, 2).contiguous()
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()

        VSA_sparsity = attn_metadata.VSA_sparsity

        # Number of tiles each query attends to: the top (1 - sparsity)
        # fraction of the total tile count.
        cur_topk = math.ceil(
            (1 - VSA_sparsity) *
            (attn_metadata.total_seq_length / math.prod(VSA_TILE_SIZE)))

        hidden_states = video_sparse_attn(
            query,
            key,
            value,
            variable_block_sizes=attn_metadata.variable_block_sizes,
            topk=cur_topk,
            block_size=VSA_TILE_SIZE,
            compress_attn_weight=None).transpose(1, 2)

        return hidden_states
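

# Illustrative end-to-end sketch (CPU, hypothetical shapes): tiling into the
# padded layout and untiling back is lossless, even with partial edge tiles.
# The sparse kernel itself (`video_sparse_attn`) needs a GPU and is not run
# here. With VSA_sparsity=0.9 over 12 * 14 * 14 = 2352 tokens (2352 / 64 =
# 36.75 tile-equivalents), `forward` would keep topk = ceil(0.1 * 36.75) = 4
# tiles per query.
def _tile_untile_demo() -> None:
    meta = _build_demo()
    impl = VideoSparseAttentionImpl(num_heads=2, head_size=8, causal=False,
                                    softmax_scale=1.0)
    x = torch.randn(1, meta.total_seq_length, 2, 8)  # [B, seq, heads, dim]
    tiled = impl.preprocess_qkv(x, meta)
    assert torch.equal(impl.postprocess_output(tiled, meta), x)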