# Deployed from GitHub repository (commit b20c769, verified)
from pathlib import Path
from typing import List
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class DeCurWrapper(nn.Module):
    """Wrap a DeCUR-pretrained ViT-S/16 (SSL4EO-S12) encoder for SAR or optical input.

    Builds a ``vit_small_patch16_224`` backbone, swaps its patch-embedding conv
    to match the modality's channel count (13 bands for Sentinel-2 optical,
    2 bands for Sentinel-1 SAR), and loads the matching DeCUR checkpoint.
    ``forward`` accepts channels-last single images ``(b, h, w, c)`` or time
    series ``(b, h, w, t, c)``; time series are encoded per timestep and then
    pooled over time with mean or max.
    """

    # (input channels, checkpoint filename) per supported modality.
    _MODALITY_CFG = {
        "optical": (13, "vits16_ssl4eo-s12_ms_decur_ep100.pth"),
        "SAR": (2, "vits16_ssl4eo-s12_sar_decur_ep100.pth"),
    }

    def __init__(
        self, weights_path: Path, modality: str, do_pool=True, temporal_pooling: str = "mean"
    ):
        """
        Args:
            weights_path: Directory containing the DeCUR checkpoint files.
            modality: ``"SAR"`` or ``"optical"``.
            do_pool: If True, mean-pool encoder tokens to a ``(b, dim)`` vector;
                otherwise return per-patch tokens with the CLS token dropped.
            temporal_pooling: ``"mean"`` or ``"max"`` pooling over the time axis
                for 5-D (time-series) inputs.

        Raises:
            ValueError: If ``modality`` or ``temporal_pooling`` is invalid.
        """
        super().__init__()
        # Validate all user input up front (fail fast, before loading weights).
        if modality not in self._MODALITY_CFG:
            raise ValueError(f"Expected modality to be in ['SAR', 'optical'], got {modality}")
        if temporal_pooling not in ["mean", "max"]:
            raise ValueError(
                f"Expected temporal_pooling to be in ['mean', 'max'], got {temporal_pooling}"
            )
        self.modality = modality
        self.temporal_pooling = temporal_pooling
        self.do_pool = do_pool
        self.dim = 384
        self.image_resolution = 224
        self.patch_size = 16
        self.grid_size = int(self.image_resolution / self.patch_size)

        self.encoder = timm.create_model("vit_small_patch16_224", pretrained=False)
        in_channels, ckpt_name = self._MODALITY_CFG[modality]
        # Replace the default 3-channel RGB patch embedding with one matching
        # the modality's band count.
        self.encoder.patch_embed.proj = torch.nn.Conv2d(
            in_channels, 384, kernel_size=(16, 16), stride=(16, 16)
        )
        # map_location="cpu": checkpoints may have been saved on GPU, and the
        # freshly created model lives on CPU — always load onto CPU first.
        state_dict = torch.load(weights_path / ckpt_name, map_location="cpu")
        msg = self.encoder.load_state_dict(state_dict, strict=False)
        # Only the (unused) classification head should be absent from the checkpoint.
        assert set(msg.missing_keys) == {"head.weight", "head.bias"}

    def resize(self, images):
        """Bilinearly resize ``(b, c, h, w)`` images to the 224x224 encoder input."""
        images = F.interpolate(
            images,
            size=(self.image_resolution, self.image_resolution),
            mode="bilinear",
            align_corners=False,
        )
        return images

    def preproccess(self, images):
        """Convert channels-last ``(b, h, w, c)`` images to ``(b, c, 224, 224)``.

        Note: the method name keeps its original spelling ("preproccess") for
        backward compatibility with existing callers.

        Raises:
            ValueError: If the channel count is neither 13 (optical) nor 2 (SAR).
        """
        images = images.permute(0, 3, 1, 2)  # b h w c -> b c h w
        if images.shape[1] not in (13, 2):
            raise ValueError(
                f"Expected 13 (optical) or 2 (SAR) channels, got {images.shape[1]}"
            )
        return self.resize(images)  # (bsz, C, H, W)

    def _encode(self, image):
        """Encode one preprocessed ``(b, c, 224, 224)`` batch.

        Returns ``(b, dim)`` when ``do_pool`` is True, otherwise
        ``(b, n_patches, dim)`` with the CLS token dropped.
        """
        output = self.encoder.forward_features(image)
        if self.do_pool:
            return output.mean(dim=1)
        return output[:, 1:]

    def _forward_modality(self, x):
        """Encode ``x`` — ``(b, h, w, c)`` or ``(b, h, w, t, c)`` — pooling over t if present."""
        if len(x.shape) == 5:
            # Encode each timestep independently, then pool across time.
            outputs_l: List[torch.Tensor] = [
                self._encode(self.preproccess(x[:, :, :, timestep]))
                for timestep in range(x.shape[3])
            ]
            # (b, dim, t) when pooled, else (b, n_patches, dim, t).
            outputs_t = torch.stack(outputs_l, dim=-1)
            if self.temporal_pooling == "mean":
                return outputs_t.mean(dim=-1)
            return torch.amax(outputs_t, dim=-1)
        return self._encode(self.preproccess(x))

    def forward(self, s2=None, s1=None, months=None):
        """Encode a batch from the configured modality.

        Exactly one of ``s2`` (optical) / ``s1`` (SAR) must be provided and it
        must match this wrapper's modality. ``months`` is accepted for
        interface compatibility but unused.

        Raises:
            ValueError: If the input modality mismatches, or neither input is given.
        """
        if s1 is not None:
            if self.modality != "SAR":
                raise ValueError("Received SAR input (s1) but wrapper modality is 'optical'")
            return self._forward_modality(s1)
        if s2 is not None:
            if self.modality != "optical":
                raise ValueError("Received optical input (s2) but wrapper modality is 'SAR'")
            return self._forward_modality(s2)
        raise ValueError("Either s1 or s2 must be provided")