Spaces:
Sleeping
Sleeping
import itertools | |
import math | |
import warnings | |
from pathlib import Path | |
from typing import List | |
import torch | |
import torch.nn.functional as F | |
from einops import rearrange | |
from torch import einsum, nn | |
class CROMAWrapper(nn.Module):
    """Adapter that feeds channels-last imagery to a pretrained CROMA encoder.

    Inputs are expected channels-last, ``(b, h, w, c)``, or with a time axis
    at position 3, ``(b, h, w, t, c)``. Every frame is converted to
    channels-first, bilinearly resized to ``image_resolution`` and encoded;
    multi-temporal inputs are encoded one timestep at a time and then pooled
    over time with ``temporal_pooling``.

    Args:
        weights_path: Directory holding ``CROMA_base.pt`` / ``CROMA_large.pt``.
        size: ``"base"`` (dim 768) or ``"large"`` (dim 1024).
        modality: ``"SAR"`` or ``"optical"``; forwarded to ``PretrainedCROMA``.
        do_pool: If true, return the ``*_GAP`` output of the encoder,
            otherwise the ``*_encodings`` output.
        temporal_pooling: ``"mean"`` or ``"max"`` reduction over timesteps.

    Raises:
        AssertionError: If ``modality`` is not ``"SAR"`` or ``"optical"``.
        ValueError: If ``size`` or ``temporal_pooling`` is not supported.
    """

    def __init__(
        self,
        weights_path: Path,
        size="base",
        modality="optical",
        do_pool=True,
        temporal_pooling: str = "mean",
    ):
        super().__init__()
        assert modality in ["SAR", "optical"]
        # Validate every cheap argument before touching the checkpoint, so a
        # typo does not cost a full weight load first.
        if temporal_pooling not in ["mean", "max"]:
            raise ValueError(
                f"Expected temporal_pooling to be in ['mean', 'max'], got {temporal_pooling}"
            )
        if size == "base":
            checkpoint_name, embed_dim = "CROMA_base.pt", 768
        elif size == "large":
            checkpoint_name, embed_dim = "CROMA_large.pt", 1024
        else:
            raise ValueError(f"size must be base or large, not {size}")

        self.image_resolution = 120
        self.patch_size = 8
        self.grid_size = self.image_resolution // self.patch_size
        self.do_pool = do_pool
        self.temporal_pooling = temporal_pooling

        self.croma = PretrainedCROMA(
            str(weights_path / checkpoint_name),
            size,
            modality=modality,
            image_resolution=self.image_resolution,
        )
        self.dim = embed_dim

    def resize(self, images):
        """Bilinearly resize ``(b, c, h, w)`` images to the encoder resolution."""
        target = (self.image_resolution, self.image_resolution)
        return F.interpolate(images, size=target, mode="bilinear", align_corners=False)

    def preproccess(self, images):
        """Prepare optical frames: channels-first, drop band index 10, resize.

        Args:
            images: ``(b, h, w, 13)`` tensor.

        Returns:
            ``(b, 12, image_resolution, image_resolution)`` tensor.
        """
        # Same layout change as rearrange "b h w c -> b c h w".
        images = images.permute(0, 3, 1, 2)
        assert images.shape[1] == 13
        # remove cirrus
        remove_idx = 10
        images = torch.cat([images[:, :remove_idx], images[:, remove_idx + 1 :]], dim=1)
        assert images.shape[1] == 12
        return self.resize(images)  # (bsz, 12, 120, 120)

    def preproccess_s1(self, images):
        """Prepare SAR frames: channels-first, then resize.

        Args:
            images: ``(b, h, w, 2)`` tensor.

        Returns:
            ``(b, 2, image_resolution, image_resolution)`` tensor.
        """
        # Same layout change as rearrange "b h w c -> b c h w".
        images = images.permute(0, 3, 1, 2)
        assert images.shape[1] == 2
        return self.resize(images)  # (bsz, 2, 120, 120)

    def _encode(self, images, preprocess, croma_kwarg, key_prefix):
        """Encode one modality, pooling over the time axis when one is present."""
        # The encoder names its outputs per modality ("SAR_*" / "optical_*"),
        # so the key has to follow the modality actually being encoded.
        output_key = f"{key_prefix}_GAP" if self.do_pool else f"{key_prefix}_encodings"

        if len(images.shape) != 5:
            return self.croma(**{croma_kwarg: preprocess(images)})[output_key]

        # (b, h, w, t, c): encode each timestep independently.
        per_timestep = [
            self.croma(**{croma_kwarg: preprocess(images[:, :, :, timestep])})[output_key]
            for timestep in range(images.shape[3])
        ]
        stacked = torch.stack(per_timestep, dim=-1)  # (*encoder_output_shape, t)
        if self.temporal_pooling == "mean":
            return stacked.mean(dim=-1)
        return torch.amax(stacked, dim=-1)

    def forward(self, s2=None, s1=None, months=None):
        """Encode either optical (``s2``) or SAR (``s1``) imagery.

        Args:
            s2: Optical input, ``(b, h, w, 13)`` or ``(b, h, w, t, 13)``.
            s1: SAR input, ``(b, h, w, 2)`` or ``(b, h, w, t, 2)``.
            months: Accepted for interface compatibility; not used.

        Returns:
            The pooled (``do_pool=True``) or per-patch encoder output, reduced
            over time for 5-D inputs.

        Raises:
            AssertionError: If both ``s1`` and ``s2`` are given.
            ValueError: If neither ``s1`` nor ``s2`` is given.
        """
        if s1 is not None:
            assert s2 is None, "joint s2 and s1 not implemented for CROMA"
            return self._encode(s1, self.preproccess_s1, "SAR_images", "SAR")
        if s2 is None:
            raise ValueError("Either s2 or s1 must be provided")
        # just S2
        return self._encode(s2, self.preproccess, "optical_images", "optical")
class PretrainedCROMA(nn.Module):
    """CROMA encoders (SAR, optical and/or joint) initialised from a checkpoint.

    Depending on ``modality`` the module builds a SAR ViT, an optical ViT, or
    both plus a cross-attention encoder, and loads their weights from the
    ``"s1_encoder"``, ``"s1_GAP_FFN"``, ``"s2_encoder"``, ``"s2_GAP_FFN"`` and
    ``"joint_encoder"`` entries of the checkpoint at ``pretrained_path``.
    """

    def __init__(
        self, pretrained_path="CROMA_base.pt", size="base", modality="both", image_resolution=120
    ):
        """
        NOTE: image_resolution is not the spatial, spectral, or temporal resolution. It is the height and width of the image, in pixels.
        E.g., CROMA was pretrained on 120x120px images, hence image_resolution is 120 by default

        Args:
            pretrained_path: Path (as ``str``) to the checkpoint file.
            size: ``"base"`` or ``"large"``.
            modality: ``"both"``, ``"SAR"`` or ``"optical"``.
            image_resolution: Image height/width in pixels; a multiple of 8.

        Raises:
            AssertionError: If an argument has the wrong type or value.
        """
        super().__init__()

        # check types
        assert isinstance(pretrained_path, str)
        assert isinstance(size, str)
        assert isinstance(modality, str)
        assert isinstance(image_resolution, int)

        # check values
        assert size in ["base", "large"], f"size must be either base or large, not {size}"
        assert (
            image_resolution % 8 == 0
        ), f"image_resolution must be a multiple of 8, not {image_resolution}"
        assert modality in [
            "both",
            "SAR",
            "optical",
        ], f"modality must be either both, SAR, or optical, not {modality}"

        # warn the user if the path contains a different size than the size parameter
        if size == "base" and "large" in pretrained_path:
            warnings.warn(
                "The size is set to base, but the word large appears in the pretrained path!"
            )
        elif size == "large" and "base" in pretrained_path:
            warnings.warn(
                "The size is set to large, but the word base appears in the pretrained path!"
            )

        if size == "base":
            self.encoder_dim = 768
            self.encoder_depth = 12
        else:
            # large by default
            self.encoder_dim = 1024
            self.encoder_depth = 24
        # identical for base and large
        self.num_heads = 16
        self.patch_size = 8

        self.modality = modality
        # image_resolution is a multiple of 8 (checked above), so this is exact
        self.num_patches = (image_resolution // 8) ** 2
        self.s1_channels = 2  # fixed at 2 SAR backscatter channels
        self.s2_channels = 12  # fixed at 12 multispectral optical channels
        # Plain tensor attribute (not a registered buffer): forward() moves it
        # to the device of the incoming images on every call.
        self.attn_bias = get_2dalibi(num_heads=self.num_heads, num_patches=self.num_patches)

        # Deserialize the checkpoint once and reuse it for every sub-module,
        # instead of re-reading the file for each state dict.
        # NOTE: torch.load unpickles the file; only load checkpoints obtained
        # from a trusted source.
        checkpoint = torch.load(pretrained_path, map_location="cpu")

        if modality in ["SAR", "both"]:
            print("Initializing SAR encoder")
            # the SAR encoder uses half the depth of the optical encoder
            self.s1_encoder = ViT(
                dim=self.encoder_dim,
                depth=self.encoder_depth // 2,
                in_channels=self.s1_channels,
            )
            self.GAP_FFN_s1 = self._make_gap_ffn()
            # load weights
            self.s1_encoder.load_state_dict(checkpoint["s1_encoder"])
            self.GAP_FFN_s1.load_state_dict(checkpoint["s1_GAP_FFN"])

        if modality in ["optical", "both"]:
            print("Initializing optical encoder")
            self.s2_encoder = ViT(
                dim=self.encoder_dim, depth=self.encoder_depth, in_channels=self.s2_channels
            )
            self.GAP_FFN_s2 = self._make_gap_ffn()
            # load weights
            self.s2_encoder.load_state_dict(checkpoint["s2_encoder"])
            self.GAP_FFN_s2.load_state_dict(checkpoint["s2_GAP_FFN"])

        if modality == "both":
            print("Initializing joint SAR-optical encoder")
            self.cross_encoder = BaseTransformerCrossAttn(
                dim=self.encoder_dim,
                depth=self.encoder_depth // 2,
                num_heads=self.num_heads,
            )
            # load weights
            self.cross_encoder.load_state_dict(checkpoint["joint_encoder"])

    def _make_gap_ffn(self):
        """Build the LayerNorm + 2-layer MLP applied to the mean-pooled encodings."""
        hidden_dim = int(4 * self.encoder_dim)
        return nn.Sequential(
            nn.LayerNorm(self.encoder_dim),
            nn.Linear(self.encoder_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, self.encoder_dim),
        )

    def forward(self, SAR_images=None, optical_images=None):
        """Encode the images required by the configured modality.

        Args:
            SAR_images: Required when modality is ``"SAR"`` or ``"both"``.
            optical_images: Required when modality is ``"optical"`` or ``"both"``.

        Returns:
            dict with, per active modality, ``"<name>_encodings"``
            (bsz, num_patches, encoder_dim) and ``"<name>_GAP"``
            (bsz, encoder_dim), where ``<name>`` is ``SAR``, ``optical`` and,
            for modality ``"both"``, ``joint``.

        Raises:
            AssertionError: If an input required by the modality is ``None``.
        """
        return_dict = {}

        if self.modality in ["SAR", "both"]:
            assert (
                SAR_images is not None
            ), f"Modality is set to {self.modality}, but SAR_images are None"
            SAR_encodings = self.s1_encoder(
                imgs=SAR_images, attn_bias=self.attn_bias.to(SAR_images.device)
            )  # (bsz, num_patches, encoder_dim)
            SAR_GAP = self.GAP_FFN_s1(SAR_encodings.mean(dim=1))  # (bsz, encoder_dim)
            return_dict["SAR_encodings"] = SAR_encodings
            return_dict["SAR_GAP"] = SAR_GAP

        if self.modality in ["optical", "both"]:
            assert (
                optical_images is not None
            ), f"Modality is set to {self.modality}, but optical_images are None"
            optical_encodings = self.s2_encoder(
                imgs=optical_images, attn_bias=self.attn_bias.to(optical_images.device)
            )  # (bsz, num_patches, encoder_dim)
            optical_GAP = self.GAP_FFN_s2(optical_encodings.mean(dim=1))  # (bsz, encoder_dim)
            return_dict["optical_encodings"] = optical_encodings
            return_dict["optical_GAP"] = optical_GAP

        if self.modality == "both":
            # SAR tokens are the queries; optical tokens are the context.
            joint_encodings = self.cross_encoder(
                x=SAR_encodings,
                context=optical_encodings,
                relative_position_bias=self.attn_bias.to(optical_images.device),
            )  # (bsz, num_patches, encoder_dim)
            joint_GAP = joint_encodings.mean(dim=1)  # (bsz, encoder_dim)
            return_dict["joint_encodings"] = joint_encodings
            return_dict["joint_GAP"] = joint_GAP

        return return_dict
def get_2dalibi(num_heads, num_patches):
    """Build a 2-D ALiBi attention bias for a square grid of patches.

    Patches are laid out row-major on a ``sqrt(num_patches)`` x
    ``sqrt(num_patches)`` grid. Entry ``[0, h, i, j]`` of the result is
    ``-slope_h * euclidean_distance(patch_i, patch_j)``, where ``slope_h`` is
    the per-head ALiBi slope.

    Args:
        num_heads: Number of attention heads.
        num_patches: Number of patches; must be a perfect square.

    Returns:
        Tensor of shape ``(1, num_heads, num_patches, num_patches)``.

    Raises:
        ValueError: If ``num_patches`` is not a perfect square.
    """
    # inspired by: https://github.com/ofirpress/attention_with_linear_biases
    side = int(math.sqrt(num_patches))
    if side * side != num_patches:
        raise ValueError(f"num_patches must be a perfect square, got {num_patches}")

    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        # Not a power of two: take the slopes of the closest smaller power of
        # two and fill the remainder with every other slope of the next one.
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
        )

    slopes = torch.tensor(get_slopes(num_heads))  # (num_heads,)

    # Row/column coordinate of every patch in row-major order. Distances are
    # computed in float64 and then cast, instead of looping over all
    # num_patches**2 pairs in Python.
    axis = torch.arange(side, dtype=torch.float64)
    rows = axis.repeat_interleave(side)  # (num_patches,)
    cols = axis.repeat(side)  # (num_patches,)
    dist = torch.sqrt(
        (rows[:, None] - rows[None, :]) ** 2 + (cols[:, None] - cols[None, :]) ** 2
    )  # (num_patches, num_patches)

    bias = -(slopes.view(num_heads, 1, 1) * dist.to(slopes.dtype))
    return bias.unsqueeze(0)  # (1, num_heads, num_patches, num_patches)
class FFN(nn.Module):
    """Pre-norm feed-forward block: LayerNorm, then Linear -> GELU -> Dropout -> Linear.

    Args:
        dim: Size of the last (feature) dimension of the input and output.
        mult: Expansion factor; the hidden width is ``int(dim * mult)``.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(
        self,
        dim,
        mult=4,
        dropout=0.0,
    ):
        super().__init__()
        hidden_dim = int(dim * mult)
        stages = [
            nn.Linear(dim, hidden_dim),  # (..., dim) -> (..., hidden_dim)
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),  # (..., hidden_dim) -> (..., dim)
        ]
        self.net = nn.Sequential(*stages)
        self.input_norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Normalize ``x`` over its last dimension and apply the MLP; shape is preserved."""
        return self.net(self.input_norm(x))
class Attention(nn.Module):
    """Pre-norm multi-head self-attention with an additive bias on the scores.

    Args:
        dim: Token feature size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        dropout=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        assert dim % num_heads == 0, "dim must be evenly divisible by num_heads"
        head_width = int(dim / num_heads)
        self.scale = head_width**-0.5

        # One fused projection produces queries, keys and values.
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.input_norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, relative_position_bias):
        """Attend over the tokens of ``x``.

        Args:
            x: Tokens of shape (BSZ, num_patches, dim).
            relative_position_bias: Added to the pre-softmax scores of shape
                (BSZ, num_heads, num_patches, num_patches); must broadcast
                against them.

        Returns:
            Tensor of shape (BSZ, num_patches, dim).
        """
        normed = self.input_norm(x)  # (BSZ, num_patches, dim)

        # Split the fused projection into q/k/v and move heads to their own axis.
        queries, keys, values = (
            rearrange(part, "b n (h d) -> b h n d", h=self.num_heads)
            for part in self.to_qkv(normed).chunk(3, dim=-1)
        )  # each (BSZ, num_heads, num_patches, dim_head)

        scores = einsum("b h i d, b h j d -> b h i j", queries, keys) * self.scale
        scores = scores + relative_position_bias  # (BSZ, num_heads, num_patches, num_patches)

        weights = self.dropout(scores.softmax(dim=-1))

        mixed = einsum("b h i j, b h j d -> b h i d", weights, values)
        merged = rearrange(mixed, "b h n d -> b n (h d)")  # (BSZ, num_patches, dim)
        return self.to_out(merged)
class CrossAttention(nn.Module):
    """Pre-norm multi-head cross-attention with an additive bias on the scores.

    Queries come from ``x``; keys and values come from ``context``. The same
    LayerNorm instance is applied to both inputs.

    Args:
        dim: Token feature size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        dropout=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        assert dim % num_heads == 0, "dim must be evenly divisible by num_heads"
        head_width = int(dim / num_heads)
        self.scale = head_width**-0.5

        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.input_norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context, relative_position_bias):
        """Let the tokens of ``x`` attend to the tokens of ``context``.

        Args:
            x: Query tokens of shape (BSZ, num_patches, dim).
            context: Key/value tokens of shape (BSZ, num_patches, dim).
            relative_position_bias: Added to the pre-softmax scores of shape
                (BSZ, num_heads, num_patches, num_patches); must broadcast
                against them.

        Returns:
            Tensor of shape (BSZ, num_patches, dim).
        """
        normed_x = self.input_norm(x)  # (BSZ, num_patches, dim)
        normed_context = self.input_norm(context)  # (BSZ, num_patches, dim)

        projections = (
            self.to_q(normed_x),
            self.to_k(normed_context),
            self.to_v(normed_context),
        )
        # Move heads to their own axis for q, k and v alike.
        queries, keys, values = (
            rearrange(proj, "b n (h d) -> b h n d", h=self.num_heads) for proj in projections
        )  # each (BSZ, num_heads, num_patches, dim_head)

        scores = einsum("b h i d, b h j d -> b h i j", queries, keys) * self.scale
        scores = scores + relative_position_bias  # (BSZ, num_heads, num_patches, num_patches)

        weights = self.dropout(scores.softmax(dim=-1))

        mixed = einsum("b h i j, b h j d -> b h i d", weights, values)
        merged = rearrange(mixed, "b h n d -> b n (h d)")  # (BSZ, num_patches, dim)
        return self.to_out(merged)
class BaseTransformer(nn.Module):
    """Stack of residual (self-attention, feed-forward) layers with an optional final LayerNorm.

    Args:
        dim: Token feature size.
        depth: Number of (attention, FFN) layers.
        num_heads: Heads per attention layer.
        attn_dropout: Dropout probability inside attention.
        ff_dropout: Dropout probability inside the feed-forward blocks.
        ff_mult: Hidden-width multiplier of the feed-forward blocks.
        final_norm: If true, apply a LayerNorm to the output of the last layer.
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
        final_norm=True,
    ):
        super().__init__()
        self.final_norm = final_norm

        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    Attention(dim=dim, num_heads=num_heads, dropout=attn_dropout),
                    FFN(dim=dim, mult=ff_mult, dropout=ff_dropout),
                ]
            )
            for _ in range(depth)
        )

        # norm_out only exists when requested, so it only appears in the
        # state dict in that case.
        if self.final_norm:
            self.norm_out = nn.LayerNorm(dim)

    def forward(self, x, relative_position_bias=False):
        """Run every layer with residual connections.

        Args:
            x: Tokens of shape (BSZ, num_patches, dim).
            relative_position_bias: Passed unchanged to every attention layer.

        Returns:
            Tensor of shape (BSZ, num_patches, dim).
        """
        for attention, feed_forward in self.layers:
            attended = attention(x, relative_position_bias)
            x = attended + x  # (BSZ, num_patches, dim)
            transformed = feed_forward(x)
            x = transformed + x  # (BSZ, num_patches, dim)

        return self.norm_out(x) if self.final_norm else x
class BaseTransformerCrossAttn(nn.Module):
    """Stack of residual (self-attention, cross-attention, feed-forward) layers plus a final LayerNorm.

    Args:
        dim: Token feature size.
        depth: Number of layers.
        num_heads: Heads per attention layer.
        attn_dropout: Dropout probability inside both attention types.
        ff_dropout: Dropout probability inside the feed-forward blocks.
        ff_mult: Hidden-width multiplier of the feed-forward blocks.
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
    ):
        super().__init__()

        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    Attention(dim=dim, num_heads=num_heads, dropout=attn_dropout),
                    CrossAttention(dim=dim, num_heads=num_heads, dropout=attn_dropout),
                    FFN(dim=dim, mult=ff_mult, dropout=ff_dropout),
                ]
            )
            for _ in range(depth)
        )
        self.norm_out = nn.LayerNorm(dim)

    def forward(self, x, context, relative_position_bias):
        """Refine ``x`` while attending to ``context`` in every layer.

        Args:
            x: Tokens of shape (BSZ, num_patches, dim).
            context: Tokens used as keys/values by the cross-attention layers.
            relative_position_bias: Passed unchanged to every self- and
                cross-attention layer.

        Returns:
            Normalized tensor of shape (BSZ, num_patches, dim).
        """
        for attention, cross_attention, feed_forward in self.layers:
            attended = attention(x, relative_position_bias)
            x = attended + x  # (BSZ, num_patches, dim)
            crossed = cross_attention(x, context, relative_position_bias)
            x = crossed + x  # (BSZ, num_patches, dim)
            transformed = feed_forward(x)
            x = transformed + x  # (BSZ, num_patches, dim)

        return self.norm_out(x)  # (BSZ, num_patches, dim)
class ViT(nn.Module):
    """Patch-embedding ViT encoder: flatten 8x8 patches, project linearly, run a transformer.

    Args:
        dim: Token feature size.
        depth: Number of transformer layers.
        in_channels: Number of channels of the input images.
    """

    def __init__(self, dim, depth, in_channels):
        super().__init__()
        self.depth = depth
        self.in_channels = in_channels
        self.dim = dim
        self.num_heads = 16  # always 16, for base and large models
        self.patch_size = 8  # always 8, for base and large models

        # Every patch is flattened to patch_size * patch_size * in_channels values.
        patch_features = int(self.patch_size * self.patch_size * in_channels)
        self.linear_input = nn.Linear(patch_features, self.dim)
        self.transformer = BaseTransformer(
            dim=self.dim,
            depth=self.depth,
            num_heads=self.num_heads,
        )

    def forward(self, imgs, attn_bias):
        """Encode images into one token per patch.

        Args:
            imgs: Images of shape (bsz, in_channels, H, W), with H and W
                multiples of ``patch_size``.
            attn_bias: Forwarded to the transformer as ``relative_position_bias``.

        Returns:
            Tensor of shape (bsz, num_patches, dim).
        """
        side = self.patch_size
        # (bsz, num_patches, in_channels * patch_size * patch_size), patches in row-major order
        patches = rearrange(imgs, "b c (h i) (w j) -> b (h w) (c i j)", i=side, j=side)
        tokens = self.linear_input(patches)  # (bsz, num_patches, dim)
        return self.transformer(tokens, relative_position_bias=attn_bias)