Spaces:
Running
Running
danielhajialigol
commited on
Commit
•
2690a96
1
Parent(s):
1841ebe
fixed seed issue
Browse files
model.py
CHANGED
@@ -90,9 +90,10 @@ class MimicTransformer(Module):
|
|
90 |
cls_results = self.model(input_ids, attention_mask=attention_mask, labels=drg_labels, output_attentions=True)
|
91 |
else:
|
92 |
cls_results = self.model(input_ids, attention_mask=attention_mask, output_attentions=True)
|
93 |
-
|
94 |
-
last_attn = torch.mean(torch.stack(cls_results[-1])[:], dim=0)
|
95 |
-
last_layer_attn = torch.mean(last_attn[:, :-3, :, :], dim=1)
|
|
|
96 |
xai_logits = self.linear(last_layer_attn).squeeze(dim=-1)
|
97 |
return (cls_results, xai_logits)
|
98 |
|
|
|
90 |
cls_results = self.model(input_ids, attention_mask=attention_mask, labels=drg_labels, output_attentions=True)
|
91 |
else:
|
92 |
cls_results = self.model(input_ids, attention_mask=attention_mask, output_attentions=True)
|
93 |
+
last_attn = cls_results[-1][-1] # (batch, attn_heads, tokens, tokens)
|
94 |
+
# last_attn = torch.mean(torch.stack(cls_results[-1])[:], dim=0)
|
95 |
+
# last_layer_attn = torch.mean(last_attn[:, :-3, :, :], dim=1)
|
96 |
+
last_layer_attn = last_attn[:, -1, :, :]
|
97 |
xai_logits = self.linear(last_layer_attn).squeeze(dim=-1)
|
98 |
return (cls_results, xai_logits)
|
99 |
|