from typing import Optional, Tuple

import torch

|
def transformer_random_masking(
    x: torch.Tensor, mask_ratio: float, constant_noise: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" |
|
Random mask patches per sample |
|
|
|
Parameters |
|
---------- |
|
x : token tensor (N, L, D) |
|
mask_ratio: float - ratio of image to mask |
|
constant_noise: None, if provided should be a tensor of shape (N, L) to produce consistent masks |
|
|
|
Returns |
|
------- |
|
x_masked : sub-sampled version of x ( int(mask_ratio * N), L, D) |
|
mask : binary mask indicated masked tokens (1 where masked) (N, L) |
|
ind_restore : locations of masked tokens, needed for decoder |
|
""" |
|
|
|
    N, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))

    # Use caller-provided noise for reproducible masks; otherwise sample fresh noise.
    if constant_noise is not None:
        noise = constant_noise
    else:
        noise = torch.rand(N, L, device=x.device)

    # Argsort the noise: the first len_keep positions index the tokens to keep.
    shuffled_tokens = torch.argsort(noise, dim=1)
    # Inverse permutation, used to map shuffled positions back to the original order.
    ind_restore = torch.argsort(shuffled_tokens, dim=1)

    # Gather the kept tokens from x.
    tokens_to_keep = shuffled_tokens[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=tokens_to_keep.unsqueeze(-1).repeat(1, 1, D))

    # Build the binary mask in shuffled order (0 = keep, 1 = masked), then
    # unshuffle it back to the original token order.
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ind_restore)

    return x_masked, mask, ind_restore
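

# Minimal usage sketch (illustrative only; the shapes, mask ratio, and the
# zero-filled stand-in for a learned mask token below are assumptions chosen
# purely for demonstration, not part of this module's API).
if __name__ == "__main__":
    tokens = torch.randn(2, 16, 8)  # N=2 samples, L=16 tokens, D=8 dims

    x_masked, mask, ind_restore = transformer_random_masking(tokens, mask_ratio=0.75)
    print(x_masked.shape)  # torch.Size([2, 4, 8]): int(16 * 0.25) = 4 tokens kept
    print(mask.sum(dim=1))  # tensor([12., 12.]): 12 of 16 tokens masked per sample

    # Passing the same constant_noise yields identical masks across calls.
    noise = torch.rand(2, 16)
    _, mask_a, _ = transformer_random_masking(tokens, 0.75, constant_noise=noise)
    _, mask_b, _ = transformer_random_masking(tokens, 0.75, constant_noise=noise)
    assert torch.equal(mask_a, mask_b)

    # Decoder-side sketch: append placeholder mask tokens, then unshuffle with
    # ind_restore to recover the original token order (MAE-style decoders use
    # a learned mask token instead of zeros).
    mask_tokens = torch.zeros(2, 16 - x_masked.shape[1], 8)
    full = torch.cat([x_masked, mask_tokens], dim=1)
    full = torch.gather(full, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, 8))
    assert full.shape == tokens.shape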
|
|