waveletdeboshir committed
Commit 3e65276 · verified · 1 Parent(s): 005595f

Upload encoder.py

Files changed (1): encoder.py (+73 -70)
encoder.py CHANGED
@@ -6,13 +6,16 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch import Tensor, nn
 
-try:
-    from flash_attn import flash_attn_func
+# try:
+#     from flash_attn import flash_attn_func
 
-    IMPORT_FLASH = True
-except Exception as err:
-    IMPORT_FLASH = False
-    IMPORT_FLASH_ERR = err
+#     IMPORT_FLASH = True
+# except Exception as err:
+#     IMPORT_FLASH = False
+#     IMPORT_FLASH_ERR = err
+
+IMPORT_FLASH = False
+IMPORT_FLASH_ERR = "Flash Attention not installed."
 
 # from .utils import apply_masked_flash_attn, apply_rotary_pos_emb
 
@@ -35,59 +38,59 @@ def apply_rotary_pos_emb(
     return (q * cos) + (rtt_half(q) * sin), (k * cos) + (rtt_half(k) * sin)
 
 
-def apply_masked_flash_attn(
-    q: Tensor,
-    k: Tensor,
-    v: Tensor,
-    mask: Tensor,
-    h: int,
-    d_k: int,
-) -> Tensor:
-    """
-    Applies Flash Attention with padding masks.
-    """
-
-    from einops import rearrange
-    from flash_attn import flash_attn_varlen_func
-    from flash_attn.bert_padding import pad_input, unpad_input
-
-    pad_mask = ~mask[:, 0, :]
-    b, t = pad_mask.shape
-    q = q.view(b, t, h * d_k)
-    k = k.view(b, t, h * d_k)
-    v = v.view(b, t, h * d_k)
-
-    q_unpad, indices_q, _, max_seqlen_q = unpad_input(q, pad_mask)[:4]
-    q_unpad = rearrange(q_unpad, "nnz (h d) -> nnz h d", h=h)
-
-    k_unpad = unpad_input(k, pad_mask)[0]
-    k_unpad = rearrange(k_unpad, "nnz (h d) -> nnz h d", h=h)
-
-    v_unpad = unpad_input(v, pad_mask)[0]
-    v_unpad = rearrange(v_unpad, "nnz (h d) -> nnz h d", h=h)
-
-    lengths_q = pad_mask.sum(1).to(torch.int32).to(q.device)
-    cu_seqlens_q = F.pad(lengths_q.cumsum(0), (1, 0), value=0).to(torch.int32)
-    max_seqlen_q = torch.max(lengths_q)
-
-    output_unpad = flash_attn_varlen_func(
-        q_unpad,
-        k_unpad,
-        v_unpad,
-        cu_seqlens_q,
-        cu_seqlens_q,
-        max_seqlen_q,
-        max_seqlen_q,
-    )
-
-    scores = pad_input(
-        rearrange(output_unpad, "nnz h d -> nnz (h d)"),
-        indices_q,
-        b,
-        t,
-    )
-
-    return scores
+# def apply_masked_flash_attn(
+#     q: Tensor,
+#     k: Tensor,
+#     v: Tensor,
+#     mask: Tensor,
+#     h: int,
+#     d_k: int,
+# ) -> Tensor:
+#     """
+#     Applies Flash Attention with padding masks.
+#     """
+
+#     from einops import rearrange
+#     from flash_attn import flash_attn_varlen_func
+#     from flash_attn.bert_padding import pad_input, unpad_input
+
+#     pad_mask = ~mask[:, 0, :]
+#     b, t = pad_mask.shape
+#     q = q.view(b, t, h * d_k)
+#     k = k.view(b, t, h * d_k)
+#     v = v.view(b, t, h * d_k)
+
+#     q_unpad, indices_q, _, max_seqlen_q = unpad_input(q, pad_mask)[:4]
+#     q_unpad = rearrange(q_unpad, "nnz (h d) -> nnz h d", h=h)
+
+#     k_unpad = unpad_input(k, pad_mask)[0]
+#     k_unpad = rearrange(k_unpad, "nnz (h d) -> nnz h d", h=h)
+
+#     v_unpad = unpad_input(v, pad_mask)[0]
+#     v_unpad = rearrange(v_unpad, "nnz (h d) -> nnz h d", h=h)
+
+#     lengths_q = pad_mask.sum(1).to(torch.int32).to(q.device)
+#     cu_seqlens_q = F.pad(lengths_q.cumsum(0), (1, 0), value=0).to(torch.int32)
+#     max_seqlen_q = torch.max(lengths_q)
+
+#     output_unpad = flash_attn_varlen_func(
+#         q_unpad,
+#         k_unpad,
+#         v_unpad,
+#         cu_seqlens_q,
+#         cu_seqlens_q,
+#         max_seqlen_q,
+#         max_seqlen_q,
+#     )
+
+#     scores = pad_input(
+#         rearrange(output_unpad, "nnz h d -> nnz (h d)"),
+#         indices_q,
+#         b,
+#         t,
+#     )
+
+#     return scores
 
 
 class StridingSubsampling(nn.Module):
@@ -266,17 +269,17 @@ class RotaryPositionMultiHeadAttention(MultiHeadAttention):
             value.view(t, b, self.h * self.d_k).transpose(0, 1),
         )
 
-        if not self.flash_attn:
-            scores = torch.matmul(q, k.transpose(-2, -1) / math.sqrt(self.d_k))
-            out = self.forward_attention(v, scores, mask)
-        else:
-            if mask is None:
-                scores = flash_attn_func(q, k, v)
-            else:
-                scores = apply_masked_flash_attn(q, k, v, mask, self.h, self.d_k)
-
-            scores = scores.view(b, -1, self.h * self.d_k)
-            out = self.linear_out(scores)
+        # if not self.flash_attn:
+        scores = torch.matmul(q, k.transpose(-2, -1) / math.sqrt(self.d_k))
+        out = self.forward_attention(v, scores, mask)
+        # else:
+        #     if mask is None:
+        #         scores = flash_attn_func(q, k, v)
+        #     else:
+        #         scores = apply_masked_flash_attn(q, k, v, mask, self.h, self.d_k)
+
+        #     scores = scores.view(b, -1, self.h * self.d_k)
+        #     out = self.linear_out(scores)
 
         return out
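
After this change the encoder always takes the plain PyTorch attention path: scores = torch.matmul(q, k.transpose(-2, -1) / math.sqrt(self.d_k)) followed by self.forward_attention(v, scores, mask). The body of forward_attention is not part of this diff, so the sketch below is only a minimal illustration of the standard masked-softmax attention it is assumed to perform; the function name and the mask convention (boolean, True at padded key positions) are placeholders, not taken from encoder.py.

import math
from typing import Optional

import torch
from torch import Tensor


def masked_softmax_attention(
    q: Tensor,  # (batch, heads, time, d_k)
    k: Tensor,  # (batch, heads, time, d_k)
    v: Tensor,  # (batch, heads, time, d_k)
    pad_mask: Optional[Tensor] = None,  # (batch, time), True at padded positions (assumed convention)
) -> Tensor:
    d_k = q.size(-1)
    # Same scaling as the line kept in the diff: q @ k^T / sqrt(d_k)
    scores = torch.matmul(q, k.transpose(-2, -1) / math.sqrt(d_k))
    if pad_mask is not None:
        # Exclude padded key positions before the softmax.
        scores = scores.masked_fill(pad_mask[:, None, None, :], float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    # Weighted sum over the values; output shape (batch, heads, time, d_k)
    return torch.matmul(attn, v)

For example, with q, k, v of shape (2, 8, 100, 64) and a (2, 100) padding mask, the output has shape (2, 8, 100, 64).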
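
The commit also replaces the try/except import with hard-coded module-level flags (IMPORT_FLASH = False, IMPORT_FLASH_ERR = "Flash Attention not installed."), so code that still consults them keeps working without flash-attn installed. The consumer of these flags is not shown in this diff; the snippet below is a hypothetical illustration of how such a flag pair is typically checked (the class name and the use_flash_attn argument are made up for the example).

IMPORT_FLASH = False
IMPORT_FLASH_ERR = "Flash Attention not installed."


class ExampleAttentionModule:
    """Hypothetical consumer of the module-level flags; not taken from encoder.py."""

    def __init__(self, use_flash_attn: bool = False):
        if use_flash_attn and not IMPORT_FLASH:
            # Fail early with the stored reason instead of crashing later at call time.
            raise ImportError(f"Flash Attention was requested but is unavailable: {IMPORT_FLASH_ERR}")
        self.flash_attn = use_flash_attn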