from typing import List, Optional, Tuple, Union, Dict
import logging
import math
import time
from datetime import datetime

import torch
import torch.nn.functional as F
import lightning.pytorch as pl
from deepspeed.ops.adam import FusedAdam

from mae_dino.models import GAIABase

Shape3d = Union[List[int], Tuple[int, int, int]]


class CosineAnnealingWarmupLR(torch.optim.lr_scheduler.LambdaLR):
    def __init__(self, optimizer, warmup_steps, total_steps, eta_min=0.0, last_step=-1):
        """
        Cosine annealing scheduler with warm-up period, operating per step.

        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_steps (int): Number of warm-up steps.
            total_steps (int): Total number of training steps.
            eta_min (float): Minimum learning rate after annealing. Default: 0.0.
            last_step (int): The index of the last step when resuming training. Default: -1.
        """
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_min = eta_min
        super().__init__(optimizer, self.lr_lambda, last_step)

    def lr_lambda(self, current_step):
        if current_step < self.warmup_steps:
            # Linear warm-up
            return float(current_step) / float(max(1, self.warmup_steps))
        else:
            # Cosine annealing
            progress = float(current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            return cosine_decay * (1 - self.eta_min) + self.eta_min


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class GAIABaseLightning(pl.LightningModule):
    def __init__(self,
                 img_size: Shape3d = [4, 224, 224],
                 patch_size: Shape3d = [1, 16, 16],
                 in_chans: int = 3,
                 encoder_embed_dim: int = 1024,
                 encoder_depth: int = 8,
                 encoder_num_heads: int = 16,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 mlp_ratio: float = 4.,
                 norm_layer: torch.nn.Module = torch.nn.LayerNorm,
                 norm_pix_loss: bool = False,
                 drop_channels_rate: float = 0.0,
                 # DINO Args
                 adjacent_masking: bool = False,
                 norm_last_layer: bool = True,
                 dino_head_dim: int = 1024,
                 warmup_teacher_temp: float = 0.04,
                 teacher_temp: float = 0.04,
                 warmup_teacher_temp_epochs: int = 5,
                 epochs: int = 100,
                 student_temp: float = 0.1,
                 center_momentum: float = 0.9,
                 momentum_teacher: float = 0.996,
                 params: Optional[Dict[str, Union[Dict[str, float], int]]] = None,
                 mask_ratio: float = 0.75,
                 val_step: int = 0):
        super().__init__()

        self.params = params or {
            'optimizer': {
                'learning_rate': 1e-3,
            },
            'scheduler': {
                'type': 'CosineAnnealingLR',
                'warmup_steps': 100,
                'total_steps': 1000,
                'eta_min': 1e-10,
                'last_step': -1,
            },
        }

        self.model = GAIABase(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            encoder_embed_dim=encoder_embed_dim,
            encoder_depth=encoder_depth,
            encoder_num_heads=encoder_num_heads,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            norm_pix_loss=norm_pix_loss,
            drop_channels_rate=drop_channels_rate,
            # DINO Args
            adjacent_masking=adjacent_masking,
            norm_last_layer=norm_last_layer,
            dino_head_dim=dino_head_dim,
            warmup_teacher_temp=warmup_teacher_temp,
            teacher_temp=teacher_temp,
            warmup_teacher_temp_epochs=warmup_teacher_temp_epochs,
            epochs=epochs,
            student_temp=student_temp,
            center_momentum=center_momentum,
            momentum_teacher=momentum_teacher,
        )

        self.val_step = val_step
        self.mask_ratio = mask_ratio
        self.momentum_teacher = momentum_teacher
        self.sync_dist = True
        self.mae_loss_start = None
        self.mae_loss_min = None
        self.dino_loss_min = None
        self.dino_loss_start = None

    def forward(self, x, x_mask, temporal_pos=None,
                mask_ratio=0.75, epoch=None):
        if epoch is None:
            try:
                epoch = self.current_epoch
            except Exception as exc:
                raise ValueError("epoch was not provided and could not be read from self.current_epoch") from exc
        return self.model(x, x_mask, temporal_pos=temporal_pos, mask_ratio=mask_ratio, epoch=epoch)

    def training_step(self, batch, batch_idx):
        x = batch['x'].to(self.device)
        x_mask = batch['x_mask'].to(self.device)
        if batch['temporal_pos']:
            temporal_pos = [time_component.to(self.device) for time_component in batch['temporal_pos']]
        else:
            temporal_pos = None

        (total_loss, dino_loss, mae_loss), (pred, mask, _, _) = self(
            x, x_mask, temporal_pos=temporal_pos, mask_ratio=self.mask_ratio, epoch=self.current_epoch)

        # Log the combined loss so monitoring/checkpointing callbacks can track it
        self.log("train_loss", total_loss, sync_dist=self.sync_dist)

        if batch_idx % 50 == 0:
            logger.info(f"Train loss : {total_loss}")

        return total_loss

    def on_after_backward(self):
        # Clip gradients to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        # Update the teacher as an exponential moving average of the student after each backward pass
        with torch.no_grad():
            for param_q, param_k in zip(self.model.student.parameters(), self.model.teacher.parameters()):
                param_k.data.mul_(self.momentum_teacher).add_(param_q.detach().data, alpha=1 - self.momentum_teacher)

    def validation_step(self, batch, batch_idx):
        x = batch['x']
        x_mask = batch['x_mask']
        if batch['temporal_pos']:
            temporal_pos = [time_component for time_component in batch['temporal_pos']]
        else:
            temporal_pos = None

        (total_loss, dino_loss, mae_loss), (pred, mask, _, _) = self(
            x, x_mask, temporal_pos=temporal_pos, mask_ratio=self.mask_ratio, epoch=self.current_epoch)

        # Log the validation loss so it is aggregated across batches and ranks
        self.log("val_loss", total_loss, sync_dist=self.sync_dist)

        return total_loss

    def configure_optimizers(self):
        optimizer = FusedAdam(self.parameters(), lr=self.params['optimizer']['learning_rate'])

        if self.params['scheduler']['type'] == 'CosineAnnealingLR':
            warmup_steps = self.params['scheduler'].get('warmup_steps', 0)
            total_steps = self.params['scheduler']['total_steps']
            eta_min = self.params['scheduler']['eta_min']
            last_step = self.params['scheduler']['last_step']

            scheduler = {
                "scheduler": CosineAnnealingWarmupLR(
                    optimizer,
                    warmup_steps,
                    total_steps,
                    eta_min=eta_min,
                    last_step=last_step
                ),
                "interval": "step",  # Explicitly update the learning rate every step
                "frequency": 1       # Apply the scheduler at every step
            }
        else:
            raise NotImplementedError(
                f"The specified {self.params['scheduler']['type']} scheduler is not implemented.")

        return [optimizer], [scheduler]

    def save_model(self, path):
        torch.save(self.state_dict(), path)

    @classmethod
    def load_from_checkpoint(cls, path, **kwargs):
        # Overrides LightningModule.load_from_checkpoint with a simpler signature:
        # the checkpoint is expected to hold the weights under a 'state_dict' key.
        model = cls(**kwargs)
        model.load_state_dict(torch.load(path)['state_dict'])
        return model

    @classmethod
    def load_from_ds_checkpoint(cls, path, **kwargs):
        model = cls(**kwargs)
        # DeepSpeed-converted checkpoints are loaded as a flat state_dict (no nested 'state_dict' key)
        model.load_state_dict(torch.load(path))
        return model
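

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the training pipeline). It only
# exercises CosineAnnealingWarmupLR on a throwaway linear model so the
# warm-up/cosine behaviour can be inspected without constructing GAIABase or
# launching a Trainer. The layer size, step counts, and learning rate below
# are arbitrary demo values, not settings used elsewhere in this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_model = torch.nn.Linear(8, 8)
    demo_optimizer = torch.optim.AdamW(demo_model.parameters(), lr=1e-3)
    demo_scheduler = CosineAnnealingWarmupLR(
        demo_optimizer, warmup_steps=10, total_steps=100, eta_min=1e-10
    )

    for step in range(100):
        demo_optimizer.step()   # a real loop would compute a loss and call backward() first
        demo_scheduler.step()   # per-step update, matching interval="step" in configure_optimizers
        if step % 20 == 0:
            logger.info(f"step {step:3d} | lr = {demo_scheduler.get_last_lr()[0]:.2e}")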