import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel, PreTrainedModel, PretrainedConfig


class BertGTSOpinionTripleConfig(PretrainedConfig):
    model_type = 'multi-infer-bert-uncased'

    def __init__(self, feat_dim=768, max_len=128, class_num=6, **kwargs):
        super().__init__(**kwargs)
        self.feat_dim = feat_dim    # hidden size of the BERT encoder
        self.max_len = max_len      # maximum (padded) sequence length
        self.class_num = class_num  # number of grid tag classes


class BertGTSOpinionTriple(PreTrainedModel):
    config_class = BertGTSOpinionTripleConfig

    def __init__(self, config):
        super().__init__(config)
        model_id = 'google-bert/bert-base-uncased'
        # BERT encoder used as the feature extractor for the GTS tagging grid.
        self.model = AutoModel.from_pretrained(model_id)
        self.max_seq_len = config.max_len
        self.bert_feat_dim = config.feat_dim
        self.class_num = config.class_num
        # Classifier over concatenated (token_i, token_j) pair features.
        self.cls_linear = torch.nn.Linear(self.bert_feat_dim * 2, self.class_num)
        # Projects [pair features; broadcast logits; previous logits] back to the
        # pair-feature size for the next inference hop.
        self.feature_linear = torch.nn.Linear(self.bert_feat_dim * 2 + self.class_num * 3,
                                              self.bert_feat_dim * 2)
        self.dropout_output = torch.nn.Dropout(0.1)
        self.post_init()

    def multi_hops(self, features, mask, k):
        '''generate a pairwise mask from the attention mask'''
        max_length = features.shape[1]
        mask = mask[:, :max_length]
        mask_a = mask.unsqueeze(1).expand([-1, max_length, -1])
        mask_b = mask.unsqueeze(2).expand([-1, -1, max_length])
        mask = mask_a * mask_b
        # Keep only the upper triangle (i <= j) and broadcast over the tag classes.
        mask = torch.triu(mask).unsqueeze(3).expand([-1, -1, -1, self.class_num])

        '''save all logits'''
        logits_list = []
        logits = self.cls_linear(features)
        logits_list.append(logits)
        for _ in range(k):
            probs = logits  # raw logits from the previous hop are reused as "probs"
            logits = probs * mask
            # Row-wise and column-wise max pooling over the grid, then an element-wise max of the two.
            logits_a = torch.max(logits, dim=1)[0]
            logits_b = torch.max(logits, dim=2)[0]
            logits = torch.cat([logits_a.unsqueeze(3), logits_b.unsqueeze(3)], dim=3)
            logits = torch.max(logits, dim=3)[0]

            # Broadcast the pooled logits back onto the grid for both tokens of each pair.
            logits = logits.unsqueeze(2).expand([-1, -1, max_length, -1])
            logits_T = logits.transpose(1, 2)
            logits = torch.cat([logits, logits_T], dim=3)

            # Refine the pair features with the previous predictions and re-classify.
            new_features = torch.cat([features, logits, probs], dim=3)
            features = self.feature_linear(new_features)
            logits = self.cls_linear(features)
            logits_list.append(logits)
        return logits_list

    def forward(self, input_ids, attention_masks, labels=None):
        model_feature = self.model(input_ids, attention_masks)
        # The encoder output is detached, so gradients only update the tagging head.
        model_feature = model_feature.last_hidden_state.detach()
        bert_feature = self.dropout_output(model_feature)
        # Build pairwise features by concatenating the representations of token i and token j;
        # inputs are expected to be padded to max_seq_len.
        bert_feature = bert_feature.unsqueeze(2).expand([-1, -1, self.max_seq_len, -1])
        bert_feature_T = bert_feature.transpose(1, 2)
        features = torch.cat([bert_feature, bert_feature_T], dim=3)
        logits = self.multi_hops(features, attention_masks, 1)
        fin_logits = logits[-1]
        loss = None
        if labels is not None:
            # Flatten the tag grid; cells labelled -1 (padding / ignored positions) are skipped.
            gold_floss = labels.reshape([-1])
            pred_floss = fin_logits.reshape([-1, fin_logits.shape[3]])
            loss = F.cross_entropy(pred_floss, gold_floss, ignore_index=-1)
        return {'logits': fin_logits, 'loss': loss}
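

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of running the model on one sentence. It assumes inputs are
# padded to `config.max_len`, which the pairwise feature construction in
# `forward` requires, and that label grids, when provided, have shape
# [batch, max_len, max_len] with -1 on ignored cells. The example sentence and
# shapes are hypothetical.
if __name__ == '__main__':
    config = BertGTSOpinionTripleConfig()
    model = BertGTSOpinionTriple(config)
    tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')

    encoded = tokenizer(
        'The pasta was great but the service was slow .',
        padding='max_length',
        truncation=True,
        max_length=config.max_len,
        return_tensors='pt',
    )

    with torch.no_grad():
        outputs = model(encoded['input_ids'], encoded['attention_mask'])

    # Grid of tag logits over all (token_i, token_j) pairs.
    print(outputs['logits'].shape)  # torch.Size([1, 128, 128, 6])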