Tarive commited on
Commit
af20e67
·
verified ·
1 Parent(s): 53fdda0

Update models/layers.py

Browse files
Files changed (1) hide show
  1. models/layers.py +26 -58
models/layers.py CHANGED
@@ -4,155 +4,123 @@ import torch
4
  from torch import nn
5
  import torch.nn.functional as F
6
 
7
- try:
8
- from flash_attn_interface import flash_attn_func # type: ignore[import]
9
- except ImportError:
10
- # Fallback to FlashAttention 2
11
- from flash_attn import flash_attn_func # type: ignore[import]
12
-
13
  from models.common import trunc_normal_init_
14
 
 
 
 
 
 
 
 
 
15
 
16
  CosSin = Tuple[torch.Tensor, torch.Tensor]
17
 
18
-
19
  def _find_multiple(a, b):
20
  return (-(a // -b)) * b
21
 
22
-
23
  def rotate_half(x: torch.Tensor):
24
- """Rotates half the hidden dims of the input."""
25
  x1 = x[..., : x.shape[-1] // 2]
26
  x2 = x[..., x.shape[-1] // 2 :]
27
  return torch.cat((-x2, x1), dim=-1)
28
 
29
-
30
  def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
31
- # q, k: [bs, seq_len, num_heads, head_dim]
32
- # cos, sin: [seq_len, head_dim]
33
  orig_dtype = q.dtype
34
  q = q.to(cos.dtype)
35
  k = k.to(cos.dtype)
36
-
37
  q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
38
  k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))
39
-
40
  return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
41
 
42
-
43
  class CastedLinear(nn.Module):
44
- def __init__(self,
45
- in_features: int,
46
- out_features: int,
47
- bias: bool):
48
  super().__init__()
49
- # Truncated LeCun normal init
50
  self.weight = nn.Parameter(
51
  trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
52
  )
53
  self.bias = None
54
  if bias:
55
- # Zero init bias
56
  self.bias = nn.Parameter(torch.zeros((out_features, )))
57
 
58
  def forward(self, input: torch.Tensor) -> torch.Tensor:
59
  return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
60
 
61
-
62
  class CastedEmbedding(nn.Module):
63
- def __init__(self,
64
- num_embeddings: int,
65
- embedding_dim: int,
66
- init_std: float,
67
- cast_to: torch.dtype):
68
  super().__init__()
69
  self.cast_to = cast_to
70
-
71
- # Truncated LeCun normal init
72
  self.embedding_weight = nn.Parameter(
73
  trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
74
  )
75
-
76
  def forward(self, input: torch.Tensor) -> torch.Tensor:
77
  return F.embedding(input, self.embedding_weight.to(self.cast_to))
78
 
79
-
80
  class RotaryEmbedding(nn.Module):
81
  def __init__(self, dim, max_position_embeddings, base, device=None):
82
  super().__init__()
83
-
84
- # RoPE
85
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
86
  t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
87
  freqs = torch.outer(t, inv_freq)
88
-
89
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
90
  emb = torch.cat((freqs, freqs), dim=-1)
91
- self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
92
- self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
93
 
94
  def forward(self):
95
  return self.cos_cached, self.sin_cached
96
 
97
-
98
  class Attention(nn.Module):
99
  def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
100
  super().__init__()
101
-
102
  self.hidden_size = hidden_size
103
  self.head_dim = head_dim
104
  self.output_size = head_dim * num_heads
105
  self.num_heads = num_heads
106
  self.num_key_value_heads = num_key_value_heads
107
  self.causal = causal
108
-
109
  self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
110
  self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)
111
 
112
  def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
113
  batch_size, seq_len, _ = hidden_states.shape
114
-
115
- # hidden_states: [bs, seq_len, num_heads, head_dim]
116
  qkv = self.qkv_proj(hidden_states)
117
-
118
- # Split head
119
  qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
120
  query = qkv[:, :, :self.num_heads]
121
  key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
122
  value = qkv[:, :, self.num_heads + self.num_key_value_heads:]
123
 
124
- # RoPE
125
  if cos_sin is not None:
126
  cos, sin = cos_sin
127
  query, key = apply_rotary_pos_emb(query, key, cos, sin)
128
 
129
- # flash attn
130
- attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)
131
- if isinstance(attn_output, tuple): # fa2 and fa3 compatibility
132
- attn_output = attn_output[0]
133
-
134
- # attn_output: [batch_size, num_heads, seq_len, head_dim]
135
- attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore
 
 
 
 
 
136
  return self.o_proj(attn_output)
137
 
138
-
139
  class SwiGLU(nn.Module):
140
  def __init__(self, hidden_size: int, expansion: float):
141
  super().__init__()
142
  inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
143
-
144
  self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
145
- self.down_proj = CastedLinear(inter, hidden_size, bias=False)
146
 
147
  def forward(self, x):
148
  gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
149
  return self.down_proj(F.silu(gate) * up)
150
 
151
-
152
  def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
153
  input_dtype = hidden_states.dtype
154
  hidden_states = hidden_states.to(torch.float32)
155
-
156
  variance = hidden_states.square().mean(-1, keepdim=True)
157
  hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
158
- return hidden_states.to(input_dtype)
 
4
  from torch import nn
5
  import torch.nn.functional as F
6
 
 
 
 
 
 
 
7
  from models.common import trunc_normal_init_
8
 
9
+ # --- MODIFICATION: Added a fallback for FlashAttention ---
10
+ try:
11
+ from flash_attn import flash_attn_func
12
+ FLASH_ATTENTION_AVAILABLE = True
13
+ except ImportError:
14
+ FLASH_ATTENTION_AVAILABLE = False
15
+ print("⚠️ FlashAttention not found. Falling back to standard PyTorch attention.")
16
+ # --- END MODIFICATION ---
17
 
18
  CosSin = Tuple[torch.Tensor, torch.Tensor]
19
 
 
20
  def _find_multiple(a, b):
21
  return (-(a // -b)) * b
22
 
 
23
  def rotate_half(x: torch.Tensor):
 
24
  x1 = x[..., : x.shape[-1] // 2]
25
  x2 = x[..., x.shape[-1] // 2 :]
26
  return torch.cat((-x2, x1), dim=-1)
27
 
 
28
  def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
 
 
29
  orig_dtype = q.dtype
30
  q = q.to(cos.dtype)
31
  k = k.to(cos.dtype)
 
32
  q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
33
  k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))
 
34
  return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
35
 
 
36
  class CastedLinear(nn.Module):
37
+ def __init__(self, in_features: int, out_features: int, bias: bool):
 
 
 
38
  super().__init__()
 
39
  self.weight = nn.Parameter(
40
  trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
41
  )
42
  self.bias = None
43
  if bias:
 
44
  self.bias = nn.Parameter(torch.zeros((out_features, )))
45
 
46
  def forward(self, input: torch.Tensor) -> torch.Tensor:
47
  return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
48
 
 
49
  class CastedEmbedding(nn.Module):
50
+ def __init__(self, num_embeddings: int, embedding_dim: int, init_std: float, cast_to: torch.dtype):
 
 
 
 
51
  super().__init__()
52
  self.cast_to = cast_to
 
 
53
  self.embedding_weight = nn.Parameter(
54
  trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
55
  )
 
56
  def forward(self, input: torch.Tensor) -> torch.Tensor:
57
  return F.embedding(input, self.embedding_weight.to(self.cast_to))
58
 
 
59
  class RotaryEmbedding(nn.Module):
60
  def __init__(self, dim, max_position_embeddings, base, device=None):
61
  super().__init__()
 
 
62
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
63
  t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
64
  freqs = torch.outer(t, inv_freq)
 
 
65
  emb = torch.cat((freqs, freqs), dim=-1)
66
+ self.register_buffer('cos_cached', emb.cos(), persistent=False)
67
+ self.register_buffer('sin_cached', emb.sin(), persistent=False)
68
 
69
  def forward(self):
70
  return self.cos_cached, self.sin_cached
71
 
 
72
  class Attention(nn.Module):
73
  def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
74
  super().__init__()
 
75
  self.hidden_size = hidden_size
76
  self.head_dim = head_dim
77
  self.output_size = head_dim * num_heads
78
  self.num_heads = num_heads
79
  self.num_key_value_heads = num_key_value_heads
80
  self.causal = causal
 
81
  self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
82
  self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)
83
 
84
  def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
85
  batch_size, seq_len, _ = hidden_states.shape
 
 
86
  qkv = self.qkv_proj(hidden_states)
 
 
87
  qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
88
  query = qkv[:, :, :self.num_heads]
89
  key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
90
  value = qkv[:, :, self.num_heads + self.num_key_value_heads:]
91
 
 
92
  if cos_sin is not None:
93
  cos, sin = cos_sin
94
  query, key = apply_rotary_pos_emb(query, key, cos, sin)
95
 
96
+ # --- MODIFICATION: Use FlashAttention if available, otherwise fallback ---
97
+ if FLASH_ATTENTION_AVAILABLE:
98
+ attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)
99
+ if isinstance(attn_output, tuple):
100
+ attn_output = attn_output[0]
101
+ else:
102
+ query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
103
+ attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=self.causal)
104
+ attn_output = attn_output.transpose(1, 2)
105
+ # --- END MODIFICATION ---
106
+
107
+ attn_output = attn_output.contiguous().view(batch_size, seq_len, self.output_size)
108
  return self.o_proj(attn_output)
109
 
 
110
  class SwiGLU(nn.Module):
111
  def __init__(self, hidden_size: int, expansion: float):
112
  super().__init__()
113
  inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
 
114
  self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
115
+ self.down_proj = CastedLinear(inter, hidden_size, bias=False)
116
 
117
  def forward(self, x):
118
  gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
119
  return self.down_proj(F.silu(gate) * up)
120
 
 
121
  def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
122
  input_dtype = hidden_states.dtype
123
  hidden_states = hidden_states.to(torch.float32)
 
124
  variance = hidden_states.square().mean(-1, keepdim=True)
125
  hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
126
+ return hidden_states.to(input_dtype)