import logging
import math
import time
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union

import lightning.pytorch as pl
import torch
import torch.nn.functional as F
from deepspeed.ops.adam import FusedAdam

from mae_dino.models import GAIABase

Shape3d = Union[List[int], Tuple[int, int, int]]

class CosineAnnealingWarmupLR(torch.optim.lr_scheduler.LambdaLR):
    def __init__(self, optimizer, warmup_steps, total_steps, eta_min=0.0, last_step=-1):
        """
        Cosine annealing scheduler with warm-up period, operating per step.

        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_steps (int): Number of warm-up steps.
            total_steps (int): Total number of training steps.
            eta_min (float): Minimum learning rate after annealing. Default: 0.0.
            last_step (int): The index of the last step when resuming training. Default: -1.
        """
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_min = eta_min
        super().__init__(optimizer, self.lr_lambda, last_step)

    def lr_lambda(self, current_step):
        if current_step < self.warmup_steps:
            # Linear warm-up from 0 to the base learning rate
            return float(current_step) / float(max(1, self.warmup_steps))
        else:
            # Cosine annealing from the base learning rate down to eta_min
            progress = float(current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            return cosine_decay * (1 - self.eta_min) + self.eta_min
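
# Example usage (illustrative sketch only; the optimizer choice and step counts below
# are assumptions, not values taken from a GAIA training configuration):
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
#   scheduler = CosineAnnealingWarmupLR(optimizer, warmup_steps=100, total_steps=1000, eta_min=1e-10)
#   for step in range(1000):
#       ...                   # forward / backward / optimizer.step()
#       scheduler.step()      # advance the per-step warmup + cosine schedule
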
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class GAIABaseLightning(pl.LightningModule):
    def __init__(self,
                 img_size: Shape3d = [4, 224, 224],
                 patch_size: Shape3d = [1, 16, 16],
                 in_chans: int = 3,
                 encoder_embed_dim: int = 1024,
                 encoder_depth: int = 8,
                 encoder_num_heads: int = 16,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 mlp_ratio: float = 4.,
                 norm_layer: torch.nn.Module = torch.nn.LayerNorm,
                 norm_pix_loss: bool = False,
                 drop_channels_rate: float = 0.0,
                 # DINO args
                 adjacent_masking: bool = False,
                 norm_last_layer: bool = True,
                 dino_head_dim: int = 1024,
                 warmup_teacher_temp: float = 0.04,
                 teacher_temp: float = 0.04,
                 warmup_teacher_temp_epochs: int = 5,
                 epochs: int = 100,
                 student_temp: float = 0.1,
                 center_momentum: float = 0.9,
                 momentum_teacher: float = 0.996,
                 params: Optional[Dict[str, Union[Dict[str, float], int]]] = None,
                 mask_ratio: float = 0.75,
                 val_step: int = 0):
        super().__init__()
        # Default optimizer/scheduler hyperparameters, used when no `params` dict is provided.
        self.params = params or {
            'optimizer': {
                'learning_rate': 1e-3,
            },
            'scheduler': {
                'type': 'CosineAnnealingLR',
                'warmup_steps': 100,
                'total_steps': 1000,
                'eta_min': 1e-10,
                'last_step': -1,
            },
        }
        self.model = GAIABase(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            encoder_embed_dim=encoder_embed_dim,
            encoder_depth=encoder_depth,
            encoder_num_heads=encoder_num_heads,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            norm_pix_loss=norm_pix_loss,
            drop_channels_rate=drop_channels_rate,
            # DINO args
            adjacent_masking=adjacent_masking,
            norm_last_layer=norm_last_layer,
            dino_head_dim=dino_head_dim,
            warmup_teacher_temp=warmup_teacher_temp,
            teacher_temp=teacher_temp,
            warmup_teacher_temp_epochs=warmup_teacher_temp_epochs,
            epochs=epochs,
            student_temp=student_temp,
            center_momentum=center_momentum,
            momentum_teacher=momentum_teacher,
        )
        self.val_step = val_step
        self.mask_ratio = mask_ratio
        self.momentum_teacher = momentum_teacher
        self.sync_dist = True
        self.mae_loss_start = None
        self.mae_loss_min = None
        self.dino_loss_min = None
        self.dino_loss_start = None
    def forward(self, x, x_mask, temporal_pos=None, mask_ratio=0.75, epoch=None):
        if epoch is None:
            # Fall back to the trainer's epoch counter when no epoch is given explicitly.
            try:
                epoch = self.current_epoch
            except Exception as e:
                raise ValueError("epoch was not provided and self.current_epoch is unavailable") from e
        return self.model(x, x_mask, temporal_pos=temporal_pos, mask_ratio=mask_ratio, epoch=epoch)
    def training_step(self, batch, batch_idx):
        x = batch['x'].to(self.device)
        x_mask = batch['x_mask'].to(self.device)
        if batch['temporal_pos']:
            temporal_pos = [time_component.to(self.device) for time_component in batch['temporal_pos']]
        else:
            temporal_pos = None
        (total_loss, dino_loss, mae_loss), (pred, mask, _, _) = self(
            x, x_mask, temporal_pos=temporal_pos, mask_ratio=self.mask_ratio, epoch=self.current_epoch
        )
        # Log through Lightning so the loss is aggregated across devices.
        self.log("train_loss", total_loss, sync_dist=self.sync_dist)
        if batch_idx % 50 == 0:
            logger.info(f"Train loss: {total_loss}")
        return total_loss
    def on_after_backward(self):
        # Clip gradients to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        # Update the teacher as an exponential moving average (EMA) of the student
        # after each backward pass: teacher = m * teacher + (1 - m) * student.
        with torch.no_grad():
            for param_q, param_k in zip(self.model.student.parameters(), self.model.teacher.parameters()):
                param_k.data.mul_(self.momentum_teacher).add_((1 - self.momentum_teacher) * param_q.detach().data)
    def validation_step(self, batch, batch_idx):
        x = batch['x']
        x_mask = batch['x_mask']
        if batch['temporal_pos']:
            temporal_pos = [time_component for time_component in batch['temporal_pos']]
        else:
            temporal_pos = None
        (total_loss, dino_loss, mae_loss), (pred, mask, _, _) = self(
            x, x_mask, temporal_pos=temporal_pos, mask_ratio=self.mask_ratio, epoch=self.current_epoch
        )
        # Log the validation loss so it is aggregated across devices and visible to callbacks.
        self.log("val_loss", total_loss, sync_dist=self.sync_dist)
        return total_loss
    def configure_optimizers(self):
        optimizer = FusedAdam(self.parameters(), lr=self.params['optimizer']['learning_rate'])
        if self.params['scheduler']['type'] == 'CosineAnnealingLR':
            warmup_steps = self.params['scheduler'].get('warmup_steps', 0)
            total_steps = self.params['scheduler']['total_steps']
            eta_min = self.params['scheduler']['eta_min']
            last_step = self.params['scheduler']['last_step']
            scheduler = {
                "scheduler": CosineAnnealingWarmupLR(
                    optimizer,
                    warmup_steps,
                    total_steps,
                    eta_min=eta_min,
                    last_step=last_step
                ),
                "interval": "step",  # Explicitly update the learning rate every step
                "frequency": 1       # Apply the scheduler at every step
            }
        else:
            raise NotImplementedError(f"The specified {self.params['scheduler']['type']} scheduler is not implemented.")
        return [optimizer], [scheduler]
    def save_model(self, path):
        torch.save(self.state_dict(), path)

    @classmethod
    def load_from_checkpoint(cls, path, **kwargs):
        model = cls(**kwargs)
        model.load_state_dict(torch.load(path)['state_dict'])
        return model

    @classmethod
    def load_from_ds_checkpoint(cls, path, **kwargs):
        model = cls(**kwargs)
        model.load_state_dict(torch.load(path))  # NESTED state_dict exists in DS converted checkpoints
        return model
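
# Example usage (illustrative sketch only; the trainer settings, datamodule name, and
# checkpoint path below are assumptions, not values from a GAIA training script):
#
#   params = {
#       'optimizer': {'learning_rate': 1e-3},
#       'scheduler': {
#           'type': 'CosineAnnealingLR',
#           'warmup_steps': 100,
#           'total_steps': 1000,
#           'eta_min': 1e-10,
#           'last_step': -1,
#       },
#   }
#   module = GAIABaseLightning(params=params, mask_ratio=0.75)
#   trainer = pl.Trainer(max_epochs=100, precision="bf16-mixed", strategy="deepspeed_stage_2")
#   trainer.fit(module, datamodule=some_datamodule)  # `some_datamodule` is a placeholder
#
#   # Restoring weights from a DeepSpeed-converted (flat) checkpoint:
#   restored = GAIABaseLightning.load_from_ds_checkpoint("path/to/converted.pt", params=params)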