import streamlit as st
from annotated_text import annotated_text
import torch
from torch.utils.data import DataLoader

from cybersecurity_knowledge_graph.args_model_utils import tokenize_and_align_labels_with_pos_ner_dep, find_nearest_nugget_features, find_dep_depth
from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS
from cybersecurity_knowledge_graph.utils import get_content, get_event_nugget, get_idxs_from_text, get_entity_from_idx, list_of_pos_tags, event_args_list

from cybersecurity_knowledge_graph.event_nugget_predict import get_event_nuggets
import spacy
from transformers import AutoTokenizer
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Note: this local definition shadows the find_dep_depth imported from args_model_utils above.
def find_dep_depth(token):
    # Walk from the token up to the root of its dependency tree, counting hops; cap the depth at 16.
    depth = 0
    current_token = token
    while current_token.head != current_token:
        depth += 1
        current_token = current_token.head
    return min(depth, 16)
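
# Illustrative example (not part of the pipeline; exact depths depend on the parse produced
# by the spaCy pipeline `nlp` loaded below):
#     doc = nlp("Attackers stole customer data.")
#     [find_dep_depth(tok) for tok in doc]
#     # e.g. [1, 0, 2, 1, 1] -- the root verb "stole" has depth 0, its direct children depth 1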

nlp = spacy.load('en_core_web_sm')

# Label inventories for the auxiliary features fed to the argument model.
pos_spacy_tag_list = ["ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM", "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "SPACE", "X"]
ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
event_nugget_tag_list = ["Databreach", "Ransom", "PatchVulnerability", "Phishing", "DiscoverVulnerability"]
arg_nugget_relative_pos_tag_list = ["before-same-sentence", "before-differ-sentence", "after-same-sentence", "after-differ-sentence"]

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_checkpoint = "ehsanaghaei/SecureBERT"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)

from cybersecurity_knowledge_graph.args_model_utils import CustomRobertaWithPOS as ArgumentModel
model_nugget = ArgumentModel(num_classes=43)
model_nugget.load_state_dict(torch.load("cybersecurity_knowledge_graph/argument_model_state_dict.pth", map_location=device))
model_nugget.eval()
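
# Optional sketch (an assumption, not part of the original flow): inference below runs on CPU
# unless the model and each batch are explicitly moved to `device`, e.g.:
#     model_nugget.to(device)
#     batch = {k: v.to(device) for k, v in batch.items()}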

"""
Function: create_dataloader(text_input)
Description: This function creates a DataLoader for the argument model: it tokenizes the input text,
attaches the spaCy- and nugget-derived features, and organizes the result into batches.
Inputs:
- text_input: The input text to be processed.
Output:
- dataloader: A DataLoader over the tokenized and batched text data.
- tokenized_dataset_ner: The tokenized dataset backing the DataLoader.
"""
def create_dataloader(text_input):
    event_nuggets = get_event_nuggets(text_input)
    doc = nlp(text_input)

    content_as_words_emdash = [tok.text for tok in doc]
    content_as_words_emdash = [word.replace("``", '"').replace("''", '"').replace("$", "") for word in content_as_words_emdash]
    content_idx_dict = get_idxs_from_text(text_input, content_as_words_emdash)

    data = []

    words = []
    arg_nugget_nearest_subtype = []
    arg_nugget_nearest_dist = []
    arg_nugget_relative_pos = []

    pos_spacy = [tok.pos_ for tok in doc]
    ner_spacy = [ent.ent_iob_ + "-" + ent.ent_type_ if ent.ent_iob_ != "O" else ent.ent_iob_ for ent in doc]
    dep_spacy = [tok.dep_ for tok in doc]
    depth_spacy = [find_dep_depth(tok) for tok in doc]

    # For every word, record the subtype of, distance to, and position relative to the nearest event nugget.
    for content_dict in content_idx_dict:
        start_idx, end_idx = content_dict["start_idx"], content_dict["end_idx"]
        nearest_subtype, nearest_dist, relative_pos = find_nearest_nugget_features(doc, start_idx, end_idx, event_nuggets)
        words.append(content_dict["word"])

        arg_nugget_nearest_subtype.append(nearest_subtype)
        arg_nugget_nearest_dist.append(nearest_dist)
        arg_nugget_relative_pos.append(relative_pos)

    # If the text exceeds the model's maximum sequence length, split it into roughly equal chunks,
    # breaking at sentence-final periods; otherwise keep it as a single example.
    content_token_len = len(tokenizer(words, truncation=False, is_split_into_words=True)["input_ids"])
    if content_token_len > tokenizer.model_max_length:
        no_split = (content_token_len // tokenizer.model_max_length) + 2
        split_len = (len(words) // no_split) + 1

        last_id = 0
        threshold = split_len

        for id, token in enumerate(words):
            if token == "." and id > threshold:
                data.append(
                    {
                        "tokens" : words[last_id : id + 1],
                        "pos_spacy" : pos_spacy[last_id : id + 1],
                        "ner_spacy" : ner_spacy[last_id : id + 1],
                        "dep_spacy" : dep_spacy[last_id : id + 1],
                        "depth_spacy" : depth_spacy[last_id : id + 1],
                        "nearest_nugget_subtype" : arg_nugget_nearest_subtype[last_id : id + 1],
                        "nearest_nugget_dist" : arg_nugget_nearest_dist[last_id : id + 1],
                        "arg_nugget_relative_pos" : arg_nugget_relative_pos[last_id : id + 1]
                    }
                )
                last_id = id + 1
                threshold += split_len
        data.append({"tokens" : words[last_id : ],
                     "pos_spacy" : pos_spacy[last_id : ],
                     "ner_spacy" : ner_spacy[last_id : ],
                     "dep_spacy" : dep_spacy[last_id : ],
                     "depth_spacy" : depth_spacy[last_id : ],
                     "nearest_nugget_subtype" : arg_nugget_nearest_subtype[last_id : ],
                     "nearest_nugget_dist" : arg_nugget_nearest_dist[last_id : ],
                     "arg_nugget_relative_pos" : arg_nugget_relative_pos[last_id : ]})
    else:
        data.append(
            {
                "tokens" : words,
                "pos_spacy" : pos_spacy,
                "ner_spacy" : ner_spacy,
                "dep_spacy" : dep_spacy,
                "depth_spacy" : depth_spacy,
                "nearest_nugget_subtype" : arg_nugget_nearest_subtype,
                "nearest_nugget_dist" : arg_nugget_nearest_dist,
                "arg_nugget_relative_pos" : arg_nugget_relative_pos
            }
        )

    ner_features = Features({'tokens' : Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
                             'pos_spacy' : Sequence(feature=ClassLabel(num_classes=len(pos_spacy_tag_list), names=pos_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
                             'ner_spacy' : Sequence(feature=ClassLabel(num_classes=len(ner_spacy_tag_list), names=ner_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
                             'dep_spacy' : Sequence(feature=ClassLabel(num_classes=len(dep_spacy_tag_list), names=dep_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
                             'depth_spacy' : Sequence(feature=ClassLabel(num_classes=17, names=list(range(17)), names_file=None, id=None), length=-1, id=None),
                             'nearest_nugget_subtype' : Sequence(feature=ClassLabel(num_classes=len(event_nugget_tag_list), names=event_nugget_tag_list, names_file=None, id=None), length=-1, id=None),
                             'nearest_nugget_dist' : Sequence(feature=ClassLabel(num_classes=11, names=list(range(11)), names_file=None, id=None), length=-1, id=None),
                             'arg_nugget_relative_pos' : Sequence(feature=ClassLabel(num_classes=len(arg_nugget_relative_pos_tag_list), names=arg_nugget_relative_pos_tag_list, names_file=None, id=None), length=-1, id=None),
                             })

    dataset = Dataset.from_list(data, features=ner_features)
    tokenized_dataset_ner = dataset.map(tokenize_and_align_labels_with_pos_ner_dep, fn_kwargs={'tokenizer' : tokenizer}, batched=True, load_from_cache_file=False)
    tokenized_dataset_ner = tokenized_dataset_ner.with_format("torch")

    tokenized_dataset_ner = tokenized_dataset_ner.remove_columns("tokens")

    batch_size = 4
    dataloader = DataLoader(tokenized_dataset_ner, batch_size=batch_size)
    return dataloader, tokenized_dataset_ner
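
# Example usage (illustrative; the exact batch keys depend on tokenize_and_align_labels_with_pos_ner_dep):
#     dataloader, tokenized_dataset = create_dataloader(article_text)
#     batch = next(iter(dataloader))
#     batch["input_ids"].shape   # (batch_size, sequence_length)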

"""
Function: predict(dataloader)
Description: This function runs the trained argument model over a dataloader and collects the predicted
label for every token position.
Inputs:
- dataloader: A DataLoader containing the input data for prediction.
Output:
- predicted_label: A tensor containing the predicted labels for each input in the dataloader.
"""
def predict(dataloader):
    predicted_label = []
    for batch in dataloader:
        with torch.no_grad():
            logits = model_nugget(**batch)

        batch_predicted_label = logits.argmax(-1)
        predicted_label.append(batch_predicted_label)
    return torch.cat(predicted_label, dim=-1)
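
# Example usage (illustrative):
#     dataloader, tokenized_dataset = create_dataloader(article_text)
#     labels = predict(dataloader)
#     event_args_list[int(labels[0][0])]   # predicted argument label of the first sub-token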

"""
Function: show_annotations(text_input)
Description: This function displays the predicted event arguments as annotated text in the Streamlit app.
Inputs:
- text_input: The input text containing event arguments to be annotated and displayed.
Output:
- An interactive display of annotated event arguments within the input text.
"""
def show_annotations(text_input):
    st.title("Event Arguments")

    dataloader, tokenized_dataset_ner = create_dataloader(text_input)
    predicted_label = predict(dataloader)

    for idx, labels in enumerate(predicted_label):
        # Keep only regular word-piece tokens; in the RoBERTa vocabulary, ids 0-2 are the
        # special tokens <s>, <pad> and </s>.
        token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]

        tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
        tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]

        text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
        idxs = get_idxs_from_text(text, tokens)

        labels = labels[token_mask]

        annotated_text_list = []
        last_label = ""
        cumulative_tokens = ""
        last_id = 0

        # Merge consecutive tokens that share a label into a single annotated span.
        for idx, label in zip(idxs, labels):
            to_label = event_args_list[label]
            label_short = to_label.split("-")[1] if "-" in to_label else to_label
            if last_label == label_short:
                cumulative_tokens += text[last_id : idx["end_idx"]]
                last_id = idx["end_idx"]
            else:
                if last_label != "":
                    if last_label == "O":
                        annotated_text_list.append(cumulative_tokens)
                    else:
                        annotated_text_list.append((cumulative_tokens, last_label))
                last_label = label_short
                cumulative_tokens = idx["word"]
                last_id = idx["end_idx"]
        if last_label == "O":
            annotated_text_list.append(cumulative_tokens)
        else:
            annotated_text_list.append((cumulative_tokens, last_label))

        annotated_text(annotated_text_list)
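
# The list handed to annotated_text() mixes plain strings (for "O" spans) with (text, label)
# tuples; an illustrative, hypothetical example:
#     ["The company said ", ("attackers", "Attacker"), " stole ", ("customer records", "Data"), "."]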

"""
Function: get_event_args(text_input)
Description: This function extracts the predicted event arguments from the provided input text.
Inputs:
- text_input: The input text from which event arguments are to be extracted.
Output:
- predicted_event_nuggets: A list of dictionaries, each representing an extracted event argument with
start and end character offsets, subtype, and text content.
"""
def get_event_args(text_input):
    dataloader, tokenized_dataset_ner = create_dataloader(text_input)
    predicted_label = predict(dataloader)

    predicted_event_nuggets = []
    text_length = 0
    for idx, labels in enumerate(predicted_label):
        token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]

        tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
        tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]

        text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
        idxs = get_idxs_from_text(text_input[text_length : ], tokens)

        labels = labels[token_mask]

        start_idx = 0
        end_idx = 0
        last_label = ""

        # Grow a span while consecutive tokens keep the same label; when the label changes,
        # store the finished span (unless it is "O") with offsets into the original text.
        for idx, label in zip(idxs, labels):
            to_label = event_args_list[label]
            if "-" in to_label:
                label_split = to_label.split("-")[1]
            else:
                label_split = to_label

            if label_split == last_label:
                end_idx = idx["end_idx"]
            else:
                if text_input[start_idx : end_idx] != "" and last_label != "O":
                    predicted_event_nuggets.append(
                        {
                            "startOffset" : text_length + start_idx,
                            "endOffset" : text_length + end_idx,
                            "subtype" : last_label,
                            "text" : text_input[text_length + start_idx : text_length + end_idx]
                        }
                    )
                start_idx = idx["start_idx"]
                end_idx = idx["start_idx"] + len(idx["word"])
                last_label = label_split
        text_length += idx["end_idx"]
    return predicted_event_nuggets
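
# Example usage (illustrative; offsets and subtype names are hypothetical):
#     get_event_args("Hackers breached the database and demanded a ransom.")
#     # -> [{"startOffset": 0, "endOffset": 7, "subtype": "Attacker", "text": "Hackers"}, ...]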