import itertools
import math
import warnings
from pathlib import Path
from typing import List
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum, nn
class CROMAWrapper(nn.Module):
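    """Wraps PretrainedCROMA for single-modality Sentinel-1 or Sentinel-2 inputs.

    Inputs are channels-last images (13 optical bands, of which the cirrus band is
    dropped, or 2 SAR backscatter bands) that are resized to 120x120 before encoding.
    The forward pass returns pooled "*_GAP" features when do_pool is True, otherwise
    per-patch "*_encodings"; 5-D inputs are treated as timestacks and reduced over
    time with mean or max pooling.
    """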
def __init__(
self,
weights_path: Path,
size="base",
modality="optical",
do_pool=True,
temporal_pooling: str = "mean",
):
super().__init__()
assert modality in ["SAR", "optical"]
if size == "base":
self.croma = PretrainedCROMA(
str(weights_path / "CROMA_base.pt"), size, modality=modality, image_resolution=120
)
self.dim = 768
elif size == "large":
self.croma = PretrainedCROMA(
str(weights_path / "CROMA_large.pt"), size, modality=modality, image_resolution=120
)
self.dim = 1024
else:
raise ValueError(f"size must be base or large, not {size}")
self.image_resolution = 120
self.patch_size = 8
self.grid_size = int(self.image_resolution / self.patch_size)
self.do_pool = do_pool
if temporal_pooling not in ["mean", "max"]:
raise ValueError(
f"Expected temporal_pooling to be in ['mean', 'max'], got {temporal_pooling}"
)
self.temporal_pooling = temporal_pooling
def resize(self, images):
images = F.interpolate(
images,
size=(self.image_resolution, self.image_resolution),
mode="bilinear",
align_corners=False,
)
return images
    def preprocess(self, images):
images = rearrange(images, "b h w c -> b c h w")
assert images.shape[1] == 13
# remove cirrus
remove_idx = 10
images = torch.cat(
[images[:, :remove_idx, :, :], images[:, (remove_idx + 1) :, :, :]], dim=1
)
assert images.shape[1] == 12
return self.resize(images) # (bsz, 12, 120, 120)
    def preprocess_s1(self, images):
images = rearrange(images, "b h w c -> b c h w")
assert images.shape[1] == 2
return self.resize(images) # (bsz, 2, 120, 120)
    def forward(self, s2=None, s1=None, months=None):
        if s1 is not None:
            assert s2 is None, "joint s2 and s1 not implemented for CROMA"
            # SAR outputs are keyed "SAR_*" by PretrainedCROMA, not "optical_*"
            output_key = "SAR_GAP" if self.do_pool else "SAR_encodings"
if len(s1.shape) == 5:
outputs: List[torch.Tensor] = []
for timestep in range(s1.shape[3]):
                    image = self.preprocess_s1(s1[:, :, :, timestep])
outputs.append(self.croma(SAR_images=image)[output_key])
                outputs_t = torch.stack(outputs, dim=-1)  # (bsz, ..., dim, timesteps)
if self.temporal_pooling == "mean":
return outputs_t.mean(dim=-1)
else:
return torch.amax(outputs_t, dim=-1)
else:
                s1 = self.preprocess_s1(s1)
return self.croma(SAR_images=s1)[output_key]
        else:
            # just S2
            output_key = "optical_GAP" if self.do_pool else "optical_encodings"
if len(s2.shape) == 5:
outputs: List[torch.Tensor] = []
for timestep in range(s2.shape[3]):
                    image = self.preprocess(s2[:, :, :, timestep])
outputs.append(self.croma(optical_images=image)[output_key])
                outputs_t = torch.stack(outputs, dim=-1)  # (bsz, ..., dim, timesteps)
if self.temporal_pooling == "mean":
return outputs_t.mean(dim=-1)
else:
return torch.amax(outputs_t, dim=-1)
else:
                s2 = self.preprocess(s2)
return self.croma(optical_images=s2)[output_key]
class PretrainedCROMA(nn.Module):
def __init__(
self, pretrained_path="CROMA_base.pt", size="base", modality="both", image_resolution=120
):
"""
NOTE: image_resolution is not the spatial, spectral, or temporal resolution. It is the height and width of the image, in pixels.
E.g., CROMA was pretrained on 120x120px images, hence image_resolution is 120 by default
"""
super().__init__()
# check types
assert isinstance(pretrained_path, str)
assert isinstance(size, str)
assert isinstance(modality, str)
assert isinstance(image_resolution, int)
# check values
assert size in ["base", "large"], f"size must be either base or large, not {size}"
assert (
image_resolution % 8 == 0
), f"image_resolution must be a multiple of 8, not {image_resolution}"
assert modality in [
"both",
"SAR",
"optical",
], f"modality must be either both, SAR, or optical, not {modality}"
# warn the user if the path contains a different size than the size parameter
if size == "base" and "large" in pretrained_path:
warnings.warn(
"The size is set to base, but the word large appears in the pretrained path!"
)
elif size == "large" and "base" in pretrained_path:
warnings.warn(
"The size is set to large, but the word base appears in the pretrained path!"
)
if size == "base":
self.encoder_dim = 768
self.encoder_depth = 12
self.num_heads = 16
self.patch_size = 8
else:
# large by default
self.encoder_dim = 1024
self.encoder_depth = 24
self.num_heads = 16
self.patch_size = 8
self.modality = modality
self.num_patches = int((image_resolution / 8) ** 2)
self.s1_channels = 2 # fixed at 2 SAR backscatter channels
self.s2_channels = 12 # fixed at 12 multispectral optical channels
        self.attn_bias = get_2dalibi(num_heads=self.num_heads, num_patches=self.num_patches)
        # load the pretrained checkpoint once; each branch pulls its own state dict from it
        checkpoint = torch.load(pretrained_path, map_location="cpu")
if modality in ["SAR", "both"]:
print("Initializing SAR encoder")
self.s1_encoder = ViT(
dim=self.encoder_dim,
depth=int(self.encoder_depth / 2),
in_channels=self.s1_channels,
)
            self.GAP_FFN_s1 = nn.Sequential(
                nn.LayerNorm(self.encoder_dim),
                nn.Linear(
                    self.encoder_dim, int(4 * self.encoder_dim)
                ),  # applied to mean-pooled tokens: (BSZ, 4 * encoder_dim)
                nn.GELU(),  # (BSZ, 4 * encoder_dim)
                nn.Linear(int(4 * self.encoder_dim), self.encoder_dim),  # (BSZ, encoder_dim)
            )
            # load weights
            self.s1_encoder.load_state_dict(checkpoint["s1_encoder"])
            self.GAP_FFN_s1.load_state_dict(checkpoint["s1_GAP_FFN"])
if modality in ["optical", "both"]:
print("Initializing optical encoder")
self.s2_encoder = ViT(
dim=self.encoder_dim, depth=self.encoder_depth, in_channels=self.s2_channels
)
            self.GAP_FFN_s2 = nn.Sequential(
                nn.LayerNorm(self.encoder_dim),
                nn.Linear(
                    self.encoder_dim, int(4 * self.encoder_dim)
                ),  # applied to mean-pooled tokens: (BSZ, 4 * encoder_dim)
                nn.GELU(),  # (BSZ, 4 * encoder_dim)
                nn.Linear(int(4 * self.encoder_dim), self.encoder_dim),  # (BSZ, encoder_dim)
            )
            # load weights
            self.s2_encoder.load_state_dict(checkpoint["s2_encoder"])
            self.GAP_FFN_s2.load_state_dict(checkpoint["s2_GAP_FFN"])
if modality == "both":
print("Initializing joint SAR-optical encoder")
self.cross_encoder = BaseTransformerCrossAttn(
dim=self.encoder_dim,
depth=int(self.encoder_depth / 2),
num_heads=self.num_heads,
)
            # load weights
            self.cross_encoder.load_state_dict(checkpoint["joint_encoder"])
def forward(self, SAR_images=None, optical_images=None):
return_dict = {}
if self.modality in ["SAR", "both"]:
assert (
SAR_images is not None
), f"Modality is set to {self.modality}, but SAR_images are None"
SAR_encodings = self.s1_encoder(
imgs=SAR_images, attn_bias=self.attn_bias.to(SAR_images.device)
) # (bsz, num_patches, encoder_dim)
SAR_GAP = self.GAP_FFN_s1(SAR_encodings.mean(dim=1)) # (bsz, encoder_dim)
return_dict["SAR_encodings"] = SAR_encodings
return_dict["SAR_GAP"] = SAR_GAP
if self.modality in ["optical", "both"]:
assert (
optical_images is not None
), f"Modality is set to {self.modality}, but optical_images are None"
optical_encodings = self.s2_encoder(
imgs=optical_images, attn_bias=self.attn_bias.to(optical_images.device)
) # (bsz, num_patches, encoder_dim)
optical_GAP = self.GAP_FFN_s2(optical_encodings.mean(dim=1)) # (bsz, encoder_dim)
return_dict["optical_encodings"] = optical_encodings
return_dict["optical_GAP"] = optical_GAP
if self.modality == "both":
joint_encodings = self.cross_encoder(
x=SAR_encodings,
context=optical_encodings,
relative_position_bias=self.attn_bias.to(optical_images.device),
) # (bsz, num_patches, encoder_dim)
joint_GAP = joint_encodings.mean(dim=1) # (bsz, encoder_dim)
return_dict["joint_encodings"] = joint_encodings
return_dict["joint_GAP"] = joint_GAP
return return_dict
def get_2dalibi(num_heads, num_patches):
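    """Build a 2D ALiBi attention bias of shape (1, num_heads, num_patches, num_patches).

    Each entry is the negative Euclidean distance between two patch positions on the
    sqrt(num_patches) x sqrt(num_patches) grid, scaled by a per-head slope (the
    standard ALiBi geometric slopes).
    """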
# inspired by: https://github.com/ofirpress/attention_with_linear_biases
points = list(
itertools.product(range(int(math.sqrt(num_patches))), range(int(math.sqrt(num_patches))))
)
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
slopes = torch.Tensor(get_slopes(num_heads)).unsqueeze(1)
idxs = []
for p1 in points:
for p2 in points:
dist = math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
idxs.append(dist * slopes * -1)
all_bias = torch.cat(idxs, dim=1)
return all_bias.view(1, num_heads, num_patches, num_patches)
class FFN(nn.Module):
def __init__(
self,
dim,
mult=4,
dropout=0.0,
):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.Linear(dim, inner_dim), # (BSZ, num_patches, inner_dim)
nn.GELU(), # (BSZ, num_patches, inner_dim)
nn.Dropout(dropout), # (BSZ, num_patches, inner_dim)
nn.Linear(inner_dim, dim), # (BSZ, num_patches, dim)
)
self.input_norm = nn.LayerNorm(dim)
def forward(self, x):
x = self.input_norm(x) # (BSZ, num_patches, dim)
return self.net(x) # (BSZ, num_patches, dim)
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
dropout=0.0,
):
super().__init__()
self.num_heads = num_heads
assert dim % num_heads == 0, "dim must be evenly divisible by num_heads"
dim_head = int(dim / num_heads)
self.scale = dim_head**-0.5
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
self.to_out = nn.Linear(dim, dim)
self.input_norm = nn.LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, relative_position_bias):
x = self.input_norm(x) # (BSZ, num_patches, dim)
q, k, v = self.to_qkv(x).chunk(3, dim=-1) # (BSZ, num_patches, dim)
q, k, v = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
) # (BSZ, num_heads, num_patches, dim_head)
attention_scores = (
einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
) # (BSZ, num_heads, num_patches, num_patches)
attention_scores = (
attention_scores + relative_position_bias
) # (BSZ, num_heads, num_patches, num_patches)
attn = attention_scores.softmax(dim=-1) # (BSZ, num_heads, num_patches, num_patches)
attn = self.dropout(attn) # (BSZ, num_heads, num_patches, num_patches)
out = einsum(
"b h i j, b h j d -> b h i d", attn, v
) # (BSZ, num_heads, num_patches, dim_head)
out = rearrange(out, "b h n d -> b n (h d)") # (BSZ, num_patches, dim)
return self.to_out(out) # (BSZ, num_patches, dim)
class CrossAttention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
dropout=0.0,
):
super().__init__()
self.num_heads = num_heads
assert dim % num_heads == 0, "dim must be evenly divisible by num_heads"
dim_head = int(dim / num_heads)
self.scale = dim_head**-0.5
self.to_q = nn.Linear(dim, dim, bias=False)
self.to_k = nn.Linear(dim, dim, bias=False)
self.to_v = nn.Linear(dim, dim, bias=False)
self.to_out = nn.Linear(dim, dim)
self.input_norm = nn.LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, context, relative_position_bias):
x = self.input_norm(x) # (BSZ, num_patches, dim)
context = self.input_norm(context) # (BSZ, num_patches, dim)
q = self.to_q(x) # (BSZ, num_patches, dim)
k = self.to_k(context) # (BSZ, num_patches, dim)
v = self.to_v(context) # (BSZ, num_patches, dim)
q, k, v = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
) # (BSZ, num_heads, num_patches, dim_head)
attention_scores = (
einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
) # (BSZ, num_heads, num_patches, num_patches)
attention_scores = (
attention_scores + relative_position_bias
) # (BSZ, num_heads, num_patches, num_patches)
attn = attention_scores.softmax(dim=-1) # (BSZ, num_heads, num_patches, num_patches)
attn = self.dropout(attn) # (BSZ, num_heads, num_patches, num_patches)
out = einsum(
"b h i j, b h j d -> b h i d", attn, v
) # (BSZ, num_heads, num_patches, dim_head)
out = rearrange(out, "b h n d -> b n (h d)") # (BSZ, num_patches, dim)
return self.to_out(out) # (BSZ, num_patches, dim)
class BaseTransformer(nn.Module):
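    """Pre-norm transformer encoder with an additive attention bias.

    Each layer applies self-attention (with the relative position bias, here the 2D
    ALiBi bias) and an FFN, both with residual connections; a final LayerNorm is
    applied when final_norm is True.
    """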
def __init__(
self,
dim,
depth,
num_heads=8,
attn_dropout=0.0,
ff_dropout=0.0,
ff_mult=4,
final_norm=True,
):
super().__init__()
self.final_norm = final_norm
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
Attention(dim=dim, num_heads=num_heads, dropout=attn_dropout),
FFN(dim=dim, mult=ff_mult, dropout=ff_dropout),
]
)
)
if self.final_norm:
self.norm_out = nn.LayerNorm(dim)
def forward(self, x, relative_position_bias=False):
for self_attn, ffn in self.layers:
x = self_attn(x, relative_position_bias) + x # (BSZ, num_patches, dim)
x = ffn(x) + x # (BSZ, num_patches, dim)
if self.final_norm:
return self.norm_out(x)
else:
return x
class BaseTransformerCrossAttn(nn.Module):
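    """Cross-attention transformer used as the joint SAR-optical encoder.

    Each layer applies self-attention over x, cross-attention from x to the context
    sequence, and an FFN, all with residual connections, followed by a final
    LayerNorm.
    """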
def __init__(
self,
dim,
depth,
num_heads=8,
attn_dropout=0.0,
ff_dropout=0.0,
ff_mult=4,
):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
Attention(dim=dim, num_heads=num_heads, dropout=attn_dropout),
CrossAttention(dim=dim, num_heads=num_heads, dropout=attn_dropout),
FFN(dim=dim, mult=ff_mult, dropout=ff_dropout),
]
)
)
self.norm_out = nn.LayerNorm(dim)
def forward(self, x, context, relative_position_bias):
for self_attn, cross_attn, ffn in self.layers:
x = self_attn(x, relative_position_bias) + x # (BSZ, num_patches, dim)
x = cross_attn(x, context, relative_position_bias) + x # (BSZ, num_patches, dim)
x = ffn(x) + x # (BSZ, num_patches, dim)
x = self.norm_out(x)
return x # (BSZ, num_patches, dim)
class ViT(nn.Module):
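    """Patch-embedding ViT encoder used for both the SAR and optical branches.

    Images are split into non-overlapping 8x8 patches, flattened, linearly projected
    to `dim`, and processed by a BaseTransformer; positional information comes solely
    from the 2D ALiBi attention bias passed in as `attn_bias`.
    """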
def __init__(self, dim, depth, in_channels):
super().__init__()
self.depth = depth
self.in_channels = in_channels
self.dim = dim
self.num_heads = 16 # always 16, for base and large models
self.patch_size = 8 # always 8, for base and large models
pixels_per_patch = int(self.patch_size * self.patch_size * in_channels)
self.linear_input = nn.Linear(pixels_per_patch, self.dim)
self.transformer = BaseTransformer(
dim=self.dim,
depth=self.depth,
num_heads=self.num_heads,
)
def forward(self, imgs, attn_bias):
x = rearrange(
imgs, "b c (h i) (w j) -> b (h w) (c i j)", i=self.patch_size, j=self.patch_size
)
# x is shape -> (bsz, num_patches, self.channels*self.patch_size*self.patch_size)
x = self.linear_input(x) # (bsz, num_patches, dim)
x = self.transformer(x, relative_position_bias=attn_bias)
return x
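

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). It assumes the
    # pretrained checkpoints (CROMA_base.pt / CROMA_large.pt) have been downloaded
    # to a local "weights" directory; adjust weights_dir to wherever they live.
    weights_dir = Path("weights")  # hypothetical checkpoint location

    # CROMAWrapper expects channels-last Sentinel-2 chips with 13 bands; the cirrus
    # band is dropped and the chip is resized to 120x120 internally.
    wrapper = CROMAWrapper(weights_path=weights_dir, size="base", modality="optical")
    wrapper.eval()
    s2 = torch.randn(2, 96, 96, 13)
    with torch.no_grad():
        pooled = wrapper(s2=s2)
    print(pooled.shape)  # (2, 768) for the base model with do_pool=True

    # PretrainedCROMA can also be used directly, e.g. with both modalities:
    #   croma = PretrainedCROMA(str(weights_dir / "CROMA_base.pt"), size="base", modality="both")
    #   out = croma(SAR_images=sar_batch, optical_images=optical_batch)
    # where out holds "SAR_*", "optical_*", and "joint_*" encodings and GAP features.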