Spaces:
Running
Running
Update models/layers.py
Browse files- models/layers.py +26 -58
models/layers.py
CHANGED
@@ -4,155 +4,123 @@ import torch
|
|
4 |
from torch import nn
|
5 |
import torch.nn.functional as F
|
6 |
|
7 |
-
try:
|
8 |
-
from flash_attn_interface import flash_attn_func # type: ignore[import]
|
9 |
-
except ImportError:
|
10 |
-
# Fallback to FlashAttention 2
|
11 |
-
from flash_attn import flash_attn_func # type: ignore[import]
|
12 |
-
|
13 |
from models.common import trunc_normal_init_
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
CosSin = Tuple[torch.Tensor, torch.Tensor]
|
17 |
|
18 |
-
|
19 |
def _find_multiple(a, b):
|
20 |
return (-(a // -b)) * b
|
21 |
|
22 |
-
|
23 |
def rotate_half(x: torch.Tensor):
|
24 |
-
"""Rotates half the hidden dims of the input."""
|
25 |
x1 = x[..., : x.shape[-1] // 2]
|
26 |
x2 = x[..., x.shape[-1] // 2 :]
|
27 |
return torch.cat((-x2, x1), dim=-1)
|
28 |
|
29 |
-
|
30 |
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
31 |
-
# q, k: [bs, seq_len, num_heads, head_dim]
|
32 |
-
# cos, sin: [seq_len, head_dim]
|
33 |
orig_dtype = q.dtype
|
34 |
q = q.to(cos.dtype)
|
35 |
k = k.to(cos.dtype)
|
36 |
-
|
37 |
q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
|
38 |
k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))
|
39 |
-
|
40 |
return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
|
41 |
|
42 |
-
|
43 |
class CastedLinear(nn.Module):
|
44 |
-
def __init__(self,
|
45 |
-
in_features: int,
|
46 |
-
out_features: int,
|
47 |
-
bias: bool):
|
48 |
super().__init__()
|
49 |
-
# Truncated LeCun normal init
|
50 |
self.weight = nn.Parameter(
|
51 |
trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
|
52 |
)
|
53 |
self.bias = None
|
54 |
if bias:
|
55 |
-
# Zero init bias
|
56 |
self.bias = nn.Parameter(torch.zeros((out_features, )))
|
57 |
|
58 |
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
59 |
return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
|
60 |
|
61 |
-
|
62 |
class CastedEmbedding(nn.Module):
|
63 |
-
def __init__(self,
|
64 |
-
num_embeddings: int,
|
65 |
-
embedding_dim: int,
|
66 |
-
init_std: float,
|
67 |
-
cast_to: torch.dtype):
|
68 |
super().__init__()
|
69 |
self.cast_to = cast_to
|
70 |
-
|
71 |
-
# Truncated LeCun normal init
|
72 |
self.embedding_weight = nn.Parameter(
|
73 |
trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
|
74 |
)
|
75 |
-
|
76 |
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
77 |
return F.embedding(input, self.embedding_weight.to(self.cast_to))
|
78 |
|
79 |
-
|
80 |
class RotaryEmbedding(nn.Module):
|
81 |
def __init__(self, dim, max_position_embeddings, base, device=None):
|
82 |
super().__init__()
|
83 |
-
|
84 |
-
# RoPE
|
85 |
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
86 |
t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
|
87 |
freqs = torch.outer(t, inv_freq)
|
88 |
-
|
89 |
-
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
90 |
emb = torch.cat((freqs, freqs), dim=-1)
|
91 |
-
self.cos_cached
|
92 |
-
self.sin_cached
|
93 |
|
94 |
def forward(self):
|
95 |
return self.cos_cached, self.sin_cached
|
96 |
|
97 |
-
|
98 |
class Attention(nn.Module):
|
99 |
def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
|
100 |
super().__init__()
|
101 |
-
|
102 |
self.hidden_size = hidden_size
|
103 |
self.head_dim = head_dim
|
104 |
self.output_size = head_dim * num_heads
|
105 |
self.num_heads = num_heads
|
106 |
self.num_key_value_heads = num_key_value_heads
|
107 |
self.causal = causal
|
108 |
-
|
109 |
self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
|
110 |
self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)
|
111 |
|
112 |
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
113 |
batch_size, seq_len, _ = hidden_states.shape
|
114 |
-
|
115 |
-
# hidden_states: [bs, seq_len, num_heads, head_dim]
|
116 |
qkv = self.qkv_proj(hidden_states)
|
117 |
-
|
118 |
-
# Split head
|
119 |
qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
|
120 |
query = qkv[:, :, :self.num_heads]
|
121 |
key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
|
122 |
value = qkv[:, :, self.num_heads + self.num_key_value_heads:]
|
123 |
|
124 |
-
# RoPE
|
125 |
if cos_sin is not None:
|
126 |
cos, sin = cos_sin
|
127 |
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
128 |
|
129 |
-
#
|
130 |
-
|
131 |
-
|
132 |
-
attn_output
|
133 |
-
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
136 |
return self.o_proj(attn_output)
|
137 |
|
138 |
-
|
139 |
class SwiGLU(nn.Module):
|
140 |
def __init__(self, hidden_size: int, expansion: float):
|
141 |
super().__init__()
|
142 |
inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
|
143 |
-
|
144 |
self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
|
145 |
-
self.down_proj
|
146 |
|
147 |
def forward(self, x):
|
148 |
gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
|
149 |
return self.down_proj(F.silu(gate) * up)
|
150 |
|
151 |
-
|
152 |
def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
|
153 |
input_dtype = hidden_states.dtype
|
154 |
hidden_states = hidden_states.to(torch.float32)
|
155 |
-
|
156 |
variance = hidden_states.square().mean(-1, keepdim=True)
|
157 |
hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
|
158 |
-
return hidden_states.to(input_dtype)
|
|
|
4 |
from torch import nn
|
5 |
import torch.nn.functional as F
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
from models.common import trunc_normal_init_
|
8 |
|
9 |
+
# --- MODIFICATION: Added a fallback for FlashAttention ---
|
10 |
+
try:
|
11 |
+
from flash_attn import flash_attn_func
|
12 |
+
FLASH_ATTENTION_AVAILABLE = True
|
13 |
+
except ImportError:
|
14 |
+
FLASH_ATTENTION_AVAILABLE = False
|
15 |
+
print("⚠️ FlashAttention not found. Falling back to standard PyTorch attention.")
|
16 |
+
# --- END MODIFICATION ---
|
17 |
|
18 |
CosSin = Tuple[torch.Tensor, torch.Tensor]
|
19 |
|
|
|
20 |
def _find_multiple(a, b):
|
21 |
return (-(a // -b)) * b
|
22 |
|
|
|
23 |
def rotate_half(x: torch.Tensor):
|
|
|
24 |
x1 = x[..., : x.shape[-1] // 2]
|
25 |
x2 = x[..., x.shape[-1] // 2 :]
|
26 |
return torch.cat((-x2, x1), dim=-1)
|
27 |
|
|
|
28 |
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
|
|
|
|
29 |
orig_dtype = q.dtype
|
30 |
q = q.to(cos.dtype)
|
31 |
k = k.to(cos.dtype)
|
|
|
32 |
q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
|
33 |
k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))
|
|
|
34 |
return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
|
35 |
|
|
|
36 |
class CastedLinear(nn.Module):
|
37 |
+
def __init__(self, in_features: int, out_features: int, bias: bool):
|
|
|
|
|
|
|
38 |
super().__init__()
|
|
|
39 |
self.weight = nn.Parameter(
|
40 |
trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
|
41 |
)
|
42 |
self.bias = None
|
43 |
if bias:
|
|
|
44 |
self.bias = nn.Parameter(torch.zeros((out_features, )))
|
45 |
|
46 |
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
47 |
return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
|
48 |
|
|
|
49 |
class CastedEmbedding(nn.Module):
|
50 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, init_std: float, cast_to: torch.dtype):
|
|
|
|
|
|
|
|
|
51 |
super().__init__()
|
52 |
self.cast_to = cast_to
|
|
|
|
|
53 |
self.embedding_weight = nn.Parameter(
|
54 |
trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
|
55 |
)
|
|
|
56 |
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
57 |
return F.embedding(input, self.embedding_weight.to(self.cast_to))
|
58 |
|
|
|
59 |
class RotaryEmbedding(nn.Module):
|
60 |
def __init__(self, dim, max_position_embeddings, base, device=None):
|
61 |
super().__init__()
|
|
|
|
|
62 |
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
63 |
t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
|
64 |
freqs = torch.outer(t, inv_freq)
|
|
|
|
|
65 |
emb = torch.cat((freqs, freqs), dim=-1)
|
66 |
+
self.register_buffer('cos_cached', emb.cos(), persistent=False)
|
67 |
+
self.register_buffer('sin_cached', emb.sin(), persistent=False)
|
68 |
|
69 |
def forward(self):
|
70 |
return self.cos_cached, self.sin_cached
|
71 |
|
|
|
72 |
class Attention(nn.Module):
|
73 |
def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
|
74 |
super().__init__()
|
|
|
75 |
self.hidden_size = hidden_size
|
76 |
self.head_dim = head_dim
|
77 |
self.output_size = head_dim * num_heads
|
78 |
self.num_heads = num_heads
|
79 |
self.num_key_value_heads = num_key_value_heads
|
80 |
self.causal = causal
|
|
|
81 |
self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
|
82 |
self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)
|
83 |
|
84 |
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
85 |
batch_size, seq_len, _ = hidden_states.shape
|
|
|
|
|
86 |
qkv = self.qkv_proj(hidden_states)
|
|
|
|
|
87 |
qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
|
88 |
query = qkv[:, :, :self.num_heads]
|
89 |
key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
|
90 |
value = qkv[:, :, self.num_heads + self.num_key_value_heads:]
|
91 |
|
|
|
92 |
if cos_sin is not None:
|
93 |
cos, sin = cos_sin
|
94 |
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
95 |
|
96 |
+
# --- MODIFICATION: Use FlashAttention if available, otherwise fallback ---
|
97 |
+
if FLASH_ATTENTION_AVAILABLE:
|
98 |
+
attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)
|
99 |
+
if isinstance(attn_output, tuple):
|
100 |
+
attn_output = attn_output[0]
|
101 |
+
else:
|
102 |
+
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
103 |
+
attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=self.causal)
|
104 |
+
attn_output = attn_output.transpose(1, 2)
|
105 |
+
# --- END MODIFICATION ---
|
106 |
+
|
107 |
+
attn_output = attn_output.contiguous().view(batch_size, seq_len, self.output_size)
|
108 |
return self.o_proj(attn_output)
|
109 |
|
|
|
110 |
class SwiGLU(nn.Module):
|
111 |
def __init__(self, hidden_size: int, expansion: float):
|
112 |
super().__init__()
|
113 |
inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
|
|
|
114 |
self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
|
115 |
+
self.down_proj = CastedLinear(inter, hidden_size, bias=False)
|
116 |
|
117 |
def forward(self, x):
|
118 |
gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
|
119 |
return self.down_proj(F.silu(gate) * up)
|
120 |
|
|
|
121 |
def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
|
122 |
input_dtype = hidden_states.dtype
|
123 |
hidden_states = hidden_states.to(torch.float32)
|
|
|
124 |
variance = hidden_states.square().mean(-1, keepdim=True)
|
125 |
hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
|
126 |
+
return hidden_states.to(input_dtype)
|