# GAIA-v1/downstream/gap_fill/lightning_wrapper.py
# GAIA: A Foundation Model for Operational Atmospheric Dynamics
from typing import Dict, List, Optional, Tuple, Type, Union
import logging
import math
import torch
import lightning.pytorch as pl
from downstream.gap_fill.models import GapFill
from deepspeed.ops.adam import FusedAdam
from torch.optim.lr_scheduler import LambdaLR
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Shape3d = Union[List[int], Tuple[int, int, int]]
class CosineAnnealingWarmupLR(torch.optim.lr_scheduler.LambdaLR):
def __init__(self, optimizer, warmup_steps, total_steps, eta_min=0.0, last_step=-1):
"""
Cosine annealing scheduler with warm-up period, operating per step.
Args:
optimizer (Optimizer): Wrapped optimizer.
warmup_steps (int): Number of warm-up steps.
total_steps (int): Total number of training steps.
            eta_min (float): Minimum LR multiplier (as a fraction of the base learning rate) after annealing. Default: 0.0.
last_step (int): The index of the last step when resuming training. Default: -1.
"""
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self.eta_min = eta_min
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_step)
def lr_lambda(self, current_step):
if current_step < self.warmup_steps:
# Linear warm-up
return float(current_step) / float(max(1, self.warmup_steps))
else:
# Cosine annealing
progress = float(current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
return cosine_decay * (1 - self.eta_min) + self.eta_min
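
# Illustrative sketch only (this helper is not used elsewhere in this file): how
# CosineAnnealingWarmupLR could be paired with an optimizer inside a LightningModule's
# configure_optimizers. The optimizer choice and the hyperparameter values below are
# assumptions for the example, not settings taken from this repository.
def _example_configure_optimizers(module: pl.LightningModule,
                                  learning_rate: float = 1e-3,
                                  warmup_steps: int = 100,
                                  total_steps: int = 1000):
    optimizer = torch.optim.AdamW(module.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingWarmupLR(optimizer,
                                        warmup_steps=warmup_steps,
                                        total_steps=total_steps,
                                        eta_min=1e-10)
    # Step the scheduler every optimizer step so warmup and annealing are per step.
    return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]
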
class GapFillLightning(pl.LightningModule):
def __init__(self,
img_size: Shape3d = [4, 224, 224],
patch_size: Shape3d = [1, 16, 16],
in_chans: int = 3,
encoder_embed_dim: int = 1024,
encoder_depth: int = 8,
encoder_num_heads: int = 16,
decoder_embed_dim: int = 512,
decoder_depth: int = 8,
decoder_num_heads: int = 16,
mlp_ratio: float = 4.,
                 norm_layer: Type[torch.nn.Module] = torch.nn.LayerNorm,
norm_pix_loss: bool = False,
drop_channels_rate: float = 0.0,
adjacent_masking: bool = False,
                 params: Optional[Dict[str, Union[Dict[str, float], int]]] = None,
                 mask_ratio: float = 0.75):
super().__init__()
self.params = params or {
'optimizer': {
'learning_rate': 1e-3,
},
'scheduler': {
'type': 'CosineAnnealingLR',
'warmup_steps': 100,
'total_steps': 1000,
'eta_min': 1e-10,
'last_step': -1,
},
}
self.model = GapFill(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
encoder_embed_dim=encoder_embed_dim,
encoder_depth=encoder_depth,
encoder_num_heads=encoder_num_heads,
decoder_embed_dim=decoder_embed_dim,
decoder_depth=decoder_depth,
decoder_num_heads=decoder_num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
norm_pix_loss=norm_pix_loss,
drop_channels_rate=drop_channels_rate,
adjacent_masking=adjacent_masking
)
self.mask_ratio = mask_ratio
self.log_helper(f"Mask Ratio: {self.mask_ratio}")
def log_helper(self, message):
logger.info(message)
def forward(self, x, x_mask, temporal_pos=None, mask_ratio=0.75, epoch=None):
self.log_helper("forward pass start")
        if epoch is None:
            try:
                epoch = self.current_epoch
            except AttributeError:
                raise ValueError("epoch value is invalid: no epoch was passed and current_epoch is unavailable")
self.log_helper("finished forward")
return self.model(x, x_mask, temporal_pos=temporal_pos, mask_ratio=mask_ratio, epoch=epoch)
def training_step(self, batch, batch_idx):
x = batch['x'].to(self.device)
x_mask = batch['x_mask'].to(self.device)
if batch['temporal_pos']:
temporal_pos = [time_component.to(self.device) for time_component in batch['temporal_pos']]
else:
temporal_pos = None
        loss, (pred, mask) = self(x, x_mask, temporal_pos=temporal_pos, mask_ratio=self.mask_ratio, epoch=self.current_epoch)
return loss
def on_after_backward(self):
# Clip gradients to prevent exploding gradients
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
def validation_step(self, batch, batch_idx):
x = batch['x'].to(self.device)
x_mask = batch['x_mask'].to(self.device)
if batch['temporal_pos']:
temporal_pos = [time_component.to(self.device) for time_component in batch['temporal_pos']]
else:
temporal_pos = None
        loss, (pred, mask) = self(x, x_mask, temporal_pos=temporal_pos, mask_ratio=self.mask_ratio, epoch=self.current_epoch)
return loss
def configure_optimizers(self):
optimizer = FusedAdam(self.parameters(), lr=self.params['optimizer']['learning_rate'])
# Check if warmup is needed
        scheduler_params = self.params.get('scheduler', {})
        if scheduler_params.get('warmup_steps', 0) > 0:
            warmup_steps = scheduler_params['warmup_steps']
# Warmup function: linear increase from 0 to target LR, then constant
def warmup_fn(step):
if step < warmup_steps:
return step / warmup_steps # Linear warmup
return 1.0 # Constant LR after warmup
scheduler = {
"scheduler": LambdaLR(optimizer, lr_lambda=warmup_fn),
"interval": "step",
"frequency": 1
}
return [optimizer], [scheduler]
# No warmup: just return optimizer
return [optimizer]
def save_model(self, path):
torch.save(self.state_dict(), path)
@classmethod
def load_from_checkpoint(cls, path, **kwargs):
model = cls(**kwargs)
model.load_state_dict(torch.load(path)['state_dict'])
return model
@classmethod
def load_from_ds_checkpoint(cls, path, **kwargs):
model = cls(**kwargs)
        model.load_state_dict(torch.load(path))  # DeepSpeed-converted checkpoints store the state_dict at the top level (no nested 'state_dict' key)
return model
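
# Minimal usage sketch (assumptions: batches are dicts with keys 'x', 'x_mask' and
# 'temporal_pos' as consumed by training_step; 'train_loader' is a hypothetical
# DataLoader, not provided by this file; FusedAdam requires a DeepSpeed-capable
# environment). The small model dimensions below are illustrative only.
if __name__ == "__main__":
    module = GapFillLightning(
        img_size=[4, 224, 224],
        patch_size=[1, 16, 16],
        in_chans=3,
        encoder_embed_dim=256,
        encoder_depth=2,
        encoder_num_heads=8,
        decoder_embed_dim=128,
        decoder_depth=2,
        decoder_num_heads=8,
        mask_ratio=0.75,
    )
    trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1, log_every_n_steps=10)
    # trainer.fit(module, train_dataloaders=train_loader)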