import torch


class DecodeAndEvaluate:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.sentiment2id = {'negative': 3, 'neutral': 4, 'positive': 5}
        self.id2sentiment = {v: k for k, v in self.sentiment2id.items()}

    def get_span_from_tags(self, tags, token_range, tok_type):
        ## tok_type: 1 = aspect, 2 = opinion
        sel_spans = []
        start_ind, end_ind = -1, -1
        has_prev = False
        for l, r in token_range:
            if tags[l][l] != tok_type:
                # current word closes any span in progress
                if has_prev:
                    sel_spans.append([start_ind, end_ind])
                start_ind, end_ind = -1, -1
                has_prev = False
            elif not has_prev:
                # first word of a new span
                start_ind, end_ind = l, r
                has_prev = True
            else:
                # continuation word: extend the open span
                end_ind = r
        if has_prev:
            sel_spans.append([start_ind, end_ind])
        return sel_spans

    ## Corner cases are handled by taking the mode over the block: one opinion
    ## span can express multiple sentiments, and one aspect can have multiple
    ## sentiments expressed on it.
    def find_triplet(self, tags, aspect_spans, opinion_spans):
        triplets = []
        for al, ar in aspect_spans:
            for pl, pr in opinion_spans:
                # Only the upper triangle of the tag matrix is annotated, so
                # the earlier span indexes the rows: tags[al:ar+1, pl:pr+1]
                # when the aspect comes first; otherwise rows and columns are
                # swapped (opinion becomes the row, aspect the column) to
                # avoid reading the unannotated lower triangle.
                if al <= pl:
                    sent_tags = tags[al:ar + 1, pl:pr + 1]
                else:
                    sent_tags = tags[pl:pr + 1, al:ar + 1]
                flat_tags = [v.item() for v in sent_tags.reshape([-1]) if v.item() >= 0]
                if not flat_tags:  # block held only ignored (-1) entries
                    continue
                val = torch.mode(torch.tensor(flat_tags)).values.item()
                if val > 0:
                    triplets.append([al, ar, pl, pr, val])
        return triplets

    def decode_triplets(self, triplets, sent_tokens):
        ## Turn index triplets into [aspect string, opinion string, polarity].
        triplet_list = []
        for alt, art, olt, ort, pol in triplets:
            asp_string = self.tokenizer.decode(sent_tokens[alt:art + 1])
            op_string = self.tokenizer.decode(sent_tokens[olt:ort + 1])
            if pol in self.id2sentiment:  # 3, 4, 5 -> negative/neutral/positive
                triplet_list.append([asp_string, op_string, self.id2sentiment[pol]])
        return triplet_list

    def decode_predict_one(self, tags, token_range, sent_tokens):
        aspect_spans = self.get_span_from_tags(tags, token_range, 1)
        opinion_spans = self.get_span_from_tags(tags, token_range, 2)
        triplets = self.find_triplet(tags, aspect_spans, opinion_spans)
        return self.decode_triplets(triplets, sent_tokens)

    def decode_pred_batch(self, tags_batch, token_range_batch, sent_tokens):
        return [
            self.decode_predict_one(tags_batch[i], token_range_batch[i], sent_tokens[i])
            for i in range(tags_batch.shape[0])
        ]

    def decode_predict_string_one(self, text_sent, model, max_len=64):
        ## Tokenize a raw sentence, run the model, and decode its triplets.
        words = text_sent.strip().split()
        bert_tokens = self.tokenizer.encode(text_sent)  # sub-word ids incl. [CLS]/[SEP]
        tok_length = len(bert_tokens)
        if tok_length > max_len:
            raise ValueError(f'sub-word length {tok_length} exceeds max_len ({max_len})')
        # Map each whitespace word to its [token_start, token_end] sub-word
        # range; index 0 is the [CLS] token, so word tokens start at position 1.
        token_range = []
        token_start = 1
        for w in words:
            token_end = token_start + len(self.tokenizer.encode(w, add_special_tokens=False))
            token_range.append([token_start, token_end - 1])
            token_start = token_end
        bert_tokens_padding = torch.zeros(max_len).long()
        bert_tokens_padding[:tok_length] = torch.tensor(bert_tokens).long()
        attention_mask = torch.zeros(max_len).long()
        attention_mask[:tok_length] = 1
        tags_pred = model(bert_tokens_padding.unsqueeze(0),
                          attention_masks=attention_mask.unsqueeze(0))
        tags = tags_pred['logits'][0].argmax(dim=-1)
        return self.decode_predict_one(tags, token_range, bert_tokens)

    def get_batch_tp_fp_tn(self, tags_batch, token_range_batch, sent_tokens, gold_labels):
        ## Exact-match triplet counts for one batch. Despite the name, the
        ## third value returned is false negatives, not true negatives.
        batch_results = self.decode_pred_batch(tags_batch, token_range_batch, sent_tokens)
        flat_gold, flat_pred = [], []
        for preds, golds in zip(batch_results, gold_labels):
            flat_pred.extend("-".join(pred) for pred in preds)
            flat_gold.extend("-".join(gold) for gold in golds)
        gold_set = set(flat_gold)
        pred_set = set(flat_pred)
        tp = len(gold_set & pred_set)
        fp = len(pred_set - gold_set)
        fn = len(gold_set - pred_set)
        return tp, fp, fn
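

## --------------------------------------------------------------------------
## Usage sketch (a minimal example, not part of the class above). The model
## class, checkpoint path, and sample batch variable names below are
## hypothetical placeholders; the bert-base-uncased tokenizer is likewise an
## assumption and must match whatever the grid-tagging model was trained with.
## --------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    decoder = DecodeAndEvaluate(tokenizer)

    # model = MyGridTaggingModel.load_from_checkpoint('model.ckpt')  # hypothetical
    # model.eval()
    # with torch.no_grad():
    #     triplets = decoder.decode_predict_string_one(
    #         'the battery life is great but the screen is dim', model)
    # print(triplets)  # e.g. [['battery life', 'great', 'positive'], ...]

    # Micro-averaged scores from the counts returned by get_batch_tp_fp_tn,
    # using the standard precision/recall/F1 formulas:
    # tp, fp, fn = decoder.get_batch_tp_fp_tn(tags_batch, token_ranges, tokens, golds)
    # precision = tp / (tp + fp) if tp + fp else 0.0
    # recall = tp / (tp + fn) if tp + fn else 0.0
    # f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0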