johnomeara commited on
Commit
58fa1b8
·
verified ·
1 Parent(s): 2ccf69f

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +5 -3
modeling.py CHANGED
@@ -1,12 +1,14 @@
 
1
  import torch
2
  import torch.nn as nn
3
  import timm
4
  from huggingface_hub import PyTorchModelHubMixin
5
 
6
  class KeypointModel(nn.Module, PyTorchModelHubMixin):
7
- def __init__(self, heatmap_size=(512, 1024), **kwargs):
8
  super().__init__()
9
- config_heatmap_size = kwargs.get("config", {}).get("heatmap_size", heatmap_size)
 
10
 
11
  backbone = timm.create_model('convnextv2_base.fcmae_ft_in22k_in1k_384', pretrained=False)
12
 
@@ -15,7 +17,7 @@ class KeypointModel(nn.Module, PyTorchModelHubMixin):
15
  self.head = nn.Sequential(
16
  nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
17
  nn.ReLU(inplace=True),
18
- nn.Upsample(size=config_heatmap_size, mode='bilinear', align_corners=False),
19
  nn.Conv2d(256, 1, kernel_size=1)
20
  )
21
 
 
1
+
2
  import torch
3
  import torch.nn as nn
4
  import timm
5
  from huggingface_hub import PyTorchModelHubMixin
6
 
7
  class KeypointModel(nn.Module, PyTorchModelHubMixin):
8
+ def __init__(self, config, **kwargs):
9
  super().__init__()
10
+
11
+ upsample_size = config.heatmap_size
12
 
13
  backbone = timm.create_model('convnextv2_base.fcmae_ft_in22k_in1k_384', pretrained=False)
14
 
 
17
  self.head = nn.Sequential(
18
  nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
19
  nn.ReLU(inplace=True),
20
+ nn.Upsample(size=upsample_size, mode='bilinear', align_corners=False),
21
  nn.Conv2d(256, 1, kernel_size=1)
22
  )
23