""" PyTorch Hub configuration for AnySat model. """ import warnings import torch import torch.nn as nn from .models.networks.encoder.Any_multi import AnyModule # Import your actual model class from .models.networks.encoder.Transformer import TransformerMulti from .models.networks.encoder.utils.ltae import PatchLTAEMulti from .models.networks.encoder.utils.patch_embeddings import PatchMLPMulti class AnySat(nn.Module): """ AnySat: Earth Observation Model for Any Resolutions, Scales, and Modalities Args: model_size (str): Model size - 'tiny', 'small', or 'base' flash_attn (bool): Whether to use flash attention **kwargs: Additional arguments to override config """ def __init__(self, model_size="base", flash_attn=True, **kwargs): super().__init__() self.res = { "aerial": 0.2, "aerial-flair": 0.2, "spot": 1.0, "naip": 1.25, "s2": 10, "s1-asc": 10, "s1-des": 10, "s1": 10, "l8": 10, "l7": 30, "alos": 30, } self.config = get_default_config(model_size) self.config["flash_attn"] = flash_attn # Override any additional parameters device = None for k, v in kwargs.items(): if k == "device": device = v else: # Update nested dictionary keys = k.split(".") current = self.config for key in keys[:-1]: current = current.setdefault(key, {}) current[keys[-1]] = v projectors = {} for modality in self.config["modalities"]["all"]: if "T" in self.config["projectors"][modality].keys(): projectors[modality] = PatchLTAEMulti(**self.config["projectors"][modality]) else: projectors[modality] = PatchMLPMulti(**self.config["projectors"][modality]) del self.config["projectors"] with warnings.catch_warnings(): # Ignore all warnings during model initialization warnings.filterwarnings("ignore") self.spatial_encoder = TransformerMulti(**self.config["spatial_encoder"]) del self.config["spatial_encoder"] self.model = AnyModule( projectors=projectors, spatial_encoder=self.spatial_encoder, **self.config ) if device is not None: self.model = self.model.to(device) @classmethod def from_pretrained(cls, model_size="base", **kwargs): """ Create a pretrained AnySat model Args: model_size (str): Model size - 'tiny', 'small', or 'base' **kwargs: Additional arguments passed to the constructor """ model = cls(model_size=model_size, **kwargs) checkpoint_urls = { "base": "https://huggingface.co/g-astruc/AnySat/resolve/main/models/AnySat.pth", # 'small': 'https://huggingface.co/gastruc/anysat/resolve/main/anysat_small_geoplex.pth', COMING SOON # 'tiny': 'https://huggingface.co/gastruc/anysat/resolve/main/anysat_tiny_geoplex.pth' COMING SOON } checkpoint_url = checkpoint_urls[model_size] state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, progress=True)[ "state_dict" ] model.model.load_state_dict(state_dict) return model def forward(self, x, patch_size, output="patch", **kwargs): assert output in [ "patch", "tile", "dense", "all", ], "Output must be one of 'patch', 'tile', 'dense', 'all'" sizes = {} for modality in list(x.keys()): if modality.endswith("_dates"): continue shape = x[modality].shape assert shape[-2] == shape[-1], "Images must be squared" if modality in ["s2", "s1-asc", "s1", "alos", "l7", "l8", "modis"]: assert ( len(shape) == 5 ), f"{modality} Images must be 5D: Batch, Time, Channels, Height, Width" else: assert ( len(shape) == 4 ), f"{modality} Images must be 4D: Batch, Channels, Height, Width" if modality != "modis": sizes[modality] = shape[-1] * self.res[modality] if len(sizes) >= 2: size_values = list(sizes.values()) for i in range(len(size_values) - 1): if ( abs(size_values[i] - size_values[i + 1]) > 1e-10 ): # Using small epsilon for float comparison mod1, mod2 = list(sizes.keys())[i], list(sizes.keys())[i + 1] raise ValueError( f"Modalities {mod1} and {mod2} have incompatible sizes: {size_values[i]} vs {size_values[i + 1]}" ) return self.model.forward_release(x, patch_size // 10, output=output, **kwargs) # Hub entry points def anysat(pretrained=False, **kwargs): """PyTorch Hub entry point""" if pretrained: return AnySat.from_pretrained(**kwargs) return AnySat(**kwargs) def anysat_tiny(pretrained=False, **kwargs): return anysat(pretrained=pretrained, model_size="tiny", **kwargs) def anysat_small(pretrained=False, **kwargs): return anysat(pretrained=pretrained, model_size="small", **kwargs) def anysat_base(pretrained=False, **kwargs): return anysat(pretrained=pretrained, model_size="base", **kwargs) def get_default_config(model_size="base"): """Get default configuration based on model size""" dim = 768 if model_size == "base" else (512 if model_size == "small" else 256) depth = 6 if model_size == "base" else (4 if model_size == "small" else 2) heads = 12 if model_size == "base" else (8 if model_size == "small" else 4) base_config = { "modalities": { "all": [ "aerial", "aerial-flair", "spot", "naip", "s2", "s1-asc", "s1", "alos", "l7", "l8", "modis", ] }, "projectors": { "aerial": { "patch_size": 10, "in_chans": 4, "embed_dim": dim, "bias": False, "mlp": [dim, dim * 2, dim], }, "aerial-flair": { "patch_size": 10, "in_chans": 5, "embed_dim": dim, "bias": False, "mlp": [dim, dim * 2, dim], }, "spot": { "patch_size": 10, "in_chans": 3, "embed_dim": dim, "bias": False, "resolution": 1.0, "mlp": [dim, dim * 2, dim], }, "naip": { "patch_size": 8, "in_chans": 4, "embed_dim": dim, "bias": False, "resolution": 1.25, "mlp": [dim, dim * 2, dim], }, "s2": { "in_channels": 10, "n_head": 16, "d_k": 8, "mlp": [dim], "mlp_in": [dim // 8, dim // 2, dim, dim * 2, dim], "dropout": 0.0, "T": 367, "in_norm": True, "return_att": False, "positional_encoding": True, }, "s1-asc": { "in_channels": 2, "n_head": 16, "d_k": 8, "mlp": [dim], "mlp_in": [dim // 8, dim // 2, dim, dim * 2, dim], "dropout": 0.2, "T": 367, "in_norm": False, "return_att": False, "positional_encoding": True, }, "s1": { "in_channels": 3, "n_head": 16, "d_k": 8, "mlp": [dim], "mlp_in": [dim // 8, dim // 2, dim, dim * 2, dim], "dropout": 0.2, "T": 367, "in_norm": False, "return_att": False, "positional_encoding": True, }, "alos": { "in_channels": 3, "n_head": 16, "d_k": 8, "mlp": [dim], "mlp_in": [dim // 8, dim // 2, dim, dim * 2, dim], "dropout": 0.2, "T": 367, "in_norm": False, "return_att": False, "positional_encoding": True, }, "l7": { "in_channels": 6, "n_head": 16, "d_k": 8, "mlp": [dim], "mlp_in": [dim // 8, dim // 2, dim, dim * 2, dim], "dropout": 0.2, "T": 367, "in_norm": False, "return_att": False, "positional_encoding": True, }, "l8": { "in_channels": 11, "n_head": 16, "d_k": 8, "mlp": [dim], "mlp_in": [dim // 8, dim // 2, dim, dim * 2, dim], "dropout": 0.2, "T": 366, "in_norm": False, "return_att": False, "positional_encoding": True, }, "modis": { "in_channels": 7, "n_head": 16, "d_k": 8, "mlp": [dim], "mlp_in": [dim // 8, dim // 2, dim, dim * 2, dim], "dropout": 0.2, "T": 367, "in_norm": False, "return_att": False, "positional_encoding": True, "reduce_scale": 12, }, }, "spatial_encoder": { "embed_dim": dim, "depth": depth, "num_heads": heads, "mlp_ratio": 4.0, "attn_drop_rate": 0.0, "drop_path_rate": 0.0, "modalities": { "all": [ "aerial", "aerial-flair", "spot", "naip", "s2", "s1-asc", "s1", "alos", "l7", "l8", "modis", ] }, "scales": {}, "input_res": { "aerial": 2, "aerial-flair": 2, "spot": 10, "naip": 10, "s2": 10, "s1-asc": 10, "s1-des": 10, "s1": 10, "l8": 10, "l7": 30, "alos": 30, "modis": 250, }, }, "num_patches": {}, "embed_dim": dim, "depth": depth, "num_heads": heads, "mlp_ratio": 4.0, "class_token": True, "pre_norm": False, "drop_rate": 0.0, "patch_drop_rate": 0.0, "drop_path_rate": 0.0, "attn_drop_rate": 0.0, "scales": {}, "flash_attn": True, "release": True, } return base_config