import torch


class HeadVQA(torch.nn.Module):
    """Single VQA classification head on top of concatenated CLIP features."""

    def __init__(self, train_config):
        super().__init__()
        # Joint embedding width of each supported CLIP backbone.
        embedding_size = {'RN50': 1024,
                          'RN101': 512,
                          'RN50x4': 640,
                          'RN50x16': 768,
                          'RN50x64': 1024,
                          'ViT-B/32': 512,
                          'ViT-B/16': 512,
                          'ViT-L/14': 768,
                          'ViT-L/14@336px': 768}
        n_aux_classes = len(set(train_config.aux_mapping.values()))

        # Main branch: LayerNorm -> Dropout -> Linear over the concatenated
        # image and question features (hence the factor of 2).
        self.ln1 = torch.nn.LayerNorm(embedding_size[train_config.model] * 2)
        self.dp1 = torch.nn.Dropout(0.5)
        self.fc1 = torch.nn.Linear(embedding_size[train_config.model] * 2, 512)
        self.ln2 = torch.nn.LayerNorm(512)
        self.dp2 = torch.nn.Dropout(0.5)
        self.fc2 = torch.nn.Linear(512, train_config.n_classes)

        # Auxiliary branch: predicts a coarse auxiliary class and turns that
        # prediction into a sigmoid gate over the answer logits.
        self.fc_aux = torch.nn.Linear(512, n_aux_classes)
        self.fc_gate = torch.nn.Linear(n_aux_classes, train_config.n_classes)
        self.act_gate = torch.nn.Sigmoid()

    def forward(self, img_features, question_features):
        # Fuse the two modalities by concatenation along the feature dimension.
        xc = torch.cat((img_features, question_features), dim=-1)
        x = self.ln1(xc)
        x = self.dp1(x)
        x = self.fc1(x)

        # Auxiliary prediction and the gate derived from it.
        aux = self.fc_aux(x)
        gate = self.act_gate(self.fc_gate(aux))

        # Answer logits, scaled element-wise by the gate.
        x = self.ln2(x)
        x = self.dp2(x)
        vqa = self.fc2(x)
        output = vqa * gate
        return output, aux
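

# Note on the gating design (added comment; the interpretation is an
# assumption based on how aux_mapping is used above): fc_aux predicts a
# coarse auxiliary class for each sample, and fc_gate + sigmoid map that
# prediction to per-answer multipliers in (0, 1), letting the head
# down-weight answers that are implausible for the predicted auxiliary class.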

class NetVQA(torch.nn.Module):
    """Ensemble of HeadVQA heads, one per cross-validation fold."""

    def __init__(self, train_config):
        super().__init__()
        self.heads = torch.nn.ModuleList()
        # train_config.folds is either a list of fold indices or a fold count.
        if isinstance(train_config.folds, list):
            self.num_heads = len(train_config.folds)
        else:
            self.num_heads = train_config.folds
        for _ in range(self.num_heads):
            self.heads.append(HeadVQA(train_config))

    def forward(self, img_features, question_features):
        output = []
        output_aux = []
        for head in self.heads:
            logits, logits_aux = head(img_features, question_features)
            # Average probabilities rather than logits across heads.
            output.append(logits.softmax(-1))
            output_aux.append(logits_aux.softmax(-1))
        output = torch.stack(output, dim=-1).mean(-1)
        output_aux = torch.stack(output_aux, dim=-1).mean(-1)
        return output, output_aux


def merge_vqa(train_config):
    """Load the per-fold checkpoints into one NetVQA ensemble and save it."""
    # Initialize model
    model = NetVQA(train_config)

    for fold in train_config.folds:
        print("load weights from fold {} into head {}".format(fold, fold))
        checkpoint_path = "{}/{}/fold_{}".format(train_config.model_path, train_config.model, fold)
        if train_config.crossvalidation:
            # Load the best checkpoint selected on the validation fold.
            model_state_dict = torch.load('{}/weights_best.pth'.format(checkpoint_path))
        else:
            # Load the checkpoint written at the end of training.
            model_state_dict = torch.load('{}/weights_end.pth'.format(checkpoint_path))
        # Fold index doubles as head index, which assumes folds enumerate
        # 0..num_heads-1.
        model.heads[fold].load_state_dict(model_state_dict, strict=True)

    checkpoint_path = "{}/{}/weights_merged.pth".format(train_config.model_path, train_config.model)
    print("Saving weights of merged model:", checkpoint_path)
    torch.save(model.state_dict(), checkpoint_path)
    return model
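

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original pipeline).
# The config below is a minimal stand-in built from the attributes this
# module actually reads; all concrete values are assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    train_config = SimpleNamespace(
        model="ViT-L/14@336px",           # must be a key of embedding_size
        aux_mapping={0: 0, 1: 1, 2: 2},   # answer id -> auxiliary class (assumed)
        n_classes=3129,                   # answer vocabulary size (assumed)
        folds=[0, 1, 2],                  # one head per fold
        model_path="./checkpoints",       # expects <model_path>/<model>/fold_<k>/
        crossvalidation=True,
    )

    # merge_vqa(train_config) would additionally load the per-fold checkpoints
    # from disk; here the ensemble is built directly just to check shapes.
    model = NetVQA(train_config).eval()

    # Dummy CLIP features (768-d for ViT-L/14@336px); the real pipeline would
    # obtain these from a CLIP image encoder and text encoder.
    img_features = torch.randn(1, 768)
    question_features = torch.randn(1, 768)
    with torch.no_grad():
        probs, probs_aux = model(img_features, question_features)
    print(probs.shape, probs_aux.shape)  # torch.Size([1, 3129]) torch.Size([1, 3])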