from typing import Dict, List, Optional, Tuple, Union

import logging
import math
import time
from datetime import datetime

import torch
import torch.nn.functional as F
import lightning.pytorch as pl
from deepspeed.ops.adam import FusedAdam

from mae_dino.models import GAIABase

Shape3d = Union[List[int], Tuple[int, int, int]]


class CosineAnnealingWarmupLR(torch.optim.lr_scheduler.LambdaLR):
    def __init__(self, optimizer, warmup_steps, total_steps, eta_min=0.0, last_step=-1):
        """
        Cosine annealing scheduler with a warm-up period, operating per step.

        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_steps (int): Number of warm-up steps.
            total_steps (int): Total number of training steps.
            eta_min (float): Minimum learning-rate multiplier reached at the end of
                annealing (the final LR is base_lr * eta_min). Default: 0.0.
            last_step (int): Index of the last step when resuming training. Default: -1.
        """
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_min = eta_min
        super().__init__(optimizer, self.lr_lambda, last_step)

    def lr_lambda(self, current_step):
        if current_step < self.warmup_steps:
            # Linear warm-up: scale the base LR from 0 up to its full value.
            return float(current_step) / float(max(1, self.warmup_steps))
        # Cosine decay from the full base LR down towards eta_min.
        progress = float(current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return cosine_decay * (1 - self.eta_min) + self.eta_min
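
# Illustrative check of the schedule (assuming warmup_steps=100, total_steps=1000,
# eta_min=0.0): lr_lambda rises linearly from 0 to 1 over steps 0-99, then follows
# half a cosine back down to 0, e.g. lr_lambda(550) = 0.5 because
# progress = 450 / 900 = 0.5 and 0.5 * (1 + cos(pi * 0.5)) = 0.5.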


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class GAIABaseLightning(pl.LightningModule):
    def __init__(self,
                 img_size: Shape3d = [4, 224, 224],
                 patch_size: Shape3d = [1, 16, 16],
                 in_chans: int = 3,
                 encoder_embed_dim: int = 1024,
                 encoder_depth: int = 8,
                 encoder_num_heads: int = 16,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 mlp_ratio: float = 4.,
                 norm_layer: torch.nn.Module = torch.nn.LayerNorm,
                 norm_pix_loss: bool = False,
                 drop_channels_rate: float = 0.0,
                 adjacent_masking: bool = False,
                 norm_last_layer: bool = True,
                 dino_head_dim: int = 1024,
                 warmup_teacher_temp: float = 0.04,
                 teacher_temp: float = 0.04,
                 warmup_teacher_temp_epochs: int = 5,
                 epochs: int = 100,
                 student_temp: float = 0.1,
                 center_momentum: float = 0.9,
                 momentum_teacher: float = 0.996,
                 params: Optional[Dict[str, Union[Dict[str, float], int]]] = None,
                 mask_ratio: float = 0.75,
                 val_step: int = 0):
        super().__init__()
        self.params = params or {
            'optimizer': {
                'learning_rate': 1e-3,
            },
            'scheduler': {
                'type': 'CosineAnnealingLR',
                'warmup_steps': 100,
                'total_steps': 1000,
                'eta_min': 1e-10,
                'last_step': -1,
            },
        }
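        # configure_optimizers reads this layout: an 'optimizer' block with
        # 'learning_rate' and a 'scheduler' block with the fields above.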

        self.model = GAIABase(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            encoder_embed_dim=encoder_embed_dim,
            encoder_depth=encoder_depth,
            encoder_num_heads=encoder_num_heads,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            norm_pix_loss=norm_pix_loss,
            drop_channels_rate=drop_channels_rate,
            adjacent_masking=adjacent_masking,
            norm_last_layer=norm_last_layer,
            dino_head_dim=dino_head_dim,
            warmup_teacher_temp=warmup_teacher_temp,
            teacher_temp=teacher_temp,
            warmup_teacher_temp_epochs=warmup_teacher_temp_epochs,
            epochs=epochs,
            student_temp=student_temp,
            center_momentum=center_momentum,
            momentum_teacher=momentum_teacher,
        )
        self.val_step = val_step
        self.mask_ratio = mask_ratio
        self.momentum_teacher = momentum_teacher
        self.sync_dist = True
        self.mae_loss_start = None
        self.mae_loss_min = None
        self.dino_loss_min = None
        self.dino_loss_start = None

    def forward(self, x, x_mask, temporal_pos=None, mask_ratio=0.75, epoch=None):
        if epoch is None:
            # Fall back to the trainer's current epoch when none is supplied.
            try:
                epoch = self.current_epoch
            except Exception as err:
                raise ValueError("epoch was not provided and could not be read from the trainer") from err
        return self.model(x, x_mask, temporal_pos=temporal_pos, mask_ratio=mask_ratio, epoch=epoch)
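
    # GAIABase.forward is expected to return a pair of tuples,
    # ((total_loss, dino_loss, mae_loss), (pred, mask, ...)); training_step and
    # validation_step below unpack its output accordingly.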

    def training_step(self, batch, batch_idx):
        x = batch['x'].to(self.device)
        x_mask = batch['x_mask'].to(self.device)
        if batch['temporal_pos']:
            temporal_pos = [time_component.to(self.device) for time_component in batch['temporal_pos']]
        else:
            temporal_pos = None

        (total_loss, dino_loss, mae_loss), (pred, mask, _, _) = self(
            x, x_mask, temporal_pos=temporal_pos, mask_ratio=self.mask_ratio, epoch=self.current_epoch
        )

        if batch_idx % 50 == 0:
            logger.info(f"Train loss: {total_loss}")

        return total_loss
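
    # Note: the loss is only written to the Python logger above. If Lightning-side
    # metric tracking is wanted, something like
    #   self.log('train_loss', total_loss, sync_dist=self.sync_dist)
    # could be added before the return (a suggestion, not part of the original loop).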

    def on_after_backward(self):
        # Clip gradients before the optimizer applies them.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        # EMA update of the DINO teacher: teacher <- m * teacher + (1 - m) * student.
        with torch.no_grad():
            for param_q, param_k in zip(self.model.student.parameters(), self.model.teacher.parameters()):
                param_k.data.mul_(self.momentum_teacher).add_(param_q.detach().data, alpha=1 - self.momentum_teacher)

    def validation_step(self, batch, batch_idx):
        x = batch['x']
        x_mask = batch['x_mask']
        if batch['temporal_pos']:
            temporal_pos = list(batch['temporal_pos'])
        else:
            temporal_pos = None

        (total_loss, dino_loss, mae_loss), (pred, mask, _, _) = self(
            x, x_mask, temporal_pos=temporal_pos, mask_ratio=self.mask_ratio, epoch=self.current_epoch
        )
        return total_loss

    def configure_optimizers(self):
        optimizer = FusedAdam(self.parameters(), lr=self.params['optimizer']['learning_rate'])

        if self.params['scheduler']['type'] == 'CosineAnnealingLR':
            warmup_steps = self.params['scheduler'].get('warmup_steps', 0)
            total_steps = self.params['scheduler']['total_steps']
            eta_min = self.params['scheduler']['eta_min']
            last_step = self.params['scheduler']['last_step']
            scheduler = {
                "scheduler": CosineAnnealingWarmupLR(
                    optimizer,
                    warmup_steps,
                    total_steps,
                    eta_min=eta_min,
                    last_step=last_step,
                ),
                "interval": "step",  # step the scheduler every batch rather than every epoch
                "frequency": 1,
            }
        else:
            raise NotImplementedError(f"The specified {self.params['scheduler']['type']} scheduler is not implemented.")
        return [optimizer], [scheduler]

    def save_model(self, path):
        torch.save(self.state_dict(), path)

    @classmethod
    def load_from_checkpoint(cls, path, **kwargs):
        # Load from a Lightning checkpoint, whose weights live under 'state_dict'.
        # Note: this overrides LightningModule.load_from_checkpoint with a simpler signature.
        model = cls(**kwargs)
        model.load_state_dict(torch.load(path)['state_dict'])
        return model

    @classmethod
    def load_from_ds_checkpoint(cls, path, **kwargs):
        # Load from a checkpoint saved as a bare state dict (e.g. a converted DeepSpeed checkpoint).
        model = cls(**kwargs)
        model.load_state_dict(torch.load(path))
        return model
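
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; names such as `train_loader` are
# assumed to be provided by the surrounding training script). Since
# configure_optimizers uses DeepSpeed's FusedAdam, a CUDA environment with
# DeepSpeed installed is assumed.
#
#   model = GAIABaseLightning(
#       params={
#           'optimizer': {'learning_rate': 1e-3},
#           'scheduler': {'type': 'CosineAnnealingLR', 'warmup_steps': 100,
#                         'total_steps': 1000, 'eta_min': 1e-10, 'last_step': -1},
#       },
#   )
#   trainer = pl.Trainer(max_epochs=100)
#   trainer.fit(model, train_dataloaders=train_loader)
#
# The DataLoader is expected to yield dicts with keys 'x', 'x_mask' and
# 'temporal_pos', as consumed by training_step / validation_step above.
# ---------------------------------------------------------------------------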
|