bartduis committed
Commit 1a31912 · Parent: 9e5ce3d

Force curope, TODO: fix pytorch rope implementation for bfloat16

Files changed (1)
  1. models/pos_embed.py +7 -4
models/pos_embed.py CHANGED
@@ -104,8 +104,9 @@ try:
     from extensions.curope import cuRoPE2D
     RoPE2D = cuRoPE2D
 except ImportError:
-    print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')
-
+    # critical error, we need to use the slow pytorch version
+    raise ImportError("CUDA-compiled version of RoPE2D is required but could not be found. Please compile the CUDA extension before running.")
+
     class RoPE2D(torch.nn.Module):
 
         def __init__(self, freq=100.0, F0=1.0):
@@ -135,7 +136,7 @@ except ImportError:
             cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
             sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
             return (tokens * cos) + (self.rotate_half(tokens) * sin)
-
+
         def forward(self, tokens, positions):
             """
             input:
@@ -144,6 +145,8 @@ except ImportError:
             output:
                 * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim)
             """
+            tokens = tokens.to(torch.float32)
+            #positions = positions.to(torch.float32)
             assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
             D = tokens.size(3) // 2
             assert positions.ndim==3 and positions.shape[-1] == 2  # Batch, Seq, 2
@@ -153,4 +156,4 @@ except ImportError:
             y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
             x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
             tokens = torch.cat((y, x), dim=-1)
-            return tokens
+            return tokens
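
With this change, importing models/pos_embed.py either binds RoPE2D to the compiled cuRoPE2D kernel or fails outright: the raise propagates out of the except block, so the pure-PyTorch class below it is never reached. A quick way to check a build before running, as a sketch (the import path matches the diff; the build step is an assumption, a typical setuptools extension built with `build_ext --inplace`):

```python
# Sketch: verify the forced import before running anything that pulls in
# models/pos_embed.py. If this fails, build the extension first (assumed:
# `python setup.py build_ext --inplace` inside extensions/curope).
try:
    from extensions.curope import cuRoPE2D
    print("cuRoPE2D found; RoPE2D will use the CUDA kernel")
except ImportError as err:
    print(f"curope extension not built: {err}")
```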
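The commit message leaves the bfloat16 root cause as a TODO. A plausible reading, suggested by the float32 upcast added to forward() but not stated in the diff, is that the fallback builds its cos/sin tables in the tokens' dtype, and bfloat16's 8 mantissa bits round large rotation angles by whole radians before cos()/sin() are taken. A minimal sketch of that failure mode, with all names local to the sketch:

```python
import torch

# Sketch only: the kind of precision loss that can break a pure-PyTorch RoPE
# under bfloat16. Angles are built the usual RoPE way (position times inverse
# frequency), then rounded to bfloat16 before cos(), analogous to building
# the tables in the tokens' dtype.
D, seq_len, base = 32, 1024, 100.0
inv_freq = 1.0 / (base ** (torch.arange(0, D, 2).float() / D))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq, D/2)

cos_fp32 = angles.cos()                            # float32 reference table
cos_bf16 = angles.to(torch.bfloat16).cos().float() # angles rounded first

# Large error: near angle ~1000, bfloat16 spacing is ~8, so the rounded
# angle can be off by several radians and the cosine is essentially noise.
print((cos_fp32 - cos_bf16).abs().max())
```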
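Note also that the patched forward() upcasts tokens to float32 but never casts back, so a caller that feeds bfloat16 gets float32 out. A dtype-preserving variant is sketched below as a drop-in method; apply_rope1d is the helper visible in the diff, while get_cos_sin and the chunk() split are assumed from the elided lines and are not confirmed by this commit:

```python
import torch

def forward(self, tokens, positions):
    """Hypothetical dtype-preserving rewrite of the patched forward():
    do the rotation in float32, return in the caller's dtype."""
    in_dtype = tokens.dtype
    tokens = tokens.to(torch.float32)
    assert tokens.size(3) % 2 == 0, "number of dimensions should be a multiple of two"
    D = tokens.size(3) // 2
    assert positions.ndim == 3 and positions.shape[-1] == 2  # Batch, Seq, 2
    # get_cos_sin is assumed to be the class's cached table builder (elided in
    # the diff); it receives float32 here because of the upcast above.
    cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
    y, x = tokens.chunk(2, dim=-1)  # split features into two halves (assumed)
    y = self.apply_rope1d(y, positions[:, :, 0], cos, sin)
    x = self.apply_rope1d(x, positions[:, :, 1], cos, sin)
    return torch.cat((y, x), dim=-1).to(in_dtype)  # cast back to input dtype
```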