import copy
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class VGGLikeEncode(nn.Module):
def __init__(
self,
in_channels: int = 1,
out_channels: int = 128,
feature_dim: int = 32,
apply_pooling: bool = False
):
"""
VGG-like encoder for grayscale images.
:param in_channels: number of input channels
:param out_channels: number of output channels
:param feature_dim: number of channels in the intermediate layers
:param apply_pooling: whether to apply global average pooling at the end
"""
super().__init__()
self.apply_pooling = apply_pooling
self.block1 = nn.Sequential(
nn.Conv2d(in_channels, feature_dim, kernel_size=3, padding=1),
nn.BatchNorm2d(feature_dim),
nn.ReLU(inplace=True),
nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2)
)
self.block2 = nn.Sequential(
nn.Conv2d(feature_dim, feature_dim * 2, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.BatchNorm2d(feature_dim * 2),
nn.Conv2d(feature_dim * 2, feature_dim * 2, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2)
)
self.block3 = nn.Sequential(
nn.Conv2d(feature_dim * 2, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.BatchNorm2d(out_channels),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
            # kernel_size=1 makes this an identity op: the spatial size stays
            # at 8x8 (for 32x32 inputs), which the classifier's seq_len relies on.
            nn.MaxPool2d(kernel_size=1)
)
self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
self.blocks = [self.block1, self.block2, self.block3]
def forward(self, x: Tensor) -> Tensor:
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
if self.apply_pooling:
x = self.global_avg_pool(x).view(x.shape[0], -1)
return x
    def get_conv_layer(self, block_num: int) -> Optional[nn.Conv2d]:
        """Return the first conv layer of block `block_num`, or None if out of range."""
        if not 0 <= block_num < len(self.blocks):
            return None
        return self.blocks[block_num][0]
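
# Shape check (a sketch, assuming 32x32 grayscale inputs, which is what the
# classifier below expects): blocks 1 and 2 each halve the spatial size, so
#   VGGLikeEncode()(torch.randn(4, 1, 32, 32)).shape                    # (4, 128, 8, 8)
#   VGGLikeEncode(apply_pooling=True)(torch.randn(4, 1, 32, 32)).shape  # (4, 128)
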
class CrossAttentionClassifier(nn.Module):
def __init__(
self,
feature_dim: int = 32,
num_heads: int = 4,
linear_dim: int = 128,
out_channels: int = 128,
encoder: Optional[VGGLikeEncode] = None
):
"""
Cross-attention classifier for comparing two grayscale images.
:param feature_dim: number of channels in the intermediate layers
:param num_heads: number of attention heads
:param linear_dim: number of units in the linear layer
:param out_channels: number of output channels
:param encoder: encoder to use
"""
        super().__init__()
        # Check against None rather than truthiness: nn.Module subclasses may
        # define __len__, which would make a bare `if encoder:` fragile.
        if encoder is not None:
            self.encoder = encoder
        else:
            self.encoder = VGGLikeEncode(in_channels=1, feature_dim=feature_dim, out_channels=out_channels)
self.out_channels = out_channels
        self.seq_len = 8 * 8  # encoder output is 8x8 for 32x32 inputs
self.pos_embedding = nn.Parameter(torch.randn(self.seq_len, 1, out_channels) * 0.01)
self.cross_attention = nn.MultiheadAttention(
embed_dim=out_channels,
num_heads=num_heads,
batch_first=False
)
self.norm = nn.LayerNorm(out_channels)
self.classifier = nn.Sequential(
nn.Linear(out_channels, linear_dim),
nn.ReLU(),
nn.Linear(linear_dim, 1)
)
def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
feat1 = self.encoder(img1)
feat2 = self.encoder(img2)
        B, C, H, W = feat1.shape
        seq_len = H * W
        # MultiheadAttention with batch_first=False expects (seq_len, batch, embed)
        feat1_flat = feat1.view(B, C, seq_len).permute(2, 0, 1)
        feat2_flat = feat2.view(B, C, seq_len).permute(2, 0, 1)
feat1_flat = feat1_flat + self.pos_embedding
feat2_flat = feat2_flat + self.pos_embedding
feat1_flat = self.norm(feat1_flat)
feat2_flat = self.norm(feat2_flat)
attn_output, attn_weights = self.cross_attention(
query=feat1_flat,
key=feat2_flat,
value=feat2_flat,
need_weights=True,
average_attn_weights=True
)
        pooled_features = attn_output.mean(dim=0)  # average over the spatial tokens
logits = self.classifier(pooled_features).squeeze(-1)
return logits, attn_weights
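
# Usage sketch (assumes 32x32 inputs, so the encoder yields 8x8 = 64 tokens):
#   clf = CrossAttentionClassifier()
#   logits, attn = clf(torch.randn(4, 1, 32, 32), torch.randn(4, 1, 32, 32))
#   # logits: (4,)       raw scores; apply torch.sigmoid for probabilities
#   # attn:   (4, 64, 64) head-averaged attention from img1 tokens to img2 tokens
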
class NormalizedMSELoss(nn.Module):
    def __init__(self):
        """
        Normalized MSE loss for BYOL training: the MSE between L2-normalized
        vectors, which reduces to 2 - 2 * cosine_similarity.
        """
        super().__init__()
def forward(self, view1: Tensor, view2: Tensor) -> Tensor:
v1 = F.normalize(view1, dim=-1)
v2 = F.normalize(view2, dim=-1)
return 2 - 2 * (v1 * v2).sum(dim=-1)
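
# For unit vectors, ||v1 - v2||^2 = 2 - 2 * <v1, v2>, so the expression above
# is the regression loss from the BYOL paper, returned per sample;
# BYOL.loss() below averages it over the batch.
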
class MLP(nn.Module):
def __init__(self, input_dim: int, projection_dim: int = 128, hidden_dim: int = 512):
"""
MLP for BYOL training.
:param input_dim: input dimension
:param projection_dim: projection dimension
:param hidden_dim: hidden dimension
"""
        super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, projection_dim)
)
def forward(self, x: Tensor) -> Tensor:
return self.net(x)

class EncoderProjecter(nn.Module):
    def __init__(
        self,
        encoder: nn.Module,
        hidden_dim: int = 512,
        projection_out_dim: int = 128,
        encoder_output_dim: int = 128
    ):
        """
        Encoder followed by a projection MLP.
        :param encoder: encoder to use (must emit flat features of size encoder_output_dim)
        :param hidden_dim: hidden dimension
        :param projection_out_dim: projection output dimension
        :param encoder_output_dim: dimensionality of the encoder's output features
        """
        super().__init__()
        self.encoder = encoder
        self.projection = MLP(input_dim=encoder_output_dim, projection_dim=projection_out_dim, hidden_dim=hidden_dim)
def forward(self, x: Tensor) -> Tensor:
h = self.encoder(x)
return self.projection(h)

# BYOL: "Bootstrap Your Own Latent" (https://arxiv.org/pdf/2006.07733)
class BYOL(nn.Module):
def __init__(
self,
hidden_dim: int = 512,
projection_out_dim: int = 128,
target_decay: float = 0.9975
):
"""
BYOL model for self-supervised learning.
:param hidden_dim: hidden dimension
:param projection_out_dim: projection output dimension
:param target_decay: target network decay rate
"""
        super().__init__()
        encoder = VGGLikeEncode(in_channels=1, out_channels=128, feature_dim=32, apply_pooling=True)
        self.online_network = EncoderProjecter(encoder)
        # The predictor consumes the online projection, so its input size must
        # match projection_out_dim (a hard-coded 128 only works for the default).
        self.online_predictor = MLP(input_dim=projection_out_dim, projection_dim=projection_out_dim, hidden_dim=hidden_dim)
        # The target network must be an independent copy. Passing the same
        # encoder instance to a second EncoderProjecter would share parameters
        # across both branches, and freezing the target would then also freeze
        # the online encoder.
        self.target_network = copy.deepcopy(self.online_network)
        self.target_network.eval()
        for param in self.target_network.parameters():
            param.requires_grad = False
        self.target_decay = target_decay
        self.loss_function = NormalizedMSELoss()
    @torch.no_grad()
    def soft_update_target_network(self):
        # Exponential moving average: target <- decay * target + (1 - decay) * online
        for online_p, target_p in zip(self.online_network.parameters(), self.target_network.parameters()):
            target_p.lerp_(online_p, 1.0 - self.target_decay)
        # BatchNorm running stats are buffers, not parameters; copy them so the
        # frozen target does not keep its initialization statistics forever.
        for online_b, target_b in zip(self.online_network.buffers(), self.target_network.buffers()):
            target_b.copy_(online_b)
    def forward(self, view: Tensor) -> Tuple[Tensor, Tensor]:
        online_proj = self.online_network(view)
        with torch.no_grad():  # the target branch never receives gradients
            target_proj = self.target_network(view)
        return online_proj, target_proj
def loss(self, view1: Tensor, view2: Tensor) -> Tensor:
online_proj1, target_proj1 = self(view1)
online_proj2, target_proj2 = self(view2)
online_prediction_1 = self.online_predictor(online_proj1)
online_prediction_2 = self.online_predictor(online_proj2)
        # Symmetrized BYOL loss: each view's prediction regresses the other
        # view's (stop-gradient) target projection.
        loss1 = self.loss_function(online_prediction_1, target_proj2.detach())
        loss2 = self.loss_function(online_prediction_2, target_proj1.detach())
        return torch.mean(loss1 + loss2)
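

# Minimal smoke test (a sketch, not part of the original training code): runs
# one BYOL step on random 32x32 grayscale "views". Batch size, learning rate,
# and the Adam optimizer are illustrative assumptions.
if __name__ == "__main__":
    model = BYOL()
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=3e-4
    )

    view1 = torch.randn(8, 1, 32, 32)  # stands in for two augmentations of a batch
    view2 = torch.randn(8, 1, 32, 32)

    optimizer.zero_grad()
    loss = model.loss(view1, view2)
    loss.backward()
    optimizer.step()
    model.soft_update_target_network()  # EMA update after the optimizer step
    print(f"BYOL loss: {loss.item():.4f}")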