Update modeling.py
modeling.py  +11 -5
@@ -28,6 +28,9 @@ class ST2ModelV2(nn.Module):
 
         self.classifier = nn.Linear(self.config.hidden_size, 18)
         #self.classifier = nn.Linear(self.config.hidden_size, 12)
+
+        if self.args.signal_classification and not self.args.pretrained_signal_detector:
+            self.signal_classifier = nn.Linear(self.config.hidden_size, 2)
 
 
 
@@ -73,12 +76,14 @@ class ST2ModelV2(nn.Module):
         #logits = logits.view(batch_size, seq_len, 2, 6).permute(0, 2, 1, 3)
 
 
-        print(logits.shape)
-        st.write(logits.shape)
+        #print(logits.shape)
+        #st.write(logits.shape)
 
         start_arg0_logits, end_arg0_logits, start_arg1_logits, end_arg1_logits, start_sig_logits, end_sig_logits = logits.unbind(-1)
 
-
+        signal_classification_logits = None
+        if self.args.signal_classification and not self.args.pretrained_signal_detector:
+            signal_classification_logits = self.signal_classifier(sequence_output[:, 0, :])
 
 
 
@@ -90,6 +95,7 @@ class ST2ModelV2(nn.Module):
             'end_arg1_logits': end_arg1_logits,
             'start_sig_logits': start_sig_logits,
             'end_sig_logits': end_sig_logits,
+            'signal_classification_logits': signal_classification_logits,
         }
 
 
@@ -104,8 +110,8 @@ class ST2ModelV2(nn.Module):
         word_ids,
         rel_idx,
     ):
-        print(start_cause_logits.shape)
-        st.write(start_cause_logits.shape)
+        #print(start_cause_logits.shape)
+        #st.write(start_cause_logits.shape)
 
         # Mask special tokens (CLS and SEP)
         for logits in [start_cause_logits, start_effect_logits, end_cause_logits, end_effect_logits]:
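For readers skimming the diff, here is a minimal, self-contained sketch of the pattern this commit introduces: an optional binary signal-classification head applied to the first-token ([CLS]) representation of `sequence_output`. The class name `SignalHeadSketch`, the default `hidden_size=768`, and the plain constructor flags (standing in for `self.args.signal_classification` / `self.args.pretrained_signal_detector`) are illustrative assumptions, not code taken from the repository beyond what the diff shows.

import torch
import torch.nn as nn

class SignalHeadSketch(nn.Module):
    """Illustrative stand-in for the head added in this commit (not repo code)."""

    def __init__(self, hidden_size=768, signal_classification=True, pretrained_signal_detector=False):
        super().__init__()
        self.signal_classification = signal_classification
        self.pretrained_signal_detector = pretrained_signal_detector
        # Mirrors the diff: only build the head when the model classifies signals itself.
        if self.signal_classification and not self.pretrained_signal_detector:
            self.signal_classifier = nn.Linear(hidden_size, 2)

    def forward(self, sequence_output):
        # sequence_output: (batch, seq_len, hidden_size) from the encoder.
        signal_classification_logits = None
        if self.signal_classification and not self.pretrained_signal_detector:
            # Classify from the first-token ([CLS]) representation, as in the diff.
            signal_classification_logits = self.signal_classifier(sequence_output[:, 0, :])
        return signal_classification_logits

# Example with assumed shapes: batch of 4 sequences, 16 tokens, hidden size 768.
head = SignalHeadSketch()
dummy = torch.randn(4, 16, 768)
print(head(dummy).shape)  # torch.Size([4, 2])

When a pretrained signal detector is used instead, the head is never built and the returned logits stay None, matching how the diff initializes `signal_classification_logits = None` and adds the key to the output dict either way.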