import copy
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class VGGLikeEncode(nn.Module):
    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 128,
        feature_dim: int = 32,
        apply_pooling: bool = False
    ):
        """
        VGG-like encoder for grayscale images.

        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param feature_dim: number of channels in the intermediate layers
        :param apply_pooling: whether to apply global average pooling at the end
        """
        super().__init__()
        self.apply_pooling = apply_pooling
        # Each block: Conv -> BN -> ReLU -> Conv -> ReLU -> pool
        # (batch norm placement made consistent across the three blocks).
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, feature_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(feature_dim, feature_dim * 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(feature_dim * 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(feature_dim * 2, feature_dim * 2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(feature_dim * 2, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=1)  # identity pool: keeps the spatial size
        )
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        # Plain list (not nn.ModuleList): the blocks are already registered as
        # submodules above; this is only an index for get_conv_layer.
        self.blocks = [self.block1, self.block2, self.block3]

    def forward(self, x: Tensor) -> Tensor:
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        if self.apply_pooling:
            # Global average pool, then flatten to (batch, out_channels).
            x = self.global_avg_pool(x).view(x.shape[0], -1)
        return x

    def get_conv_layer(self, block_num: int) -> Optional[nn.Conv2d]:
        """Return the first conv layer of the given block, or None if block_num is out of range."""
        if block_num >= len(self.blocks):
            return None
        return self.blocks[block_num][0]
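

# Usage sketch (not part of the original module): checks the encoder's output
# shapes. The 32x32 input size is an assumption, inferred from the hard-coded
# seq_len = 8 * 8 in CrossAttentionClassifier below.
def _demo_vgg_encoder() -> None:
    encoder = VGGLikeEncode(in_channels=1, out_channels=128, feature_dim=32)
    x = torch.randn(4, 1, 32, 32)  # batch of four grayscale 32x32 images
    features = encoder(x)
    assert features.shape == (4, 128, 8, 8)  # two /2 pools: 32 -> 16 -> 8
    pooled = VGGLikeEncode(apply_pooling=True)(x)
    assert pooled.shape == (4, 128)  # global average pooling flattens the map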


class CrossAttentionClassifier(nn.Module):
    def __init__(
        self,
        feature_dim: int = 32,
        num_heads: int = 4,
        linear_dim: int = 128,
        out_channels: int = 128,
        encoder: Optional[VGGLikeEncode] = None
    ):
        """
        Cross-attention classifier for comparing two grayscale images.

        :param feature_dim: number of channels in the intermediate layers
        :param num_heads: number of attention heads
        :param linear_dim: number of units in the hidden linear layer
        :param out_channels: number of output channels of the encoder
        :param encoder: encoder to use; a fresh VGGLikeEncode is built if None
        """
        super().__init__()
        if encoder is not None:
            self.encoder = encoder
        else:
            self.encoder = VGGLikeEncode(in_channels=1, feature_dim=feature_dim, out_channels=out_channels)
        self.out_channels = out_channels
        # Assumes 32x32 inputs: the encoder's two /2 pools give an 8x8 feature map.
        self.seq_len = 8 * 8
        # Learned positional embedding, shaped (seq_len, 1, embed_dim) to match
        # the sequence-first layout used by nn.MultiheadAttention below.
        self.pos_embedding = nn.Parameter(torch.randn(self.seq_len, 1, out_channels) * 0.01)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=out_channels,
            num_heads=num_heads,
            batch_first=False
        )
        self.norm = nn.LayerNorm(out_channels)
        self.classifier = nn.Sequential(
            nn.Linear(out_channels, linear_dim),
            nn.ReLU(),
            nn.Linear(linear_dim, 1)
        )

    def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
        feat1 = self.encoder(img1)  # (B, C, H, W)
        feat2 = self.encoder(img2)
        B, C, H, W = feat1.shape
        seq_len = H * W
        assert seq_len == self.seq_len, (
            f"expected {self.seq_len} spatial positions, got {seq_len}; "
            "the positional embedding is sized for 32x32 inputs"
        )
        # Flatten to the sequence-first layout (seq_len, B, C).
        feat1_flat = feat1.view(B, C, seq_len).permute(2, 0, 1)
        feat2_flat = feat2.view(B, C, seq_len).permute(2, 0, 1)
        feat1_flat = feat1_flat + self.pos_embedding
        feat2_flat = feat2_flat + self.pos_embedding
        feat1_flat = self.norm(feat1_flat)
        feat2_flat = self.norm(feat2_flat)
        # Image 1 attends to image 2.
        attn_output, attn_weights = self.cross_attention(
            query=feat1_flat,
            key=feat2_flat,
            value=feat2_flat,
            need_weights=True,
            average_attn_weights=True
        )
        # Mean over the sequence dimension -> (B, C), then one logit per pair.
        pooled_features = attn_output.mean(dim=0)
        logits = self.classifier(pooled_features).squeeze(-1)
        return logits, attn_weights
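

# Usage sketch (assumption: 32x32 grayscale pairs, matching the hard-coded
# seq_len above). Shows the output shapes and a typical binary objective;
# the labels here are random placeholders.
def _demo_cross_attention_classifier() -> None:
    model = CrossAttentionClassifier()
    img1 = torch.randn(4, 1, 32, 32)
    img2 = torch.randn(4, 1, 32, 32)
    logits, attn_weights = model(img1, img2)
    assert logits.shape == (4,)  # one similarity logit per image pair
    assert attn_weights.shape == (4, 64, 64)  # (batch, query_len, key_len)
    labels = torch.randint(0, 2, (4,)).float()
    loss = F.binary_cross_entropy_with_logits(logits, labels)
    loss.backward()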


class NormalizedMSELoss(nn.Module):
    def __init__(self):
        """
        Normalized MSE loss for BYOL training: 2 - 2 * cosine similarity,
        returned per sample (no reduction).
        """
        super().__init__()

    def forward(self, view1: Tensor, view2: Tensor) -> Tensor:
        v1 = F.normalize(view1, dim=-1)
        v2 = F.normalize(view2, dim=-1)
        return 2 - 2 * (v1 * v2).sum(dim=-1)
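

# Sanity-check sketch (added for illustration): for L2-normalized vectors,
# 2 - 2 * cos(v1, v2) equals the squared Euclidean distance between them,
# which is the form used in the BYOL paper.
def _demo_normalized_mse() -> None:
    v1, v2 = torch.randn(4, 128), torch.randn(4, 128)
    loss = NormalizedMSELoss()(v1, v2)  # shape (4,): one loss per sample
    squared_dist = (F.normalize(v1, dim=-1) - F.normalize(v2, dim=-1)).pow(2).sum(dim=-1)
    assert torch.allclose(loss, squared_dist, atol=1e-6)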


class MLP(nn.Module):
    def __init__(self, input_dim: int, projection_dim: int = 128, hidden_dim: int = 512):
        """
        MLP for BYOL training.

        :param input_dim: input dimension
        :param projection_dim: projection dimension
        :param hidden_dim: hidden dimension
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, projection_dim)
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)


class EncoderProjecter(nn.Module):
    def __init__(self, encoder: nn.Module, hidden_dim: int = 512, projection_out_dim: int = 128,
                 encoder_out_dim: int = 128):
        """
        Encoder followed by a projection MLP.

        :param encoder: encoder to use
        :param hidden_dim: hidden dimension
        :param projection_out_dim: projection output dimension
        :param encoder_out_dim: dimension of the encoder's pooled output
        """
        super().__init__()
        self.encoder = encoder
        self.projection = MLP(input_dim=encoder_out_dim, projection_dim=projection_out_dim, hidden_dim=hidden_dim)

    def forward(self, x: Tensor) -> Tensor:
        h = self.encoder(x)
        return self.projection(h)


# https://arxiv.org/pdf/2006.07733
class BYOL(nn.Module):
    def __init__(
        self,
        hidden_dim: int = 512,
        projection_out_dim: int = 128,
        target_decay: float = 0.9975
    ):
        """
        BYOL model for self-supervised learning.

        :param hidden_dim: hidden dimension
        :param projection_out_dim: projection output dimension
        :param target_decay: EMA decay rate for the target network
        """
        super().__init__()
        encoder = VGGLikeEncode(in_channels=1, out_channels=128, feature_dim=32, apply_pooling=True)
        self.online_network = EncoderProjecter(encoder, hidden_dim=hidden_dim, projection_out_dim=projection_out_dim)
        # The predictor maps online projections to predicted target projections,
        # so its input dimension is the projection size.
        self.online_predictor = MLP(input_dim=projection_out_dim, projection_dim=projection_out_dim, hidden_dim=hidden_dim)
        # The target network must be an independent copy of the online network;
        # reusing the same encoder instance would make the two networks share
        # weights and turn the EMA update into a no-op.
        self.target_network = copy.deepcopy(self.online_network)
        self.target_network.eval()
        for param in self.target_network.parameters():
            param.requires_grad = False
        self.target_decay = target_decay
        self.loss_function = NormalizedMSELoss()

    @torch.no_grad()
    def soft_update_target_network(self):
        # EMA update: target <- decay * target + (1 - decay) * online.
        for online_p, target_p in zip(self.online_network.parameters(), self.target_network.parameters()):
            target_p.data = target_p.data * self.target_decay + online_p.data * (1. - self.target_decay)

    def forward(self, view: Tensor) -> Tuple[Tensor, Tensor]:
        online_proj = self.online_network(view)
        with torch.no_grad():
            target_proj = self.target_network(view)
        return online_proj, target_proj

    def loss(self, view1: Tensor, view2: Tensor) -> Tensor:
        # Symmetrized BYOL loss: each view's online prediction regresses
        # the other view's (detached) target projection.
        online_proj1, target_proj1 = self(view1)
        online_proj2, target_proj2 = self(view2)
        online_prediction_1 = self.online_predictor(online_proj1)
        online_prediction_2 = self.online_predictor(online_proj2)
        loss1 = self.loss_function(online_prediction_1, target_proj2.detach())
        loss2 = self.loss_function(online_prediction_2, target_proj1.detach())
        return torch.mean(loss1 + loss2)
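

# Training-step sketch (assumptions: `view1` and `view2` are two augmented
# 32x32 grayscale views of the same batch, and the optimizer covers the
# online network and predictor; neither is specified in the original file).
def _demo_byol_training_step(
    model: BYOL, view1: Tensor, view2: Tensor, optimizer: torch.optim.Optimizer
) -> float:
    loss = model.loss(view1, view2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.soft_update_target_network()  # EMA update after each optimizer step
    return loss.item()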