leeyunjai committed
Commit 06af277 · 1 Parent(s): 6b38ff0

Update caption_model.py

Files changed (1)
  1. caption_model.py +57 -0
caption_model.py CHANGED
@@ -0,0 +1,57 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils import NestedTensor, nested_tensor_from_tensor_list
+from backbone import build_backbone
+from transformer import build_transformer
+
+
+class Caption(nn.Module):
+    def __init__(self, backbone, transformer, hidden_dim, vocab_size):
+        super().__init__()
+        self.backbone = backbone
+        self.input_proj = nn.Conv2d(
+            backbone.num_channels, hidden_dim, kernel_size=1)
+        self.transformer = transformer
+        self.mlp = MLP(hidden_dim, 512, vocab_size, 3)
+
+    def forward(self, samples, target, target_mask):
+        if not isinstance(samples, NestedTensor):
+            samples = nested_tensor_from_tensor_list(samples)
+
+        features, pos = self.backbone(samples)
+        src, mask = features[-1].decompose()
+
+        assert mask is not None
+
+        hs = self.transformer(self.input_proj(src), mask,
+                              pos[-1], target, target_mask)
+        out = self.mlp(hs.permute(1, 0, 2))
+        return out
+
+
+class MLP(nn.Module):
+    """ Very simple multi-layer perceptron (also called FFN)"""
+
+    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(nn.Linear(n, k)
+                                    for n, k in zip([input_dim] + h, h + [output_dim]))
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
+
+
+def build_model(config):
+    backbone = build_backbone(config)
+    transformer = build_transformer(config)
+
+    model = Caption(backbone, transformer, config.hidden_dim, config.vocab_size)
+    criterion = torch.nn.CrossEntropyLoss()
+
+    return model, criterion
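
For context, below is a minimal usage sketch of `build_model` and `Caption.forward` under stated assumptions: a project-specific `Config` object exposing at least `hidden_dim` and `vocab_size` (plus whatever fields `build_backbone` and `build_transformer` read), image batches shaped (batch, 3, H, W), and padded caption token-id tensors with matching boolean padding masks. None of these details beyond the code above are confirmed by this commit.

    # Hypothetical usage sketch (not part of this commit).
    import torch
    from caption_model import build_model
    from configuration import Config  # assumed project config module

    config = Config()  # must provide hidden_dim, vocab_size, etc.
    model, criterion = build_model(config)

    images = torch.randn(2, 3, 299, 299)                  # dummy image batch
    caps = torch.randint(0, config.vocab_size, (2, 128))  # caption token ids
    cap_masks = torch.zeros(2, 128, dtype=torch.bool)     # padding mask (no padding here)

    # Teacher forcing: feed every token but the last, predict the next token.
    outputs = model(images, caps[:, :-1], cap_masks[:, :-1])  # (batch, 127, vocab_size)
    loss = criterion(outputs.permute(0, 2, 1), caps[:, 1:])

The final `permute` reflects that `nn.CrossEntropyLoss` expects logits shaped (batch, classes, sequence) when the targets are (batch, sequence).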