import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, PreTrainedModel, PretrainedConfig


class BertGTSOpinionTripleConfig(PretrainedConfig):
    model_type = 'multi-infer-bert-uncased'

    def __init__(self, feat_dim=768, max_len=128, class_num=6, **kwargs):
        super().__init__(**kwargs)
        self.feat_dim = feat_dim    # hidden size of the BERT encoder (768 for bert-base)
        self.max_len = max_len      # maximum (padded) sequence length
        self.class_num = class_num  # number of grid tag classes


class BertGTSOpinionTriple(PreTrainedModel):
    """Grid-tagging (GTS-style) model for opinion triple extraction on top of a BERT encoder."""

    config_class = BertGTSOpinionTripleConfig

    def __init__(self, config):
        super().__init__(config)
        model_id = 'google-bert/bert-base-uncased'
        # the encoder is loaded from the hub checkpoint rather than initialised from the config
        self.model = AutoModel.from_pretrained(model_id)
        self.max_seq_len = config.max_len
        self.bert_feat_dim = config.feat_dim  # 768
        self.class_num = config.class_num     # 6
        # classifier over concatenated token-pair features [h_i ; h_j]
        self.cls_linear = torch.nn.Linear(self.bert_feat_dim * 2, self.class_num)
        # projects [pair features ; propagated logits ; previous logits] back to the pair-feature size
        self.feature_linear = torch.nn.Linear(self.bert_feat_dim * 2 + self.class_num * 3,
                                              self.bert_feat_dim * 2)
        self.dropout_output = torch.nn.Dropout(0.1)
        self.post_init()

    def multi_hops(self, features, mask, k):
        """Run k rounds of inference, refining the grid logits at each hop."""
        max_length = features.shape[1]
        # build a (batch, seq, seq) pair mask from the attention mask and keep the upper triangle,
        # so each token pair (i, j) with i <= j is scored exactly once
        mask = mask[:, :max_length]
        mask_a = mask.unsqueeze(1).expand([-1, max_length, -1])
        mask_b = mask.unsqueeze(2).expand([-1, -1, max_length])
        mask = mask_a * mask_b
        mask = torch.triu(mask).unsqueeze(3).expand([-1, -1, -1, self.class_num])

        # collect the logits produced after every hop
        logits_list = []
        logits = self.cls_linear(features)
        logits_list.append(logits)
        for _ in range(k):
            # probs = torch.softmax(logits, dim=3)  # softmax disabled: raw logits are propagated
            probs = logits
            logits = probs * mask
            # propagate the strongest evidence along rows and columns of the grid
            logits_a = torch.max(logits, dim=1)[0]
            logits_b = torch.max(logits, dim=2)[0]
            logits = torch.cat([logits_a.unsqueeze(3), logits_b.unsqueeze(3)], dim=3)
            logits = torch.max(logits, dim=3)[0]
            logits = logits.unsqueeze(2).expand([-1, -1, max_length, -1])
            logits_T = logits.transpose(1, 2)
            logits = torch.cat([logits, logits_T], dim=3)
            # re-encode the pair features together with the propagated and previous logits
            new_features = torch.cat([features, logits, probs], dim=3)
            features = self.feature_linear(new_features)
            logits = self.cls_linear(features)
            logits_list.append(logits)
        return logits_list

    def forward(self, input_ids, attention_masks, labels=None):
        # NOTE: rename `attention_masks` to `attention_mask` if the calling code (e.g. the HF
        # Trainer) passes batches by keyword. Inputs are expected to be padded to `config.max_len`.
        model_feature = self.model(input_ids=input_ids, attention_mask=attention_masks)
        # detach the encoder output so the BERT weights are not updated; only the grid layers train
        model_feature = model_feature.last_hidden_state.detach()
        bert_feature = self.dropout_output(model_feature)
        # build pairwise features: features[b, i, j] = [h_i ; h_j]
        bert_feature = bert_feature.unsqueeze(2).expand([-1, -1, self.max_seq_len, -1])
        bert_feature_T = bert_feature.transpose(1, 2)
        features = torch.cat([bert_feature, bert_feature_T], dim=3)
        logits = self.multi_hops(features, attention_masks, 1)
        fin_logits = logits[-1]

        loss = None
        if labels is not None:
            # flatten the tag grid and compute cross-entropy; cells labelled -1 (padding) are ignored
            gold_floss = labels.reshape([-1])
            pred_floss = fin_logits.reshape([-1, fin_logits.shape[3]])
            loss = F.cross_entropy(pred_floss, gold_floss, ignore_index=-1)
        return {'logits': fin_logits, 'loss': loss}
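

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch, not part of the model definition).
# The sentence and the all-zero label grid are dummy inputs; the sketch
# assumes tag class 0 means "no relation", which may differ in your dataset.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
    config = BertGTSOpinionTripleConfig(feat_dim=768, max_len=128, class_num=6)
    model = BertGTSOpinionTriple(config)
    model.eval()

    encoded = tokenizer(
        'the battery life is great but the screen is dim',
        padding='max_length',
        truncation=True,
        max_length=config.max_len,
        return_tensors='pt',
    )
    # dummy tag grid: every cell labelled 0 ("no relation" under the assumption above)
    labels = torch.zeros((1, config.max_len, config.max_len), dtype=torch.long)

    with torch.no_grad():
        out = model(encoded['input_ids'], encoded['attention_mask'], labels=labels)
    print(out['logits'].shape)  # torch.Size([1, max_len, max_len, class_num])
    print(out['loss'])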