# GAIA-v1/mae_dino/lightning_wrapper.py
# GAIA: A Foundation Model for Operational Atmospheric Dynamics
from typing import Dict, List, Optional, Tuple, Type, Union
import logging
import math
import torch
import time
from datetime import datetime
import lightning.pytorch as pl
from mae_dino.models import GAIABase
from deepspeed.ops.adam import FusedAdam
import torch.nn.functional as F
Shape3d = Union[List[int], Tuple[int, int, int]]
class CosineAnnealingWarmupLR(torch.optim.lr_scheduler.LambdaLR):
def __init__(self, optimizer, warmup_steps, total_steps, eta_min=0.0, last_step=-1):
"""
Cosine annealing scheduler with warm-up period, operating per step.
Args:
optimizer (Optimizer): Wrapped optimizer.
warmup_steps (int): Number of warm-up steps.
total_steps (int): Total number of training steps.
            eta_min (float): Minimum learning-rate multiplier (as a fraction of the base learning rate) reached at the end of annealing. Default: 0.0.
last_step (int): The index of the last step when resuming training. Default: -1.
"""
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self.eta_min = eta_min
super().__init__(optimizer, self.lr_lambda, last_step)
def lr_lambda(self, current_step):
if current_step < self.warmup_steps:
# Linear warm-up
return float(current_step) / float(max(1, self.warmup_steps))
else:
# Cosine annealing
progress = float(current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
return cosine_decay * (1 - self.eta_min) + self.eta_min
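# Illustration of the schedule produced by CosineAnnealingWarmupLR: the multiplier
# returned by lr_lambda ramps linearly from 0 to 1 over `warmup_steps`, then follows
# a cosine curve from 1 down to `eta_min` (interpreted as a fraction of the base
# learning rate) at `total_steps`. For example, with warmup_steps=100 and
# total_steps=1000, step 50 gives a factor of 0.5, step 100 gives 1.0, and
# step 1000 gives approximately eta_min.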
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GAIABaseLightning(pl.LightningModule):
def __init__(self,
img_size: Shape3d = [4, 224, 224],
patch_size: Shape3d = [1, 16, 16],
in_chans: int = 3,
encoder_embed_dim: int = 1024,
encoder_depth: int = 8,
encoder_num_heads: int = 16,
decoder_embed_dim: int = 512,
decoder_depth: int = 8,
decoder_num_heads: int = 16,
mlp_ratio: float = 4.,
                 norm_layer: Type[torch.nn.Module] = torch.nn.LayerNorm,
norm_pix_loss: bool = False,
drop_channels_rate: float = 0.0,
# DINO Args
adjacent_masking: bool = False,
norm_last_layer: bool = True,
dino_head_dim: int = 1024,
warmup_teacher_temp: float = 0.04,
teacher_temp: float = 0.04,
warmup_teacher_temp_epochs: int = 5,
epochs: int = 100,
student_temp: float = 0.1,
center_momentum: float = 0.9,
momentum_teacher: float = 0.996,
                 params: Optional[Dict[str, Dict[str, Union[str, float, int]]]] = None,
                 mask_ratio: float = 0.75,
                 val_step: int = 0):
super().__init__()
self.params = params or {
'optimizer': {
'learning_rate': 1e-3,
},
'scheduler': {
'type': 'CosineAnnealingLR',
'warmup_steps': 100,
'total_steps': 1000,
'eta_min': 1e-10,
'last_step': -1,
},
}
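        # Note: this default dict also documents the structure expected of `params`
        # when it is passed explicitly (see configure_optimizers).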
self.model = GAIABase(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
encoder_embed_dim=encoder_embed_dim,
encoder_depth=encoder_depth,
encoder_num_heads=encoder_num_heads,
decoder_embed_dim=decoder_embed_dim,
decoder_depth=decoder_depth,
decoder_num_heads=decoder_num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
norm_pix_loss=norm_pix_loss,
drop_channels_rate=drop_channels_rate,
# DINO Args
adjacent_masking=adjacent_masking,
norm_last_layer=norm_last_layer,
dino_head_dim=dino_head_dim,
warmup_teacher_temp=warmup_teacher_temp,
teacher_temp=teacher_temp,
warmup_teacher_temp_epochs=warmup_teacher_temp_epochs,
epochs=epochs,
student_temp=student_temp,
center_momentum=center_momentum,
momentum_teacher=momentum_teacher,
)
self.val_step = val_step
self.mask_ratio = mask_ratio
self.momentum_teacher = momentum_teacher
self.sync_dist = True
self.mae_loss_start = None
self.mae_loss_min = None
self.dino_loss_min = None
self.dino_loss_start = None
    def forward(self, x, x_mask, temporal_pos=None, mask_ratio=0.75, epoch=None):
        # Delegates to GAIABase; returns ((total_loss, dino_loss, mae_loss), model outputs).
        if epoch is None:
            try:
                epoch = self.current_epoch
            except AttributeError as exc:
                raise ValueError("epoch was not provided and could not be read from the LightningModule") from exc
        return self.model(x, x_mask, temporal_pos=temporal_pos, mask_ratio=mask_ratio, epoch=epoch)
def training_step(self, batch, batch_idx):
x = batch['x'].to(self.device)
x_mask = batch['x_mask'].to(self.device)
if batch['temporal_pos']:
temporal_pos = [time_component.to(self.device) for time_component in batch['temporal_pos']]
else:
temporal_pos = None
(total_loss, dino_loss, mae_loss), (pred, mask, _, _) = self(x, x_mask, temporal_pos=temporal_pos, mask_ratio=self.mask_ratio, epoch=self.current_epoch)
        # Log losses for callbacks and progress monitoring; sync across ranks when using DDP.
        self.log_dict(
            {"train_total_loss": total_loss, "train_dino_loss": dino_loss, "train_mae_loss": mae_loss},
            sync_dist=self.sync_dist,
        )
        if batch_idx % 50 == 0:
            logger.info(f"Train loss: {total_loss}")
return total_loss
def on_after_backward(self):
# Clip gradients to prevent exploding gradients
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
# Update EMA for teacher after each backward pass
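        # DINO-style EMA update: teacher_param <- m * teacher_param + (1 - m) * student_param,
        # with m = momentum_teacher (0.996 by default).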
with torch.no_grad():
for param_q, param_k in zip(self.model.student.parameters(), self.model.teacher.parameters()):
                param_k.data.mul_(self.momentum_teacher).add_(param_q.detach().data, alpha=1 - self.momentum_teacher)
    def validation_step(self, batch, batch_idx):
        x = batch['x']
        x_mask = batch['x_mask']
        temporal_pos = batch['temporal_pos'] if batch['temporal_pos'] else None
        (total_loss, dino_loss, mae_loss), (pred, mask, _, _) = self(x, x_mask, temporal_pos=temporal_pos, mask_ratio=self.mask_ratio, epoch=self.current_epoch)
        # Log validation losses so checkpointing/early-stopping callbacks can monitor them.
        self.log_dict(
            {"val_total_loss": total_loss, "val_dino_loss": dino_loss, "val_mae_loss": mae_loss},
            sync_dist=self.sync_dist,
        )
        return total_loss
def configure_optimizers(self):
optimizer = FusedAdam(self.parameters(), lr=self.params['optimizer']['learning_rate'])
if self.params['scheduler']['type'] == 'CosineAnnealingLR':
            warmup_steps = self.params['scheduler'].get('warmup_steps', 0)
total_steps = self.params['scheduler']['total_steps']
eta_min = self.params['scheduler']['eta_min']
last_step = self.params['scheduler']['last_step']
scheduler = {
"scheduler": CosineAnnealingWarmupLR(
optimizer,
warmup_steps,
total_steps,
eta_min=eta_min,
last_step=last_step
),
"interval": "step", # Explicitly update learning rate every step
"frequency": 1 # Apply the scheduler at every step
}
else:
raise NotImplementedError(f"The specified {self.params['scheduler']['type']} scheduler is not implemented.")
return [optimizer], [scheduler]
def save_model(self, path):
torch.save(self.state_dict(), path)
@classmethod
def load_from_checkpoint(cls, path, **kwargs):
model = cls(**kwargs)
model.load_state_dict(torch.load(path)['state_dict'])
return model
@classmethod
def load_from_ds_checkpoint(cls, path, **kwargs):
model = cls(**kwargs)
        model.load_state_dict(torch.load(path))  # DS-converted checkpoints store the state_dict at the top level (no nested 'state_dict' key)
return model
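# Minimal usage sketch (illustrative; the Trainer settings, precision, strategy, and the
# train_loader/val_loader dataloaders below are placeholders, not part of this module):
#
#   model = GAIABaseLightning(
#       img_size=[4, 224, 224],
#       patch_size=[1, 16, 16],
#       in_chans=3,
#       params={
#           "optimizer": {"learning_rate": 1e-3},
#           "scheduler": {"type": "CosineAnnealingLR", "warmup_steps": 100,
#                         "total_steps": 1000, "eta_min": 1e-10, "last_step": -1},
#       },
#   )
#   trainer = pl.Trainer(max_epochs=100, precision="16-mixed", strategy="deepspeed_stage_2")
#   trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)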