cpi-connect commited on
Commit
acd5094
·
1 Parent(s): c431a37

Upload model

Browse files
Files changed (3) hide show
  1. config.json +4 -1
  2. configuration.py +5 -7
  3. model.py +8 -8
config.json CHANGED
@@ -97,6 +97,10 @@
97
  "Supported_Platform"
98
  ]
99
  },
 
 
 
 
100
  "event_args_list": [
101
  "O",
102
  "B-System",
@@ -158,7 +162,6 @@
158
  ],
159
  "event_nugget_model_path": "cybersecurity_knowledge_graph/nugget_model_state_dict.pth",
160
  "event_realis_model_path": "cybersecurity_knowledge_graph/realis_model_state_dict.pth",
161
- "model_type": "cybersecurity_knowledge_graph",
162
  "realis_list": [
163
  "O",
164
  "Generic",
 
97
  "Supported_Platform"
98
  ]
99
  },
100
+ "auto_map": {
101
+ "AutoConfig": "configuration.CybersecurityKnowledgeGraphConfig",
102
+ "AutoModelForTokenClassification": "model.CybersecurityKnowledgeGraphModel"
103
+ },
104
  "event_args_list": [
105
  "O",
106
  "B-System",
 
162
  ],
163
  "event_nugget_model_path": "cybersecurity_knowledge_graph/nugget_model_state_dict.pth",
164
  "event_realis_model_path": "cybersecurity_knowledge_graph/realis_model_state_dict.pth",
 
165
  "realis_list": [
166
  "O",
167
  "Generic",
configuration.py CHANGED
@@ -1,13 +1,11 @@
1
  from transformers import PretrainedConfig
2
  import torch
3
 
4
- import utils
5
 
6
 
7
 
8
  class CybersecurityKnowledgeGraphConfig(PretrainedConfig):
9
- model_type = "cybersecurity_knowledge_graph"
10
-
11
  def __init__(
12
  self,
13
  event_nugget_model_path : str = "nugget_model_state_dict.pth",
@@ -19,9 +17,9 @@ class CybersecurityKnowledgeGraphConfig(PretrainedConfig):
19
  self.event_argument_model_path = event_argument_model_path
20
  self.event_realis_model_path = event_realis_model_path
21
 
22
- self.event_nugget_list = utils.event_nugget_list
23
- self.event_args_list = utils.event_args_list
24
- self.realis_list = utils.realis_list
25
- self.arg_2_role = utils.arg_2_role
26
 
27
  super().__init__(**kwargs)
 
1
  from transformers import PretrainedConfig
2
  import torch
3
 
4
+ import cybersecurity_knowledge_graph.utils
5
 
6
 
7
 
8
  class CybersecurityKnowledgeGraphConfig(PretrainedConfig):
 
 
9
  def __init__(
10
  self,
11
  event_nugget_model_path : str = "nugget_model_state_dict.pth",
 
17
  self.event_argument_model_path = event_argument_model_path
18
  self.event_realis_model_path = event_realis_model_path
19
 
20
+ self.event_nugget_list = cybersecurity_knowledge_graph.utils.event_nugget_list
21
+ self.event_args_list = cybersecurity_knowledge_graph.utils.event_args_list
22
+ self.realis_list = cybersecurity_knowledge_graph.utils.realis_list
23
+ self.arg_2_role = cybersecurity_knowledge_graph.utils.arg_2_role
24
 
25
  super().__init__(**kwargs)
model.py CHANGED
@@ -6,15 +6,15 @@ from sentence_transformers import SentenceTransformer
6
  from transformers import AutoTokenizer
7
 
8
 
9
- from nugget_model_utils import CustomRobertaWithPOS as NuggetModel
10
- from args_model_utils import CustomRobertaWithPOS as ArgumentModel
11
- from realis_model_utils import CustomRobertaWithPOS as RealisModel
12
 
13
- from configuration import CybersecurityKnowledgeGraphConfig
14
 
15
- from event_nugget_predict import create_dataloader as event_nugget_dataloader
16
- from event_realis_predict import create_dataloader as event_realis_dataloader
17
- from event_arg_predict import create_dataloader as event_argument_dataloader
18
 
19
  class CybersecurityKnowledgeGraphModel(PreTrainedModel):
20
  config_class = CybersecurityKnowledgeGraphConfig
@@ -40,7 +40,7 @@ class CybersecurityKnowledgeGraphModel(PreTrainedModel):
40
  self.event_argument_model.load_state_dict(torch.load(self.event_argument_model_path))
41
 
42
  role_classifiers = {}
43
- folder_path = '/arg_role_models'
44
 
45
  for filename in os.listdir(os.getcwd() + folder_path):
46
  if filename.endswith('.joblib'):
 
6
  from transformers import AutoTokenizer
7
 
8
 
9
+ from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS as NuggetModel
10
+ from cybersecurity_knowledge_graph.args_model_utils import CustomRobertaWithPOS as ArgumentModel
11
+ from cybersecurity_knowledge_graph.realis_model_utils import CustomRobertaWithPOS as RealisModel
12
 
13
+ from .configuration import CybersecurityKnowledgeGraphConfig
14
 
15
+ from cybersecurity_knowledge_graph.event_nugget_predict import create_dataloader as event_nugget_dataloader
16
+ from cybersecurity_knowledge_graph.event_realis_predict import create_dataloader as event_realis_dataloader
17
+ from cybersecurity_knowledge_graph.event_arg_predict import create_dataloader as event_argument_dataloader
18
 
19
  class CybersecurityKnowledgeGraphModel(PreTrainedModel):
20
  config_class = CybersecurityKnowledgeGraphConfig
 
40
  self.event_argument_model.load_state_dict(torch.load(self.event_argument_model_path))
41
 
42
  role_classifiers = {}
43
+ folder_path = '/cybersecurity_knowledge_graph/arg_role_models'
44
 
45
  for filename in os.listdir(os.getcwd() + folder_path):
46
  if filename.endswith('.joblib'):