import torch
import torch.nn as nn
import timm
class TransformerBlock(nn.Module):
    """Post-norm transformer encoder block applied to the pooled CNN feature vector."""

    def __init__(self, embed_dim=1280, num_heads=8, ff_dim=3072, dropout=0.1):
        super(TransformerBlock, self).__init__()
        # Multi-head self-attention; expects (seq_len, batch, embed_dim) since batch_first defaults to False
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Position-wise feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # x: (batch, embed_dim) -> (batch, 1, embed_dim) -> (1, batch, embed_dim)
        x = x.unsqueeze(1)
        x = x.permute(1, 0, 2)
        # Self-attention with residual connection and layer norm
        attn_output, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_output)
        # Feed-forward with residual connection and layer norm
        ffn_output = self.ffn(x)
        x = self.norm2(x + ffn_output)
        # Back to (batch, 1, embed_dim)
        x = x.permute(1, 0, 2)
        return x
class EfficientNetBackbone(nn.Module):
    """EfficientNet-B0 feature extractor; global average pooling yields a 1280-d vector per image."""

    def __init__(self):
        super(EfficientNetBackbone, self).__init__()
        # num_classes=0 removes the classifier head; global_pool='avg' returns pooled features
        self.model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=0, global_pool='avg')
        self.out_features = 1280

    def forward(self, x):
        # x: (batch, 3, H, W) -> (batch, 1280)
        x = self.model(x)
        return x
class CNNViT(nn.Module):
    """Hybrid model: EfficientNet-B0 features refined by a transformer block, then an MLP classifier."""

    def __init__(self, num_classes=5):
        super(CNNViT, self).__init__()
        self.cnn_backbone = EfficientNetBackbone()
        self.transformer = TransformerBlock(embed_dim=1280, num_heads=8, ff_dim=3072)
        # MLP classification head
        self.fc = nn.Sequential(
            nn.Linear(1280, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.cnn_backbone(x)   # (batch, 1280)
        x = self.transformer(x)    # (batch, 1, 1280)
        x = x.squeeze(1)           # (batch, 1280)
        x = self.fc(x)             # (batch, num_classes)
        return x


model_Hybrid = CNNViT(num_classes=5)
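

# Quick sanity check (a minimal sketch added for illustration, not part of the training pipeline):
# efficientnet_b0 is conventionally fed 224x224 RGB images, so a dummy batch of that size
# should yield (batch, num_classes) logits. The resolution here is an assumption.
if __name__ == "__main__":
    dummy = torch.randn(2, 3, 224, 224)  # assumed 224x224 input resolution
    with torch.no_grad():
        logits = model_Hybrid(dummy)
    print(logits.shape)  # expected: torch.Size([2, 5])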