bartduis committed on
Commit d43e49d · verified · 1 Parent(s): 0a76fc3

Update models/pos_embed.py

Files changed (1): models/pos_embed.py (+48 -55)
models/pos_embed.py CHANGED
@@ -100,61 +100,54 @@ def interpolate_pos_embed(model, checkpoint_model):
 # RoPE2D: RoPE implementation in 2D
 #----------------------------------------------------------
 
-try:
-    from extensions.curope import cuRoPE2D
-    RoPE2D = cuRoPE2D
-except ImportError:
-    # critical error, we need to use the slow pytorch version
-    print("CUDA-compiled version of RoPE2D is required but could not be found. Please compile the CUDA extension before running.")
-    #raise ImportError("CUDA-compiled version of RoPE2D is required but could not be found. Please compile the CUDA extension before running.")
-
-    class RoPE2D(torch.nn.Module):
-
-        def __init__(self, freq=100.0, F0=1.0):
-            super().__init__()
-            self.base = freq
-            self.F0 = F0
-            self.cache = {}
-
-        def get_cos_sin(self, D, seq_len, device, dtype):
-            if (D,seq_len,device,dtype) not in self.cache:
-                inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
-                t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
-                freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
-                freqs = torch.cat((freqs, freqs), dim=-1)
-                cos = freqs.cos()  # (Seq, Dim)
-                sin = freqs.sin()
-                self.cache[D,seq_len,device,dtype] = (cos,sin)
-            return self.cache[D,seq_len,device,dtype]
-
-        @staticmethod
-        def rotate_half(x):
-            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
-            return torch.cat((-x2, x1), dim=-1)
-
-        def apply_rope1d(self, tokens, pos1d, cos, sin):
-            assert pos1d.ndim==2
-            cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
-            sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
-            return (tokens * cos) + (self.rotate_half(tokens) * sin)
-
-        def forward(self, tokens, positions):
-            """
-            input:
-                * tokens: batch_size x nheads x ntokens x dim
-                * positions: batch_size x ntokens x 2 (y and x position of each token)
-            output:
-                * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
-            """
-            # tokens = tokens.to(torch.float32)
-            # #positions = positions.to(torch.float32)
-            # assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
-            # D = tokens.size(3) // 2
-            # assert positions.ndim==3 and positions.shape[-1] == 2  # Batch, Seq, 2
-            # cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
-            # # split features into two along the feature dimension, and apply rope1d on each half
-            # y, x = tokens.chunk(2, dim=-1)
-            # y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
-            # x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
-            # tokens = torch.cat((y, x), dim=-1)
-            return tokens
+
+class RoPE2D(torch.nn.Module):
+
+    def __init__(self, freq=100.0, F0=1.0):
+        super().__init__()
+        self.base = freq
+        self.F0 = F0
+        self.cache = {}
+
+    def get_cos_sin(self, D, seq_len, device, dtype):
+        if (D,seq_len,device,dtype) not in self.cache:
+            inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
+            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+            freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
+            freqs = torch.cat((freqs, freqs), dim=-1)
+            cos = freqs.cos()  # (Seq, Dim)
+            sin = freqs.sin()
+            self.cache[D,seq_len,device,dtype] = (cos,sin)
+        return self.cache[D,seq_len,device,dtype]
+
+    @staticmethod
+    def rotate_half(x):
+        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rope1d(self, tokens, pos1d, cos, sin):
+        assert pos1d.ndim==2
+        cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
+        sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
+        return (tokens * cos) + (self.rotate_half(tokens) * sin)
+
+    def forward(self, tokens, positions):
+        """
+        input:
+            * tokens: batch_size x nheads x ntokens x dim
+            * positions: batch_size x ntokens x 2 (y and x position of each token)
+        output:
+            * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
+        """
+        # tokens = tokens.to(torch.float32)
+        # #positions = positions.to(torch.float32)
+        # assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
+        # D = tokens.size(3) // 2
+        # assert positions.ndim==3 and positions.shape[-1] == 2  # Batch, Seq, 2
+        # cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
+        # # split features into two along the feature dimension, and apply rope1d on each half
+        # y, x = tokens.chunk(2, dim=-1)
+        # y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
+        # x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
+        # tokens = torch.cat((y, x), dim=-1)
+        return tokens
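Net effect of the commit: the old code tried to import the compiled cuRoPE2D CUDA kernel and only defined the pure-PyTorch RoPE2D as a fallback inside the except ImportError branch; the new code makes the PyTorch class the only path. For orientation, a hedged usage sketch of that class (shapes follow the forward() docstring; the grid size, batch/head counts, and freq value are illustrative assumptions, and the F0 argument is stored but never used). Note that because the forward() body is commented out in this commit, the module currently returns tokens unchanged:

```python
import torch

from models.pos_embed import RoPE2D

B, nheads, dim = 2, 4, 64      # dim must be even: forward() splits it into y/x halves
H = W = 7                      # a 7x7 patch grid, so ntokens = 49

rope = RoPE2D(freq=100.0)
tokens = torch.randn(B, nheads, H * W, dim)

# (y, x) integer position of each token, as described in the docstring
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
positions = torch.stack((ys, xs), dim=-1).reshape(1, H * W, 2).expand(B, -1, -1)

out = rope(tokens, positions)  # (B, nheads, H*W, dim)
assert torch.equal(out, tokens)  # identity as committed: the forward body is disabled
```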
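And a minimal numerical sketch (not part of the repo) checking that apply_rope1d, which is live code unlike forward, performs the standard half-rotation form of RoPE: feature j is rotated with feature j + D/2 by angle p * theta_j, where theta_j = base**(-2j/D) and p is the token's 1D position:

```python
import torch

from models.pos_embed import RoPE2D

rope = RoPE2D(freq=100.0)
B, H, N, D = 1, 1, 4, 8
tokens = torch.randn(B, H, N, D)
pos1d = torch.arange(N)[None, :]        # (B, N) integer positions

cos, sin = rope.get_cos_sin(D, N, tokens.device, tokens.dtype)
out = rope.apply_rope1d(tokens, pos1d, cos, sin)

# Explicit rotation of each pair (x_j, x_{j+D/2}) by angle p * theta_j.
theta = 1.0 / (rope.base ** (torch.arange(0, D, 2).float() / D))   # (D/2,)
ang = pos1d[0].float()[:, None] * theta[None, :]                   # (N, D/2)
x1, x2 = tokens[..., : D // 2], tokens[..., D // 2 :]
expected = torch.cat((x1 * ang.cos() - x2 * ang.sin(),
                      x1 * ang.sin() + x2 * ang.cos()), dim=-1)
assert torch.allclose(out, expected, atol=1e-5)
```

In forward, the same 1D rotation would be applied twice: get_cos_sin is called with D = dim // 2, and the y and x halves of the features are rotated by the tokens' row and column positions respectively before being concatenated back together.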