# NOTE(review): removed non-Python extraction artifacts that preceded the code
# (a "File size" line, git commit hashes, and a line-number gutter).
from transformers import PreTrainedModel
import torch
import joblib, os
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS as NuggetModel
from cybersecurity_knowledge_graph.args_model_utils import CustomRobertaWithPOS as ArgumentModel
from cybersecurity_knowledge_graph.realis_model_utils import CustomRobertaWithPOS as RealisModel
from .configuration import CybersecurityKnowledgeGraphConfig
from cybersecurity_knowledge_graph.event_nugget_predict import create_dataloader as event_nugget_dataloader
from cybersecurity_knowledge_graph.event_realis_predict import create_dataloader as event_realis_dataloader
from cybersecurity_knowledge_graph.event_arg_predict import create_dataloader as event_argument_dataloader
class CybersecurityKnowledgeGraphModel(PreTrainedModel):
    """Pipeline model that extracts a cybersecurity event graph from raw text.

    Combines three SecureBERT-based token classifiers (event nuggets,
    event arguments, realis) with one scikit-learn role classifier per
    argument type, applied to sentence-transformer embeddings of each
    detected argument and its surrounding context.
    """

    config_class = CybersecurityKnowledgeGraphConfig

    def __init__(self, config):
        """Build the pipeline from ``config`` checkpoint paths and label lists.

        Raises
        ------
        FileNotFoundError / OSError
            If a checkpoint path or the ``arg_role_models`` directory
            (resolved relative to the current working directory) is missing.
        """
        super().__init__(config)
        self.tokenizer = AutoTokenizer.from_pretrained("ehsanaghaei/SecureBERT")

        # Checkpoint paths and dataloader factories for the three heads.
        self.event_nugget_model_path = config.event_nugget_model_path
        self.event_argument_model_path = config.event_argument_model_path
        self.event_realis_model_path = config.event_realis_model_path
        self.event_nugget_dataloader = event_nugget_dataloader
        self.event_argument_dataloader = event_argument_dataloader
        self.event_realis_dataloader = event_realis_dataloader

        self.event_nugget_model = NuggetModel(num_classes=11)
        self.event_argument_model = ArgumentModel(num_classes=43)
        self.event_realis_model = RealisModel(num_classes_realis=4)

        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
        # hosts; load_state_dict then copies tensors onto each model's
        # own parameters regardless of where they were deserialized.
        self.event_nugget_model.load_state_dict(
            torch.load(self.event_nugget_model_path, map_location="cpu"))
        self.event_realis_model.load_state_dict(
            torch.load(self.event_realis_model_path, map_location="cpu"))
        self.event_argument_model.load_state_dict(
            torch.load(self.event_argument_model_path, map_location="cpu"))

        # Inference-only pipeline: switch to eval mode so dropout (if any)
        # is disabled and predictions are deterministic.
        self.event_nugget_model.eval()
        self.event_argument_model.eval()
        self.event_realis_model.eval()

        # One role classifier per argument type, keyed by the joblib
        # filename's leading stem (e.g. "File.joblib" -> "File").
        role_classifiers = {}
        models_dir = os.path.join(
            os.getcwd(), "cybersecurity_knowledge_graph", "arg_role_models")
        for filename in os.listdir(models_dir):
            if filename.endswith(".joblib"):
                arg = filename.split(".")[0]
                role_classifiers[arg] = joblib.load(os.path.join(models_dir, filename))
        self.role_classifiers = role_classifiers

        self.embed_model = SentenceTransformer('sentence_transformer')

        # Label vocabularies used to decode class indices into strings.
        self.event_nugget_list = config.event_nugget_list
        self.event_args_list = config.event_args_list
        self.realis_list = config.realis_list
        self.arg_2_role = config.arg_2_role

    def forward(self, text):
        """Run the full extraction pipeline on ``text``.

        Returns
        -------
        list[dict]
            One dict per non-special token with keys ``id`` (token id),
            ``token`` (decoded text), ``nugget``, ``argument``, ``realis``
            (decoded labels) and ``role`` ("O" when the token is not part
            of an argument entity).
        """
        nugget_dataloader, _ = self.event_nugget_dataloader(text)
        argument_dataloader, _ = self.event_argument_dataloader(text)
        realis_dataloader, _ = self.event_realis_dataloader(text)

        nugget_pred = self.forward_model(self.event_nugget_model, nugget_dataloader)
        # Batches whose every token is class 0 ("O") carry no event nugget.
        no_nuggets = torch.all(nugget_pred == 0, dim=1)

        argument_preds = torch.empty(nugget_pred.size())
        realis_preds = torch.empty(nugget_pred.size())
        # The argument/realis passes do not depend on the loop variable:
        # run each at most once and reuse the result for every batch that
        # contains a nugget (the original recomputed them per batch).
        argument_pred = realis_pred = None
        for idx, (batch, no_nugget) in enumerate(zip(nugget_pred, no_nuggets)):
            if no_nugget:
                argument_preds[idx] = torch.zeros(batch.size())
                realis_preds[idx] = torch.zeros(batch.size())
            else:
                if argument_pred is None:
                    argument_pred = self.forward_model(
                        self.event_argument_model, argument_dataloader)
                    realis_pred = self.forward_model(
                        self.event_realis_model, realis_dataloader)
                argument_preds[idx] = argument_pred
                realis_preds[idx] = realis_pred

        attention_mask = [batch["attention_mask"] for batch in nugget_dataloader]
        attention_mask = torch.cat(attention_mask, dim=-1)
        input_ids = [batch["input_ids"] for batch in nugget_dataloader]
        input_ids = torch.cat(input_ids, dim=-1)
        output = {"nugget": nugget_pred, "argument": argument_preds,
                  "realis": realis_preds, "input_ids": input_ids,
                  "attention_mask": attention_mask}

        # Flatten every batch into one token-level record list, dropping
        # special tokens (padding, BOS/EOS, ...).
        no_of_batch = output['input_ids'].shape[0]
        structured_output = []
        for b in range(no_of_batch):
            token_mask = [self.tokenizer.decode(token) not in self.tokenizer.all_special_tokens
                          for token in output['input_ids'][b]]
            filtered_ids = output['input_ids'][b][token_mask]
            filtered_tokens = [self.tokenizer.decode(token) for token in filtered_ids]
            filtered_nuggets = output['nugget'][b][token_mask]
            filtered_args = output['argument'][b][token_mask]
            filtered_realis = output['realis'][b][token_mask]
            batch_output = [{"id": id.item(),
                             "token": token,
                             "nugget": self.event_nugget_list[int(nugget.item())],
                             "argument": self.event_args_list[int(arg.item())],
                             "realis": self.realis_list[int(realis.item())]}
                            for id, token, nugget, arg, realis
                            in zip(filtered_ids, filtered_tokens, filtered_nuggets,
                                   filtered_args, filtered_realis)]
            structured_output.extend(batch_output)

        # Merge BIO argument tags into entity spans.
        args = [(idx, item["argument"], item["token"])
                for idx, item in enumerate(structured_output)
                if item["argument"] != "O"]
        entities = []
        current_entity = None
        for position, label, token in args:
            if label.startswith('B-'):
                if current_entity is not None:
                    entities.append(current_entity)
                current_entity = {'label': label[2:],
                                  'text': token.replace(" ", ""),
                                  'start': position, 'end': position}
            elif label.startswith('I-'):
                if current_entity is not None:
                    current_entity['text'] += ' ' + token.replace(" ", "")
                    current_entity['end'] = position
        # Flush the span still open when the tag sequence ends; without
        # this the final entity of every document is silently dropped.
        if current_entity is not None:
            entities.append(current_entity)

        # Attach a +/-15-token decoded context window around each entity.
        for entity in entities:
            context = self.tokenizer.decode(
                [item["id"] for item in
                 structured_output[max(0, entity["start"] - 15):
                                   min(len(structured_output), entity["end"] + 15)]])
            entity["context"] = context

        # Resolve each entity's role: trivial when the argument type maps
        # to a single role, otherwise classify the concatenated context
        # and entity embeddings with the per-type scikit-learn model.
        for entity in entities:
            if len(self.arg_2_role[entity["label"]]) > 1:
                sent_embed = self.embed_model.encode(entity["context"])
                arg_embed = self.embed_model.encode(entity["text"])
                embed = np.concatenate((sent_embed, arg_embed))
                arg_clf = self.role_classifiers[entity["label"]]
                role_id = arg_clf.predict(embed.reshape(1, -1))
                entity["role"] = self.arg_2_role[entity["label"]][role_id[0]]
            else:
                entity["role"] = self.arg_2_role[entity["label"]][0]

        # Project entity roles back onto the token records ("O" elsewhere).
        for item in structured_output:
            item["role"] = "O"
        for entity in entities:
            for i in range(entity["start"], entity["end"] + 1):
                structured_output[i]["role"] = entity["role"]
        return structured_output

    def forward_model(self, model, dataloader):
        """Run ``model`` over ``dataloader`` and return argmax class ids.

        Gradients are disabled; per-batch predictions are concatenated
        along the last dimension.
        """
        predicted_label = []
        for batch in dataloader:
            with torch.no_grad():
                logits = model(**batch)
                batch_predicted_label = logits.argmax(-1)
                predicted_label.append(batch_predicted_label)
        return torch.cat(predicted_label, dim=-1)