Commit · 4e38daf1
Parent(s): d60b7fd
Upload 18 files
- .gitattributes +7 -0
- args_model_utils.py +210 -0
- argument_model_state_dict.pth +3 -0
- configuration.py +1 -0
- event_arg_predict.py +280 -0
- event_arg_role_dataloader.py +100 -0
- event_arg_role_predict.py +113 -0
- event_nugget_predict.py +250 -0
- event_realis_predict.py +270 -0
- model_59.pt +3 -0
- model_64_pos_ner.pt +3 -0
- model_66.pt +3 -0
- model_97.pt +3 -0
- nugget_model_state_dict.pth +3 -0
- nugget_model_utils.py +151 -0
- realis_model_state_dict.pth +3 -0
- realis_model_utils.py +146 -0
- utils.py +196 -0
.gitattributes
CHANGED
@@ -1 +1,8 @@
 pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
+argument_model_state_dict.pth filter=lfs diff=lfs merge=lfs -text
+model_59.pt filter=lfs diff=lfs merge=lfs -text
+model_64_pos_ner.pt filter=lfs diff=lfs merge=lfs -text
+model_66.pt filter=lfs diff=lfs merge=lfs -text
+model_97.pt filter=lfs diff=lfs merge=lfs -text
+nugget_model_state_dict.pth filter=lfs diff=lfs merge=lfs -text
+realis_model_state_dict.pth filter=lfs diff=lfs merge=lfs -text
args_model_utils.py
ADDED
@@ -0,0 +1,210 @@
import torch
import spacy
import en_core_web_sm
from torch import nn
import math


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

from transformers import AutoModel, TrainingArguments, Trainer, RobertaTokenizer, RobertaModel
from transformers import AutoTokenizer

model_checkpoint = "ehsanaghaei/SecureBERT"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
roberta_model = RobertaModel.from_pretrained(model_checkpoint).to(device)

nlp = en_core_web_sm.load()
pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
event_nugget_tag_list = ["Databreach", "Ransom", "PatchVulnerability", "Phishing", "DiscoverVulnerability"]
arg_nugget_relative_pos_tag_list = ["before-same-sentence", "before-differ-sentence", "after-same-sentence", "after-differ-sentence"]

class CustomRobertaWithPOS(nn.Module):
    def __init__(self, num_classes):
        super(CustomRobertaWithPOS, self).__init__()
        self.num_classes = num_classes

        self.pos_embed = nn.Embedding(len(pos_spacy_tag_list), 16)
        self.ner_embed = nn.Embedding(len(ner_spacy_tag_list), 8)
        self.dep_embed = nn.Embedding(len(dep_spacy_tag_list), 8)
        self.depth_embed = nn.Embedding(17, 8)
        self.subtype_embed = nn.Embedding(len(event_nugget_tag_list), 2)
        self.dist_embed = nn.Embedding(11, 6)
        self.relative_pos_embed = nn.Embedding(len(arg_nugget_relative_pos_tag_list), 2)

        self.roberta = roberta_model
        self.dropout1 = nn.Dropout(0.2)
        self.fc1 = nn.Linear(self.roberta.config.hidden_size + 50, num_classes)

    def forward(self, input_ids, attention_mask, pos_spacy, ner_spacy, dep_spacy, depth_spacy, nearest_nugget_subtype, nearest_nugget_dist, arg_nugget_relative_pos):
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_output = outputs.last_hidden_state

        pooler_output = outputs.pooler_output
        pooler_output_unsqz = pooler_output.unsqueeze(1)
        pooler_output_fin = pooler_output_unsqz.expand(-1, last_hidden_output.shape[1], -1)


        pos_mask = pos_spacy != -100
        pos_embed_masked = self.pos_embed(pos_spacy[pos_mask])
        pos_embed = torch.zeros((pos_spacy.shape[0], pos_spacy.shape[1], 16), dtype=torch.float).to(device)
        pos_embed[pos_mask] = pos_embed_masked

        ner_mask = ner_spacy != -100
        ner_embed_masked = self.ner_embed(ner_spacy[ner_mask])
        ner_embed = torch.zeros((ner_spacy.shape[0], ner_spacy.shape[1], 8), dtype=torch.float).to(device)
        ner_embed[ner_mask] = ner_embed_masked

        dep_mask = dep_spacy != -100
        dep_embed_masked = self.dep_embed(dep_spacy[dep_mask])
        dep_embed = torch.zeros((dep_spacy.shape[0], dep_spacy.shape[1], 8), dtype=torch.float).to(device)
        dep_embed[dep_mask] = dep_embed_masked

        depth_mask = depth_spacy != -100
        depth_embed_masked = self.depth_embed(depth_spacy[depth_mask])
        depth_embed = torch.zeros((depth_spacy.shape[0], depth_spacy.shape[1], 8), dtype=torch.float).to(device)
        depth_embed[dep_mask] = depth_embed_masked

        nearest_nugget_subtype_mask = nearest_nugget_subtype != -100
        nearest_nugget_subtype_embed_masked = self.subtype_embed(nearest_nugget_subtype[nearest_nugget_subtype_mask])
        nearest_nugget_subtype_embed = torch.zeros((nearest_nugget_subtype.shape[0], nearest_nugget_subtype.shape[1], 2), dtype=torch.float).to(device)
        nearest_nugget_subtype_embed[dep_mask] = nearest_nugget_subtype_embed_masked

        nearest_nugget_dist_mask = nearest_nugget_dist != -100
        nearest_nugget_dist_embed_masked = self.dist_embed(nearest_nugget_dist[nearest_nugget_dist_mask])
        nearest_nugget_dist_embed = torch.zeros((nearest_nugget_dist.shape[0], nearest_nugget_dist.shape[1], 6), dtype=torch.float).to(device)
        nearest_nugget_dist_embed[dep_mask] = nearest_nugget_dist_embed_masked

        arg_nugget_relative_pos_mask = arg_nugget_relative_pos != -100
        arg_nugget_relative_pos_embed_masked = self.relative_pos_embed(arg_nugget_relative_pos[arg_nugget_relative_pos_mask])
        arg_nugget_relative_pos_embed = torch.zeros((arg_nugget_relative_pos.shape[0], arg_nugget_relative_pos.shape[1], 2), dtype=torch.float).to(device)
        arg_nugget_relative_pos_embed[dep_mask] = arg_nugget_relative_pos_embed_masked

        features_concat = torch.cat((last_hidden_output, pos_embed, ner_embed, dep_embed, depth_embed, nearest_nugget_subtype_embed, nearest_nugget_dist_embed, arg_nugget_relative_pos_embed), 2).to(device)
        features_concat = self.dropout1(features_concat)

        logits = self.fc1(features_concat)

        return logits


def tokenize_and_align_labels_with_pos_ner_dep(examples, tokenizer, label_all_tokens = True):
    tokenized_inputs = tokenizer(examples["tokens"], padding='max_length', truncation=True, is_split_into_words=True)
    #tokenized_inputs.pop('input_ids')
    ner_spacy = []
    pos_spacy = []
    dep_spacy = []
    depth_spacy = []
    nearest_nugget_subtype = []
    nearest_nugget_dist = []
    arg_nugget_relative_pos = []

    for i, (pos, ner, dep, depth, subtype, dist, relative_pos) in enumerate(zip(examples["pos_spacy"],
                                                                                examples["ner_spacy"],
                                                                                examples["dep_spacy"],
                                                                                examples["depth_spacy"],
                                                                                examples["nearest_nugget_subtype"],
                                                                                examples["nearest_nugget_dist"],
                                                                                examples["arg_nugget_relative_pos"])):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        ner_spacy_ids = []
        pos_spacy_ids = []
        dep_spacy_ids = []
        depth_spacy_ids = []
        nearest_nugget_subtype_ids = []
        nearest_nugget_dist_ids = []
        arg_nugget_relative_pos_ids = []

        for word_idx in word_ids:
            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
            # ignored in the loss function.
            if word_idx is None:
                ner_spacy_ids.append(-100)
                pos_spacy_ids.append(-100)
                dep_spacy_ids.append(-100)
                depth_spacy_ids.append(-100)
                nearest_nugget_subtype_ids.append(-100)
                nearest_nugget_dist_ids.append(-100)
                arg_nugget_relative_pos_ids.append(-100)
            # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                ner_spacy_ids.append(ner[word_idx])
                pos_spacy_ids.append(pos[word_idx])
                dep_spacy_ids.append(dep[word_idx])
                depth_spacy_ids.append(depth[word_idx])
                nearest_nugget_subtype_ids.append(subtype[word_idx])
                nearest_nugget_dist_ids.append(dist[word_idx])
                arg_nugget_relative_pos_ids.append(relative_pos[word_idx])
            # For the other tokens in a word, we set the label to either the current label or -100, depending on
            # the label_all_tokens flag.
            else:
                ner_spacy_ids.append(ner[word_idx] if label_all_tokens else -100)
                pos_spacy_ids.append(pos[word_idx] if label_all_tokens else -100)
                dep_spacy_ids.append(dep[word_idx] if label_all_tokens else -100)
                depth_spacy_ids.append(depth[word_idx] if label_all_tokens else -100)
                nearest_nugget_subtype_ids.append(subtype[word_idx] if label_all_tokens else -100)
                nearest_nugget_dist_ids.append(dist[word_idx] if label_all_tokens else -100)
                arg_nugget_relative_pos_ids.append(relative_pos[word_idx] if label_all_tokens else -100)
            previous_word_idx = word_idx

        ner_spacy.append(ner_spacy_ids)
        pos_spacy.append(pos_spacy_ids)
        dep_spacy.append(dep_spacy_ids)
        depth_spacy.append(depth_spacy_ids)
        nearest_nugget_subtype.append(nearest_nugget_subtype_ids)
        nearest_nugget_dist.append(nearest_nugget_dist_ids)
        arg_nugget_relative_pos.append(arg_nugget_relative_pos_ids)

    tokenized_inputs["pos_spacy"] = pos_spacy
    tokenized_inputs["ner_spacy"] = ner_spacy
    tokenized_inputs["dep_spacy"] = dep_spacy
    tokenized_inputs["depth_spacy"] = depth_spacy
    tokenized_inputs["nearest_nugget_subtype"] = nearest_nugget_subtype
    tokenized_inputs["nearest_nugget_dist"] = nearest_nugget_dist
    tokenized_inputs["arg_nugget_relative_pos"] = arg_nugget_relative_pos
    return tokenized_inputs

def find_nearest_nugget_features(doc, start_idx, end_idx, event_nuggets):
    nearest_subtype = None
    nearest_dist = math.inf
    relative_pos = None

    mid_idx = (end_idx + start_idx) / 2
    for nugget in event_nuggets:
        mid_nugget_idx = (nugget["startOffset"] + nugget["endOffset"]) / 2
        dist = abs(mid_nugget_idx - mid_idx)

        if dist < nearest_dist:
            nearest_dist = dist
            nearest_subtype = nugget["subtype"]
            for sent in doc.sents:
                if between_idxs(mid_idx, sent.start_char, sent.end_char) and between_idxs(mid_nugget_idx, sent.start_char, sent.end_char):
                    if mid_idx < mid_nugget_idx:
                        relative_pos = "before-same-sentence"
                    else:
                        relative_pos = "after-same-sentence"
                    break
                elif between_idxs(mid_nugget_idx, sent.start_char, sent.end_char) and mid_idx > mid_nugget_idx:
                    relative_pos = "after-differ-sentence"
                    break
                elif between_idxs(mid_idx, sent.start_char, sent.end_char) and mid_idx < mid_nugget_idx:
                    relative_pos = "before-differ-sentence"
                    break

    nearest_dist = int(min(10, nearest_dist // 20))
    return nearest_subtype, nearest_dist, relative_pos

def find_dep_depth(token):
    depth = 0
    current_token = token
    while current_token.head != current_token:
        depth += 1
        current_token = current_token.head
    return min(depth, 16)

def between_idxs(idx, start_idx, end_idx):
    return idx >= start_idx and idx <= end_idx
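The forward pass concatenates the RoBERTa hidden states with seven auxiliary embeddings whose widths sum to 50 (16 + 8 + 8 + 8 + 2 + 6 + 2), which is why fc1 expects hidden_size + 50 inputs per token. A minimal shape smoke test of that wiring might look like the sketch below; it is illustrative only, not part of the commit, and it assumes a CPU-only run, a randomly initialized head, an arbitrary sample sentence, and all-zero feature indices.

# Illustrative sketch (assumptions noted above): check the per-token logit shape.
import torch
from cybersecurity_knowledge_graph.args_model_utils import CustomRobertaWithPOS, tokenizer

model = CustomRobertaWithPOS(num_classes=43)  # 43 matches event_arg_predict.py
model.eval()

enc = tokenizer(["Attackers stole customer records."], padding="max_length",
                truncation=True, return_tensors="pt")
seq_len = enc["input_ids"].shape[1]
zeros = torch.zeros((1, seq_len), dtype=torch.long)  # index 0 is valid in every embedding table

with torch.no_grad():
    logits = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"],
                   pos_spacy=zeros, ner_spacy=zeros, dep_spacy=zeros, depth_spacy=zeros,
                   nearest_nugget_subtype=zeros, nearest_nugget_dist=zeros,
                   arg_nugget_relative_pos=zeros)
print(logits.shape)  # (1, seq_len, 43): one label distribution per subword position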
argument_model_state_dict.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:185e22992430c80ec1eb1fca7f3ba4ebe801163c3ba13bed00abc6dc24072712
size 498813605
configuration.py
CHANGED
@@ -5,6 +5,7 @@ from cybersecurity_knowledge_graph.utils import event_args_list, event_nugget_li
 
 
 class CybersecurityKnowledgeGraphConfig(PretrainedConfig):
+    model_type = "cybersecurity_knowledge_graph"
 
     def __init__(
         self,
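Declaring model_type on the PretrainedConfig subclass is what allows the config to be registered with and resolved by the transformers Auto* machinery. A hedged sketch of that round trip follows; the AutoConfig.register call is standard transformers usage, while the import path, the output directory name, and the assumption that the remaining __init__ parameters have defaults are mine.

# Illustrative sketch (not part of the commit): register and round-trip the config.
from transformers import AutoConfig
from cybersecurity_knowledge_graph.configuration import CybersecurityKnowledgeGraphConfig

AutoConfig.register("cybersecurity_knowledge_graph", CybersecurityKnowledgeGraphConfig)

cfg = CybersecurityKnowledgeGraphConfig()          # assumes default __init__ arguments
cfg.save_pretrained("ckg-config")                   # writes config.json with "model_type"
same_cfg = AutoConfig.from_pretrained("ckg-config") # resolved via the registered model_type
print(type(same_cfg).__name__)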
event_arg_predict.py
ADDED
@@ -0,0 +1,280 @@
import streamlit as st
from annotated_text import annotated_text
import torch
from torch.utils.data import DataLoader

from cybersecurity_knowledge_graph.args_model_utils import tokenize_and_align_labels_with_pos_ner_dep, find_nearest_nugget_features, find_dep_depth
from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS
from cybersecurity_knowledge_graph.utils import get_content, get_event_nugget, get_idxs_from_text, get_entity_from_idx, list_of_pos_tags, event_args_list

from cybersecurity_knowledge_graph.event_nugget_predict import get_event_nuggets
import spacy
from transformers import AutoTokenizer
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"

def find_dep_depth(token):
    depth = 0
    current_token = token
    while current_token.head != current_token:
        depth += 1
        current_token = current_token.head
    return min(depth, 16)


nlp = spacy.load('en_core_web_sm')

pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
event_nugget_tag_list = ["Databreach", "Ransom", "PatchVulnerability", "Phishing", "DiscoverVulnerability"]
arg_nugget_relative_pos_tag_list = ["before-same-sentence", "before-differ-sentence", "after-same-sentence", "after-differ-sentence"]

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_checkpoint = "ehsanaghaei/SecureBERT"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)

from cybersecurity_knowledge_graph.args_model_utils import CustomRobertaWithPOS as ArgumentModel
model_nugget = ArgumentModel(num_classes=43)
model_nugget.load_state_dict(torch.load("cybersecurity_knowledge_graph/argument_model_state_dict.pth", map_location=device))
model_nugget.eval()

"""
Function: create_dataloader(text_input)
Description: This function creates a DataLoader for processing text data, tokenizes it, and organizes it into batches.
Inputs:
- text_input: The input text to be processed.
Output:
- dataloader: A DataLoader for the tokenized and batched text data.
- tokenized_dataset_ner: The tokenized dataset used for training.
"""
def create_dataloader(text_input):

    event_nuggets = get_event_nuggets(text_input)
    doc = nlp(text_input)

    content_as_words_emdash = [tok.text for tok in doc]
    content_as_words_emdash = [word.replace("``", '"').replace("''", '"').replace("$", "") for word in content_as_words_emdash]
    content_idx_dict = get_idxs_from_text(text_input, content_as_words_emdash)

    data = []

    words = []
    arg_nugget_nearest_subtype = []
    arg_nugget_nearest_dist = []
    arg_nugget_relative_pos = []

    pos_spacy = [tok.pos_ for tok in doc]
    ner_spacy = [ent.ent_iob_ + "-" + ent.ent_type_ if ent.ent_iob_ != "O" else ent.ent_iob_ for ent in doc]
    dep_spacy = [tok.dep_ for tok in doc]
    depth_spacy = [find_dep_depth(tok) for tok in doc]

    for content_dict in content_idx_dict:
        start_idx, end_idx = content_dict["start_idx"], content_dict["end_idx"]
        nearest_subtype, nearest_dist, relative_pos = find_nearest_nugget_features(doc, content_dict["start_idx"], content_dict["end_idx"], event_nuggets)
        words.append(content_dict["word"])

        arg_nugget_nearest_subtype.append(nearest_subtype)
        arg_nugget_nearest_dist.append(nearest_dist)
        arg_nugget_relative_pos.append(relative_pos)


    content_token_len = len(tokenizer(words, truncation=False, is_split_into_words=True)["input_ids"])
    if content_token_len > tokenizer.model_max_length:
        no_split = (content_token_len // tokenizer.model_max_length) + 2
        split_len = (len(words) // no_split) + 1

        last_id = 0
        threshold = split_len

        for id, token in enumerate(words):
            if token == "." and id > threshold:
                data.append(
                    {
                        "tokens" : words[last_id : id + 1],
                        "pos_spacy" : pos_spacy[last_id : id + 1],
                        "ner_spacy" : ner_spacy[last_id : id + 1],
                        "dep_spacy" : dep_spacy[last_id : id + 1],
                        "depth_spacy" : depth_spacy[last_id : id + 1],
                        "nearest_nugget_subtype" : arg_nugget_nearest_subtype[last_id : id + 1],
                        "nearest_nugget_dist" : arg_nugget_nearest_dist[last_id : id + 1],
                        "arg_nugget_relative_pos" : arg_nugget_relative_pos[last_id : id + 1]
                    }
                )
                last_id = id + 1
                threshold += split_len
        data.append({"tokens" : words[last_id : ],
                     "pos_spacy" : pos_spacy[last_id : ],
                     "ner_spacy" : ner_spacy[last_id : ],
                     "dep_spacy" : dep_spacy[last_id : ],
                     "depth_spacy" : depth_spacy[last_id : ],
                     "nearest_nugget_subtype" : arg_nugget_nearest_subtype[last_id : ],
                     "nearest_nugget_dist" : arg_nugget_nearest_dist[last_id : ],
                     "arg_nugget_relative_pos" : arg_nugget_relative_pos[last_id : ]})
    else:
        data.append(
            {
                "tokens" : words,
                "pos_spacy" : pos_spacy,
                "ner_spacy" : ner_spacy,
                "dep_spacy" : dep_spacy,
                "depth_spacy" : depth_spacy,
                "nearest_nugget_subtype" : arg_nugget_nearest_subtype,
                "nearest_nugget_dist" : arg_nugget_nearest_dist,
                "arg_nugget_relative_pos" : arg_nugget_relative_pos
            }
        )


    ner_features = Features({'tokens' : Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
                             'pos_spacy' : Sequence(feature=ClassLabel(num_classes=len(pos_spacy_tag_list), names=pos_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
                             'ner_spacy' : Sequence(feature=ClassLabel(num_classes=len(ner_spacy_tag_list), names=ner_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
                             'dep_spacy' : Sequence(feature=ClassLabel(num_classes=len(dep_spacy_tag_list), names=dep_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
                             'depth_spacy' : Sequence(feature=ClassLabel(num_classes=17, names= list(range(17)), names_file=None, id=None), length=-1, id=None),
                             'nearest_nugget_subtype' : Sequence(feature=ClassLabel(num_classes=len(event_nugget_tag_list), names=event_nugget_tag_list, names_file=None, id=None), length=-1, id=None),
                             'nearest_nugget_dist' : Sequence(feature=ClassLabel(num_classes=11, names=list(range(11)), names_file=None, id=None), length=-1, id=None),
                             'arg_nugget_relative_pos' : Sequence(feature=ClassLabel(num_classes=len(arg_nugget_relative_pos_tag_list), names=arg_nugget_relative_pos_tag_list, names_file=None, id=None), length=-1, id=None),
                             })

    dataset = Dataset.from_list(data, features=ner_features)
    tokenized_dataset_ner = dataset.map(tokenize_and_align_labels_with_pos_ner_dep, fn_kwargs={'tokenizer' : tokenizer}, batched=True, load_from_cache_file=False)
    tokenized_dataset_ner = tokenized_dataset_ner.with_format("torch")

    tokenized_dataset_ner = tokenized_dataset_ner.remove_columns("tokens")

    batch_size = 4 # Number of input texts
    dataloader = DataLoader(tokenized_dataset_ner, batch_size=batch_size)
    return dataloader, tokenized_dataset_ner

"""
Function: predict(dataloader)
Description: This function performs prediction on a given dataloader using a trained model for label classification.
Inputs:
- dataloader: A DataLoader containing the input data for prediction.
Output:
- predicted_label: A tensor containing the predicted labels for each input in the dataloader.
"""
def predict(dataloader):
    predicted_label = []
    for batch in dataloader:
        with torch.no_grad():
            logits = model_nugget(**batch)

        batch_predicted_label = logits.argmax(-1)
        predicted_label.append(batch_predicted_label)
    return torch.cat(predicted_label, dim=-1)

"""
Function: show_annotations(text_input)
Description: This function displays annotated event arguments in the provided input text.
Inputs:
- text_input: The input text containing event arguments to be annotated and displayed.
Output:
- An interactive display of annotated event arguments within the input text.
"""
def show_annotations(text_input):
    st.title("Event Arguments")

    dataloader, tokenized_dataset_ner = create_dataloader(text_input)
    predicted_label = predict(dataloader)

    for idx, labels in enumerate(predicted_label):
        token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]

        tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
        tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]

        text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
        idxs = get_idxs_from_text(text, tokens)

        labels = labels[token_mask]

        annotated_text_list = []
        last_label = ""
        cumulative_tokens = ""
        last_id = 0

        for idx, label in zip(idxs, labels):
            to_label = event_args_list[label]
            label_short = to_label.split("-")[1] if "-" in to_label else to_label
            if last_label == label_short:
                cumulative_tokens += text[last_id : idx["end_idx"]]
                last_id = idx["end_idx"]
            else:
                if last_label != "":
                    if last_label == "O":
                        annotated_text_list.append(cumulative_tokens)
                    else:
                        annotated_text_list.append((cumulative_tokens, last_label))
                last_label = label_short
                cumulative_tokens = idx["word"]
                last_id = idx["end_idx"]
        if last_label == "O":
            annotated_text_list.append(cumulative_tokens)
        else:
            annotated_text_list.append((cumulative_tokens, last_label))

        annotated_text(annotated_text_list)

"""
Function: get_event_args(text_input)
Description: This function extracts predicted event arguments (event nuggets) from the provided input text.
Inputs:
- text_input: The input text containing event nuggets to be extracted.
Output:
- predicted_event_nuggets: A list of dictionaries, each representing an extracted event nugget with start and end offsets,
subtype, and text content.
"""
def get_event_args(text_input):
    dataloader, tokenized_dataset_ner = create_dataloader(text_input)
    predicted_label = predict(dataloader)

    predicted_event_nuggets = []
    text_length = 0
    for idx, labels in enumerate(predicted_label):
        token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]

        tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
        tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]

        text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
        idxs = get_idxs_from_text(text_input[text_length : ], tokens)

        labels = labels[token_mask]

        start_idx = 0
        end_idx = 0
        last_label = ""

        for idx, label in zip(idxs, labels):
            to_label = event_args_list[label]
            if "-" in to_label:
                label_split = to_label.split("-")[1]
            else:
                label_split = to_label

            if label_split == last_label:
                end_idx = idx["end_idx"]
            else:
                if text_input[start_idx : end_idx] != "" and last_label != "O":
                    predicted_event_nuggets.append(
                        {
                            "startOffset" : text_length + start_idx,
                            "endOffset" : text_length + end_idx,
                            "subtype" : last_label,
                            "text" : text_input[text_length + start_idx : text_length + end_idx]
                        }
                    )
                start_idx = idx["start_idx"]
                end_idx = idx["start_idx"] + len(idx["word"])
                last_label = label_split
        text_length += idx["end_idx"]
    return predicted_event_nuggets
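A short usage sketch for the argument extractor follows; the sample text is an assumption, and the output fields are the ones built inside get_event_args above.

# Illustrative sketch (not part of the commit): extract event arguments as character-offset spans.
from cybersecurity_knowledge_graph.event_arg_predict import get_event_args

text = ("Hackers breached the vendor's servers and stole customer data, "
        "then demanded a ransom payment in Bitcoin.")
for arg in get_event_args(text):
    print(arg["startOffset"], arg["endOffset"], arg["subtype"], repr(arg["text"]))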
event_arg_role_dataloader.py
ADDED
@@ -0,0 +1,100 @@
import os, json
from cybersecurity_knowledge_graph.utils import get_content, get_event_args, get_event_nugget, get_idxs_from_text, get_args_entity_from_idx, find_dict_by_overlap
from tqdm import tqdm
import spacy
import jsonlines
from sklearn.model_selection import train_test_split
import math
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import numpy as np

embed_model = SentenceTransformer('all-MiniLM-L6-v2')

pipe = pipeline("token-classification", model="CyberPeace-Institute/SecureBERT-NER")

nlp = spacy.load('en_core_web_sm')

"""
Class: EventArgumentRoleDataset
Description: This class represents a dataset for training and evaluating event argument role classifiers.
Attributes:
- path: The path to the folder containing JSON files with event data.
- tokenizer: A tokenizer for encoding text data.
- arg: The specific argument type (subtype) for which the dataset is created.
- data: A list to store data samples, each consisting of an embedding and a label.
- train_data, val_data, test_data: Lists to store the split training, validation, and test data samples.
- datapoint_id: An identifier for tracking data samples.
Methods:
- __len__(): Returns the total number of data samples in the dataset.
- __getitem__(index): Retrieves a data sample at a specified index.
- to_jsonlines(train_path, val_path, test_path): Writes the dataset to JSON files for train, validation, and test sets.
- train_val_test_split(): Splits the data into training and test sets.
- load_data(): Loads and preprocesses event data from JSON files, creating embeddings for argument-role classification.
"""
class EventArgumentRoleDataset():
    def __init__(self, path, tokenizer, arg):
        self.path = path
        self.tokenizer = tokenizer
        self.arg = arg
        self.data = []
        self.train_data, self.val_data, self.test_data = None, None, None
        self.datapoint_id = 0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        return sample

    def to_jsonlines(self, train_path, val_path, test_path):
        if self.train_data is None or self.test_data is None:
            raise ValueError("Do the train-val-test split")
        with jsonlines.open(train_path, "w") as f:
            f.write_all(self.train_data)
        # with jsonlines.open(val_path, "w") as f:
        #     f.write_all(self.val_data)
        with jsonlines.open(test_path, "w") as f:
            f.write_all(self.test_data)

    def train_val_test_split(self):
        self.train_data, self.test_data = train_test_split(self.data, test_size=0.1, random_state=42, shuffle=True)
        # self.val_data, self.test_data = train_test_split(test_val, test_size=0.5, random_state=42, shuffle=True)

    def load_data(self):
        folder_path = self.path
        json_files = [file for file in os.listdir(folder_path) if file.endswith('.json')]

        # Load the nuggets
        for idx, file_path in enumerate(tqdm(json_files)):
            try:
                with open(self.path + file_path, "r") as f:
                    file_json = json.load(f)
            except:
                print("Error in ", file_path)
            content = get_content(file_json)
            content = content.replace("\xa0", " ")

            event_args = get_event_args(file_json)
            doc = nlp(content)

            sentence_indexes = []
            for sent in doc.sents:
                start_index = sent[0].idx
                end_index = sent[-1].idx + len(sent[-1].text)
                sentence_indexes.append((start_index, end_index))

            for idx, (start, end) in enumerate(sentence_indexes):
                sentence = content[start:end]
                is_arg_sentence = [event_arg["startOffset"] >= start and event_arg["endOffset"] <= end for event_arg in event_args]
                args = [event_args[idx] for idx, boolean in enumerate(is_arg_sentence) if boolean]
                if args != []:
                    sentence_doc = nlp(sentence)
                    sentence_embed = embed_model.encode(sentence)
                    for arg in args:
                        if arg["type"] == self.arg:
                            arg_embed = embed_model.encode(arg["text"])
                            embedding = np.concatenate((sentence_embed, arg_embed))

                            self.data.append({"embedding" : embedding, "label" : arg["role"]["type"]})
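The dataset pairs a sentence embedding with an argument-span embedding and keeps the annotated role as the label. A hedged sketch of how it appears to be used during training follows; the annotation folder path and the "File" argument type are assumptions for the example.

# Illustrative sketch (not part of the commit): build the role dataset for one argument type.
from transformers import AutoTokenizer
from cybersecurity_knowledge_graph.event_arg_role_dataloader import EventArgumentRoleDataset

tokenizer = AutoTokenizer.from_pretrained("ehsanaghaei/SecureBERT", add_prefix_space=True)
dataset = EventArgumentRoleDataset(path="./data/annotation/", tokenizer=tokenizer, arg="File")  # "File" is assumed
dataset.load_data()            # builds (sentence embedding ++ argument embedding, role label) samples
dataset.train_val_test_split()
print(len(dataset), dataset[0]["label"])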
event_arg_role_predict.py
ADDED
@@ -0,0 +1,113 @@
from cybersecurity_knowledge_graph.event_arg_role_dataloader import EventArgumentRoleDataset
from cybersecurity_knowledge_graph.utils import arg_2_role

import os
from transformers import AutoTokenizer
import optuna
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
from sklearn.metrics import make_scorer, f1_score
from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from joblib import dump, load
from sentence_transformers import SentenceTransformer
import numpy as np

embed_model = SentenceTransformer('all-MiniLM-L6-v2')

model_checkpoint = "ehsanaghaei/SecureBERT"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)

classifiers = {}
folder_path = '/cybersecurity_knowledge_graph/arg_role_models'

for filename in os.listdir(os.getcwd() + folder_path):
    if filename.endswith('.joblib'):
        file_path = os.getcwd() + os.path.join(folder_path, filename)
        clf = load(file_path)
        arg = filename.split(".")[0]
        classifiers[arg] = clf

"""
Function: fit()
Description: This function performs a machine learning task to train and evaluate classifiers for multiple argument roles.
It utilizes Optuna for hyperparameter optimization and creates a Voting Classifier.
The trained classifiers are saved as joblib files.
"""
def fit():
    for arg, roles in arg_2_role.items():
        if len(roles) > 1:

            dataset = EventArgumentRoleDataset(path="./data/annotation/", tokenizer=tokenizer, arg=arg)
            dataset.load_data()
            dataset.train_val_test_split()


            X = [datapoint["embedding"] for datapoint in dataset.data]
            y = [roles.index(datapoint["label"]) for datapoint in dataset.data]


            # FYI: Objective functions can take additional arguments
            # (https://optuna.readthedocs.io/en/stable/faq.html#objective-func-additional-args).
            def objective(trial):

                classifier_name = trial.suggest_categorical("classifier", ["voting"])
                if classifier_name == "voting":
                    svc_c = trial.suggest_float("svc_c", 1e-3, 1e3, log=True)
                    svc_kernel = trial.suggest_categorical("kernel", ['rbf'])
                    classifier_obj = VotingClassifier(estimators=[
                        ('Logistic Regression', LogisticRegression()),
                        ('Neural Network', MLPClassifier(max_iter=500)),
                        ('Support Vector Machine', SVC(C=svc_c, kernel=svc_kernel))
                    ], voting='hard')

                f1_scorer = make_scorer(f1_score, average = "weighted")
                stratified_kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
                cv_scores = cross_val_score(classifier_obj, X, y, cv=stratified_kfold, scoring=f1_scorer)
                return cv_scores.mean()


            study = optuna.create_study(direction="maximize")
            study.optimize(objective, n_trials=20)
            print(f"{arg} : {study.best_trial.values[0]}")

            best_clf = VotingClassifier(estimators=[
                ('Logistic Regression', LogisticRegression()),
                ('Neural Network', MLPClassifier(max_iter=500)),
                ('Support Vector Machine', SVC(C=study.best_trial.params["svc_c"], kernel=study.best_trial.params["kernel"]))
            ], voting='hard')

            best_clf.fit(X, y)
            dump(best_clf, f'{arg}.joblib')

"""
Function: get_arg_roles(event_args, doc)
Description: This function assigns argument roles to a list of event arguments within a document.
Inputs:
- event_args: A list of event argument dictionaries, each containing information about an argument.
- doc: A spaCy document representing the analyzed text.
Output:
- The input 'event_args' list with updated 'role' values assigned to each argument.
"""
def get_arg_roles(event_args, doc):
    for arg in event_args:
        if len(arg_2_role[arg["subtype"]]) > 1:
            sent = next(filter(lambda x : arg["startOffset"] >= x.start_char and arg["endOffset"] <= x.end_char, doc.sents))

            sent_embed = embed_model.encode(sent.text)
            arg_embed = embed_model.encode(arg["text"])
            embed = np.concatenate((sent_embed, arg_embed))

            arg_clf = classifiers[arg["subtype"]]
            role_id = arg_clf.predict(embed.reshape(1, -1))
            role = arg_2_role[arg["subtype"]][role_id[0]]

            arg["role"] = role
        else:
            arg["role"] = arg_2_role[arg["subtype"]][0]
    return event_args
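Role assignment sits on top of argument extraction: each extracted span is embedded together with its sentence and routed to the per-subtype classifier. A hedged sketch of that chaining follows; the input text and the combination with get_event_args are assumptions based on the modules above.

# Illustrative sketch (not part of the commit): assign roles to extracted event arguments.
import spacy
from cybersecurity_knowledge_graph.event_arg_predict import get_event_args
from cybersecurity_knowledge_graph.event_arg_role_predict import get_arg_roles

nlp = spacy.load("en_core_web_sm")
text = "Attackers exploited a flaw in the payment portal and exfiltrated card numbers."
doc = nlp(text)
event_args = get_event_args(text)
for arg in get_arg_roles(event_args, doc):
    print(arg["subtype"], "->", arg["role"], ":", arg["text"])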
event_nugget_predict.py
ADDED
@@ -0,0 +1,250 @@
import streamlit as st
from annotated_text import annotated_text
import torch
from torch import nn
from torch.utils.data import DataLoader
from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS as NuggetModel
from cybersecurity_knowledge_graph.nugget_model_utils import tokenize_and_align_labels_with_pos_ner_dep, find_nearest_nugget_features, find_dep_depth
from cybersecurity_knowledge_graph.utils import get_idxs_from_text, event_nugget_list
import spacy
from transformers import AutoTokenizer
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
import os


os.environ["TOKENIZERS_PARALLELISM"] = "true"

def find_dep_depth(token):
    depth = 0
    current_token = token
    while current_token.head != current_token:
        depth += 1
        current_token = current_token.head
    return min(depth, 16)


nlp = spacy.load('en_core_web_sm')

pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_checkpoint = "ehsanaghaei/SecureBERT"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)

model_nugget = NuggetModel(num_classes = 11)
model_nugget.load_state_dict(torch.load("cybersecurity_knowledge_graph/nugget_model_state_dict.pth", map_location=device))
model_nugget.eval()

"""
Function: create_dataloader(text_input)
Description: This function prepares a DataLoader for processing text input, including tokenization and alignment of labels.
Inputs:
- text_input: The input text to be processed.
Output:
- dataloader: A DataLoader for the tokenized and batched text data.
- tokenized_dataset_ner: The tokenized dataset used for training.
"""
def create_dataloader(text_input):

    doc = nlp(text_input)

    content_as_words_emdash = [tok.text for tok in doc]
    content_as_words_emdash = [word.replace("``", '"').replace("''", '"').replace("$", "") for word in content_as_words_emdash]
    content_idx_dict = get_idxs_from_text(text_input, content_as_words_emdash)

    data = []

    words = []

    pos_spacy = [tok.pos_ for tok in doc]
    ner_spacy = [ent.ent_iob_ + "-" + ent.ent_type_ if ent.ent_iob_ != "O" else ent.ent_iob_ for ent in doc]
    dep_spacy = [tok.dep_ for tok in doc]
    depth_spacy = [find_dep_depth(tok) for tok in doc]

    for content_dict in content_idx_dict:
        start_idx, end_idx = content_dict["start_idx"], content_dict["end_idx"]
        words.append(content_dict["word"])


    content_token_len = len(tokenizer(words, truncation=False, is_split_into_words=True)["input_ids"])
    if content_token_len > tokenizer.model_max_length:
        no_split = (content_token_len // tokenizer.model_max_length) + 2
        split_len = (len(words) // no_split) + 1

        last_id = 0
        threshold = split_len

        for id, token in enumerate(words):
            if token == "." and id > threshold:
                data.append(
                    {
                        "tokens" : words[last_id : id + 1],
                        "pos_spacy" : pos_spacy[last_id : id + 1],
                        "ner_spacy" : ner_spacy[last_id : id + 1],
                        "dep_spacy" : dep_spacy[last_id : id + 1],
                        "depth_spacy" : depth_spacy[last_id : id + 1],
                    }
                )
                last_id = id + 1
                threshold += split_len
        data.append({"tokens" : words[last_id : ],
                     "pos_spacy" : pos_spacy[last_id : ],
                     "ner_spacy" : ner_spacy[last_id : ],
                     "dep_spacy" : dep_spacy[last_id : ],
                     "depth_spacy" : depth_spacy[last_id : ]})
    else:
        data.append(
            {
                "tokens" : words,
                "pos_spacy" : pos_spacy,
                "ner_spacy" : ner_spacy,
                "dep_spacy" : dep_spacy,
                "depth_spacy" : depth_spacy
            }
        )


    ner_features = Features({'tokens' : Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
                             'pos_spacy' : Sequence(feature=ClassLabel(num_classes=len(pos_spacy_tag_list), names=pos_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
                             'ner_spacy' : Sequence(feature=ClassLabel(num_classes=len(ner_spacy_tag_list), names=ner_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
                             'dep_spacy' : Sequence(feature=ClassLabel(num_classes=len(dep_spacy_tag_list), names=dep_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
                             'depth_spacy' : Sequence(feature=ClassLabel(num_classes=17, names= list(range(17)), names_file=None, id=None), length=-1, id=None)
                             })

    dataset = Dataset.from_list(data, features=ner_features)
    tokenized_dataset_ner = dataset.map(tokenize_and_align_labels_with_pos_ner_dep, fn_kwargs={'tokenizer' : tokenizer}, batched=True, load_from_cache_file=False)
    tokenized_dataset_ner = tokenized_dataset_ner.with_format("torch")

    tokenized_dataset_ner = tokenized_dataset_ner.remove_columns("tokens")

    batch_size = 4 # Number of input texts
    dataloader = DataLoader(tokenized_dataset_ner, batch_size=batch_size)
    # TODO : context_idx_dict should be used to index the words
    return dataloader, tokenized_dataset_ner

"""
Function: predict(dataloader)
Description: This function performs inference on a given DataLoader using a trained model and returns the predicted labels.
Inputs:
- dataloader: A DataLoader containing input data for prediction.
Output:
- predicted_label: A tensor containing the predicted labels for the input data.
"""
def predict(dataloader):
    predicted_label = []
    for batch in dataloader:
        with torch.no_grad():
            logits = model_nugget(**batch)
        batch_predicted_label = logits.argmax(-1)
        predicted_label.append(batch_predicted_label)
    return torch.cat(predicted_label, dim=-1)

"""
Function: show_annotations(text_input)
Description: This function displays annotated event nuggets in the provided input text using the Streamlit library.
Inputs:
- text_input: The input text containing event nuggets to be annotated and displayed.
Output:
- An interactive display of annotated event nuggets within the input text.
"""
def show_annotations(text_input):
    st.title("Event Nuggets")

    dataloader, tokenized_dataset_ner = create_dataloader(text_input)
    predicted_label = predict(dataloader)

    for idx, labels in enumerate(predicted_label):
        token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]

        tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
        tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]

        text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
        idxs = get_idxs_from_text(text, tokens)

        labels = labels[token_mask]

        annotated_text_list = []
        last_label = ""
        cumulative_tokens = ""
        last_id = 0

        for idx, label in zip(idxs, labels):
            to_label = event_nugget_list[label]
            label_short = to_label.split("-")[1] if "-" in to_label else to_label
            if last_label == label_short:
                cumulative_tokens += text[last_id : idx["end_idx"]]
                last_id = idx["end_idx"]
            else:
                if last_label != "":
                    if last_label == "O":
                        annotated_text_list.append(cumulative_tokens)
                    else:
                        annotated_text_list.append((cumulative_tokens, last_label))
                last_label = label_short
                cumulative_tokens = idx["word"]
                last_id = idx["end_idx"]
        if last_label == "O":
            annotated_text_list.append(cumulative_tokens)
        else:
            annotated_text_list.append((cumulative_tokens, last_label))
        annotated_text(annotated_text_list)

"""
Function: get_event_nuggets(text_input)
Description: This function extracts predicted event nuggets (event entities) from the provided input text.
Inputs:
- text_input: The input text containing event nuggets to be extracted.
Output:
- predicted_event_nuggets: A list of dictionaries, each representing an extracted event nugget with start and end offsets,
subtype, and text content.
"""
def get_event_nuggets(text_input):
    dataloader, tokenized_dataset_ner = create_dataloader(text_input)
    predicted_label = predict(dataloader)

    predicted_event_nuggets = []
    text_length = 0
    for idx, labels in enumerate(predicted_label):
        token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]

        tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
        tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]

        text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
        idxs = get_idxs_from_text(text_input[text_length : ], tokens)

        labels = labels[token_mask]

        start_idx = 0
        end_idx = 0
        last_label = ""

        for idx, label in zip(idxs, labels):
            to_label = event_nugget_list[label]
            label_short = to_label.split("-")[1] if "-" in to_label else to_label

            if label_short == last_label:
                end_idx = idx["end_idx"]
            else:
                if text_input[start_idx : end_idx] != "" and last_label != "O":
                    predicted_event_nuggets.append(
                        {
                            "startOffset" : text_length + start_idx,
                            "endOffset" : text_length + end_idx,
                            "subtype" : last_label,
                            "text" : text_input[text_length + start_idx : text_length + end_idx]
                        }
                    )
                start_idx = idx["start_idx"]
                end_idx = idx["start_idx"] + len(idx["word"])
                last_label = label_short

        text_length += idx["end_idx"]
    return predicted_event_nuggets
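Nugget detection is the entry point the argument and realis modules build on. A short usage sketch follows; the sample text is an assumption, and the output fields are the ones built inside get_event_nuggets above.

# Illustrative sketch (not part of the commit): detect event nuggets and their subtypes.
from cybersecurity_knowledge_graph.event_nugget_predict import get_event_nuggets

text = ("A phishing campaign tricked employees into revealing credentials, "
        "leading to a data breach.")
for nugget in get_event_nuggets(text):
    print(nugget["startOffset"], nugget["endOffset"], nugget["subtype"], repr(nugget["text"]))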
event_realis_predict.py
ADDED
|
@@ -0,0 +1,270 @@
|
| 1 |
+
import os
|
| 2 |
+
import spacy
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
from cybersecurity_knowledge_graph.utils import get_idxs_from_text
|
| 7 |
+
import streamlit as st
|
| 8 |
+
from annotated_text import annotated_text
|
| 9 |
+
from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS
|
| 10 |
+
from cybersecurity_knowledge_graph.event_nugget_predict import get_event_nuggets
|
| 11 |
+
from cybersecurity_knowledge_graph.realis_model_utils import get_entity_for_realis_from_idx, tokenize_and_align_labels_with_pos_ner_realis
|
| 12 |
+
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
| 13 |
+
|
| 14 |
+
event_nugget_list = ['B-Phishing',
|
| 15 |
+
'I-Phishing',
|
| 16 |
+
'O',
|
| 17 |
+
'B-DiscoverVulnerability',
|
| 18 |
+
'B-Ransom',
|
| 19 |
+
'I-Ransom',
|
| 20 |
+
'B-Databreach',
|
| 21 |
+
'I-DiscoverVulnerability',
|
| 22 |
+
'B-PatchVulnerability',
|
| 23 |
+
'I-PatchVulnerability',
|
| 24 |
+
'I-Databreach']
|
| 25 |
+
|
| 26 |
+
realis_list = ["O", "Generic", "Other", "Actual"]
|
| 27 |
+
|
| 28 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def find_dep_depth(token):
|
| 33 |
+
depth = 0
|
| 34 |
+
current_token = token
|
| 35 |
+
while current_token.head != current_token:
|
| 36 |
+
depth += 1
|
| 37 |
+
current_token = current_token.head
|
| 38 |
+
return min(depth, 16)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
nlp = spacy.load('en_core_web_sm')
|
| 42 |
+
|
| 43 |
+
pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
|
| 44 |
+
ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
|
| 45 |
+
dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
|
| 46 |
+
|
| 47 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 48 |
+
|
| 49 |
+
model_checkpoint = "ehsanaghaei/SecureBERT"
|
| 50 |
+
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
| 51 |
+
|
| 52 |
+
from cybersecurity_knowledge_graph.realis_model_utils import CustomRobertaWithPOS as RealisModel
|
| 53 |
+
model_realis = RealisModel(num_classes_realis=4)
|
| 54 |
+
model_realis.load_state_dict(torch.load("cybersecurity_knowledge_graph/realis_model_state_dict.pth", map_location=device))
|
| 55 |
+
model_realis.eval()
|
| 56 |
+
|
| 57 |
+
"""
|
| 58 |
+
Function: create_dataloader(text_input)
|
| 59 |
+
Description: This function prepares a DataLoader for processing text input, including tokenization and alignment of labels.
|
| 60 |
+
Inputs:
|
| 61 |
+
- text_input: The input text to be processed.
|
| 62 |
+
Output:
|
| 63 |
+
- dataloader: A DataLoader for the tokenized and batched text data.
|
| 64 |
+
- tokenized_dataset_ner: The tokenized dataset that the DataLoader draws its batches from.
|
| 65 |
+
"""
|
| 66 |
+
def create_dataloader(text_input):
|
| 67 |
+
|
| 68 |
+
event_nuggets = get_event_nuggets(text_input)
|
| 69 |
+
doc = nlp(text_input)
|
| 70 |
+
|
| 71 |
+
content_as_words_emdash = [tok.text for tok in doc]
|
| 72 |
+
content_as_words_emdash = [word.replace("``", '"').replace("''", '"').replace("$", "") for word in content_as_words_emdash]
|
| 73 |
+
content_idx_dict = get_idxs_from_text(text_input, content_as_words_emdash)
|
| 74 |
+
|
| 75 |
+
data = []
|
| 76 |
+
|
| 77 |
+
words = []
|
| 78 |
+
nugget_ner_tags = []
|
| 79 |
+
|
| 80 |
+
pos_spacy = [tok.pos_ for tok in doc]
|
| 81 |
+
ner_spacy = [ent.ent_iob_ + "-" + ent.ent_type_ if ent.ent_iob_ != "O" else ent.ent_iob_ for ent in doc]
|
| 82 |
+
dep_spacy = [tok.dep_ for tok in doc]
|
| 83 |
+
depth_spacy = [find_dep_depth(tok) for tok in doc]
|
| 84 |
+
|
| 85 |
+
for content_dict in content_idx_dict:
|
| 86 |
+
start_idx, end_idx = content_dict["start_idx"], content_dict["end_idx"]
|
| 87 |
+
entity = get_entity_for_realis_from_idx(start_idx, end_idx, event_nuggets)
|
| 88 |
+
words.append(content_dict["word"])
|
| 89 |
+
nugget_ner_tags.append(entity)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
content_token_len = len(tokenizer(words, truncation=False, is_split_into_words=True)["input_ids"])
|
| 93 |
+
if content_token_len > tokenizer.model_max_length:
|
| 94 |
+
no_split = (content_token_len // tokenizer.model_max_length) + 2
|
| 95 |
+
split_len = (len(words) // no_split) + 1
|
| 96 |
+
|
| 97 |
+
last_id = 0
|
| 98 |
+
threshold = split_len
|
| 99 |
+
|
| 100 |
+
for id, token in enumerate(words):
|
| 101 |
+
if token == "." and id > threshold:
|
| 102 |
+
data.append(
|
| 103 |
+
{
|
| 104 |
+
"tokens" : words[last_id : id + 1],
|
| 105 |
+
"ner_tags" : nugget_ner_tags[last_id : id + 1],
|
| 106 |
+
"pos_spacy" : pos_spacy[last_id : id + 1],
|
| 107 |
+
"ner_spacy" : ner_spacy[last_id : id + 1],
|
| 108 |
+
"dep_spacy" : dep_spacy[last_id : id + 1],
|
| 109 |
+
"depth_spacy" : depth_spacy[last_id : id + 1],
|
| 110 |
+
}
|
| 111 |
+
)
|
| 112 |
+
last_id = id + 1
|
| 113 |
+
threshold += split_len
|
| 114 |
+
data.append({"tokens" : words[last_id : ],
|
| 115 |
+
"ner_tags" : nugget_ner_tags[last_id : ],
|
| 116 |
+
"pos_spacy" : pos_spacy[last_id : ],
|
| 117 |
+
"ner_spacy" : ner_spacy[last_id : ],
|
| 118 |
+
"dep_spacy" : dep_spacy[last_id : ],
|
| 119 |
+
"depth_spacy" : depth_spacy[last_id : ]})
|
| 120 |
+
else:
|
| 121 |
+
data.append(
|
| 122 |
+
{
|
| 123 |
+
"tokens" : words,
|
| 124 |
+
"ner_tags" : nugget_ner_tags,
|
| 125 |
+
"pos_spacy" : pos_spacy,
|
| 126 |
+
"ner_spacy" : ner_spacy,
|
| 127 |
+
"dep_spacy" : dep_spacy,
|
| 128 |
+
"depth_spacy" : depth_spacy
|
| 129 |
+
}
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
ner_features = Features({'tokens' : Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
|
| 134 |
+
'ner_tags' : Sequence(feature=ClassLabel(num_classes=len(event_nugget_list), names=event_nugget_list, names_file=None, id=None), length=-1, id=None),
|
| 135 |
+
'pos_spacy' : Sequence(feature=ClassLabel(num_classes=len(pos_spacy_tag_list), names=pos_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
|
| 136 |
+
'ner_spacy' : Sequence(feature=ClassLabel(num_classes=len(ner_spacy_tag_list), names=ner_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
|
| 137 |
+
'dep_spacy' : Sequence(feature=ClassLabel(num_classes=len(dep_spacy_tag_list), names=dep_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
|
| 138 |
+
'depth_spacy' : Sequence(feature=ClassLabel(num_classes=17, names= list(range(17)), names_file=None, id=None), length=-1, id=None)
|
| 139 |
+
})
|
| 140 |
+
|
| 141 |
+
dataset = Dataset.from_list(data, features=ner_features)
|
| 142 |
+
tokenized_dataset_ner = dataset.map(tokenize_and_align_labels_with_pos_ner_realis, fn_kwargs={'tokenizer' : tokenizer, 'ner_names' : event_nugget_list}, batched=True, load_from_cache_file=False)
|
| 143 |
+
tokenized_dataset_ner = tokenized_dataset_ner.with_format("torch")
|
| 144 |
+
|
| 145 |
+
tokenized_dataset_ner = tokenized_dataset_ner.remove_columns("tokens")
|
| 146 |
+
|
| 147 |
+
batch_size = 4 # Number of text chunks per inference batch
|
| 148 |
+
dataloader = DataLoader(tokenized_dataset_ner, batch_size=batch_size)
|
| 149 |
+
return dataloader, tokenized_dataset_ner
|
| 150 |
+
|
| 151 |
+
"""
|
| 152 |
+
Function: predict(dataloader)
|
| 153 |
+
Description: This function performs inference on a given DataLoader using a trained model and returns the predicted labels.
|
| 154 |
+
Inputs:
|
| 155 |
+
- dataloader: A DataLoader containing input data for prediction.
|
| 156 |
+
Output:
|
| 157 |
+
- predicted_label: A tensor containing the predicted labels for the input data.
|
| 158 |
+
"""
|
| 159 |
+
def predict(dataloader):
|
| 160 |
+
predicted_label = []
|
| 161 |
+
for batch in dataloader:
|
| 162 |
+
with torch.no_grad():
|
| 163 |
+
logits = model_realis(**batch)
|
| 164 |
+
|
| 165 |
+
batch_predicted_label = logits.argmax(-1)
|
| 166 |
+
predicted_label.append(batch_predicted_label)
|
| 167 |
+
return torch.cat(predicted_label, dim=-1)
|
| 168 |
+
|
| 169 |
+
"""
|
| 170 |
+
Function: show_annotations(text_input)
|
| 171 |
+
Description: This function displays annotated event realis labels over the provided input text using the Streamlit library.
|
| 172 |
+
Inputs:
|
| 173 |
+
- text_input: The input text whose event realis labels are to be annotated and displayed.
|
| 174 |
+
Output:
|
| 175 |
+
- An interactive display of annotated event realis labels within the input text.
|
| 176 |
+
"""
|
| 177 |
+
def show_annotations(text_input):
|
| 178 |
+
st.title("Event Realis")
|
| 179 |
+
|
| 180 |
+
dataloader, tokenized_dataset_ner = create_dataloader(text_input)
|
| 181 |
+
predicted_label = predict(dataloader)
|
| 182 |
+
|
| 183 |
+
for idx, labels in enumerate(predicted_label):
|
| 184 |
+
token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]
|
| 185 |
+
|
| 186 |
+
tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
|
| 187 |
+
tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
|
| 188 |
+
|
| 189 |
+
text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
|
| 190 |
+
idxs = get_idxs_from_text(text, tokens)
|
| 191 |
+
|
| 192 |
+
labels = labels[token_mask]
|
| 193 |
+
|
| 194 |
+
annotated_text_list = []
|
| 195 |
+
last_label = ""
|
| 196 |
+
cumulative_tokens = ""
|
| 197 |
+
last_id = 0
|
| 198 |
+
|
| 199 |
+
for idx, label in zip(idxs, labels):
|
| 200 |
+
to_label = realis_list[label]
|
| 201 |
+
label_short = to_label.split("-")[1] if "-" in to_label else to_label
|
| 202 |
+
if last_label == label_short:
|
| 203 |
+
cumulative_tokens += text[last_id : idx["end_idx"]]
|
| 204 |
+
last_id = idx["end_idx"]
|
| 205 |
+
else:
|
| 206 |
+
if last_label != "":
|
| 207 |
+
if last_label == "O":
|
| 208 |
+
annotated_text_list.append(cumulative_tokens)
|
| 209 |
+
else:
|
| 210 |
+
annotated_text_list.append((cumulative_tokens, last_label))
|
| 211 |
+
last_label = label_short
|
| 212 |
+
cumulative_tokens = idx["word"]
|
| 213 |
+
last_id = idx["end_idx"]
|
| 214 |
+
if last_label == "O":
|
| 215 |
+
annotated_text_list.append(cumulative_tokens)
|
| 216 |
+
else:
|
| 217 |
+
annotated_text_list.append((cumulative_tokens, last_label))
|
| 218 |
+
annotated_text(annotated_text_list)
|
| 219 |
+
|
| 220 |
+
"""
|
| 221 |
+
Function: get_event_realis(text_input)
|
| 222 |
+
Description: This function extracts predicted event realis (event modality) from the provided input text.
|
| 223 |
+
Inputs:
|
| 224 |
+
- text_input: The input text from which event realis labels are extracted.
|
| 225 |
+
Output:
|
| 226 |
+
- predicted_event_realis: A list of dictionaries, each representing an extracted event realis with start and end offsets,
|
| 227 |
+
realis type, and text content.
|
| 228 |
+
"""
|
| 229 |
+
def get_event_realis(text_input):
|
| 230 |
+
dataloader, tokenized_dataset_ner = create_dataloader(text_input)
|
| 231 |
+
predicted_label = predict(dataloader)
|
| 232 |
+
|
| 233 |
+
predicted_event_realis = []
|
| 234 |
+
text_length = 0
|
| 235 |
+
for idx, labels in enumerate(predicted_label):
|
| 236 |
+
token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]
|
| 237 |
+
|
| 238 |
+
tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
|
| 239 |
+
tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
|
| 240 |
+
|
| 241 |
+
text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
|
| 242 |
+
idxs = get_idxs_from_text(text_input[text_length : ], tokens)
|
| 243 |
+
|
| 244 |
+
labels = labels[token_mask]
|
| 245 |
+
|
| 246 |
+
start_idx = 0
|
| 247 |
+
end_idx = 0
|
| 248 |
+
last_label = ""
|
| 249 |
+
|
| 250 |
+
for idx, label in zip(idxs, labels):
|
| 251 |
+
to_label = realis_list[label]
|
| 252 |
+
label_split = to_label
|
| 253 |
+
|
| 254 |
+
if label_split == last_label:
|
| 255 |
+
end_idx = idx["end_idx"]
|
| 256 |
+
else:
|
| 257 |
+
if text_input[start_idx : end_idx] != "" and last_label != "O":
|
| 258 |
+
predicted_event_realis.append(
|
| 259 |
+
{
|
| 260 |
+
"startOffset" : text_length + start_idx,
|
| 261 |
+
"endOffset" : text_length + end_idx,
|
| 262 |
+
"realis" : last_label,
|
| 263 |
+
"text" : text_input[text_length + start_idx : text_length + end_idx]
|
| 264 |
+
}
|
| 265 |
+
)
|
| 266 |
+
start_idx = idx["start_idx"]
|
| 267 |
+
end_idx = idx["start_idx"] + len(idx["word"])
|
| 268 |
+
last_label = label_split
|
| 269 |
+
text_length += idx["end_idx"]
|
| 270 |
+
return predicted_event_realis
|
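A similarly hedged sketch for get_event_realis defined above; the sample text is an assumption:

# Hypothetical usage example, not part of the uploaded module.
from cybersecurity_knowledge_graph.event_realis_predict import get_event_realis

sample = "Researchers discovered a vulnerability that was patched last week."
for span in get_event_realis(sample):
    # Each span is a dict with startOffset, endOffset, realis and text.
    print(span["realis"], span["startOffset"], span["endOffset"], repr(span["text"]))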
model_59.pt
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:09bc24b422adbe6c4c6ca1333a3a8c33146e6152e00a7ad6376cab616b51e53f
|
| 3 |
+
size 498858353
|
model_64_pos_ner.pt
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:76125c5bbce2c32e536fe74d24dc51fb1fce3ba076104b459ee290102ce4bd5d
|
| 3 |
+
size 498746934
|
model_66.pt
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:46531e8ccf92661a025b15c829be791f72416d1b458ae1aa82cc66e069193bf5
|
| 3 |
+
size 498751092
|
model_97.pt
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e6147a98aa2baaa545903103e9e2f0e55fc249ec638cfe27e273ffdd247479c4
|
| 3 |
+
size 498729523
|
nugget_model_state_dict.pth
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8d04c7ccd654b3af96c1c8e0f391a20d79ae1b5970d5419680f379c6a09e78bf
|
| 3 |
+
size 498703483
|
nugget_model_utils.py
ADDED
|
@@ -0,0 +1,151 @@
|
| 1 |
+
import torch
|
| 2 |
+
import spacy
|
| 3 |
+
import en_core_web_sm
|
| 4 |
+
from torch import nn
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 9 |
+
|
| 10 |
+
from transformers import AutoModel, TrainingArguments, Trainer, RobertaTokenizer, RobertaModel
|
| 11 |
+
from transformers import AutoTokenizer
|
| 12 |
+
|
| 13 |
+
model_checkpoint = "ehsanaghaei/SecureBERT"
|
| 14 |
+
|
| 15 |
+
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
| 16 |
+
roberta_model = RobertaModel.from_pretrained(model_checkpoint).to(device)
|
| 17 |
+
|
| 18 |
+
nlp = en_core_web_sm.load()
|
| 19 |
+
pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
|
| 20 |
+
ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CustomRobertaWithPOS(nn.Module):
|
| 24 |
+
def __init__(self, num_classes):
|
| 25 |
+
super(CustomRobertaWithPOS, self).__init__()
|
| 26 |
+
self.num_classes = num_classes
|
| 27 |
+
self.pos_embed = nn.Embedding(len(pos_spacy_tag_list), 16)
|
| 28 |
+
self.ner_embed = nn.Embedding(len(ner_spacy_tag_list), 16)
|
| 29 |
+
self.roberta = roberta_model
|
| 30 |
+
self.dropout1 = nn.Dropout(0.2)
|
| 31 |
+
self.fc1 = nn.Linear(self.roberta.config.hidden_size, num_classes)
|
| 32 |
+
|
| 33 |
+
def forward(self, input_ids, attention_mask, pos_spacy, ner_spacy, dep_spacy, depth_spacy):
|
| 34 |
+
outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
|
| 35 |
+
last_hidden_output = outputs.last_hidden_state
|
| 36 |
+
|
| 37 |
+
pos_mask = pos_spacy != -100
|
| 38 |
+
|
| 39 |
+
pos_one_hot = torch.zeros((pos_spacy.shape[0], pos_spacy.shape[1], len(pos_spacy_tag_list)), dtype=torch.long)
|
| 40 |
+
pos_one_hot[pos_mask, pos_spacy[pos_mask]] = 1
|
| 41 |
+
pos_one_hot = pos_one_hot.to(device)
|
| 42 |
+
|
| 43 |
+
ner_mask = ner_spacy != -100
|
| 44 |
+
|
| 45 |
+
ner_one_hot = torch.zeros((ner_spacy.shape[0], ner_spacy.shape[1], len(ner_spacy_tag_list)), dtype=torch.long)
|
| 46 |
+
ner_one_hot[ner_mask, ner_spacy[ner_mask]] = 1
|
| 47 |
+
ner_one_hot = ner_one_hot.to(device)
|
| 48 |
+
|
| 49 |
+
# Only the RoBERTa hidden states feed the classifier; the POS/NER one-hot features built above are not concatenated in this variant of the model.
features_concat = last_hidden_output
|
| 50 |
+
features_concat = self.dropout1(features_concat)
|
| 51 |
+
|
| 52 |
+
logits = self.fc1(features_concat)
|
| 53 |
+
|
| 54 |
+
return logits
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def tokenize_and_align_labels_with_pos_ner_dep(examples, tokenizer, label_all_tokens = True):
|
| 58 |
+
tokenized_inputs = tokenizer(examples["tokens"], padding='max_length', truncation=True, is_split_into_words=True)
|
| 59 |
+
#tokenized_inputs.pop('input_ids')
|
| 60 |
+
ner_spacy = []
|
| 61 |
+
pos_spacy = []
|
| 62 |
+
dep_spacy = []
|
| 63 |
+
depth_spacy = []
|
| 64 |
+
|
| 65 |
+
for i, (pos, ner, dep, depth) in enumerate(zip(examples["pos_spacy"],
|
| 66 |
+
examples["ner_spacy"],
|
| 67 |
+
examples["dep_spacy"],
|
| 68 |
+
examples["depth_spacy"])):
|
| 69 |
+
word_ids = tokenized_inputs.word_ids(batch_index=i)
|
| 70 |
+
previous_word_idx = None
|
| 71 |
+
ner_spacy_ids = []
|
| 72 |
+
pos_spacy_ids = []
|
| 73 |
+
dep_spacy_ids = []
|
| 74 |
+
depth_spacy_ids = []
|
| 75 |
+
|
| 76 |
+
for word_idx in word_ids:
|
| 77 |
+
# Special tokens have a word id that is None. We set the label to -100 so they are automatically
|
| 78 |
+
# ignored in the loss function.
|
| 79 |
+
if word_idx is None:
|
| 80 |
+
ner_spacy_ids.append(-100)
|
| 81 |
+
pos_spacy_ids.append(-100)
|
| 82 |
+
dep_spacy_ids.append(-100)
|
| 83 |
+
depth_spacy_ids.append(-100)
|
| 84 |
+
# We set the label for the first token of each word.
|
| 85 |
+
elif word_idx != previous_word_idx:
|
| 86 |
+
ner_spacy_ids.append(ner[word_idx])
|
| 87 |
+
pos_spacy_ids.append(pos[word_idx])
|
| 88 |
+
dep_spacy_ids.append(dep[word_idx])
|
| 89 |
+
depth_spacy_ids.append(depth[word_idx])
|
| 90 |
+
# For the other tokens in a word, we set the label to either the current label or -100, depending on
|
| 91 |
+
# the label_all_tokens flag.
|
| 92 |
+
else:
|
| 93 |
+
ner_spacy_ids.append(ner[word_idx] if label_all_tokens else -100)
|
| 94 |
+
pos_spacy_ids.append(pos[word_idx] if label_all_tokens else -100)
|
| 95 |
+
dep_spacy_ids.append(dep[word_idx] if label_all_tokens else -100)
|
| 96 |
+
depth_spacy_ids.append(depth[word_idx] if label_all_tokens else -100)
|
| 97 |
+
previous_word_idx = word_idx
|
| 98 |
+
|
| 99 |
+
ner_spacy.append(ner_spacy_ids)
|
| 100 |
+
pos_spacy.append(pos_spacy_ids)
|
| 101 |
+
dep_spacy.append(dep_spacy_ids)
|
| 102 |
+
depth_spacy.append(depth_spacy_ids)
|
| 103 |
+
|
| 104 |
+
tokenized_inputs["pos_spacy"] = pos_spacy
|
| 105 |
+
tokenized_inputs["ner_spacy"] = ner_spacy
|
| 106 |
+
tokenized_inputs["dep_spacy"] = dep_spacy
|
| 107 |
+
tokenized_inputs["depth_spacy"] = depth_spacy
|
| 108 |
+
|
| 109 |
+
return tokenized_inputs
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def find_nearest_nugget_features(doc, start_idx, end_idx, event_nuggets):
|
| 113 |
+
nearest_subtype = None
|
| 114 |
+
nearest_dist = math.inf
|
| 115 |
+
relative_pos = None
|
| 116 |
+
|
| 117 |
+
mid_idx = (end_idx + start_idx) / 2
|
| 118 |
+
for nugget in event_nuggets:
|
| 119 |
+
mid_nugget_idx = (nugget["nugget"]["startOffset"] + nugget["nugget"]["endOffset"]) / 2
|
| 120 |
+
dist = abs(mid_nugget_idx - mid_idx)
|
| 121 |
+
|
| 122 |
+
if dist < nearest_dist:
|
| 123 |
+
nearest_dist = dist
|
| 124 |
+
nearest_subtype = nugget["subtype"]
|
| 125 |
+
for sent in doc.sents:
|
| 126 |
+
if between_idxs(mid_idx, sent.start_char, sent.end_char) and between_idxs(mid_nugget_idx, sent.start_char, sent.end_char):
|
| 127 |
+
if mid_idx < mid_nugget_idx:
|
| 128 |
+
relative_pos = "before-same-sentence"
|
| 129 |
+
else:
|
| 130 |
+
relative_pos = "after-same-sentence"
|
| 131 |
+
break
|
| 132 |
+
elif between_idxs(mid_nugget_idx, sent.start_char, sent.end_char) and mid_idx > mid_nugget_idx:
|
| 133 |
+
relative_pos = "after-differ-sentence"
|
| 134 |
+
break
|
| 135 |
+
elif between_idxs(mid_idx, sent.start_char, sent.end_char) and mid_idx < mid_nugget_idx:
|
| 136 |
+
relative_pos = "before-differ-sentence"
|
| 137 |
+
break
|
| 138 |
+
|
| 139 |
+
nearest_dist = int(min(10, nearest_dist // 20))
|
| 140 |
+
return nearest_subtype, nearest_dist, relative_pos
|
| 141 |
+
|
| 142 |
+
def find_dep_depth(token):
|
| 143 |
+
depth = 0
|
| 144 |
+
current_token = token
|
| 145 |
+
while current_token.head != current_token:
|
| 146 |
+
depth += 1
|
| 147 |
+
current_token = current_token.head
|
| 148 |
+
return min(depth, 16)
|
| 149 |
+
|
| 150 |
+
def between_idxs(idx, start_idx, end_idx):
|
| 151 |
+
return idx >= start_idx and idx <= end_idx
|
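A small smoke test for the CustomRobertaWithPOS head defined above; the sentence is made up, num_classes=11 matches the event_nugget_list used elsewhere in this repo, and the -100 placeholders rely on the fact that this head does not concatenate the extra linguistic features:

# Hypothetical smoke test, not part of the uploaded module.
import torch
from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS, tokenizer, device

model = CustomRobertaWithPOS(num_classes=11).to(device).eval()
enc = tokenizer(["A ransomware attack was discovered."], padding="max_length", truncation=True, return_tensors="pt")
# POS/NER/dep/depth features are ignored by this head, so -100 placeholders are enough.
dummy = torch.full_like(enc["input_ids"], -100)
with torch.no_grad():
    logits = model(enc["input_ids"].to(device), enc["attention_mask"].to(device), dummy, dummy, dummy, dummy)
print(logits.shape)  # (1, sequence length, 11)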
realis_model_state_dict.pth
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2ad63eeee95888dc6f22e94e0a8425a99912f7d727cd255881e8630218a3b7f0
|
| 3 |
+
size 498684837
|
realis_model_utils.py
ADDED
|
@@ -0,0 +1,146 @@
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import en_core_web_sm
|
| 4 |
+
from transformers import AutoModel, TrainingArguments, Trainer, RobertaTokenizer, RobertaModel
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
|
| 7 |
+
model_checkpoint = "ehsanaghaei/SecureBERT"
|
| 8 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 9 |
+
|
| 10 |
+
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
| 11 |
+
roberta_model = RobertaModel.from_pretrained(model_checkpoint).to(device)
|
| 12 |
+
|
| 13 |
+
event_nugget_list = ['B-Phishing',
|
| 14 |
+
'I-Phishing',
|
| 15 |
+
'O',
|
| 16 |
+
'B-DiscoverVulnerability',
|
| 17 |
+
'B-Ransom',
|
| 18 |
+
'I-Ransom',
|
| 19 |
+
'B-Databreach',
|
| 20 |
+
'I-DiscoverVulnerability',
|
| 21 |
+
'B-PatchVulnerability',
|
| 22 |
+
'I-PatchVulnerability',
|
| 23 |
+
'I-Databreach']
|
| 24 |
+
|
| 25 |
+
nlp = en_core_web_sm.load()
|
| 26 |
+
pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
|
| 27 |
+
ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
|
| 28 |
+
dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
|
| 29 |
+
|
| 30 |
+
class CustomRobertaWithPOS(nn.Module):
|
| 31 |
+
def __init__(self, num_classes_realis):
|
| 32 |
+
super(CustomRobertaWithPOS, self).__init__()
|
| 33 |
+
self.num_classes_realis = num_classes_realis
|
| 34 |
+
self.pos_embed = nn.Embedding(len(pos_spacy_tag_list), 16)
|
| 35 |
+
self.ner_embed = nn.Embedding(len(ner_spacy_tag_list), 8)
|
| 36 |
+
self.dep_embed = nn.Embedding(len(dep_spacy_tag_list), 8)
|
| 37 |
+
self.depth_embed = nn.Embedding(17, 8)
|
| 38 |
+
self.nugget_embed = nn.Embedding(len(event_nugget_list), 8)
|
| 39 |
+
self.roberta = roberta_model
|
| 40 |
+
self.dropout1 = nn.Dropout(0.2)
|
| 41 |
+
# 48 extra dims = 16 (POS) + 8 (NER) + 8 (dependency) + 8 (tree depth) + 8 (nugget tag) embeddings
self.fc1 = nn.Linear(self.roberta.config.hidden_size + 48, self.num_classes_realis)
|
| 42 |
+
|
| 43 |
+
def forward(self, input_ids, attention_mask, pos_spacy, ner_spacy, dep_spacy, depth_spacy, ner_tags):
|
| 44 |
+
outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
|
| 45 |
+
last_hidden_output = outputs.last_hidden_state
|
| 46 |
+
|
| 47 |
+
pos_mask = pos_spacy != -100
|
| 48 |
+
pos_embed_masked = self.pos_embed(pos_spacy[pos_mask])
|
| 49 |
+
pos_embed = torch.zeros((pos_spacy.shape[0], pos_spacy.shape[1], 16), dtype=torch.float).to(device)
|
| 50 |
+
pos_embed[pos_mask] = pos_embed_masked
|
| 51 |
+
|
| 52 |
+
ner_mask = ner_spacy != -100
|
| 53 |
+
ner_embed_masked = self.ner_embed(ner_spacy[ner_mask])
|
| 54 |
+
ner_embed = torch.zeros((ner_spacy.shape[0], ner_spacy.shape[1], 8), dtype=torch.float).to(device)
|
| 55 |
+
ner_embed[ner_mask] = ner_embed_masked
|
| 56 |
+
|
| 57 |
+
dep_mask = dep_spacy != -100
|
| 58 |
+
dep_embed_masked = self.dep_embed(dep_spacy[dep_mask])
|
| 59 |
+
dep_embed = torch.zeros((dep_spacy.shape[0], dep_spacy.shape[1], 8), dtype=torch.float).to(device)
|
| 60 |
+
dep_embed[dep_mask] = dep_embed_masked
|
| 61 |
+
|
| 62 |
+
depth_mask = depth_spacy != -100
|
| 63 |
+
depth_embed_masked = self.depth_embed(depth_spacy[depth_mask])
|
| 64 |
+
depth_embed = torch.zeros((depth_spacy.shape[0], depth_spacy.shape[1], 8), dtype=torch.float).to(device)
|
| 65 |
+
depth_embed[depth_mask] = depth_embed_masked
|
| 66 |
+
|
| 67 |
+
nugget_mask = ner_tags != -100
|
| 68 |
+
nugget_embed_masked = self.nugget_embed(ner_tags[nugget_mask])
|
| 69 |
+
nugget_embed = torch.zeros((ner_tags.shape[0], ner_tags.shape[1], 8), dtype=torch.float).to(device)
|
| 70 |
+
nugget_embed[nugget_mask] = nugget_embed_masked
|
| 71 |
+
|
| 72 |
+
features_concat = torch.cat((last_hidden_output, pos_embed, ner_embed, dep_embed, depth_embed, nugget_embed), 2).to(device)
|
| 73 |
+
features_concat = self.dropout1(features_concat)
|
| 74 |
+
features_concat = self.dropout1(features_concat)
|
| 75 |
+
|
| 76 |
+
logits = self.fc1(features_concat)
|
| 77 |
+
|
| 78 |
+
return logits
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_entity_for_realis_from_idx(start_idx, end_idx, event_nuggets):
|
| 82 |
+
event_nuggets_idxs = [(nugget["startOffset"], nugget["endOffset"]) for nugget in event_nuggets]
|
| 83 |
+
for idx, (nugget_start, nugget_end) in enumerate(event_nuggets_idxs):
|
| 84 |
+
if (start_idx == nugget_start and end_idx == nugget_end) or (start_idx == nugget_start and end_idx <= nugget_end) or (start_idx == nugget_start and end_idx > nugget_end) or (end_idx == nugget_end and start_idx < nugget_start) or (start_idx <= nugget_start and end_idx <= nugget_end and end_idx > nugget_start):
|
| 85 |
+
return "B-" + event_nuggets[idx]["subtype"]
|
| 86 |
+
elif (start_idx > nugget_start and end_idx <= nugget_end) or (start_idx > nugget_start and start_idx < nugget_end):
|
| 87 |
+
return "I-" + event_nuggets[idx]["subtype"]
|
| 88 |
+
return "O"
|
| 89 |
+
|
| 90 |
+
def tokenize_and_align_labels_with_pos_ner_realis(examples, tokenizer, ner_names, label_all_tokens = True):
|
| 91 |
+
tokenized_inputs = tokenizer(examples["tokens"], padding='max_length', truncation=True, is_split_into_words=True)
|
| 92 |
+
#tokenized_inputs.pop('input_ids')
|
| 93 |
+
labels = []
|
| 94 |
+
nuggets = []
|
| 95 |
+
ner_spacy = []
|
| 96 |
+
pos_spacy = []
|
| 97 |
+
dep_spacy = []
|
| 98 |
+
depth_spacy = []
|
| 99 |
+
|
| 100 |
+
for i, (nugget, pos, ner, dep, depth) in enumerate(zip(examples["ner_tags"], examples["pos_spacy"], examples["ner_spacy"], examples["dep_spacy"], examples["depth_spacy"])):
|
| 101 |
+
word_ids = tokenized_inputs.word_ids(batch_index=i)
|
| 102 |
+
previous_word_idx = None
|
| 103 |
+
nugget_ids = []
|
| 104 |
+
ner_spacy_ids = []
|
| 105 |
+
pos_spacy_ids = []
|
| 106 |
+
dep_spacy_ids = []
|
| 107 |
+
depth_spacy_ids = []
|
| 108 |
+
|
| 109 |
+
for word_idx in word_ids:
|
| 110 |
+
# Special tokens have a word id that is None. We set the label to -100 so they are automatically
|
| 111 |
+
# ignored in the loss function.
|
| 112 |
+
if word_idx is None:
|
| 113 |
+
nugget_ids.append(-100)
|
| 114 |
+
ner_spacy_ids.append(-100)
|
| 115 |
+
pos_spacy_ids.append(-100)
|
| 116 |
+
dep_spacy_ids.append(-100)
|
| 117 |
+
depth_spacy_ids.append(-100)
|
| 118 |
+
# We set the label for the first token of each word.
|
| 119 |
+
elif word_idx != previous_word_idx:
|
| 120 |
+
nugget_ids.append(nugget[word_idx])
|
| 121 |
+
ner_spacy_ids.append(ner[word_idx])
|
| 122 |
+
pos_spacy_ids.append(pos[word_idx])
|
| 123 |
+
dep_spacy_ids.append(dep[word_idx])
|
| 124 |
+
depth_spacy_ids.append(depth[word_idx])
|
| 125 |
+
# For the other tokens in a word, we set the label to either the current label or -100, depending on
|
| 126 |
+
# the label_all_tokens flag.
|
| 127 |
+
else:
|
| 128 |
+
nugget_ids.append(nugget[word_idx] if label_all_tokens else -100)
|
| 129 |
+
ner_spacy_ids.append(ner[word_idx] if label_all_tokens else -100)
|
| 130 |
+
pos_spacy_ids.append(pos[word_idx] if label_all_tokens else -100)
|
| 131 |
+
dep_spacy_ids.append(dep[word_idx] if label_all_tokens else -100)
|
| 132 |
+
depth_spacy_ids.append(depth[word_idx] if label_all_tokens else -100)
|
| 133 |
+
previous_word_idx = word_idx
|
| 134 |
+
|
| 135 |
+
nuggets.append(nugget_ids)
|
| 136 |
+
ner_spacy.append(ner_spacy_ids)
|
| 137 |
+
pos_spacy.append(pos_spacy_ids)
|
| 138 |
+
dep_spacy.append(dep_spacy_ids)
|
| 139 |
+
depth_spacy.append(depth_spacy_ids)
|
| 140 |
+
|
| 141 |
+
tokenized_inputs["ner_tags"] = nuggets
|
| 142 |
+
tokenized_inputs["pos_spacy"] = pos_spacy
|
| 143 |
+
tokenized_inputs["ner_spacy"] = ner_spacy
|
| 144 |
+
tokenized_inputs["dep_spacy"] = dep_spacy
|
| 145 |
+
tokenized_inputs["depth_spacy"] = depth_spacy
|
| 146 |
+
return tokenized_inputs
|
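A brief illustration of the BIO tagging helper get_entity_for_realis_from_idx defined above; the offsets and the Databreach nugget are made up:

# Hypothetical example, not part of the uploaded module.
from cybersecurity_knowledge_graph.realis_model_utils import get_entity_for_realis_from_idx

nuggets = [{"startOffset": 10, "endOffset": 25, "subtype": "Databreach"}]
print(get_entity_for_realis_from_idx(10, 17, nuggets))  # "B-Databreach": span starts at the nugget start
print(get_entity_for_realis_from_idx(18, 25, nuggets))  # "I-Databreach": span falls inside the nugget
print(get_entity_for_realis_from_idx(30, 35, nuggets))  # "O": span does not overlap any nugget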
utils.py
ADDED
|
@@ -0,0 +1,196 @@
|
| 1 |
+
list_of_pos_tags = [
|
| 2 |
+
"ADJ",
|
| 3 |
+
"ADP",
|
| 4 |
+
"ADV",
|
| 5 |
+
"AUX",
|
| 6 |
+
"CCONJ",
|
| 7 |
+
"DET",
|
| 8 |
+
"INTJ",
|
| 9 |
+
"NOUN",
|
| 10 |
+
"NUM",
|
| 11 |
+
"PART",
|
| 12 |
+
"PRON",
|
| 13 |
+
"PROPN",
|
| 14 |
+
"PUNCT",
|
| 15 |
+
"SCONJ",
|
| 16 |
+
"SYM",
|
| 17 |
+
"VERB",
|
| 18 |
+
"X"
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
realis_list = ["O",
|
| 22 |
+
"Generic",
|
| 23 |
+
"Other",
|
| 24 |
+
"Actual"
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
event_args_list = ['O',
|
| 29 |
+
'B-System',
|
| 30 |
+
'I-System',
|
| 31 |
+
'B-Organization',
|
| 32 |
+
'B-Money',
|
| 33 |
+
'I-Money',
|
| 34 |
+
'B-Device',
|
| 35 |
+
'B-Person',
|
| 36 |
+
'I-Person',
|
| 37 |
+
'B-Vulnerability',
|
| 38 |
+
'I-Vulnerability',
|
| 39 |
+
'B-Capabilities',
|
| 40 |
+
'I-Capabilities',
|
| 41 |
+
'I-Organization',
|
| 42 |
+
'B-PaymentMethod',
|
| 43 |
+
'I-PaymentMethod',
|
| 44 |
+
'B-Data',
|
| 45 |
+
'I-Data',
|
| 46 |
+
'B-Number',
|
| 47 |
+
'I-Number',
|
| 48 |
+
'B-Malware',
|
| 49 |
+
'I-Malware',
|
| 50 |
+
'B-PII',
|
| 51 |
+
'I-PII',
|
| 52 |
+
'B-CVE',
|
| 53 |
+
'I-CVE',
|
| 54 |
+
'B-Purpose',
|
| 55 |
+
'I-Purpose',
|
| 56 |
+
'B-File',
|
| 57 |
+
'I-File',
|
| 58 |
+
'I-Device',
|
| 59 |
+
'B-Time',
|
| 60 |
+
'I-Time',
|
| 61 |
+
'B-Software',
|
| 62 |
+
'I-Software',
|
| 63 |
+
'B-Patch',
|
| 64 |
+
'I-Patch',
|
| 65 |
+
'B-Version',
|
| 66 |
+
'I-Version',
|
| 67 |
+
'B-Website',
|
| 68 |
+
'I-Website',
|
| 69 |
+
'B-GPE',
|
| 70 |
+
'I-GPE'
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
event_nugget_list = ['O',
|
| 74 |
+
'B-Ransom',
|
| 75 |
+
'I-Ransom',
|
| 76 |
+
'B-DiscoverVulnerability',
|
| 77 |
+
'I-DiscoverVulnerability',
|
| 78 |
+
'B-PatchVulnerability',
|
| 79 |
+
'I-PatchVulnerability',
|
| 80 |
+
'B-Databreach',
|
| 81 |
+
'I-Databreach',
|
| 82 |
+
'B-Phishing',
|
| 83 |
+
'I-Phishing'
|
| 84 |
+
]
|
| 85 |
+
|
| 86 |
+
arg_2_role = {
|
| 87 |
+
"File" : ['Tool', 'Trusted-Entity'],
|
| 88 |
+
"Person" : ['Victim', 'Attacker', 'Discoverer', 'Releaser', 'Trusted-Entity', 'Vulnerable_System_Owner'],
|
| 89 |
+
"Capabilities" : ['Attack-Pattern', 'Capabilities', 'Issues-Addressed'],
|
| 90 |
+
"Purpose" : ['Purpose'],
|
| 91 |
+
"Time" : ['Time'],
|
| 92 |
+
"PII" : ['Compromised-Data', 'Trusted-Entity'],
|
| 93 |
+
"Data" : ['Compromised-Data', 'Trusted-Entity'],
|
| 94 |
+
"Organization" : ['Victim', 'Releaser', 'Discoverer', 'Attacker', 'Vulnerable_System_Owner', 'Trusted-Entity'],
|
| 95 |
+
"Patch" : ['Patch'],
|
| 96 |
+
"Software" : ['Vulnerable_System', 'Victim', 'Trusted-Entity', 'Supported_Platform'],
|
| 97 |
+
"Vulnerability" : ['Vulnerability'],
|
| 98 |
+
"Version" : ['Patch-Number', 'Vulnerable_System_Version'],
|
| 99 |
+
"Device" : ['Vulnerable_System', 'Victim', 'Supported_Platform'],
|
| 100 |
+
"CVE" : ['CVE'],
|
| 101 |
+
"Number" : ['Number-of-Data', 'Number-of-Victim'],
|
| 102 |
+
"System" : ['Victim', 'Supported_Platform', 'Vulnerable_System', 'Trusted-Entity'],
|
| 103 |
+
"Malware" : ['Tool'],
|
| 104 |
+
"Money" : ['Price', 'Damage-Amount'],
|
| 105 |
+
"PaymentMethod" : ['Payment-Method'],
|
| 106 |
+
"GPE" : ['Place'],
|
| 107 |
+
"Website" : ['Trusted-Entity', 'Tool', 'Vulnerable_System', 'Victim', 'Supported_Platform'],
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
def get_content(data):
|
| 111 |
+
return data["content"]
|
| 112 |
+
|
| 113 |
+
def get_event_nugget(data):
|
| 114 |
+
return [
|
| 115 |
+
{"nugget" : event["nugget"], "type" : event["type"], "subtype" : event["subtype"], "realis" : event["realis"]}
|
| 116 |
+
for hopper in data["cyberevent"]["hopper"] for event in hopper["events"]
|
| 117 |
+
]
|
| 118 |
+
def get_event_args(data):
|
| 119 |
+
events = [event for hopper in data["cyberevent"]["hopper"] for event in hopper["events"]]
|
| 120 |
+
args = []
|
| 121 |
+
for event in events:
|
| 122 |
+
if "argument" in event.keys():
|
| 123 |
+
args.extend(event["argument"])
|
| 124 |
+
return args
|
| 125 |
+
|
| 126 |
+
def get_idxs_from_text(text, text_tokenized):
|
| 127 |
+
rest_text = text
|
| 128 |
+
last_idx = 0
|
| 129 |
+
result_dict = []
|
| 130 |
+
|
| 131 |
+
for substring in text_tokenized:
|
| 132 |
+
index = rest_text.find(substring)
|
| 133 |
+
result_dict.append(
|
| 134 |
+
{
|
| 135 |
+
"word" : substring,
|
| 136 |
+
"start_idx" : last_idx + index,
|
| 137 |
+
"end_idx" : last_idx + index + len(substring)
|
| 138 |
+
}
|
| 139 |
+
)
|
| 140 |
+
rest_text = rest_text[index + len(substring) : ]
|
| 141 |
+
last_idx += index + len(substring)
|
| 142 |
+
return result_dict
|
| 143 |
+
|
| 144 |
+
def get_entity_from_idx(start_idx, end_idx, event_nuggets):
|
| 145 |
+
event_nuggets_idxs = [(nugget["nugget"]["startOffset"], nugget["nugget"]["endOffset"]) for nugget in event_nuggets]
|
| 146 |
+
for idx, (nugget_start, nugget_end) in enumerate(event_nuggets_idxs):
|
| 147 |
+
if (start_idx == nugget_start and end_idx == nugget_end) or (start_idx == nugget_start and end_idx <= nugget_end) or (start_idx == nugget_start and end_idx > nugget_end) or (end_idx == nugget_end and start_idx < nugget_start) or (start_idx <= nugget_start and end_idx <= nugget_end and end_idx > nugget_start):
|
| 148 |
+
return "B-" + event_nuggets[idx]["subtype"]
|
| 149 |
+
elif (start_idx > nugget_start and end_idx <= nugget_end) or (start_idx > nugget_start and start_idx < nugget_end):
|
| 150 |
+
return "I-" + event_nuggets[idx]["subtype"]
|
| 151 |
+
return "O"
|
| 152 |
+
|
| 153 |
+
def get_entity_and_realis_from_idx(start_idx, end_idx, event_nuggets):
|
| 154 |
+
event_nuggets_idxs = [(nugget["nugget"]["startOffset"], nugget["nugget"]["endOffset"]) for nugget in event_nuggets]
|
| 155 |
+
for idx, (nugget_start, nugget_end) in enumerate(event_nuggets_idxs):
|
| 156 |
+
if (start_idx == nugget_start and end_idx == nugget_end) or (start_idx == nugget_start and end_idx <= nugget_end) or (start_idx == nugget_start and end_idx > nugget_end) or (end_idx == nugget_end and start_idx < nugget_start) or (start_idx <= nugget_start and end_idx <= nugget_end and end_idx > nugget_start):
|
| 157 |
+
return "B-" + event_nuggets[idx]["subtype"], "B-" + event_nuggets[idx]["realis"]
|
| 158 |
+
elif (start_idx > nugget_start and end_idx <= nugget_end) or (start_idx > nugget_start and start_idx < nugget_end):
|
| 159 |
+
return "I-" + event_nuggets[idx]["subtype"], "I-" + event_nuggets[idx]["realis"]
|
| 160 |
+
return "O", "O"
|
| 161 |
+
|
| 162 |
+
def get_args_entity_from_idx(start_idx, end_idx, event_args):
|
| 163 |
+
event_nuggets_idxs = [(nugget["startOffset"], nugget["endOffset"]) for nugget in event_args]
|
| 164 |
+
for idx, (nugget_start, nugget_end) in enumerate(event_nuggets_idxs):
|
| 165 |
+
if (start_idx == nugget_start and end_idx == nugget_end) or (start_idx == nugget_start and end_idx <= nugget_end) or (start_idx == nugget_start and end_idx > nugget_end) or (end_idx == nugget_end and start_idx < nugget_start) or (start_idx <= nugget_start and end_idx <= nugget_end and end_idx > nugget_start):
|
| 166 |
+
return "B-" + event_args[idx]["type"]
|
| 167 |
+
elif (start_idx > nugget_start and end_idx <= nugget_end) or (start_idx > nugget_start and start_idx < nugget_end):
|
| 168 |
+
return "I-" + event_args[idx]["type"]
|
| 169 |
+
return "O"
|
| 170 |
+
|
| 171 |
+
def split_with_character(string, char):
|
| 172 |
+
result = []
|
| 173 |
+
start = 0
|
| 174 |
+
for i, c in enumerate(string):
|
| 175 |
+
if c == char:
|
| 176 |
+
result.append(string[start:i])
|
| 177 |
+
result.append(char)
|
| 178 |
+
start = i + 1
|
| 179 |
+
result.append(string[start:])
|
| 180 |
+
return [x for x in result if x != '']
|
| 181 |
+
|
| 182 |
+
def extend_list_with_character(content_list, character):
|
| 183 |
+
content_as_words = []
|
| 184 |
+
for word in content_list:
|
| 185 |
+
if character in word:
|
| 186 |
+
split_list = split_with_character(word, character)
|
| 187 |
+
content_as_words.extend(split_list)
|
| 188 |
+
else:
|
| 189 |
+
content_as_words.append(word)
|
| 190 |
+
return content_as_words
|
| 191 |
+
|
| 192 |
+
def find_dict_by_overlap(list_of_dicts, key_value_pairs):
|
| 193 |
+
for dictionary in list_of_dicts:
|
| 194 |
+
if max(dictionary["start"], dictionary["end"]) >= min(key_value_pairs["start"], key_value_pairs["end"]) and max(key_value_pairs["start"], key_value_pairs["end"]) >= min(dictionary["start"], dictionary["end"]):
|
| 195 |
+
return dictionary
|
| 196 |
+
return None
|
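A short example of the offset-recovery helper get_idxs_from_text defined above; the text and token list are made up:

# Hypothetical example, not part of the uploaded module.
from cybersecurity_knowledge_graph.utils import get_idxs_from_text

text = "Hackers stole data."
tokens = ["Hackers", "stole", "data", "."]
for entry in get_idxs_from_text(text, tokens):
    # Each entry maps a token back to its character span in the original text.
    print(entry["word"], entry["start_idx"], entry["end_idx"])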