import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, PretrainedConfig, PreTrainedModel


class BertGTSOpinionTripleConfig(PretrainedConfig):
    """Configuration for the BERT-based grid-tagging (GTS) opinion-triple model.

    feat_dim:  hidden size of the BERT encoder (768 for bert-base-uncased)
    max_len:   maximum sequence length of the tagging grid
    class_num: number of grid tag classes
    """
    model_type = 'multi-infer-bert-uncased'

    def __init__(self, feat_dim=768, max_len=128, class_num=6, **kwargs):
        super().__init__(**kwargs)
        self.feat_dim = feat_dim
        self.max_len = max_len
        self.class_num = class_num


class BertGTSOpinionTriple(PreTrainedModel):
    config_class = BertGTSOpinionTripleConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = AutoModel.from_pretrained('google-bert/bert-base-uncased')
        self.max_seq_len = config.max_len
        self.bert_feat_dim = config.feat_dim  # 768 for bert-base-uncased
        self.class_num = config.class_num  # number of grid tag classes
        # Scores each token pair from its concatenated pair representation.
        self.cls_linear = torch.nn.Linear(self.bert_feat_dim * 2, self.class_num)
        # Projects [pair features; pooled logits; previous logits] back to the
        # pair-feature size for the next inference hop.
        self.feature_linear = torch.nn.Linear(self.bert_feat_dim * 2 + self.class_num * 3, self.bert_feat_dim * 2)
        self.dropout_output = torch.nn.Dropout(0.1)
        self.post_init()

    def multi_hops(self, features, mask, k):
        """Refine the pair-level tag logits over k inference hops."""
        max_length = features.shape[1]
        # Build a pairwise validity mask from the attention mask and keep only
        # the upper triangle, so each (i, j) token pair is scored once.
        mask = mask[:, :max_length]
        mask_a = mask.unsqueeze(1).expand([-1, max_length, -1])
        mask_b = mask.unsqueeze(2).expand([-1, -1, max_length])
        mask = mask_a * mask_b
        mask = torch.triu(mask).unsqueeze(3).expand([-1, -1, -1, self.class_num])

        # Save the logits from every hop.
        logits_list = []
        logits = self.cls_linear(features)
        logits_list.append(logits)
        for _ in range(k):
            # NOTE: raw logits are propagated instead of softmax probabilities
            # (the softmax variant is left disabled below).
            # probs = torch.softmax(logits, dim=3)
            probs = logits
            logits = probs * mask
            # Pool the strongest evidence for each token along rows and columns.
            logits_a = torch.max(logits, dim=1)[0]
            logits_b = torch.max(logits, dim=2)[0]
            logits = torch.cat([logits_a.unsqueeze(3), logits_b.unsqueeze(3)], dim=3)
            logits = torch.max(logits, dim=3)[0]

            # Broadcast the pooled per-token scores back to every (i, j) pair.
            logits = logits.unsqueeze(2).expand([-1, -1, max_length, -1])
            logits_T = logits.transpose(1, 2)
            logits = torch.cat([logits, logits_T], dim=3)

            # Fuse pair features with pooled and previous logits, then re-score.
            new_features = torch.cat([features, logits, probs], dim=3)
            features = self.feature_linear(new_features)
            logits = self.cls_linear(features)
            logits_list.append(logits)
        return logits_list

    def forward(self, input_ids, attention_masks, labels=None):
        model_feature = self.model(input_ids, attention_mask=attention_masks)
        # Detach so gradients do not flow into the BERT encoder; only the
        # grid-tagging head is trained.
        model_feature = model_feature.last_hidden_state.detach()
        bert_feature = self.dropout_output(model_feature)
        # Build the (batch, seq, seq, 2 * feat_dim) grid of token-pair features.
        bert_feature = bert_feature.unsqueeze(2).expand([-1, -1, self.max_seq_len, -1])
        bert_feature_T = bert_feature.transpose(1, 2)
        features = torch.cat([bert_feature, bert_feature_T], dim=3)
        logits = self.multi_hops(features, attention_masks, 1)
        fin_logits = logits[-1]
        loss = None
        if labels is not None:
            # Flatten the grid and compute cross-entropy, ignoring cells
            # labelled -1 (padding / unused pairs).
            gold_labels = labels.reshape([-1])
            pred_logits = fin_logits.reshape([-1, fin_logits.shape[3]])
            loss = F.cross_entropy(pred_logits, gold_labels, ignore_index=-1)
        return {'logits': fin_logits, 'loss': loss}
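

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition).
# Assumes the 'google-bert/bert-base-uncased' weights are available; the
# sample sentence and the labelled grid cell below are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
    config = BertGTSOpinionTripleConfig(feat_dim=768, max_len=128, class_num=6)
    model = BertGTSOpinionTriple(config)

    # Pad to max_len: the pair-grid construction in forward() expects
    # sequences of exactly config.max_len tokens.
    encoding = tokenizer(
        'The battery life is great but the screen is dim.',
        padding='max_length',
        truncation=True,
        max_length=config.max_len,
        return_tensors='pt',
    )

    # Dummy tag grid: -1 marks cells ignored by the loss; one hypothetical
    # token pair is tagged with class 1 so the loss is well-defined.
    labels = torch.full((1, config.max_len, config.max_len), -1, dtype=torch.long)
    labels[0, 1, 3] = 1

    with torch.no_grad():
        out = model(encoding['input_ids'], encoding['attention_mask'], labels=labels)
    print(out['logits'].shape)  # torch.Size([1, 128, 128, 6])
    print(out['loss'])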