import logging
import math
from typing import Callable, Dict, List, Optional, Tuple, Union

import lightning.pytorch as pl
import torch
from torch.optim.lr_scheduler import LambdaLR

from deepspeed.ops.adam import FusedAdam

from downstream.gap_fill.models import GapFill

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# A 3-D shape such as [4, 224, 224].
Shape3d = Union[List[int], Tuple[int, int, int]]

class CosineAnnealingWarmupLR(LambdaLR):

    def __init__(self, optimizer, warmup_steps, total_steps, eta_min=0.0, last_step=-1):
        """
        Cosine annealing scheduler with a linear warm-up period, operating per step.

        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_steps (int): Number of warm-up steps.
            total_steps (int): Total number of training steps.
            eta_min (float): Minimum learning-rate multiplier after annealing. Default: 0.0.
            last_step (int): Index of the last step when resuming training. Default: -1.
        """
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_min = eta_min

        # Attributes must be set before the parent constructor, which calls lr_lambda once.
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_step)

    def lr_lambda(self, current_step):
        if current_step < self.warmup_steps:
            # Linear warm-up from 0 to the base learning rate.
            return float(current_step) / float(max(1, self.warmup_steps))
        # Cosine decay from the base learning rate down to eta_min (as a multiplier).
        progress = float(current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return cosine_decay * (1 - self.eta_min) + self.eta_min
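

# Usage sketch for the scheduler above (illustrative only, kept as a comment so it
# is not executed at import time). The `model`, step counts, and `train_loader`
# below are placeholders, not part of this module:
#
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
#     scheduler = CosineAnnealingWarmupLR(optimizer, warmup_steps=500,
#                                         total_steps=10_000, eta_min=1e-6)
#     for step, batch in enumerate(train_loader):
#         ...
#         optimizer.step()
#         scheduler.step()  # advance once per optimizer step, not once per epoch
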
class GapFillLightning(pl.LightningModule):

    def __init__(self,
                 img_size: Shape3d = [4, 224, 224],
                 patch_size: Shape3d = [1, 16, 16],
                 in_chans: int = 3,
                 encoder_embed_dim: int = 1024,
                 encoder_depth: int = 8,
                 encoder_num_heads: int = 16,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 mlp_ratio: float = 4.,
                 norm_layer: Callable[..., torch.nn.Module] = torch.nn.LayerNorm,
                 norm_pix_loss: bool = False,
                 drop_channels_rate: float = 0.0,
                 adjacent_masking: bool = False,
                 params: Optional[Dict[str, Union[Dict[str, float], int]]] = None,
                 mask_ratio: float = 0.75):
        super().__init__()
        # Fall back to default optimizer/scheduler settings when no config is passed.
        # Note: configure_optimizers currently reads only 'learning_rate' and 'warmup_steps'.
        self.params = params or {
            'optimizer': {
                'learning_rate': 1e-3,
            },
            'scheduler': {
                'type': 'CosineAnnealingLR',
                'warmup_steps': 100,
                'total_steps': 1000,
                'eta_min': 1e-10,
                'last_step': -1,
            },
        }

        self.model = GapFill(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            encoder_embed_dim=encoder_embed_dim,
            encoder_depth=encoder_depth,
            encoder_num_heads=encoder_num_heads,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            norm_pix_loss=norm_pix_loss,
            drop_channels_rate=drop_channels_rate,
            adjacent_masking=adjacent_masking,
        )
        self.mask_ratio = mask_ratio

        self.log_helper(f"Mask Ratio: {self.mask_ratio}")

    def log_helper(self, message):
        logger.info(message)

    def forward(self, x, x_mask, temporal_pos=None, mask_ratio=0.75, epoch=None):
        self.log_helper("forward pass start")
        if epoch is None:
            try:
                epoch = self.current_epoch
            except Exception as exc:
                raise ValueError("epoch was not provided and could not be read from the attached trainer") from exc
        out = self.model(x, x_mask, temporal_pos=temporal_pos, mask_ratio=mask_ratio, epoch=epoch)
        self.log_helper("finished forward")
        return out

    def training_step(self, batch, batch_idx):
        x = batch['x'].to(self.device)
        x_mask = batch['x_mask'].to(self.device)
        if batch['temporal_pos']:
            temporal_pos = [time_component.to(self.device) for time_component in batch['temporal_pos']]
        else:
            temporal_pos = None

        loss, (pred, mask) = self(x, x_mask, temporal_pos=temporal_pos, mask_ratio=self.mask_ratio, epoch=self.current_epoch)

        self.log("train_loss", loss, prog_bar=True, sync_dist=True)
        return loss

    def on_after_backward(self):
        # Clip gradient norms after each backward pass to stabilise training.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

    def validation_step(self, batch, batch_idx):
        x = batch['x'].to(self.device)
        x_mask = batch['x_mask'].to(self.device)

        if batch['temporal_pos']:
            temporal_pos = [time_component.to(self.device) for time_component in batch['temporal_pos']]
        else:
            temporal_pos = None

        loss, (pred, mask) = self(x, x_mask, temporal_pos=temporal_pos, mask_ratio=self.mask_ratio, epoch=self.current_epoch)

        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        optimizer = FusedAdam(self.parameters(), lr=self.params['optimizer']['learning_rate'])

        if 'warmup_steps' in self.params['scheduler'] and self.params['scheduler']['warmup_steps'] > 0:
            warmup_steps = self.params['scheduler']['warmup_steps']

            def warmup_fn(step):
                # Linear warm-up to the base learning rate, then hold it constant.
                if step < warmup_steps:
                    return step / warmup_steps
                return 1.0

            scheduler = {
                "scheduler": LambdaLR(optimizer, lr_lambda=warmup_fn),
                # Step the scheduler every optimizer step rather than every epoch.
                "interval": "step",
                "frequency": 1,
            }

            return [optimizer], [scheduler]

        return [optimizer]

    def save_model(self, path):
        torch.save(self.state_dict(), path)

    @classmethod
    def load_from_checkpoint(cls, path, **kwargs):
        # Note: this overrides LightningModule.load_from_checkpoint and expects a
        # checkpoint dict that contains a 'state_dict' entry.
        model = cls(**kwargs)
        model.load_state_dict(torch.load(path)['state_dict'])
        return model

    @classmethod
    def load_from_ds_checkpoint(cls, path, **kwargs):
        # Expects a plain state dict, e.g. one consolidated from DeepSpeed shards.
        model = cls(**kwargs)
        model.load_state_dict(torch.load(path))
        return model
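

# Usage sketch (illustrative only; the trainer arguments and `my_datamodule` below
# are placeholders, not part of this repository):
#
#     model = GapFillLightning(
#         mask_ratio=0.75,
#         params={
#             'optimizer': {'learning_rate': 1e-4},
#             'scheduler': {'warmup_steps': 500, 'total_steps': 10_000},
#         },
#     )
#     trainer = pl.Trainer(max_epochs=10, precision="16-mixed", devices=1)
#     trainer.fit(model, datamodule=my_datamodule)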