johnomeara commited on
Commit
2ccf69f
·
verified ·
1 Parent(s): 4c227f0

Create modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +25 -0
modeling.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+
6
+ class KeypointModel(nn.Module, PyTorchModelHubMixin):
7
+ def __init__(self, heatmap_size=(512, 1024), **kwargs):
8
+ super().__init__()
9
+ config_heatmap_size = kwargs.get("config", {}).get("heatmap_size", heatmap_size)
10
+
11
+ backbone = timm.create_model('convnextv2_base.fcmae_ft_in22k_in1k_384', pretrained=False)
12
+
13
+ self.feature_extractor = nn.Sequential(*list(backbone.children())[:-2])
14
+ in_channels = backbone.num_features
15
+ self.head = nn.Sequential(
16
+ nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
17
+ nn.ReLU(inplace=True),
18
+ nn.Upsample(size=config_heatmap_size, mode='bilinear', align_corners=False),
19
+ nn.Conv2d(256, 1, kernel_size=1)
20
+ )
21
+
22
+ def forward(self, image):
23
+ features = self.feature_extractor(image)
24
+ heatmap = self.head(features)
25
+ return heatmap