davda54 committed (verified)
Commit 36aeed6 · Parent: 695d6bf

make FlashAttention logic more robust

Files changed (1):
  modeling_gptbert.py  +5 -5
modeling_gptbert.py CHANGED
@@ -367,7 +367,7 @@ class SelfAttention(nn.Module):
         theta = 160_000 if (layer_idx + 1) % config.local_global_ratio == 0 else 10_000
 
         # Initialize rotary embeddings based on whether FlashAttention is available
-        if is_flash_attn_2_available():
+        if flash_attn_varlen_qkvpacked_func is not None:
             self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_sequence_length)
         else:
             self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
@@ -418,7 +418,7 @@ class SelfAttention(nn.Module):
 
     def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, padding_info):
         # Get original shape info
-        if is_flash_attn_2_available():
+        if flash_attn_varlen_qkvpacked_func is not None:
             # Unpadded case
             indices, cu_seqlens, max_seqlen = padding_info
             total_seqlen = hidden_layer.size(0)
@@ -433,7 +433,7 @@ class SelfAttention(nn.Module):
         query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
         value = self.v_proj(hidden_layer)
 
-        if is_flash_attn_2_available():
+        if flash_attn_varlen_qkvpacked_func is not None:
             # Reshape for FlashAttention: (total_seqlen, num_heads, head_dim)
             query = query.view(total_seqlen, self.num_attention_heads, self.d_qk)
             key = key.view(total_seqlen, self.num_kv_heads, self.d_qk)
@@ -645,7 +645,7 @@ class GptBertModel(GptBertPreTrainedModel):
         else:
             attention_mask = attention_mask.bool()
 
-        if is_flash_attn_2_available():
+        if flash_attn_varlen_qkvpacked_func is not None:
             if len(attention_mask.size()) != 2:
                 raise ValueError("Bare `attention_mask` med to dimensjoner støttes nå for FlashAttention.")
             with torch.no_grad():
@@ -676,7 +676,7 @@ class GptBertModel(GptBertPreTrainedModel):
         contextualized_embeddings = [layer.to(original_dtype) for layer in contextualized_embeddings]
 
         # Pad output if using FlashAttention
-        if is_flash_attn_2_available():
+        if flash_attn_varlen_qkvpacked_func is not None:
             last_layer = _pad_output(last_layer, indices, batch_size, seq_length)
             if output_hidden_states:
                 contextualized_embeddings = [_pad_output(layer, indices, batch_size, seq_length) for layer in contextualized_embeddings]
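
For context on the change above: branching on `flash_attn_varlen_qkvpacked_func is not None` instead of transformers' `is_flash_attn_2_available()` ties the FlashAttention path to whether the variable-length kernel was actually imported, so a missing or broken flash-attn install falls back to the padded attention path rather than failing at runtime. The sketch below illustrates that guarded-import pattern with a toy call to the varlen kernel; the toy shapes, sequence lengths, and the CUDA guard are illustrative assumptions, not code taken from modeling_gptbert.py.

import torch

# Guarded import: expose the kernel as None when flash-attn (>= 2) is unavailable,
# so downstream code can test the symbol itself rather than package discovery.
try:
    from flash_attn import flash_attn_varlen_qkvpacked_func
except ImportError:
    flash_attn_varlen_qkvpacked_func = None

if flash_attn_varlen_qkvpacked_func is not None and torch.cuda.is_available():
    # Toy unpadded batch: two sequences of lengths 3 and 5 concatenated along dim 0.
    num_heads, head_dim = 4, 64
    seqlens = [3, 5]
    total_seqlen = sum(seqlens)
    qkv = torch.randn(total_seqlen, 3, num_heads, head_dim, dtype=torch.float16, device="cuda")
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")  # cumulative sequence lengths
    out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max(seqlens))
    print(out.shape)  # torch.Size([8, 4, 64]): (total_seqlen, num_heads, head_dim)
else:
    # Padded fallback, analogous to the non-FlashAttention branches in the diff above.
    print("flash-attn unavailable; using the padded attention implementation")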