anamargarida commited on
Commit
55ad1b4
·
verified ·
1 Parent(s): 441f540

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +11 -5
modeling.py CHANGED
@@ -28,6 +28,9 @@ class ST2ModelV2(nn.Module):
28
 
29
  self.classifier = nn.Linear(self.config.hidden_size, 18)
30
  #self.classifier = nn.Linear(self.config.hidden_size, 12)
 
 
 
31
 
32
 
33
 
@@ -73,12 +76,14 @@ class ST2ModelV2(nn.Module):
73
  #logits = logits.view(batch_size, seq_len, 2, 6).permute(0, 2, 1, 3)
74
 
75
 
76
- print(logits.shape)
77
- st.write(logits.shape)
78
 
79
  start_arg0_logits, end_arg0_logits, start_arg1_logits, end_arg1_logits, start_sig_logits, end_sig_logits = logits.unbind(-1)
80
 
81
-
 
 
82
 
83
 
84
 
@@ -90,6 +95,7 @@ class ST2ModelV2(nn.Module):
90
  'end_arg1_logits': end_arg1_logits,
91
  'start_sig_logits': start_sig_logits,
92
  'end_sig_logits': end_sig_logits,
 
93
  }
94
 
95
 
@@ -104,8 +110,8 @@ class ST2ModelV2(nn.Module):
104
  word_ids,
105
  rel_idx,
106
  ):
107
- print(start_cause_logits.shape)
108
- st.write(start_cause_logits.shape)
109
 
110
  # Mask special tokens (CLS and SEP)
111
  for logits in [start_cause_logits, start_effect_logits, end_cause_logits, end_effect_logits]:
 
28
 
29
  self.classifier = nn.Linear(self.config.hidden_size, 18)
30
  #self.classifier = nn.Linear(self.config.hidden_size, 12)
31
+
32
+ if self.args.signal_classification and not self.args.pretrained_signal_detector:
33
+ self.signal_classifier = nn.Linear(self.config.hidden_size, 2)
34
 
35
 
36
 
 
76
  #logits = logits.view(batch_size, seq_len, 2, 6).permute(0, 2, 1, 3)
77
 
78
 
79
+ #print(logits.shape)
80
+ #st.write(logits.shape)
81
 
82
  start_arg0_logits, end_arg0_logits, start_arg1_logits, end_arg1_logits, start_sig_logits, end_sig_logits = logits.unbind(-1)
83
 
84
+ signal_classification_logits = None
85
+ if self.args.signal_classification and not self.args.pretrained_signal_detector:
86
+ signal_classification_logits = self.signal_classifier(sequence_output[:, 0, :])
87
 
88
 
89
 
 
95
  'end_arg1_logits': end_arg1_logits,
96
  'start_sig_logits': start_sig_logits,
97
  'end_sig_logits': end_sig_logits,
98
+ 'signal_classification_logits': signal_classification_logits,
99
  }
100
 
101
 
 
110
  word_ids,
111
  rel_idx,
112
  ):
113
+ #print(start_cause_logits.shape)
114
+ #st.write(start_cause_logits.shape)
115
 
116
  # Mask special tokens (CLS and SEP)
117
  for logits in [start_cause_logits, start_effect_logits, end_cause_logits, end_effect_logits]: