import copy

import torch
import torch.nn.functional as F
from lightning import LightningModule

from models.networks.encoder.Transformer import apply_masks, repeat_interleave_batch
from utils.mask_collator import MaskCollator, MaskCollatorNaive


class Module(LightningModule):
    """JEPA-style pretraining module for a single dataset/scale.

    The trainable network (`self.model`) predicts, from a context mask, the
    features that a frozen target encoder produces for the masked regions.
    The target encoder is an EMA copy of the online encoder, updated in
    `on_after_backward`.
    """

    def __init__(
        self,
        network,
        loss,
        train_metrics,
        val_metrics,
        test_metrics,
        scheduler,
        optimizer,
        ema,
        ipe_scale,
        len_data,
        batch_size,
        num_epochs,
        scale,
        shape,
    ):
        super().__init__()
        self.model = network.instance
        # Frozen copy of the encoder used to produce prediction targets;
        # its weights are only updated via EMA in on_after_backward().
        self.target_encoder = copy.deepcopy(self.model.encoder)
        for p in self.target_encoder.parameters():
            p.requires_grad = False
        self.loss = loss
        self.train_metrics = train_metrics
        self.val_metrics = val_metrics
        self.test_metrics = test_metrics
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.mask_collator = MaskCollator(
            input_size=(shape // scale, shape // scale),
            patch_size=1,
            enc_mask_scale=(0.85, 1.0),
            pred_mask_scale=(0.2, 0.8),
            aspect_ratio=(0.75, 1.5),
            nenc=1,
            npred=4,
            min_keep=0,
            allow_overlap=False,
        )
        # Linear momentum ramp from ema[0] to ema[1]; one value is consumed
        # per optimizer step (ipe = iterations per epoch).
        ipe = len_data // batch_size
        self.momentum_scheduler = (
            ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale)
            for i in range(int(ipe * num_epochs * ipe_scale) + 1)
        )

    def forward(self, x):
        mask_enc, mask_pred = self.mask_collator(x)
        with torch.no_grad():
            h = self.target_encoder(x)[:, 1:, :]
            h = F.layer_norm(h, (h.size(-1),))  # normalize over feature-dim
            B = len(h)
            # -- create targets (masked regions of h)
            h = apply_masks(h, mask_pred)
            h = repeat_interleave_batch(h, B, repeat=len(mask_enc))
        return self.model(x, mask_enc, mask_pred), h

    def training_step(self, batch, batch_idx):
        pred, target = self.forward(batch)
        batch["target"] = target
        loss = self.loss(pred, batch, average=True)
        if "logits" in loss.keys():
            loss.pop("logits")
        for metric_name, metric_value in loss.items():
            self.log(
                f"train/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=True,
                on_epoch=True,
            )
        return loss

    def on_after_backward(self):
        # EMA update of the target encoder from the online encoder.
        with torch.no_grad():
            m = next(self.momentum_scheduler)
            for param_q, param_k in zip(
                self.model.encoder.parameters(), self.target_encoder.parameters()
            ):
                param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data)

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        pred, target = self.forward(batch)
        batch["target"] = target
        loss = self.loss(pred, batch, average=True)
        if "logits" in loss.keys():
            self.val_metrics.update(loss["logits"])
            loss.pop("logits")
        else:
            self.val_metrics.update(pred, batch)
        for metric_name, metric_value in loss.items():
            self.log(
                f"val/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    def on_validation_epoch_end(self):
        metrics = self.val_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"val/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        pred, target = self.forward(batch)
        batch["target"] = target
        loss = self.loss(pred, batch, average=True)
        if "logits" in loss.keys():
            self.test_metrics.update(loss["logits"])
            loss.pop("logits")
        else:
            self.test_metrics.update(pred, batch)

    def on_test_epoch_end(self):
        metrics = self.test_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"test/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    def configure_optimizers(self):
        optimizer = self.optimizer(params=self.parameters())
        if self.scheduler is not None:
            scheduler = self.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
"monitor": "val/loss", "interval": "epoch", "frequency": 1, }, } return {"optimizer": optimizer} class ModuleMulti(LightningModule): def __init__( self, network, loss, train_metrics, val_metrics, test_metrics, scheduler, optimizer, ema, ipe_scale, ipe, batch_size, num_epochs, scales, shapes, devices, ): super().__init__() self.model = network.instance self.target_encoder = copy.deepcopy(self.model.encoder) for p in self.target_encoder.parameters(): p.requires_grad = False self.loss = loss self.train_metrics = train_metrics self.val_metrics = val_metrics self.test_metrics = test_metrics self.optimizer = optimizer self.scheduler = scheduler self.mask_collator = {} datasets = list(scales.keys()) for dataset in datasets: for scale in scales[dataset]: shape = shapes[dataset] // scale self.mask_collator["_".join([dataset, str(scale)])] = MaskCollator( input_size=(shape, shape), patch_size=1, enc_mask_scale=(0.85, 1.0), pred_mask_scale=(0.2, 0.8), aspect_ratio=(0.75, 1.5), nenc=1, npred=4, min_keep=0, allow_overlap=False, ) self.momentum_scheduler = ( ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale) for i in range(int(ipe * num_epochs * ipe_scale) + 1) ) def forward(self, x): mask_enc, mask_pred = self.mask_collator["_".join([x["dataset"], str(x["scale"])])](x) with torch.no_grad(): h = self.target_encoder(x)[:, 1:, :] h = F.layer_norm(h, (h.size(-1),)) # normalize over feature-dim B = len(h) # -- create targets (masked regions of h) h = apply_masks(h, mask_pred) h = repeat_interleave_batch(h, B, repeat=len(mask_enc)) return self.model(x, mask_enc, mask_pred), h def training_step(self, batch, batch_idx): pred, target = self.forward(batch) batch["target"] = target loss = self.loss(pred, batch, average=True) if "logits" in loss.keys(): loss.pop("logits") for metric_name, metric_value in loss.items(): self.log( f"train/{metric_name}", metric_value, sync_dist=True, on_step=True, on_epoch=True, ) return loss def on_after_backward(self): with torch.no_grad(): m = next(self.momentum_scheduler) for param_q, param_k in zip( self.model.encoder.parameters(), self.target_encoder.parameters() ): param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data) @torch.no_grad() def validation_step(self, batch, batch_idx): pred, target = self.forward(batch) batch["target"] = target loss = self.loss(pred, batch, average=True) if "logits" in loss.keys(): self.val_metrics.update(loss["logits"], dataset=batch["dataset"]) loss.pop("logits") else: self.val_metrics.update(pred, batch) for metric_name, metric_value in loss.items(): self.log( f"val/{metric_name}", metric_value, sync_dist=True, on_step=False, on_epoch=True, ) def on_validation_epoch_end(self): metrics = self.val_metrics.compute() for metric_name, metric_value in metrics.items(): self.log( f"val/{metric_name}", metric_value, sync_dist=True, on_step=False, on_epoch=True, ) @torch.no_grad() def test_step(self, batch, batch_idx): pred, target = self.forward(batch) batch["target"] = target loss = self.loss(pred, batch, average=True) if "logits" in loss.keys(): self.test_metrics.update(loss["logits"], dataset=batch["dataset"]) loss.pop("logits") else: self.test_metrics.update(pred, batch) def on_test_epoch_end(self): metrics = self.test_metrics.compute() for metric_name, metric_value in metrics.items(): self.log( f"test/{metric_name}", metric_value, sync_dist=True, on_step=False, on_epoch=True, ) def configure_optimizers(self): optimizer = self.optimizer(params=self.parameters()) if self.scheduler is not None: scheduler = 
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}


class ModuleMultiNaive(LightningModule):
    """Variant of `ModuleMulti` that builds its per-(dataset, scale) masks with
    `MaskCollatorNaive` instead of `MaskCollator`."""

    def __init__(
        self,
        network,
        loss,
        train_metrics,
        val_metrics,
        test_metrics,
        scheduler,
        optimizer,
        ema,
        ipe_scale,
        ipe,
        batch_size,
        num_epochs,
        scales,
        shapes,
        devices,
    ):
        super().__init__()
        self.model = network.instance
        self.target_encoder = copy.deepcopy(self.model.encoder)
        for p in self.target_encoder.parameters():
            p.requires_grad = False
        self.loss = loss
        self.train_metrics = train_metrics
        self.val_metrics = val_metrics
        self.test_metrics = test_metrics
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.mask_collator = {}
        datasets = list(scales.keys())
        for dataset in datasets:
            for scale in scales[dataset]:
                shape = shapes[dataset] // scale
                self.mask_collator["_".join([dataset, str(scale)])] = MaskCollatorNaive(
                    input_size=(shape, shape), patch_size=1
                )
        self.momentum_scheduler = (
            ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale)
            for i in range(int(ipe * num_epochs * ipe_scale) + 1)
        )

    def forward(self, x):
        mask_enc, mask_pred = self.mask_collator["_".join([x["dataset"], str(x["scale"])])](x)
        with torch.no_grad():
            h = self.target_encoder(x)[:, 1:, :]
            h = F.layer_norm(h, (h.size(-1),))  # normalize over feature-dim
            B = len(h)
            # -- create targets (masked regions of h)
            h = apply_masks(h, mask_pred)
            h = repeat_interleave_batch(h, B, repeat=len(mask_enc))
        return self.model(x, mask_enc, mask_pred), h

    def training_step(self, batch, batch_idx):
        pred, target = self.forward(batch)
        batch["target"] = target
        loss = self.loss(pred, batch, average=True)
        if "logits" in loss.keys():
            loss.pop("logits")
        for metric_name, metric_value in loss.items():
            self.log(
                f"train/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=True,
                on_epoch=True,
            )
        return loss

    def on_after_backward(self):
        # EMA update of the target encoder from the online encoder.
        with torch.no_grad():
            m = next(self.momentum_scheduler)
            for param_q, param_k in zip(
                self.model.encoder.parameters(), self.target_encoder.parameters()
            ):
                param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data)

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        pred, target = self.forward(batch)
        batch["target"] = target
        loss = self.loss(pred, batch, average=True)
        if "logits" in loss.keys():
            self.val_metrics.update(loss["logits"], dataset=batch["dataset"])
            loss.pop("logits")
        else:
            self.val_metrics.update(pred, batch)
        for metric_name, metric_value in loss.items():
            self.log(
                f"val/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    def on_validation_epoch_end(self):
        metrics = self.val_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"val/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        pred, target = self.forward(batch)
        batch["target"] = target
        loss = self.loss(pred, batch, average=True)
        if "logits" in loss.keys():
            self.test_metrics.update(loss["logits"], dataset=batch["dataset"])
            loss.pop("logits")
        else:
            self.test_metrics.update(pred, batch)

    def on_test_epoch_end(self):
        metrics = self.test_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"test/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    def configure_optimizers(self):
        optimizer = self.optimizer(params=self.parameters())
        if self.scheduler is not None:
            scheduler = self.scheduler(optimizer=optimizer)
return { "optimizer": optimizer, "lr_scheduler": { "scheduler": scheduler, "monitor": "val/loss", "interval": "epoch", "frequency": 1, }, } return {"optimizer": optimizer}