make FlashAttention logic more robust
modeling_gptbert.py (+5 −5)

All five FlashAttention branches in `modeling_gptbert.py` now check `flash_attn_varlen_qkvpacked_func is not None` explicitly before taking the unpadded FlashAttention code path.
```diff
@@ -367,7 +367,7 @@ class SelfAttention(nn.Module):
         theta = 160_000 if (layer_idx + 1) % config.local_global_ratio == 0 else 10_000
 
         # Initialize rotary embeddings based on whether FlashAttention is available
-        if
+        if flash_attn_varlen_qkvpacked_func is not None:
             self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_sequence_length)
         else:
             self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
```
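The `is not None` check presumes that `flash_attn_varlen_qkvpacked_func` is bound to `None` when flash-attn cannot be imported. A minimal sketch of such an import guard (an assumption about the file header, which is not shown in this commit):

```python
# Sketch of the import-time fallback the guard assumes: if flash-attn is not
# installed, the symbol is None and every branch falls back to the padded path.
try:
    from flash_attn import flash_attn_varlen_qkvpacked_func  # flash-attn >= 2.x
except ImportError:
    flash_attn_varlen_qkvpacked_func = None
```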
```diff
@@ -418,7 +418,7 @@ class SelfAttention(nn.Module):
 
     def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, padding_info):
         # Get original shape info
-        if
+        if flash_attn_varlen_qkvpacked_func is not None:
             # Unpadded case
             indices, cu_seqlens, max_seqlen = padding_info
             total_seqlen = hidden_layer.size(0)
```
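For reference, a toy example of what the `padding_info` tuple unpacked above could contain, for a batch of two sequences of lengths 3 and 2 padded to length 4 (illustrative values, not taken from the model):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.bool)
indices = torch.nonzero(attention_mask.flatten()).flatten()  # tensor([0, 1, 2, 4, 5]): flat positions of real tokens
cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32)      # cumulative sequence boundaries
max_seqlen = 3                                               # longest unpadded sequence in the batch
```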
```diff
@@ -433,7 +433,7 @@ class SelfAttention(nn.Module):
         query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
         value = self.v_proj(hidden_layer)
 
-        if
+        if flash_attn_varlen_qkvpacked_func is not None:
             # Reshape for FlashAttention: (total_seqlen, num_heads, head_dim)
             query = query.view(total_seqlen, self.num_attention_heads, self.d_qk)
             key = key.view(total_seqlen, self.num_kv_heads, self.d_qk)
```
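The hunk ends before the attention call itself. As a hedged sketch, tensors in this `(total_seqlen, num_heads, head_dim)` layout are typically handed to a variable-length FlashAttention kernel as below; because `query` and `key` can have different head counts here, the sketch uses the unpacked `flash_attn_varlen_func` rather than the qkvpacked variant named in the guard, and the actual call in `modeling_gptbert.py` is not shown in this commit.

```python
import torch
from flash_attn import flash_attn_varlen_func  # flash-attn >= 2.x, CUDA tensors only

def varlen_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                     cu_seqlens: torch.Tensor, max_seqlen: int) -> torch.Tensor:
    """Sketch: attention over unpadded (total_seqlen, num_heads, head_dim) tensors."""
    return flash_attn_varlen_func(
        query, key, value,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        causal=False,  # assumption: bidirectional, BERT-style attention
    )
```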
```diff
@@ -645,7 +645,7 @@ class GptBertModel(GptBertPreTrainedModel):
         else:
             attention_mask = attention_mask.bool()
 
-        if
+        if flash_attn_varlen_qkvpacked_func is not None:
             if len(attention_mask.size()) != 2:
                 raise ValueError("Only an `attention_mask` with two dimensions is currently supported for FlashAttention.")
             with torch.no_grad():
```
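A rough pure-PyTorch equivalent of what the `with torch.no_grad():` block most likely computes from the validated 2-D mask, i.e. the `(indices, cu_seqlens, max_seqlen)` bookkeeping used in the attention layers (a sketch mirroring flash-attn's `bert_padding.unpad_input` helper, not the model's own code):

```python
import torch
import torch.nn.functional as F

def unpad_info(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Derive (indices, cu_seqlens, max_seqlen) from a 2-D boolean attention mask."""
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                       # real tokens per sequence
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()   # flat positions of real tokens
    max_seqlen = int(seqlens.max())
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))   # prefix sums with a leading 0
    return indices, cu_seqlens, max_seqlen
```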
```diff
@@ -676,7 +676,7 @@ class GptBertModel(GptBertPreTrainedModel):
         contextualized_embeddings = [layer.to(original_dtype) for layer in contextualized_embeddings]
 
         # Pad output if using FlashAttention
-        if
+        if flash_attn_varlen_qkvpacked_func is not None:
             last_layer = _pad_output(last_layer, indices, batch_size, seq_length)
             if output_hidden_states:
                 contextualized_embeddings = [_pad_output(layer, indices, batch_size, seq_length) for layer in contextualized_embeddings]
```
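`_pad_output` itself is outside this diff; assuming it mirrors flash-attn's `bert_padding.pad_input`, a plausible sketch is:

```python
import torch

def _pad_output(unpadded: torch.Tensor, indices: torch.Tensor,
                batch_size: int, seq_length: int) -> torch.Tensor:
    """Hypothetical sketch: scatter (total_tokens, hidden) rows back to (batch, seq_len, hidden)."""
    hidden_size = unpadded.size(-1)
    padded = torch.zeros(batch_size * seq_length, hidden_size,
                         dtype=unpadded.dtype, device=unpadded.device)
    padded[indices] = unpadded  # padding positions stay zero
    return padded.view(batch_size, seq_length, hidden_size)
```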