CyberZHG committed
Commit 5f13a7e
1 Parent(s): 3783374

Use input attention mask instead of causal mask in attention


The current implementation does not work with left/leading padding.
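For context, a minimal standalone sketch (not part of the commit) of the pattern the patch switches to. It assumes, as the conversion line in the diff suggests, that the mask reaching the attention layer is boolean with True marking positions that must not be attended to, and that it already combines the causal structure with the padding; the shapes and values below are toy ones chosen for illustration.

import torch
import torch.nn.functional as F

batch_size, num_heads, seq_len, head_dim = 2, 4, 5, 8

# Toy boolean mask: True = blocked position. Assumed to already merge the
# causal triangle with left padding (two leading pad tokens in sequence 0).
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
padding = torch.tensor([[True, True, False, False, False],
                        [False, False, False, False, False]])   # [batch, kv_length]
attention_mask = (causal | padding[:, None, :])[:, None, :, :]  # [batch, 1, q, kv]

query = torch.randn(batch_size, num_heads, seq_len, head_dim)
key = torch.randn(batch_size, num_heads, seq_len, head_dim)
value = torch.randn(batch_size, num_heads, seq_len, head_dim)

# The conversion used by the patch: 0.0 where attention is allowed,
# -1e9 where it is blocked, cast to the query dtype.
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(query.dtype)

# After the patch the additive mask is passed explicitly and is_causal=False,
# so padded positions are excluded instead of being silently attended to.
attn_output = F.scaled_dot_product_attention(
    query, key, value, attention_mask_float, 0.0, is_causal=False
)
print(attn_output.shape)  # torch.Size([2, 4, 5, 8])

Before the patch, the call used is_causal=True with attn_mask=None, which rebuilds only the causal triangle internally and therefore lets every query attend to the left-padding tokens.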

Files changed (1)
  1. modelling_RW.py +2 -2
modelling_RW.py CHANGED
@@ -271,13 +271,14 @@ class Attention(nn.Module):
         else:
             present = None
 
+        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(query_layer.dtype)
         if alibi is None:
             query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
             key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
             value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
 
             attn_output = F.scaled_dot_product_attention(
-                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+                query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
             )
 
             x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
@@ -290,7 +291,6 @@ class Attention(nn.Module):
             assert not output_attentions  # not supported.
             return outputs
         else:
-            attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
             matmul_result = query_layer @ key_layer.transpose(-1, -2)
 
             # change view to [batch_size, num_heads, q_length, kv_length]
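A hypothetical smoke test for the scenario the fix targets, batched generation with left padding. The repository id below is a placeholder and is not taken from this commit; it only assumes a checkpoint that ships this modelling_RW.py and is loaded with trust_remote_code.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-org/your-rw-checkpoint"  # placeholder: any checkpoint using this modelling_RW.py
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "left"           # the case the old is_causal=True path mishandled
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer(["Hello", "A much longer prompt that forces padding"],
                   return_tensors="pt", padding=True)
# generate() forwards inputs["attention_mask"], which the patched Attention now uses.
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))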