Update models/pos_embed.py
models/pos_embed.py  CHANGED  (+48, -55)

@@ -100,61 +100,54 @@ def interpolate_pos_embed(model, checkpoint_model):

# RoPE2D: RoPE implementation in 2D
#----------------------------------------------------------

-    RoPE2D = cuRoPE2D
-except ImportError:
-    # critical error, we need to use the slow pytorch version
-    print("CUDA-compiled version of RoPE2D is required but could not be found. Please compile the CUDA extension before running.")
-    #raise ImportError("CUDA-compiled version of RoPE2D is required but could not be found. Please compile the CUDA extension before running.")

-        self.cache = {}
-
-    def get_cos_sin(self, D, seq_len, device, dtype):
-        if (D,seq_len,device,dtype) not in self.cache:
-            inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
-            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
-            freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
-            freqs = torch.cat((freqs, freqs), dim=-1)
-            cos = freqs.cos() # (Seq, Dim)
-            sin = freqs.sin()
-            self.cache[D,seq_len,device,dtype] = (cos,sin)
-        return self.cache[D,seq_len,device,dtype]
-
-    @staticmethod
-    def rotate_half(x):
-        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
-        return torch.cat((-x2, x1), dim=-1)
-
-    def apply_rope1d(self, tokens, pos1d, cos, sin):
-        assert pos1d.ndim==2
-        cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
-        sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
-        return (tokens * cos) + (self.rotate_half(tokens) * sin)
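The 1D rotation core (get_cos_sin, rotate_half, apply_rope1d) is the same in the removed and the added code: for a token at position p, each feature pair (x_i, x_{i+D/2}) is rotated by the angle p * base^(-2i/D). A minimal standalone PyTorch sketch (not part of the diff) that checks this half-rotation form against an explicit per-pair rotation:

# Minimal check (PyTorch only) that the half-rotation form used by apply_rope1d,
#   out = tokens * cos + rotate_half(tokens) * sin,
# rotates each feature pair (x[i], x[i + D//2]) by the angle p * inv_freq[i].
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

D, base, p = 8, 100.0, 3            # feature dim (even), frequency base, token position
x = torch.randn(D)
inv_freq = 1.0 / (base ** (torch.arange(0, D, 2).float() / D))   # (D/2,) frequencies
angles = p * inv_freq                                            # one angle per feature pair
cos = torch.cat((angles.cos(), angles.cos()))                    # same layout as get_cos_sin
sin = torch.cat((angles.sin(), angles.sin()))

out = x * cos + rotate_half(x) * sin

# Explicit 2x2 rotation of each pair (x[i], x[i + D//2]) by the same angle.
x1, x2 = x[: D // 2], x[D // 2:]
expected = torch.cat((x1 * angles.cos() - x2 * angles.sin(),
                      x2 * angles.cos() + x1 * angles.sin()))
assert torch.allclose(out, expected)

The cat((freqs, freqs), dim=-1) layout in get_cos_sin is what pairs feature i with feature i + D/2, rather than pairing adjacent features.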
# RoPE2D: RoPE implementation in 2D
#----------------------------------------------------------

+
+class RoPE2D(torch.nn.Module):

+    def __init__(self, freq=100.0, F0=1.0):
+        super().__init__()
+        self.base = freq
+        self.F0 = F0
+        self.cache = {}
+
+    def get_cos_sin(self, D, seq_len, device, dtype):
+        if (D,seq_len,device,dtype) not in self.cache:
+            inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
+            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+            freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
+            freqs = torch.cat((freqs, freqs), dim=-1)
+            cos = freqs.cos() # (Seq, Dim)
+            sin = freqs.sin()
+            self.cache[D,seq_len,device,dtype] = (cos,sin)
+        return self.cache[D,seq_len,device,dtype]

+    @staticmethod
+    def rotate_half(x):
+        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)

+    def apply_rope1d(self, tokens, pos1d, cos, sin):
+        assert pos1d.ndim==2
+        cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
+        sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
+        return (tokens * cos) + (self.rotate_half(tokens) * sin)
+
+    def forward(self, tokens, positions):
+        """
+        input:
+            * tokens: batch_size x nheads x ntokens x dim
+            * positions: batch_size x ntokens x 2 (y and x position of each token)
+        output:
+            * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
+        """
+        # tokens = tokens.to(torch.float32)
+        # #positions = positions.to(torch.float32)
+        # assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
+        # D = tokens.size(3) // 2
+        # assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
+        # cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
+        # # split features into two along the feature dimension, and apply rope1d on each half
+        # y, x = tokens.chunk(2, dim=-1)
+        # y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
+        # x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
+        # tokens = torch.cat((y, x), dim=-1)
+        return tokens
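As committed, the body of forward is left commented out, so the module currently returns tokens unchanged. The commented lines describe the intended 2D scheme: split the head features into two halves and apply 1D RoPE to the first half with the y position and to the second half with the x position. Below is a sketch that re-enables that logic outside the class; it assumes the RoPE2D class from this diff is in scope, and tensor shapes follow the forward docstring.

# Sketch of the 2D application described by the commented-out forward body.
# Assumes the RoPE2D class from the diff above is importable in the current scope;
# since RoPE2D.forward currently passes tokens through unchanged, this helper
# performs the rotation without modifying the class.
import torch

def apply_rope2d(rope, tokens, positions):
    # tokens:    (batch, nheads, ntokens, dim), dim must be even
    # positions: (batch, ntokens, 2) holding the (y, x) position of each token
    assert tokens.size(3) % 2 == 0, "number of dimensions should be a multiple of two"
    D = tokens.size(3) // 2
    assert positions.ndim == 3 and positions.shape[-1] == 2   # Batch, Seq, 2
    cos, sin = rope.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
    # first half of the features is rotated by the y position, second half by the x position
    y, x = tokens.chunk(2, dim=-1)
    y = rope.apply_rope1d(y, positions[:, :, 0], cos, sin)
    x = rope.apply_rope1d(x, positions[:, :, 1], cos, sin)
    return torch.cat((y, x), dim=-1)

# Usage with the shapes documented in the forward docstring (values are illustrative):
rope = RoPE2D(freq=100.0)
tokens = torch.randn(2, 8, 16 * 16, 64)                       # batch x nheads x ntokens x dim
ys, xs = torch.meshgrid(torch.arange(16), torch.arange(16), indexing="ij")
positions = torch.stack((ys, xs), dim=-1).reshape(1, -1, 2).expand(2, -1, -1)  # batch x ntokens x 2
out = apply_rope2d(rope, tokens, positions)                    # same shape as tokens

With this split, a head of dimension dim devotes dim/2 channels to the y coordinate and dim/2 to the x coordinate, each rotated exactly as in 1D RoPE.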