Update models/pos_embed.py
models/pos_embed.py (CHANGED, +48 -55)
@@ -100,61 +100,54 @@ def interpolate_pos_embed(model, checkpoint_model):
 # RoPE2D: RoPE implementation in 2D
 #----------------------------------------------------------
 
-try:
-    from models.curope import cuRoPE2D
-    RoPE2D = cuRoPE2D
-except ImportError:
-    # critical error, we need to use the slow pytorch version
-    print("CUDA-compiled version of RoPE2D is required but could not be found. Please compile the CUDA extension before running.")
-    #raise ImportError("CUDA-compiled version of RoPE2D is required but could not be found. Please compile the CUDA extension before running.")
+
+class RoPE2D(torch.nn.Module):
 
-    class RoPE2D(torch.nn.Module):
+    def __init__(self, freq=100.0, F0=1.0):
+        super().__init__()
+        self.base = freq
+        self.F0 = F0
+        self.cache = {}
+
+    def get_cos_sin(self, D, seq_len, device, dtype):
+        if (D,seq_len,device,dtype) not in self.cache:
+            inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
+            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+            freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
+            freqs = torch.cat((freqs, freqs), dim=-1)
+            cos = freqs.cos() # (Seq, Dim)
+            sin = freqs.sin()
+            self.cache[D,seq_len,device,dtype] = (cos,sin)
+        return self.cache[D,seq_len,device,dtype]
 
-        def __init__(self, freq=100.0, F0=1.0):
-            super().__init__()
-            self.base = freq
-            self.F0 = F0
-            self.cache = {}
-
-        def get_cos_sin(self, D, seq_len, device, dtype):
-            if (D,seq_len,device,dtype) not in self.cache:
-                inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
-                t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
-                freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
-                freqs = torch.cat((freqs, freqs), dim=-1)
-                cos = freqs.cos() # (Seq, Dim)
-                sin = freqs.sin()
-                self.cache[D,seq_len,device,dtype] = (cos,sin)
-            return self.cache[D,seq_len,device,dtype]
-
-        @staticmethod
-        def rotate_half(x):
-            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
-            return torch.cat((-x2, x1), dim=-1)
-
-        def apply_rope1d(self, tokens, pos1d, cos, sin):
-            assert pos1d.ndim==2
-            cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
-            sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
-            return (tokens * cos) + (self.rotate_half(tokens) * sin)
+    @staticmethod
+    def rotate_half(x):
+        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
 
-        def forward(self, tokens, positions):
-            """
-            input:
-                * tokens: batch_size x nheads x ntokens x dim
-                * positions: batch_size x ntokens x 2 (y and x position of each token)
-            output:
-                * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
-            """
-            tokens = tokens.to(torch.float32)
-            #positions = positions.to(torch.float32)
-            assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
-            D = tokens.size(3) // 2
-            assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
-            cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
-            # split features into two along the feature dimension, and apply rope1d on each half
-            y, x = tokens.chunk(2, dim=-1)
-            y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
-            x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
-            tokens = torch.cat((y, x), dim=-1)
-            return tokens
+    def apply_rope1d(self, tokens, pos1d, cos, sin):
+        assert pos1d.ndim==2
+        cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
+        sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
+        return (tokens * cos) + (self.rotate_half(tokens) * sin)
+
+    def forward(self, tokens, positions):
+        """
+        input:
+            * tokens: batch_size x nheads x ntokens x dim
+            * positions: batch_size x ntokens x 2 (y and x position of each token)
+        output:
+            * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
+        """
+        # tokens = tokens.to(torch.float32)
+        # #positions = positions.to(torch.float32)
+        # assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
+        # D = tokens.size(3) // 2
+        # assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
+        # cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
+        # # split features into two along the feature dimension, and apply rope1d on each half
+        # y, x = tokens.chunk(2, dim=-1)
+        # y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
+        # x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
+        # tokens = torch.cat((y, x), dim=-1)
+        return tokens
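
A minimal usage sketch of the pure-PyTorch RoPE2D class above (the shapes, the grid construction, and the sanity check are illustrative assumptions, not part of the commit). It builds integer (y, x) patch coordinates, calls the module the way an attention layer would, and checks the rotation identity that apply_rope1d implements: each feature pair (x1, x2) is rotated through the angle pos * base**(-2i/D).

import torch

rope = RoPE2D(freq=100.0)

# Illustrative shapes: 2 images, 4 heads, a 6x6 patch grid, 64-dim heads.
B, nheads, H, W, dim = 2, 4, 6, 6, 64
tokens = torch.randn(B, nheads, H * W, dim)

# Integer (y, x) coordinates per token, shape (B, ntokens, 2), as forward() expects.
yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
positions = torch.stack((yy.flatten(), xx.flatten()), dim=-1).unsqueeze(0).expand(B, -1, -1)

out = rope(tokens, positions)
# With the forward body commented out as committed, the call returns tokens unchanged:
assert torch.equal(out, tokens)

# Sanity check of apply_rope1d against the explicit 2x2 rotation it implements:
# (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin), with cos/sin duplicated across halves.
t = torch.randn(1, 1, 5, 8)          # (batch, heads, seq, D) with D = 8
pos = torch.arange(5).unsqueeze(0)   # (1, 5) integer positions
cos, sin = rope.get_cos_sin(8, 5, t.device, t.dtype)
rotated = rope.apply_rope1d(t, pos, cos, sin)
x1, x2 = t.chunk(2, dim=-1)          # halves of the feature dimension
c, s = cos[pos][:, None, :, :4], sin[pos][:, None, :, :4]
expected = torch.cat((x1 * c - x2 * s, x2 * c + x1 * s), dim=-1)
assert torch.allclose(rotated, expected, atol=1e-6)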