waveletdeboshir commited on
Commit
8d98c92
·
verified ·
1 Parent(s): c173722

Fix device and type in RotaryPositionalEmbedding

Browse files
Files changed (1) hide show
  1. encoder.py +2 -2
encoder.py CHANGED
@@ -354,9 +354,9 @@ class RotaryPositionalEmbedding(PositionalEncoding):
354
  return None
355
  positions = torch.arange(0, length, dtype=torch.float32, device=device)
356
  inv_freq = 1.0 / (
357
- self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
358
  )
359
- t = torch.arange(length, device=positions.device).type_as(inv_freq)
360
  freqs = torch.einsum("i,j->ij", t, inv_freq)
361
  emb = torch.cat((freqs, freqs), dim=-1).to(positions.device)
362
  return torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]])
 
354
  return None
355
  positions = torch.arange(0, length, dtype=torch.float32, device=device)
356
  inv_freq = 1.0 / (
357
+ self.base ** (torch.arange(0, self.dim, 2, device=positions.device).float() / self.dim)
358
  )
359
+ t = torch.arange(length, device=positions.device, dtype=inv_freq.dtype)
360
  freqs = torch.einsum("i,j->ij", t, inv_freq)
361
  emb = torch.cat((freqs, freqs), dim=-1).to(positions.device)
362
  return torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]])