gauneg committed on
Commit
32da917
·
1 Parent(s): 6a268f7

commit files to HF hub

Files changed (9)
  1. README.md +79 -0
  2. bert_opinion.py +76 -0
  3. config.json +11 -0
  4. model.safetensors +3 -0
  5. post.py +135 -0
  6. special_tokens_map.json +7 -0
  7. tokenizer.json +0 -0
  8. tokenizer_config.json +55 -0
  9. vocab.txt +0 -0
README.md CHANGED
@@ -1,3 +1,82 @@
  ---
  license: apache-2.0
  ---
+
+ ## Dataset Domain: Restaurant Reviews
+
+ ## Overview
+ This work is based on [Grid Tagging Scheme for Aspect-oriented Fine-grained Opinion Extraction](https://aclanthology.org/2020.findings-emnlp.234/). The code from
+ their [GitHub repository](https://github.com/NJUNLP/GTS) was also used, along with their dataset.
+
+ This model requires custom code because it uses the Grid Tagging Scheme (GTS) to predict labels for the input. For convenience,
+ the custom code and model architecture are included with the model.
+
+ ## Example code for inference
+
+ ### STEP 1 (Install the required libraries)
+
+ ```bash
+ pip install --upgrade huggingface_hub transformers torch
+ ```
+
+ ### STEP 2 (Download the custom code and model to predict opinion target, opinion span and sentiment polarity)
+ ```python
+ import os
+ import sys
+
+ from huggingface_hub import hf_hub_download
+
+ # Download the custom model code
+ bert_gts_pretrained = hf_hub_download(repo_id='gauneg/bert-gts-absa-triple-restaurant', filename="bert_opinion.py")
+ post = hf_hub_download(repo_id='gauneg/bert-gts-absa-triple-restaurant', filename="post.py")
+
+ # Make the downloaded modules importable
+ sys.path.append(os.path.dirname(bert_gts_pretrained))
+ sys.path.append(os.path.dirname(post))
+
+ from bert_opinion import BertGTSOpinionTriple
+ from post import DecodeAndEvaluate
+
+ from transformers import AutoTokenizer
+
+ model_id = 'gauneg/bert-gts-absa-triple-restaurant'
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = BertGTSOpinionTriple.from_pretrained(model_id)
+ dec_and_infer = DecodeAndEvaluate(tokenizer)
+ test_sentence = """The food was good but the ambience was bad"""
+
+ # prediction
+ print(dec_and_infer.decode_predict_string_one(test_sentence, model, max_len=128))
+
+ ```
+ Example of the output format (these triples do not correspond to the test sentence above)
+
+ ```bash
+ [['display', 'great', 'positive'], ['battery life', 'great', 'positive']]
+ ```
+
+ # DETAILS
+ The model has been trained to use the Grid Tagging Scheme (GTS) to predict `Opinion Target`, `Opinion Span` and `Sentiment Polarity`. To train this model, the domain-specific datasets (laptop and restaurant reviews) were combined. A grid tagging example is shown
+ in the following diagram:
+
+ <figure>
+ <img src="./gts_pic.png" alt="gts-image" style="width:45%">
+ <figcaption>Fig 1. Grid Tagging Scheme from <a href="https://aclanthology.org/2020.findings-emnlp.234/">(Wu et al., Findings 2020)</a> </figcaption>
+ </figure>
+
+ In the above sentence there are two ABSA triples. Each triple is expressed in the following order:
+
+ [<span style="color:red">Aspect Term/Opinion Target</span>, <span style="color:#7393B3">opinion span</span>, <span style="color:purple">sentiment polarity</span>]
+
+ The model and the sample code shown in the snippet will extract the opinion triplets as: [
+ [<span style="color:red">hot dogs</span>, <span style="color:#7393B3">top notch</span>, <span style="color:purple">positive</span>],
+ [<span style="color:red">coffee</span>, <span style="color:#7393B3">average</span>, <span style="color:purple">neutral</span>]
+ ]
+
+ Definitions <a href="https://aclanthology.org/2020.findings-emnlp.234/">(Wu et al., Findings 2020)</a>:
+
+ 1. <span style="color:red">Aspect Term/Opinion Target</span>: The aspect term, also known as the opinion target, is the word or phrase in a sentence that represents a feature or entity of a product or service.
+ 2. <span style="color:#7393B3">Opinion Term/Opinion Span</span>: The opinion term is the word or phrase in a sentence that explicitly expresses an attitude or opinion.
+ 3. <span style="color:purple">Sentiment Polarity</span>: The sentiment expressed toward the aspect term (positive, neutral or negative).
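As a rough illustration of the tagging scheme this README describes, the sketch below builds the upper-triangular tag grid for a toy sentence. The tag ids follow `post.py` (1 = aspect, 2 = opinion, 3/4/5 = negative/neutral/positive). The helper name `build_grid`, the word-level indexing and the toy sentence are assumptions made for illustration; this is not the training code from the commit.

```python
# Tag ids as used in post.py; 0 means "no relation".
ASPECT, OPINION = 1, 2
SENTIMENT2ID = {"negative": 3, "neutral": 4, "positive": 5}


def build_grid(n_words, triples):
    """Build an n_words x n_words grid; only cells with row <= col are used.

    Each triple is ((a_start, a_end), (o_start, o_end), polarity) with
    inclusive word indices. (Hypothetical helper, for illustration only.)
    """
    grid = [[0] * n_words for _ in range(n_words)]

    def fill_span(start, end, tag):
        # Mark every word pair inside the span (upper triangle only).
        for i in range(start, end + 1):
            for j in range(i, end + 1):
                grid[i][j] = tag

    for (a_start, a_end), (o_start, o_end), polarity in triples:
        fill_span(a_start, a_end, ASPECT)
        fill_span(o_start, o_end, OPINION)
        # Each aspect-word / opinion-word pair carries the sentiment tag.
        for i in range(a_start, a_end + 1):
            for j in range(o_start, o_end + 1):
                row, col = min(i, j), max(i, j)
                grid[row][col] = SENTIMENT2ID[polarity]
    return grid


words = "the hot dogs are top notch".split()
grid = build_grid(len(words), [((1, 2), (4, 5), "positive")])
for word, row in zip(words, grid):
    print(f"{word:>6}", row)
```

Reading the grid back is what the decoder in `post.py` does: spans come from the diagonal, and the sentiment comes from the block where an aspect span and an opinion span intersect.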
bert_opinion.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModel, PreTrainedModel, PretrainedConfig
+ import torch.nn.functional as F
+
+
+ class BertGTSOpinionTripleConfig(PretrainedConfig):
+     model_type = 'multi-infer-bert-uncased'
+     def __init__(self, feat_dim = 768, max_len=128, class_num=6, **kwargs):
+         super().__init__(**kwargs)
+         self.feat_dim = feat_dim
+         self.max_len = max_len
+         self.class_num = class_num
+
+ class BertGTSOpinionTriple(PreTrainedModel):
+     config_class = BertGTSOpinionTripleConfig
+     def __init__(self, config):
+         model_id = 'google-bert/bert-base-uncased'
+         super().__init__(config)
+         self.model = AutoModel.from_pretrained(model_id)
+         self.max_seq_len = config.max_len
+         self.bert_feat_dim = config.feat_dim #768
+         self.class_num = config.class_num #6
+         self.cls_linear = torch.nn.Linear(self.bert_feat_dim*2, self.class_num)
+         self.feature_linear = torch.nn.Linear(self.bert_feat_dim*2+self.class_num*3, self.bert_feat_dim*2)
+         self.dropout_output = torch.nn.Dropout(0.1)
+         self.post_init()
+
+
+     def multi_hops(self, features, mask, k):
+         max_length = features.shape[1]
+         mask = mask[:, :max_length]
+         # pairwise mask over (i, j) positions, restricted to the upper triangle
+         mask_a = mask.unsqueeze(1).expand([-1, max_length, -1])
+         mask_b = mask.unsqueeze(2).expand([-1, -1, max_length])
+         mask = mask_a * mask_b
+         mask = torch.triu(mask).unsqueeze(3).expand([-1, -1, -1, self.class_num])
+
+         '''save all logits'''
+         logits_list = []
+         logits = self.cls_linear(features)
+         logits_list.append(logits)
+         for i in range(k):
+             #probs = torch.softmax(logits, dim=3)
+             probs = logits
+             logits = probs * mask
+             # max-pool the masked logits over rows and over columns
+             logits_a = torch.max(logits, dim=1)[0]
+             logits_b = torch.max(logits, dim=2)[0]
+             logits = torch.cat([logits_a.unsqueeze(3), logits_b.unsqueeze(3)], dim=3)
+             logits = torch.max(logits, dim=3)[0]
+
+             logits = logits.unsqueeze(2).expand([-1,-1, max_length, -1])
+             logits_T = logits.transpose(1, 2)
+             logits = torch.cat([logits, logits_T], dim=3)
+
+             new_features = torch.cat([features, logits, probs], dim=3)
+             features = self.feature_linear(new_features)
+             logits = self.cls_linear(features)
+             logits_list.append(logits)
+         return logits_list
+
+     def forward(self, input_ids, attention_masks, labels=None): # rename if required
+         model_feature = self.model(input_ids, attention_masks)
+         model_feature = model_feature.last_hidden_state.detach()
+         bert_feature = self.dropout_output(model_feature)
+         # build the pairwise grid: cell (i, j) holds [h_i ; h_j]
+         bert_feature = bert_feature.unsqueeze(2).expand([-1, -1, self.max_seq_len, -1])
+         bert_feature_T = bert_feature.transpose(1, 2)
+         features = torch.cat([bert_feature, bert_feature_T], dim=3)
+         logits = self.multi_hops(features, attention_masks, 1)
+         fin_logits = logits[-1]
+         loss = None
+         if labels is not None:
+             ## performing the loss operation, crosscheck with the previous impl
+             gold_floss = labels.reshape([-1])
+             pred_floss = fin_logits.reshape([-1, fin_logits.shape[3]])
+             loss = F.cross_entropy(pred_floss, gold_floss, ignore_index=-1)
+         return {'logits': fin_logits, 'loss': loss}
+
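To make the tensor manipulation in `forward` and `multi_hops` easier to follow, here is a minimal pure-Python sketch of two steps for a single sentence: pairing every token representation with every other one, and building the upper-triangular pair mask from the attention mask. The function names `pair_features` and `pair_mask` are hypothetical, and plain lists stand in for tensors; the committed code does the same thing with `expand`, `transpose`, `torch.cat` and `torch.triu` over a batch.

```python
def pair_features(hidden):
    """Cell (i, j) holds the concatenation [h_i ; h_j] of two token vectors."""
    return [[h_i + h_j for h_j in hidden] for h_i in hidden]


def pair_mask(attention_mask):
    """1 where both tokens are real and j >= i (upper triangle), else 0."""
    n = len(attention_mask)
    return [
        [attention_mask[i] * attention_mask[j] if j >= i else 0 for j in range(n)]
        for i in range(n)
    ]


# Three real tokens with 2-dimensional vectors, plus one padding position.
hidden = [[1.0, 1.5], [2.0, 2.5], [3.0, 3.5], [0.0, 0.0]]
features = pair_features(hidden)
mask = pair_mask([1, 1, 1, 0])

print(features[0][2])  # vector of token 0 followed by vector of token 2
for row in mask:
    print(row)
```

With `feat_dim = 768`, each cell of the real grid therefore has 1536 features, which is why `cls_linear` takes `bert_feat_dim*2` inputs.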
config.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "architectures": [
+     "BertGTSOpinionTriple"
+   ],
+   "class_num": 6,
+   "feat_dim": 768,
+   "max_len": 128,
+   "model_type": "multi-infer-bert-uncased",
+   "torch_dtype": "float32",
+   "transformers_version": "4.44.2"
+ }
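The three custom fields in this config determine the shapes of the two linear layers defined in `bert_opinion.py`. The short sketch below works through that arithmetic; the variable names are illustrative assumptions, and only the formulas (`feat_dim*2` and `feat_dim*2 + class_num*3`) are taken from the model code.

```python
import json

# Only the custom fields from config.json are needed for this calculation.
config = json.loads('{"class_num": 6, "feat_dim": 768, "max_len": 128}')

# Each grid cell concatenates two token vectors.
pair_dim = 2 * config["feat_dim"]

# cls_linear maps a pair feature to one score per tag class.
cls_linear_shape = (pair_dim, config["class_num"])

# feature_linear sees the pair feature, the pooled logits for both tokens
# (2 * class_num) and the cell's own logits (class_num).
feature_linear_shape = (pair_dim + 3 * config["class_num"], pair_dim)

# The grid has max_len x max_len cells per sentence.
grid_cells = config["max_len"] ** 2

print(cls_linear_shape, feature_linear_shape, grid_cells)
```

Because the grid grows quadratically with `max_len`, the pairwise feature tensor is the main memory cost at inference time.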
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1d2431c4f6ce4c2e013d67c47439ce09450cfb2b156fc9cb4d13be7b4bde7f1a
+ size 447543680
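The three lines added for `model.safetensors` are a Git LFS pointer, not the weights themselves. As a small sketch, the pointer can be read as space-separated key/value pairs; `parse_lfs_pointer` is a hypothetical helper written for this note, not part of any library.

```python
def parse_lfs_pointer(text):
    """Split each 'key value' line of a Git LFS pointer into a dict."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields


pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:1d2431c4f6ce4c2e013d67c47439ce09450cfb2b156fc9cb4d13be7b4bde7f1a
size 447543680
"""

info = parse_lfs_pointer(pointer)
algo, digest = info["oid"].split(":", 1)
size_mib = int(info["size"]) / 2**20  # size is given in bytes

print(algo, digest[:12], f"{size_mib:.1f} MiB")
```

The `size` field gives the size of the real file in bytes, and the `oid` is the hash that a downloaded copy can be checked against.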
post.py ADDED
@@ -0,0 +1,135 @@
+
+ import torch
+
+ class DecodeAndEvaluate:
+     def __init__(self, tokenizer):
+         self.tokenizer = tokenizer
+         self.sentiment2id = {'negative': 3, 'neutral': 4, 'positive': 5}
+         self.id2sentiment = {v:k for k, v in self.sentiment2id.items()}
+
+     def get_span_from_tags(self, tags, token_range, tok_type): ## tok_type 1=aspect, 2 for opinions
+         sel_spans = []
+         end_ind = -1
+         has_prev = False
+         start_ind = -1
+         for i in range(len(token_range)):
+             l,r = token_range[i]
+             if tags[l][l]!= tok_type:
+                 if has_prev:
+                     sel_spans.append([start_ind, end_ind])
+                     start_ind = -1
+                     end_ind= -1
+                 has_prev = False
+             if tags[l][l] == tok_type and not has_prev:
+                 start_ind = l
+                 end_ind = r
+                 has_prev = True
+             if tags[l][l] == tok_type and has_prev:
+                 end_ind = r
+                 has_prev = True
+         if has_prev:
+             sel_spans.append([start_ind, end_ind])
+
+         return sel_spans
+
+     ## Corner cases where one sentiment span expresses over multiple sentiments
+     # and one aspect has multiple sentiments expressed on it
+     def find_triplet(self, tags, aspect_spans, opinion_spans):
+         triplets = []
+         for al, ar in aspect_spans:
+             for pl, pr in opinion_spans:
+                 ## get the overlapping indices
+                 # we select such that tag[aspect_l :aspect_r+1, opi_l: opi_r]
+                 # if opi>asp then lower triangular matrix starts being selected that is not annotated
+                 # print(al, ar, pl, pr)
+                 if al<=pl:
+                     sent_tags = tags[al:ar+1, pl:pr+1]
+                     flat_tags = sent_tags.reshape([-1])
+                     flat_tags = torch.tensor([v.item() for v in flat_tags if v.item()>=0])
+                     val = torch.mode(flat_tags).values.item()
+                     if val > 0:
+                         triplets.append([al, ar, pl, pr, val])
+                 else: # In this case the aspect becomes column and sentiment becomes the row
+                     # print(al, pl)
+                     sent_tags = tags[pl:pr+1, al: ar+1]
+                     # print(sent_tags)
+                     flat_tags = sent_tags.reshape([-1])
+                     flat_tags = torch.tensor([v.item() for v in flat_tags if v.item()>=0])
+                     val = torch.mode(flat_tags).values.item()
+                     if val>0:
+                         triplets.append([al, ar, pl, pr, val])
+         return triplets
+
+     def decode_triplets(self, triplets, sent_tokens):
+         triplet_list = []
+         for alt, art, olt, ort, pol in triplets:
+             asp_toks = sent_tokens[alt:art+1]
+             op_toks = sent_tokens[olt: ort+1]
+             asp_string = self.tokenizer.decode(asp_toks)
+             op_string = self.tokenizer.decode(op_toks)
+             if pol in [3, 4, 5]:
+                 sentiment_pol = self.id2sentiment[pol] #.get(pol, "inconsistent")
+                 triplet_list.append([asp_string, op_string, sentiment_pol])
+         return triplet_list
+
+     def decode_predict_one(self, tags, token_range, sent_tokens):
+         aspect_spans = self.get_span_from_tags(tags, token_range, 1)
+         opinion_spans = self.get_span_from_tags(tags, token_range, 2)
+         triplets = self.find_triplet(tags, aspect_spans, opinion_spans)
+         return self.decode_triplets(triplets, sent_tokens)
+
+
+     def decode_pred_batch(self, tags_batch, token_range_batch, sent_tokens):
+         decoded_batch_results = []
+         for i in range(tags_batch.shape[0]):
+             res = self.decode_predict_one(tags_batch[i], token_range_batch[i], sent_tokens[i])
+             decoded_batch_results.append(res)
+         return decoded_batch_results
+
+     def decode_predict_string_one(self, text_sent, model, max_len=64):
+         token_range = []
+         words = text_sent.strip().split()
+         bert_tokens_padding = torch.zeros(max_len).long()
+         bert_tokens = self.tokenizer.encode(text_sent) # tokenization (in sub-words)
+
+         tok_length = len(bert_tokens)
+         if tok_length>max_len:
+             raise Exception(f'Sub word length exceeded `maxlen` (>{max_len})')
+         # this maps each word to its (token_start, token_end) sub-word range;
+         # token_start begins at 1 to skip the leading [CLS] token
+         token_start=1
+         for w in words:
+             token_end = token_start + len(self.tokenizer.encode(w, add_special_tokens=False))
+             token_range.append([token_start, token_end-1])
+             token_start = token_end
+
+         bert_tokens_padding[:tok_length] = torch.tensor(bert_tokens).long()
+         attention_mask = torch.zeros(max_len).long()
+         attention_mask[:tok_length]=1
+
+         tags_pred = model(bert_tokens_padding.unsqueeze(0),
+                           attention_masks=attention_mask.unsqueeze(0))
+
+         tags = tags_pred['logits'][0].argmax(dim=-1)
+         return self.decode_predict_one(tags, token_range, bert_tokens)
+
+
+
+     def get_batch_tp_fp_tn(self, tags_batch, token_range_batch, sent_tokens, gold_labels):
+
+         batch_results = self.decode_pred_batch(tags_batch, token_range_batch, sent_tokens)
+         flat_gold, flat_pred = [], []
+
+         for preds, golds in list(zip(batch_results, gold_labels)):
+             for pred in preds:
+                 flat_pred.append("-".join(pred))
+             for gold in golds:
+                 flat_gold.append("-".join(gold))
+         gold_set = set(flat_gold)
+         pred_set = set(flat_pred)
+         tp = len(gold_set & pred_set)
+         fp = len(pred_set - gold_set)
+         fn = len(gold_set - pred_set)
+
+         return tp, fp, fn
+
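`get_batch_tp_fp_tn` returns true-positive, false-positive and false-negative counts (despite the `tn` in its name). A minimal sketch of how such counts are usually turned into precision, recall and F1 follows; `precision_recall_f1` is a hypothetical helper that is not part of this commit, and the counts in the example are made up purely to show the arithmetic.

```python
def precision_recall_f1(tp, fp, fn):
    """Standard precision/recall/F1, returning 0.0 where a ratio is undefined."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return precision, recall, f1


# Illustrative counts only: 8 correct triples, 2 spurious, 2 missed.
p, r, f1 = precision_recall_f1(tp=8, fp=2, fn=2)
print(f"precision={p:.2f} recall={r:.2f} f1={f1:.2f}")
```

Note that the committed method compares sets of `"-"`-joined triples pooled over the whole batch, so an identical triple occurring in two different sentences is counted once.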
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,55 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "100": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "101": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "102": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "103": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_lower_case": true,
+   "mask_token": "[MASK]",
+   "model_max_length": 512,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
vocab.txt ADDED
The diff for this file is too large to render. See raw diff