"""DINO pre-training of a ResNet-50 backbone as a PyTorch Lightning module,
built from lightly's DINO components. The teacher network is an exponential
moving average (EMA) of the student and receives no gradient updates."""

import copy
from typing import List, Tuple, Union

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Identity
from torch.optim import SGD
from torchvision.models import resnet50

from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import (
    activate_requires_grad,
    deactivate_requires_grad,
    get_weight_decay_parameters,
    update_momentum,
)
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule


class DINO(LightningModule):
    def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.batch_size_per_device = batch_size_per_device

        resnet = resnet50()
        resnet.fc = Identity()  # Ignore classification head

        # Teacher: updated only as an EMA of the student, never by gradients.
        self.backbone = resnet
        self.projection_head = DINOProjectionHead()

        # Student: trained by backpropagation. freeze_last_layer=1 keeps the
        # last layer of the student head frozen during the first epoch via
        # cancel_last_layer_gradients (see configure_gradient_clipping below).
        self.student_backbone = copy.deepcopy(self.backbone)
        self.student_projection_head = DINOProjectionHead(freeze_last_layer=1)

        self.criterion = DINOLoss(output_dim=65536)
        self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        # Teacher forward pass; also used for validation and the online classifier.
        return self.backbone(x)

    def forward_student(self, x: Tensor) -> Tensor:
        features = self.student_backbone(x).flatten(start_dim=1)
        projections = self.student_projection_head(features)
        return projections

    def on_train_start(self) -> None:
        # The teacher receives no gradients during training; it is only
        # updated via the EMA in training_step.
        deactivate_requires_grad(self.backbone)
        deactivate_requires_grad(self.projection_head)

    def on_train_end(self) -> None:
        # Re-enable gradients so the saved model behaves like a regular module.
        activate_requires_grad(self.backbone)
        activate_requires_grad(self.projection_head)

    def training_step(
        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        # Momentum update teacher: the EMA coefficient follows a cosine
        # schedule from 0.996 to 1.0 over the course of training.
        momentum = cosine_schedule(
            step=self.trainer.global_step,
            max_steps=self.trainer.estimated_stepping_batches,
            start_value=0.996,
            end_value=1.0,
        )
        update_momentum(self.student_backbone, self.backbone, m=momentum)
        update_momentum(self.student_projection_head, self.projection_head, m=momentum)

        views, targets = batch[0], batch[1]
        # Multi-crop: the first two views are global crops, the rest are local
        # crops. The teacher only sees the global views; the student sees all.
        global_views = torch.cat(views[:2])
        local_views = torch.cat(views[2:])
        teacher_features = self.forward(global_views).flatten(start_dim=1)
        teacher_projections = self.projection_head(teacher_features)
        student_projections = torch.cat(
            [self.forward_student(global_views), self.forward_student(local_views)]
        )

        loss = self.criterion(
            teacher_out=teacher_projections.chunk(2),
            student_out=student_projections.chunk(len(views)),
            epoch=self.current_epoch,
        )
        self.log_dict(
            {"train_loss": loss, "ema_momentum": momentum},
            prog_bar=True,
            sync_dist=True,
            batch_size=len(targets),
        )

        # Online classification: a linear probe on detached teacher features of
        # the first global view; its gradients never reach the backbone.
        cls_loss, cls_log = self.online_classifier.training_step(
            (teacher_features.chunk(2)[0].detach(), targets), batch_idx
        )
        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
        return loss + cls_loss

    def validation_step(
        self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        images, targets = batch[0], batch[1]
        features = self.forward(images).flatten(start_dim=1)
        cls_loss, cls_log = self.online_classifier.validation_step(
            (features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
        return cls_loss

    def configure_optimizers(self):
        # Don't use weight decay for batch norm, bias parameters, and
        # classification head to improve performance.
        params, params_no_weight_decay = get_weight_decay_parameters(
            [self.student_backbone, self.student_projection_head]
        )
        # For ResNet50 we use SGD instead of AdamW/LARS as recommended by the
        # authors:
        # https://github.com/facebookresearch/dino#resnet-50-and-other-convnets-trainings
        optimizer = SGD(
            [
                {"name": "dino", "params": params},
                {
                    "name": "dino_no_weight_decay",
                    "params": params_no_weight_decay,
                    "weight_decay": 0.0,
                },
                {
                    "name": "online_classifier",
                    "params": self.online_classifier.parameters(),
                    "weight_decay": 0.0,
                },
            ],
            # Linear scaling rule: base lr 0.03 at a total batch size of 256.
            lr=0.03 * self.batch_size_per_device * self.trainer.world_size / 256,
            momentum=0.9,
            weight_decay=1e-4,
        )
        # The scheduler is stepped once per batch ("interval": "step"), so its
        # "epoch" arguments are given in steps: warmup for the equivalent of
        # 10 epochs, then cosine decay over all remaining training steps.
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                warmup_epochs=int(
                    self.trainer.estimated_stepping_batches
                    / self.trainer.max_epochs
                    * 10
                ),
                max_epochs=int(self.trainer.estimated_stepping_batches),
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]

    def configure_gradient_clipping(
        self,
        optimizer,
        gradient_clip_val: Union[int, float, None] = None,
        gradient_clip_algorithm: Union[str, None] = None,
    ) -> None:
        # Clip gradients to a norm of 3.0 and cancel the gradients of the
        # student head's last layer during the first epoch (freeze_last_layer=1
        # in __init__), following the DINO training recipe.
        self.clip_gradients(
            optimizer=optimizer,
            gradient_clip_val=3.0,
            gradient_clip_algorithm="norm",
        )
        self.student_projection_head.cancel_last_layer_gradients(self.current_epoch)
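

# ---------------------------------------------------------------------------
# Minimal usage sketch: one way this module could be trained. It assumes an
# ImageFolder-style dataset on disk and lightly's DINOTransform, which yields
# the two global crops plus local crops that training_step expects. The path,
# batch size, device count, and epoch count below are placeholder values, not
# settings taken from this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from pytorch_lightning import Trainer
    from torch.utils.data import DataLoader

    from lightly.data import LightlyDataset
    from lightly.transforms import DINOTransform

    # Multi-crop augmentation: 2 global views plus several local views per image.
    transform = DINOTransform()
    dataset = LightlyDataset(input_dir="path/to/train", transform=transform)
    dataloader = DataLoader(
        dataset, batch_size=128, shuffle=True, drop_last=True, num_workers=8
    )

    model = DINO(batch_size_per_device=128, num_classes=1000)
    trainer = Trainer(max_epochs=100, accelerator="gpu", devices=1)
    trainer.fit(model, train_dataloaders=dataloader)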