"""GAIA: A Foundation Model for Operational Atmospheric Dynamics."""
from typing import List, Optional, Tuple, Union
import timm
import torch
import numpy as np
from einops import rearrange
from timm.layers import Mlp
Shape3d = Union[List[int], Tuple[int, int, int]]
# PatchEmbed
class PatchEmbed(torch.nn.Module):
""" 3D Image to Patch Embedding
"""
def __init__(
self,
img_size: Shape3d = (4, 224, 224),
patch_size: Shape3d = (1, 16, 16),
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[torch.nn.Module] = None,
flatten: bool = True,
bias: bool = True,
):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
assert len(self.img_size) == 3 and len(self.patch_size) == 3
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1], img_size[2] // patch_size[2])
self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
self.flatten = flatten
self.proj = torch.nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else torch.nn.Identity()
def forward(self, x):
B, C, T, H, W = x.shape
assert H == self.img_size[1], f"Input data height ({H}) doesn't match model ({self.img_size[1]})."
assert W == self.img_size[2], f"Input data width ({W}) doesn't match model ({self.img_size[2]})."
assert T == self.img_size[0], f"Input data timesteps ({T}) doesn't match model ({self.img_size[0]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCTHW -> BCL -> BLC
x = self.norm(x)
return x
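# Shape sketch (assuming the default sizes above): an input of shape (B, 3, 4, 224, 224)
# is projected by the Conv3d to (B, 768, 4, 14, 14); with flatten=True the module then
# returns (B, 4 * 14 * 14, 768) = (B, 784, 768) patch tokens.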
# DropPath
def generate_mask(x, drop_prob: float = 0., scale_by_keep: bool = True):
""" Create drop mask for x. Adapted from timm.models.layers.drop_path. """
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return random_tensor
class DropPath(torch.nn.Module):
""" Adapted from timm.models.layers.DropPath. In this version, drop mask can be saved and reused.
This is useful when applying the same drop mask more than once.
"""
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super().__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
self.drop_mask = None
def generate_mask(self, x: torch.Tensor):
self.drop_mask = generate_mask(x, self.drop_prob, self.scale_by_keep)
def forward(self, x: torch.Tensor, new_mask: bool = True):
if self.drop_prob == 0. or not self.training:
return x
if self.drop_mask is None or new_mask:
self.generate_mask(x)
return self.drop_mask * x
def extra_repr(self):
return f'drop_prob={round(self.drop_prob, 3):0.3f}'
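# Usage note: forward(x) draws a fresh per-sample mask on every call, while
# forward(x, new_mask=False) reuses the stored self.drop_mask, so the same samples are
# kept or dropped across repeated calls. Masks are only applied when the module is in
# training mode and drop_prob > 0.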
class Attention(torch.nn.Module):
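    """Multi-head self-attention; forward returns both the projected output and the attention map."""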
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = torch.nn.Dropout(attn_drop)
self.proj = torch.nn.Linear(dim, dim)
self.proj_drop = torch.nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn
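# Shape sketch: for tokens of shape (B, N, C), forward returns the projected output of
# shape (B, N, C) together with the attention weights of shape (B, num_heads, N, N).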
class Block(torch.nn.Module):
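    """Pre-norm transformer block: multi-head self-attention and an MLP, each wrapped in a
    residual connection with optional DropPath."""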
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=torch.nn.GELU, norm_layer=torch.nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else torch.nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x, return_attention=False):
y, attn = self.attn(self.norm1(x))
if return_attention:
return attn
x = x + self.drop_path(y)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
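

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming the default hyperparameters above: embed a random
    # (B, C, T, H, W) cube with PatchEmbed and push the resulting tokens through one Block.
    patch_embed = PatchEmbed()                    # img_size=(4, 224, 224), patch_size=(1, 16, 16)
    block = Block(dim=768, num_heads=8)
    x = torch.randn(2, 3, 4, 224, 224)            # (B, C, T, H, W)
    tokens = patch_embed(x)                       # (2, 784, 768)
    out = block(tokens)                           # (2, 784, 768)
    attn = block(tokens, return_attention=True)   # (2, 8, 784, 784)
    print(tokens.shape, out.shape, attn.shape)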