import torch
import torch.nn as nn
import timm
|
class TransformerBlock(nn.Module):
    """Transformer encoder block (self-attention + feed-forward) applied to
    pooled CNN features, treated as a length-1 token sequence."""

    def __init__(self, embed_dim=1280, num_heads=8, ff_dim=3072, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: (batch, embed_dim) pooled CNN features.
        x = x.unsqueeze(1)        # (batch, 1, embed_dim): add a sequence dim
        x = x.permute(1, 0, 2)    # (1, batch, embed_dim): seq-first layout for nn.MultiheadAttention

        # Self-attention with residual connection and post-norm.
        attn_output, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_output)

        # Position-wise feed-forward network, again residual + post-norm.
        ffn_output = self.ffn(x)
        x = self.norm2(x + ffn_output)

        x = x.permute(1, 0, 2)    # back to (batch, 1, embed_dim)
        return x
|
|
|
class EfficientNetBackbone(nn.Module):
    """EfficientNet-B0 feature extractor from timm (classifier head removed)."""

    def __init__(self):
        super().__init__()
        # num_classes=0 drops the classifier head; global_pool='avg' returns a
        # single 1280-d feature vector per image.
        self.model = timm.create_model('efficientnet_b0', pretrained=True,
                                       num_classes=0, global_pool='avg')
        self.out_features = 1280

    def forward(self, x):
        # x: (batch, 3, H, W) -> (batch, 1280)
        return self.model(x)
|
|
|
class CNNViT(nn.Module):
    """Hybrid classifier: EfficientNet-B0 features refined by a transformer
    block, followed by an MLP classification head."""

    def __init__(self, num_classes=5):
        super().__init__()
        self.cnn_backbone = EfficientNetBackbone()
        self.transformer = TransformerBlock(embed_dim=1280, num_heads=8, ff_dim=3072)

        self.fc = nn.Sequential(
            nn.Linear(1280, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.cnn_backbone(x)  # (batch, 1280)
        x = self.transformer(x)   # (batch, 1, 1280)
        x = x.squeeze(1)          # (batch, 1280)
        x = self.fc(x)            # (batch, num_classes)
        return x
|
|
|
model_Hybrid = CNNViT(num_classes=5)
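
# Quick shape sanity check (illustrative sketch only: the batch size and the
# 224x224 input resolution are assumptions, 224 being EfficientNet-B0's
# default training resolution).
if __name__ == "__main__":
    model_Hybrid.eval()
    dummy = torch.randn(2, 3, 224, 224)  # (batch, channels, H, W)
    with torch.no_grad():
        logits = model_Hybrid(dummy)
    print(logits.shape)  # expected: torch.Size([2, 5])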