# bert-gts-absa-triple / bert_opinion.py
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PretrainedConfig


class BertGTSOpinionTripleConfig(PretrainedConfig):
    model_type = 'multi-infer-bert-uncased'

    def __init__(self, feat_dim=768, max_len=128, class_num=6, **kwargs):
        super().__init__(**kwargs)
        self.feat_dim = feat_dim    # hidden size of the BERT encoder
        self.max_len = max_len      # fixed length the tagging grid is padded to
        self.class_num = class_num  # number of grid-tagging classes


class BertGTSOpinionTriple(PreTrainedModel):
    config_class = BertGTSOpinionTripleConfig

    def __init__(self, config):
        super().__init__(config)
        # The encoder is loaded here and its output is detached in forward(),
        # so only the classification heads below receive gradients.
        self.model = AutoModel.from_pretrained('google-bert/bert-base-uncased')
        self.max_seq_len = config.max_len
        self.bert_feat_dim = config.feat_dim  # 768 for bert-base
        self.class_num = config.class_num     # 6
        # Classifier over pairwise token features [h_i; h_j].
        self.cls_linear = torch.nn.Linear(self.bert_feat_dim * 2, self.class_num)
        # Maps [pair features; row/col logits; previous logits] back to the
        # pair-feature size for the next inference hop.
        self.feature_linear = torch.nn.Linear(
            self.bert_feat_dim * 2 + self.class_num * 3, self.bert_feat_dim * 2)
        self.dropout_output = torch.nn.Dropout(0.1)
        self.post_init()

    def multi_hops(self, features, mask, k):
        max_length = features.shape[1]
        # Build a (batch, len, len) validity mask over token pairs and keep only
        # the upper triangle, since the tagging grid is read above the diagonal.
        mask = mask[:, :max_length]
        mask_a = mask.unsqueeze(1).expand([-1, max_length, -1])
        mask_b = mask.unsqueeze(2).expand([-1, -1, max_length])
        mask = mask_a * mask_b
        mask = torch.triu(mask).unsqueeze(3).expand([-1, -1, -1, self.class_num])
        # Save the logits of every hop.
        logits_list = []
        logits = self.cls_linear(features)
        logits_list.append(logits)
        for _ in range(k):
            # probs = torch.softmax(logits, dim=3)
            probs = logits  # raw logits are propagated instead of softmax probabilities
            logits = probs * mask
            # Max-pool the grid along rows and columns to get per-token evidence.
            logits_a = torch.max(logits, dim=1)[0]
            logits_b = torch.max(logits, dim=2)[0]
            logits = torch.cat([logits_a.unsqueeze(3), logits_b.unsqueeze(3)], dim=3)
            logits = torch.max(logits, dim=3)[0]
            # Broadcast the pooled evidence back over both axes of the grid.
            logits = logits.unsqueeze(2).expand([-1, -1, max_length, -1])
            logits_T = logits.transpose(1, 2)
            logits = torch.cat([logits, logits_T], dim=3)
            # Refine the pair features with the current predictions and re-classify.
            new_features = torch.cat([features, logits, probs], dim=3)
            features = self.feature_linear(new_features)
            logits = self.cls_linear(features)
            logits_list.append(logits)
        return logits_list

    def forward(self, input_ids, attention_masks, labels=None):
        model_feature = self.model(input_ids, attention_masks)
        # Detach so the BERT encoder is not updated by this loss.
        model_feature = model_feature.last_hidden_state.detach()
        bert_feature = self.dropout_output(model_feature)
        # Build the (batch, len, len, 2 * feat_dim) grid of pair features
        # [h_i; h_j]; inputs must be padded to max_len for the expand to hold.
        bert_feature = bert_feature.unsqueeze(2).expand([-1, -1, self.max_seq_len, -1])
        bert_feature_T = bert_feature.transpose(1, 2)
        features = torch.cat([bert_feature, bert_feature_T], dim=3)
        logits = self.multi_hops(features, attention_masks, 1)
        fin_logits = logits[-1]
        loss = None
        if labels is not None:
            # Flatten the grid and score it with cross-entropy; cells labelled
            # -1 (padding) are ignored.
            gold_flat = labels.reshape([-1])
            pred_flat = fin_logits.reshape([-1, fin_logits.shape[3]])
            loss = F.cross_entropy(pred_flat, gold_flat, ignore_index=-1)
        return {'logits': fin_logits, 'loss': loss}
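

# --- Minimal usage sketch (illustrative addition, not part of the original
# file): registers the custom classes with the Auto* factories and runs a
# dummy forward pass with randomly initialised heads. The example sentence is
# an assumption; for the published checkpoint, loading via
# AutoModel.from_pretrained(repo_id, trust_remote_code=True) is the usual route.
if __name__ == '__main__':
    from transformers import AutoConfig

    AutoConfig.register('multi-infer-bert-uncased', BertGTSOpinionTripleConfig)
    AutoModel.register(BertGTSOpinionTripleConfig, BertGTSOpinionTriple)

    config = BertGTSOpinionTripleConfig()
    model = BertGTSOpinionTriple(config).eval()
    tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')

    # forward() builds a fixed max_len x max_len tagging grid, so inputs must
    # be padded to config.max_len.
    enc = tokenizer('The pizza was great but the service was slow.',
                    padding='max_length', truncation=True,
                    max_length=config.max_len, return_tensors='pt')
    with torch.no_grad():
        out = model(enc['input_ids'], enc['attention_mask'])
    print(out['logits'].shape)  # torch.Size([1, 128, 128, 6])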