Force curope, TODO: fix pytorch rope implementation for bfloat16
models/pos_embed.py (+7 -4)
@@ -104,8 +104,9 @@ try:
     from extensions.curope import cuRoPE2D
     RoPE2D = cuRoPE2D
 except ImportError:
-
-
+    # critical error, we need to use the slow pytorch version
+    raise ImportError("CUDA-compiled version of RoPE2D is required but could not be found. Please compile the CUDA extension before running.")
+
     class RoPE2D(torch.nn.Module):
 
         def __init__(self, freq=100.0, F0=1.0):
@@ -135,7 +136,7 @@ except ImportError:
             cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
             sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
             return (tokens * cos) + (self.rotate_half(tokens) * sin)
-
+
         def forward(self, tokens, positions):
             """
             input:
@@ -144,6 +145,8 @@ except ImportError:
             output:
                 * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim)
             """
+            tokens = tokens.to(torch.float32)
+            #positions = positions.to(torch.float32)
             assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
             D = tokens.size(3) // 2
             assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
@@ -153,4 +156,4 @@ except ImportError:
             y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
             x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
             tokens = torch.cat((y, x), dim=-1)
-            return tokens
+            return tokens
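Why the upcast: a minimal sketch, assuming the bfloat16 problem flagged in the TODO is rounding error when the rotation itself is carried out in bfloat16. The rotate_half/apply_rope helpers and the angle table below are illustrative stand-ins, not the repository's implementation; they only reuse the fallback's (batch_size x nheads x ntokens x dim) layout and the freq=100.0 default visible in __init__ above.

import torch

# Illustrative rotate-half RoPE step (not the repository's code).
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(tokens, cos, sin):
    return (tokens * cos) + (rotate_half(tokens) * sin)

torch.manual_seed(0)
B, H, N, D = 2, 4, 300, 64
tokens = torch.randn(B, H, N, D)

# Float32 angle table, using the same default base (freq=100.0) as __init__ above.
inv_freq = 1.0 / (100.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
ang = torch.cat([torch.outer(torch.arange(N, dtype=torch.float32), inv_freq)] * 2, dim=-1)
cos, sin = ang.cos()[None, None], ang.sin()[None, None]  # (1, 1, N, D)

ref  = apply_rope(tokens, cos, sin)                                    # float32 reference
bf16 = apply_rope(tokens.bfloat16(), cos.bfloat16(), sin.bfloat16())   # rotation done in bfloat16
up   = apply_rope(tokens.bfloat16().float(), cos, sin)                 # upcast first, as in the patch

print((bf16.float() - ref).abs().max())  # typically the larger error
print((up - ref).abs().max())            # typically the smaller error

Two side effects of the patch worth noting: the raised ImportError turns a missing CUDA extension into a hard failure instead of a silent slow path (hence "Force curope"), and because the float32 cast is never undone before `return tokens`, the fallback presumably returns float32 even for bfloat16 inputs, which is likely part of the remaining TODO.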