import copy
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class VGGLikeEncode(nn.Module):
    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 128,
        feature_dim: int = 32,
        apply_pooling: bool = False
    ):
        """
        VGG-like encoder for grayscale images.

        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param feature_dim: number of channels in the intermediate layers
        :param apply_pooling: whether to apply global average pooling at the end
        """
        super().__init__()
        self.apply_pooling = apply_pooling
        # Each block follows the Conv -> BatchNorm -> ReLU ordering.
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, feature_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(feature_dim, feature_dim * 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(feature_dim * 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(feature_dim * 2, feature_dim * 2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(feature_dim * 2, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # kernel_size=1 pooling is an identity op; it preserves the 8x8
            # feature map that CrossAttentionClassifier expects for 32x32 inputs.
            nn.MaxPool2d(kernel_size=1)
        )
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.blocks = [self.block1, self.block2, self.block3]

    def forward(self, x: Tensor) -> Tensor:
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        if self.apply_pooling:
            # (B, C, H, W) -> (B, C, 1, 1) -> (B, C)
            x = self.global_avg_pool(x).view(x.shape[0], -1)
        return x

    def get_conv_layer(self, block_num: int) -> Optional[nn.Conv2d]:
        """Return the first convolution of the given block, or None if out of range."""
        if not 0 <= block_num < len(self.blocks):
            return None
        return self.blocks[block_num][0]
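
# Illustrative usage sketch (not part of the original module): checks the shape
# flow this file assumes, i.e. 32x32 grayscale inputs yielding an 8x8 feature
# map, and a flat 128-dim vector when apply_pooling=True.
def _demo_vgg_like_encoder() -> None:
    encoder = VGGLikeEncode(in_channels=1, out_channels=128, feature_dim=32)
    feats = encoder(torch.randn(4, 1, 32, 32))
    assert feats.shape == (4, 128, 8, 8)  # 32 -> 16 -> 8 -> 8 (block3 pool is a no-op)
    pooled_encoder = VGGLikeEncode(apply_pooling=True)
    pooled = pooled_encoder(torch.randn(4, 1, 32, 32))
    assert pooled.shape == (4, 128)  # global average pooling flattens to (B, C)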

class CrossAttentionClassifier(nn.Module):
    def __init__(
        self,
        feature_dim: int = 32,
        num_heads: int = 4,
        linear_dim: int = 128,
        out_channels: int = 128,
        encoder: Optional[VGGLikeEncode] = None
    ):
        """
        Cross-attention classifier for comparing two grayscale images.

        :param feature_dim: number of channels in the intermediate layers
        :param num_heads: number of attention heads
        :param linear_dim: number of units in the linear layer
        :param out_channels: number of output channels
        :param encoder: encoder to use; a fresh VGGLikeEncode is built if None
        """
        super().__init__()
        if encoder:
            self.encoder = encoder
        else:
            self.encoder = VGGLikeEncode(in_channels=1, feature_dim=feature_dim, out_channels=out_channels)
        self.out_channels = out_channels
        # Hard-coded for 32x32 inputs: the encoder downsamples them to 8x8 maps.
        self.seq_len = 8 * 8
        # Learned positional embedding, shaped (seq_len, 1, embed_dim) so it
        # broadcasts over the batch in the (S, B, E) attention layout.
        self.pos_embedding = nn.Parameter(torch.randn(self.seq_len, 1, out_channels) * 0.01)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=out_channels,
            num_heads=num_heads,
            batch_first=False
        )
        self.norm = nn.LayerNorm(out_channels)
        self.classifier = nn.Sequential(
            nn.Linear(out_channels, linear_dim),
            nn.ReLU(),
            nn.Linear(linear_dim, 1)
        )

    def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
        feat1 = self.encoder(img1)
        feat2 = self.encoder(img2)
        B, C, H, W = feat1.shape
        seq_len = H * W
        assert seq_len == self.seq_len, \
            f"expected {self.seq_len} spatial positions, got {seq_len}; inputs must be 32x32"
        # (B, C, H, W) -> (H*W, B, C), the sequence-first layout MultiheadAttention expects.
        feat1_flat = feat1.view(B, C, seq_len).permute(2, 0, 1)
        feat2_flat = feat2.view(B, C, seq_len).permute(2, 0, 1)
        feat1_flat = self.norm(feat1_flat + self.pos_embedding)
        feat2_flat = self.norm(feat2_flat + self.pos_embedding)
        # Image 1 attends to image 2; weights are averaged over heads.
        attn_output, attn_weights = self.cross_attention(
            query=feat1_flat,
            key=feat2_flat,
            value=feat2_flat,
            need_weights=True,
            average_attn_weights=True
        )
        # Mean over the sequence dimension -> one vector per image pair.
        pooled_features = attn_output.mean(dim=0)
        logits = self.classifier(pooled_features).squeeze(-1)
        return logits, attn_weights
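
# Illustrative usage sketch (assumes 32x32 inputs so the encoder's 8x8 map
# matches the hard-coded seq_len of 64): the classifier returns one logit per
# pair plus head-averaged attention weights of shape (B, 64, 64).
def _demo_cross_attention_classifier() -> None:
    model = CrossAttentionClassifier()
    img1 = torch.randn(4, 1, 32, 32)
    img2 = torch.randn(4, 1, 32, 32)
    logits, attn = model(img1, img2)
    assert logits.shape == (4,)       # one similarity logit per pair
    assert attn.shape == (4, 64, 64)  # query positions x key positions
    probs = torch.sigmoid(logits)     # e.g. probability that the pair matches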

class NormalizedMSELoss(nn.Module):
    def __init__(self):
        """
        Normalized MSE loss for BYOL training.
        """
        super().__init__()

    def forward(self, view1: Tensor, view2: Tensor) -> Tensor:
        # Cosine-similarity form of the BYOL regression loss; returns one
        # value per sample (shape (B,)), reduced later by the caller.
        v1 = F.normalize(view1, dim=-1)
        v2 = F.normalize(view2, dim=-1)
        return 2 - 2 * (v1 * v2).sum(dim=-1)
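
# Quick numerical check (illustrative, not in the original file): for
# L2-normalized vectors, 2 - 2 * cos(v1, v2) equals ||v1_hat - v2_hat||^2,
# the form in which the BYOL paper states its regression loss.
def _check_normalized_mse() -> None:
    loss_fn = NormalizedMSELoss()
    a, b = torch.randn(4, 128), torch.randn(4, 128)
    lhs = loss_fn(a, b)
    rhs = (F.normalize(a, dim=-1) - F.normalize(b, dim=-1)).pow(2).sum(dim=-1)
    assert torch.allclose(lhs, rhs, atol=1e-5)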

class MLP(nn.Module):
    def __init__(self, input_dim: int, projection_dim: int = 128, hidden_dim: int = 512):
        """
        MLP for BYOL training.

        :param input_dim: input dimension
        :param projection_dim: projection dimension
        :param hidden_dim: hidden dimension
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, projection_dim)
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)

class EncoderProjecter(nn.Module):
    def __init__(self, encoder: nn.Module, hidden_dim: int = 512, projection_out_dim: int = 128):
        """
        Encoder followed by a projection MLP.

        :param encoder: encoder to use; must output flat 128-dim features
        :param hidden_dim: hidden dimension
        :param projection_out_dim: projection output dimension
        """
        super().__init__()
        self.encoder = encoder
        # input_dim is hard-coded to 128 to match VGGLikeEncode's default
        # out_channels with apply_pooling=True.
        self.projection = MLP(input_dim=128, projection_dim=projection_out_dim, hidden_dim=hidden_dim)

    def forward(self, x: Tensor) -> Tensor:
        h = self.encoder(x)
        return self.projection(h)
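
# Illustrative sketch: because EncoderProjecter hard-codes input_dim=128, the
# wrapped encoder must emit flat 128-dim vectors, i.e. a VGGLikeEncode built
# with out_channels=128 and apply_pooling=True (as BYOL below does).
def _demo_encoder_projecter() -> None:
    encoder = VGGLikeEncode(in_channels=1, out_channels=128, feature_dim=32, apply_pooling=True)
    net = EncoderProjecter(encoder, hidden_dim=512, projection_out_dim=128)
    z = net(torch.randn(4, 1, 32, 32))
    assert z.shape == (4, 128)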

# BYOL: Bootstrap Your Own Latent, https://arxiv.org/pdf/2006.07733
class BYOL(nn.Module):
    def __init__(
        self,
        hidden_dim: int = 512,
        projection_out_dim: int = 128,
        target_decay: float = 0.9975
    ):
        """
        BYOL model for self-supervised learning.

        :param hidden_dim: hidden dimension
        :param projection_out_dim: projection output dimension
        :param target_decay: target network decay rate (EMA coefficient)
        """
        super().__init__()
        encoder = VGGLikeEncode(in_channels=1, out_channels=128, feature_dim=32, apply_pooling=True)
        self.online_network = EncoderProjecter(encoder)
        self.online_predictor = MLP(input_dim=128, projection_dim=projection_out_dim, hidden_dim=hidden_dim)
        # The target network must be an independent copy of the online network:
        # reusing the same encoder instance would freeze the online encoder too
        # once target gradients are disabled below.
        self.target_network = copy.deepcopy(self.online_network)
        self.target_network.eval()
        for param in self.target_network.parameters():
            param.requires_grad = False
        self.target_decay = target_decay
        self.loss_function = NormalizedMSELoss()

    def soft_update_target_network(self):
        # Exponential moving average: target <- decay * target + (1 - decay) * online.
        for online_p, target_p in zip(self.online_network.parameters(), self.target_network.parameters()):
            target_p.data = target_p.data * self.target_decay + online_p.data * (1. - self.target_decay)

    def forward(self, view: Tensor) -> Tuple[Tensor, Tensor]:
        online_proj = self.online_network(view)
        # The target branch never receives gradients, so skip autograd tracking.
        with torch.no_grad():
            target_proj = self.target_network(view)
        return online_proj, target_proj

    def loss(self, view1: Tensor, view2: Tensor) -> Tensor:
        # Symmetrized BYOL loss: each view's online prediction regresses the
        # other view's (stop-gradient) target projection.
        online_proj1, target_proj1 = self(view1)
        online_proj2, target_proj2 = self(view2)
        online_prediction_1 = self.online_predictor(online_proj1)
        online_prediction_2 = self.online_predictor(online_proj2)
        loss1 = self.loss_function(online_prediction_1, target_proj2.detach())
        loss2 = self.loss_function(online_prediction_2, target_proj1.detach())
        return torch.mean(loss1 + loss2)
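
# Minimal training-loop sketch (illustrative, not part of the original file;
# the optimizer choice, learning rate, and the noise "augmentation" below are
# assumptions). Only the online network and predictor receive gradients; the
# target network is refreshed by the EMA update after each optimizer step.
def _toy_augment(x: Tensor) -> Tensor:
    # Stand-in augmentation: additive Gaussian noise. Real BYOL pipelines use
    # random crops, flips, and color distortions.
    return x + 0.1 * torch.randn_like(x)


def train_byol_sketch(model: BYOL, loader, epochs: int = 1) -> None:
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=3e-4)
    for _ in range(epochs):
        for images in loader:
            view1, view2 = _toy_augment(images), _toy_augment(images)
            loss = model.loss(view1, view2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            model.soft_update_target_network()  # EMA refresh of the target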