Fix device and type in RotaryPositionalEmbedding
encoder.py +2 -2
@@ -354,9 +354,9 @@ class RotaryPositionalEmbedding(PositionalEncoding):
             return None
         positions = torch.arange(0, length, dtype=torch.float32, device=device)
         inv_freq = 1.0 / (
-            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
+            self.base ** (torch.arange(0, self.dim, 2, device=positions.device).float() / self.dim)
         )
-        t = torch.arange(length, device=positions.device)
+        t = torch.arange(length, device=positions.device, dtype=inv_freq.dtype)
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1).to(positions.device)
         return torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]])
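For context, below is a minimal self-contained sketch of the patched computation. The helper name rope_cos_sin, its signature, and the default base=10000.0 are illustrative assumptions, not taken from encoder.py. The point of the fix is that both torch.arange calls now allocate directly on the same device as positions, so torch.einsum never mixes CPU and GPU tensors, and t adopts inv_freq's dtype so the outer product stays in a single dtype.

import torch

def rope_cos_sin(length: int, dim: int, base: float = 10000.0, device=None):
    # Hypothetical standalone version of the fixed method body; `dim` and
    # `base` stand in for self.dim and self.base in the original class.
    positions = torch.arange(0, length, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (
        base ** (torch.arange(0, dim, 2, device=positions.device).float() / dim)
    )
    # Same device as positions and same dtype as inv_freq, as in the patch.
    t = torch.arange(length, device=positions.device, dtype=inv_freq.dtype)
    freqs = torch.einsum("i,j->ij", t, inv_freq)   # (length, dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)        # (length, dim)
    # Concatenate cos and sin along dim 0, shaped to broadcast over
    # (batch, heads) dimensions.
    return torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]])

cos_sin = rope_cos_sin(length=16, dim=64, device="cpu")
print(cos_sin.shape)  # torch.Size([32, 1, 1, 64]): cos rows followed by sin rows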