davda54 committed
Commit ebfe554 · verified · 1 Parent(s): 576f0ce

removed SDPA

Files changed (1)
  1. modeling_gptbert.py +30 -14
modeling_gptbert.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from torch import _softmax_backward_data as _softmax_backward_data
 
 from functools import partial, lru_cache
 
@@ -37,17 +38,11 @@ try:
         logger.warning_once(
             "NorBERT4 støtter FlashAttention, men det er ikke funnet i miljøet ditt. Du bør vurdere å oppdatere miljøet ditt for å få raskere og mindre minnekrevende behandling."
         )
-        torch.backends.cuda.enable_flash_sdp(False)
-        torch.backends.cuda.enable_mem_efficient_sdp(False)
-        torch.backends.cuda.enable_math_sdp(True)
 except ImportError:
     flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
     logger.warning_once(
         "NorBERT4 støtter FlashAttention, men det er ikke funnet i miljøet ditt. Du bør vurdere å oppdatere miljøet ditt for å få raskere og mindre minnekrevende behandling."
     )
-    torch.backends.cuda.enable_flash_sdp(False)
-    torch.backends.cuda.enable_mem_efficient_sdp(False)
-    torch.backends.cuda.enable_math_sdp(True)
 
 
 # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
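The deleted lines in this hunk were process-wide toggles on PyTorch's scaled-dot-product-attention dispatcher: whenever flash-attn could not be used, the module forced `F.scaled_dot_product_attention` onto its reference math kernel. (The Norwegian warning translates roughly as: "NorBERT4 supports FlashAttention, but it was not found in your environment. You should consider updating your environment for faster and less memory-intensive processing.") Since the SDPA call itself is removed in the last hunk below, these toggles no longer affect anything. A minimal sketch of what they did, using only the standard `torch.backends.cuda` flags:

```python
# Illustration only: the effect of the removed toggles on SDPA backend selection.
import torch

torch.backends.cuda.enable_flash_sdp(False)          # turn off the fused FlashAttention kernel
torch.backends.cuda.enable_mem_efficient_sdp(False)  # turn off the memory-efficient kernel
torch.backends.cuda.enable_math_sdp(True)            # leave only the reference "math" kernel

# any later F.scaled_dot_product_attention call on CUDA inputs in this process
# now falls back to the math kernel
print(torch.backends.cuda.flash_sdp_enabled(),
      torch.backends.cuda.mem_efficient_sdp_enabled(),
      torch.backends.cuda.math_sdp_enabled())         # False False True
```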
@@ -318,6 +313,25 @@ class RotaryPositionalEmbeddings(nn.Module):
         return out.type_as(x)
 
 
+class MaskedSoftmax(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, mask: torch.BoolTensor, dim: int) -> torch.Tensor:
+        ctx.dim = dim
+        x.masked_fill_(mask, float('-inf'))
+        x = torch.softmax(x, ctx.dim)
+        x.masked_fill_(mask, 0.0)
+        ctx.save_for_backward(x)
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
+        output: torch.Tensor
+
+        output, = ctx.saved_tensors
+        inputGrad: torch.Tensor = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype)
+        return inputGrad, None, None
+
+
 class SelfAttention(nn.Module):
     def __init__(self, config: GptBertConfig, layer_idx: int):
         super().__init__()
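The new `MaskedSoftmax` fuses the masking and the softmax and writes the backward pass by hand through the private `torch._softmax_backward_data` (hence the new import in the first hunk). Its forward mutates the incoming score tensor in place with `masked_fill_`, and `_softmax_backward_data` is a private PyTorch symbol whose signature has changed between releases, so treat this as version-sensitive. A small standalone check (not part of the commit) of the two assumptions behind the handwritten backward:

```python
# (1) torch._softmax_backward_data computes the standard softmax vector-Jacobian product
#     out * (grad - (grad * out).sum(dim, keepdim=True));
# (2) zeroing masked positions after the softmax does not change the gradient, because
#     the masked softmax outputs are exactly zero.
import torch
from torch import _softmax_backward_data

scores = torch.randn(2, 4, 8, 8, requires_grad=True)  # [batch, heads, q_len, k_len]
mask = torch.zeros(1, 1, 1, 8, dtype=torch.bool)
mask[..., -3:] = True                                  # True = masked out, as in ~attention_mask

# reference masked softmax built only from ordinary autograd ops
probs = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1).masked_fill(mask, 0.0)
grad_out = torch.randn_like(probs)
(reference_grad,) = torch.autograd.grad(probs, scores, grad_out)

# the handwritten backward used by MaskedSoftmax
manual_grad = _softmax_backward_data(grad_out, probs.detach(), -1, probs.dtype)
torch.testing.assert_close(manual_grad, reference_grad)
```

Because the masked outputs are exactly zero, passing the unmodified `grad_output` straight into `_softmax_backward_data` still yields zero gradient at masked positions, so no separate masking of the gradient is needed.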
@@ -347,6 +361,7 @@ class SelfAttention(nn.Module):
         self.k_scale = nn.Parameter(torch.ones(self.num_kv_heads, self.d_qk))
         self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, self.d_qk))
 
+        self.attention_dropout = nn.Dropout(config.attention_dropout)
         self.dropout = nn.Dropout(config.hidden_dropout)
 
         theta = 160_000 if (layer_idx + 1) % config.local_global_ratio == 0 else 10_000
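Attention dropout becomes a module rather than a `dropout_p` argument: an `nn.Dropout` respects `model.train()` / `model.eval()` on its own, so it reproduces the old `dropout_p=config.attention_dropout if self.training else 0.0` behaviour. A tiny illustration (the 0.1 rate and the uniform probabilities are made up):

```python
# Illustration only: nn.Dropout is active in train() mode and a no-op in eval() mode.
import torch
import torch.nn as nn

attention_dropout = nn.Dropout(p=0.1)    # stands in for nn.Dropout(config.attention_dropout)
probs = torch.full((2, 4, 8, 8), 0.125)  # made-up attention probabilities

attention_dropout.train()
dropped = attention_dropout(probs)       # roughly 10% of entries zeroed, the rest scaled by 1/0.9

attention_dropout.eval()
assert torch.equal(attention_dropout(probs), probs)  # identity at inference time
```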
@@ -390,14 +405,15 @@ class SelfAttention(nn.Module):
         else:
             attention_mask = window_mask
 
-        output = F.scaled_dot_product_attention(
-            query=query,
-            key=key,
-            value=value,
-            attn_mask=attention_mask if not self.is_causal else None,
-            dropout_p=self.config.attention_dropout if self.training else 0.0,
-            is_causal=self.is_causal
-        )
+        attention_scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * self.scale  # shape: [B*H, Q_T, K_T]
+        attention_scores = attention_scores.view(batch_size, self.num_attention_heads, query_length, key_length)
+
+        attention_probabilities = MaskedSoftmax.apply(attention_scores, ~attention_mask, -1)
+        attention_probabilities = self.attention_dropout(attention_probabilities)
+
+        output = torch.bmm(attention_probabilities.flatten(0, 1), value.flatten(0, 1))
+        output = output.view(batch_size, self.num_attention_heads, query_length, self.d_v)
+
         return output
 
     def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, padding_info):
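The `F.scaled_dot_product_attention` call is replaced by an explicit attention computation: a batched `Q·Kᵀ` matmul scaled by `self.scale`, the new `MaskedSoftmax`, dropout through the new `self.attention_dropout` module, and a final matmul with `V`. The names `self.scale`, `batch_size`, `query_length`, `key_length` and `self.d_v` come from surrounding code that is not part of this diff; the sketch below assumes `[batch, heads, seq, head_dim]` tensors and `scale = 1 / sqrt(head_dim)`, and checks that, with dropout off, the explicit path matches the removed SDPA call:

```python
# Minimal equivalence sketch (not part of the commit), under the assumptions above.
import math
import torch
import torch.nn.functional as F

batch_size, num_heads, q_len, k_len, head_dim = 2, 4, 8, 8, 16
query = torch.randn(batch_size, num_heads, q_len, head_dim)
key = torch.randn(batch_size, num_heads, k_len, head_dim)
value = torch.randn(batch_size, num_heads, k_len, head_dim)
attention_mask = torch.ones(batch_size, 1, q_len, k_len, dtype=torch.bool)
attention_mask[..., -2:] = False   # forbid attending to the last two key positions
scale = 1.0 / math.sqrt(head_dim)  # assumption: what self.scale holds in the model

# explicit path, mirroring the new code (MaskedSoftmax written out with plain ops)
scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * scale  # [B*H, Q, K]
scores = scores.view(batch_size, num_heads, q_len, k_len)
probs = torch.softmax(scores.masked_fill(~attention_mask, float("-inf")), dim=-1)
explicit = torch.bmm(probs.flatten(0, 1), value.flatten(0, 1))                         # [B*H, Q, D]
explicit = explicit.view(batch_size, num_heads, q_len, head_dim)

# the SDPA call that this commit removes (boolean mask, no dropout, no causal masking)
reference = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)

torch.testing.assert_close(explicit, reference, rtol=1e-5, atol=1e-5)
```

The `flatten(0, 1)` / `view` round trip only merges the batch and head dimensions so that `torch.bmm` can be used; a 4-D `torch.matmul` would give the same result.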
 