File size: 8,384 Bytes
fd943c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
from typing import List, Optional, Tuple, Union, Dict
import logging
import math
import torch
import time
from datetime import datetime
import lightning.pytorch as pl

from mae_dino.models import GAIABase
from deepspeed.ops.adam import FusedAdam

import torch.nn.functional as F

Shape3d = Union[List[int], Tuple[int, int, int]]


class CosineAnnealingWarmupLR(torch.optim.lr_scheduler.LambdaLR):
    """Per-step linear warm-up followed by cosine annealing.

    Built on ``LambdaLR``: :meth:`lr_lambda` yields a *multiplier* that
    ``LambdaLR`` applies to each parameter group's initial learning rate.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, eta_min=0.0, last_step=-1):
        """
        Cosine annealing scheduler with warm-up period, operating per step.

        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_steps (int): Number of warm-up steps.
            total_steps (int): Total number of training steps.
            eta_min (float): Minimum value of the learning-rate *multiplier*
                reached at the end of annealing; the effective minimum
                learning rate is ``initial_lr * eta_min``. Default: 0.0.
            last_step (int): The index of the last step when resuming training.
                Forwarded to ``LambdaLR`` as ``last_epoch``. Default: -1.
        """
        # These must be set before super().__init__(), which performs an
        # initial step and therefore already calls lr_lambda().
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_min = eta_min

        super().__init__(optimizer, self.lr_lambda, last_step)

    def lr_lambda(self, current_step):
        """Return the learning-rate multiplier for ``current_step``.

        Ramps linearly from 0 to 1 over ``warmup_steps``, then follows half a
        cosine period from 1 down to ``eta_min`` at ``total_steps``. Steps
        beyond ``total_steps`` stay at ``eta_min``.
        """
        if current_step < self.warmup_steps:
            # Linear warm-up
            return float(current_step) / float(max(1, self.warmup_steps))

        # Cosine annealing
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        # Clamp to 1.0: without it, training past total_steps would walk the
        # cosine back up and silently raise the learning rate again.
        progress = min(1.0, float(current_step - self.warmup_steps) / float(decay_steps))
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return cosine_decay * (1 - self.eta_min) + self.eta_min


# Module-level logger used by the training step for periodic loss messages.
logger = logging.getLogger(__name__)
# Configure the root logger at INFO so those messages are emitted
# (basicConfig is a no-op when the root logger already has handlers).
logging.basicConfig(level=logging.INFO)

class GAIABaseLightning(pl.LightningModule):
    """Lightning training wrapper around ``GAIABase``.

    The wrapped model is called as
    ``model(x, x_mask, temporal_pos=..., mask_ratio=..., epoch=...)`` and is
    expected to return ``((total_loss, dino_loss, mae_loss), outputs)``; only
    ``total_loss`` is used for optimisation here.
    """

    def __init__(self,
                 img_size: Shape3d = [4, 224, 224],
                 patch_size: Shape3d = [1, 16, 16],
                 in_chans: int = 3,
                 encoder_embed_dim: int = 1024,
                 encoder_depth: int = 8,
                 encoder_num_heads: int = 16,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 mlp_ratio: float = 4.,
                 norm_layer: torch.nn.Module = torch.nn.LayerNorm,
                 norm_pix_loss: bool = False,
                 drop_channels_rate: float = 0.0,
                 # DINO Args
                 adjacent_masking: bool = False,
                 norm_last_layer: bool = True,
                 dino_head_dim: int = 1024,
                 warmup_teacher_temp: float = 0.04,
                 teacher_temp: float = 0.04,
                 warmup_teacher_temp_epochs: int = 5,
                 epochs: int = 100,
                 student_temp: float = 0.1,
                 center_momentum: float = 0.9,
                 momentum_teacher: float = 0.996,
                 params: Optional[Dict[str, Union[Dict[str, float], int]]] = None,
                 mask_ratio = 0.75,
                 val_step=0):
        """Build the wrapped ``GAIABase`` model and store training settings.

        All arguments from ``img_size`` through ``momentum_teacher`` are
        forwarded unchanged to ``GAIABase``; see that class for their meaning.

        Args:
            momentum_teacher: Besides being forwarded to ``GAIABase``, used as
                the EMA momentum for the teacher update in
                :meth:`on_after_backward`.
            params: Optimizer/scheduler configuration with the layout
                ``{'optimizer': {'learning_rate': ...}, 'scheduler': {'type',
                'warmup_steps', 'total_steps', 'eta_min', 'last_step'}}``.
                When falsy, a built-in default configuration is used.
            mask_ratio: Mask ratio passed to the model on every training and
                validation step.
            val_step: Stored as ``self.val_step``; not read inside this class.
        """
        super().__init__()
        self.params = params or {
            'optimizer': {
                'learning_rate': 1e-3,
            },
            'scheduler': {
                'type': 'CosineAnnealingLR',
                'warmup_steps': 100,
                'total_steps': 1000,
                'eta_min': 1e-10,
                'last_step': -1,
            },
        }

        self.model = GAIABase(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            encoder_embed_dim=encoder_embed_dim,
            encoder_depth=encoder_depth,
            encoder_num_heads=encoder_num_heads,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            norm_pix_loss=norm_pix_loss,
            drop_channels_rate=drop_channels_rate,
            # DINO Args
            adjacent_masking=adjacent_masking,
            norm_last_layer=norm_last_layer,
            dino_head_dim=dino_head_dim,
            warmup_teacher_temp=warmup_teacher_temp,
            teacher_temp=teacher_temp,
            warmup_teacher_temp_epochs=warmup_teacher_temp_epochs,
            epochs=epochs,
            student_temp=student_temp,
            center_momentum=center_momentum,
            momentum_teacher=momentum_teacher,
        )
        self.val_step = val_step
        self.mask_ratio = mask_ratio
        self.momentum_teacher = momentum_teacher
        # The attributes below are initialised here but never read or updated
        # inside this class; they are kept for external code that may use them.
        self.sync_dist = True
        self.mae_loss_start = None
        self.mae_loss_min = None
        self.dino_loss_min = None
        self.dino_loss_start = None

    def forward(self, x, x_mask, temporal_pos=None, mask_ratio=0.75, epoch=None):
        """Run the wrapped model.

        Args:
            x: Model input, forwarded as-is.
            x_mask: Mask input, forwarded as-is.
            temporal_pos: Optional temporal position components, forwarded as-is.
            mask_ratio: Mask ratio forwarded to the model.
            epoch: Epoch index forwarded to the model; defaults to
                ``self.current_epoch`` when ``None``.

        Returns:
            Whatever ``self.model`` returns.

        Raises:
            ValueError: If ``epoch`` is ``None`` and ``self.current_epoch``
                cannot be read.
        """
        if epoch is None:
            try:
                epoch = self.current_epoch
            except Exception as err:
                # Narrower than a bare ``except`` (lets KeyboardInterrupt /
                # SystemExit through) and keeps the root cause attached.
                raise ValueError("epoch value is invalid") from err
        return self.model(x, x_mask, temporal_pos=temporal_pos, mask_ratio=mask_ratio, epoch=epoch)

    def _unpack_batch(self, batch):
        """Extract ``(x, x_mask, temporal_pos)`` from a batch dict on ``self.device``.

        ``temporal_pos`` becomes ``None`` when the key is absent, ``None`` or
        empty; otherwise it is a list of its components moved to the device.
        """
        x = batch['x'].to(self.device)
        x_mask = batch['x_mask'].to(self.device)

        # Explicit None/length test instead of truthiness: truthiness of a
        # multi-element tensor raises, and a missing key should mean "no
        # temporal positions" rather than a KeyError.
        temporal_pos = batch.get('temporal_pos')
        if temporal_pos is None or len(temporal_pos) == 0:
            return x, x_mask, None
        return x, x_mask, [time_component.to(self.device) for time_component in temporal_pos]

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch.

        Logs the loss through the module logger every 50 batches.

        Returns:
            The total loss returned by the model.
        """
        x, x_mask, temporal_pos = self._unpack_batch(batch)

        (total_loss, _dino_loss, _mae_loss), _outputs = self(
            x, x_mask, temporal_pos=temporal_pos, mask_ratio=self.mask_ratio, epoch=self.current_epoch)

        if batch_idx % 50 == 0:
            logger.info(f"Train loss : {total_loss}")

        return total_loss

    def on_after_backward(self):
        """Clip gradients and update the teacher as an EMA of the student.

        NOTE(review): Lightning documents that with native AMP the gradients
        are still scaled when this hook runs; if AMP is enabled, the clipping
        below acts on scaled gradients -- confirm this is intended.
        NOTE(review): this hook runs before the optimizer step, so the teacher
        is averaged with the student weights from before the current update.
        """
        # Clip gradients to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        # Update EMA for teacher after each backward pass:
        # teacher = m * teacher + (1 - m) * student
        with torch.no_grad():
            for param_q, param_k in zip(self.model.student.parameters(), self.model.teacher.parameters()):
                param_k.data.mul_(self.momentum_teacher).add_((1 - self.momentum_teacher) * param_q.detach().data)

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch.

        Returns:
            The total loss returned by the model.
        """
        x, x_mask, temporal_pos = self._unpack_batch(batch)

        (total_loss, _dino_loss, _mae_loss), _outputs = self(
            x, x_mask, temporal_pos=temporal_pos, mask_ratio=self.mask_ratio, epoch=self.current_epoch)
        return total_loss

    def configure_optimizers(self):
        """Create the ``FusedAdam`` optimizer and the per-step LR scheduler.

        Reads ``self.params['optimizer']['learning_rate']`` and the
        ``self.params['scheduler']`` section. ``warmup_steps``, ``eta_min``
        and ``last_step`` are optional (defaults 0, 0.0 and -1);
        ``total_steps`` is required.

        Returns:
            ``([optimizer], [scheduler_config])`` where the scheduler config
            asks Lightning to step the scheduler on every optimizer step.

        Raises:
            NotImplementedError: If the scheduler type is not
                ``'CosineAnnealingLR'``.
        """
        scheduler_cfg = self.params['scheduler']
        # Reject unsupported schedulers before allocating any optimizer state.
        if scheduler_cfg['type'] != 'CosineAnnealingLR':
            raise NotImplementedError(f"The specified {scheduler_cfg['type']} scheduler is not implemented.")

        optimizer = FusedAdam(self.parameters(), lr=self.params['optimizer']['learning_rate'])

        scheduler = {
            "scheduler": CosineAnnealingWarmupLR(
                optimizer,
                scheduler_cfg.get('warmup_steps', 0),
                scheduler_cfg['total_steps'],
                eta_min=scheduler_cfg.get('eta_min', 0.0),
                last_step=scheduler_cfg.get('last_step', -1),
            ),
            "interval": "step",  # Explicitly update learning rate every step
            "frequency": 1  # Apply the scheduler at every step
        }
        return [optimizer], [scheduler]

    def save_model(self, path):
        """Save this module's ``state_dict`` to ``path`` with ``torch.save``."""
        torch.save(self.state_dict(), path)

    @classmethod
    def load_from_checkpoint(cls, path, *, map_location="cpu", **kwargs):
        """Build a model from ``kwargs`` and load weights from a checkpoint.

        The file at ``path`` must contain a mapping with a ``'state_dict'``
        entry. Note that this replaces the ``load_from_checkpoint`` inherited
        from ``LightningModule`` with a simpler signature.

        Args:
            path: Checkpoint file to read.
            map_location: Passed to ``torch.load``. Defaults to ``"cpu"`` so
                checkpoints saved from GPU also load on CPU-only hosts; the
                values are copied into the freshly built model either way.
            **kwargs: Constructor arguments for the model.

        Returns:
            The model with weights loaded.
        """
        model = cls(**kwargs)
        checkpoint = torch.load(path, map_location=map_location)
        model.load_state_dict(checkpoint['state_dict'])
        return model

    @classmethod
    def load_from_ds_checkpoint(cls, path, *, map_location="cpu", **kwargs):
        """Build a model from ``kwargs`` and load weights from a converted DeepSpeed file.

        Accepts either a bare state dict or a mapping that nests the weights
        under a ``'state_dict'`` key.

        Args:
            path: Checkpoint file to read.
            map_location: Passed to ``torch.load``. Defaults to ``"cpu"``.
            **kwargs: Constructor arguments for the model.

        Returns:
            The model with weights loaded.
        """
        model = cls(**kwargs)
        state = torch.load(path, map_location=map_location)
        # Some converted checkpoints nest the weights under 'state_dict';
        # unwrap when present, otherwise treat the file as a bare state dict.
        if isinstance(state, dict) and 'state_dict' in state:
            state = state['state_dict']
        model.load_state_dict(state)
        return model