Update RotaryEmbedding caching

#33
by beibin79 - opened
Files changed (1)
  1. modelling_RW.py +11 -15
modelling_RW.py CHANGED
@@ -56,13 +56,12 @@ class RotaryEmbedding(torch.nn.Module):
         base=10000,
     ):
         super().__init__()
-        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+        inv_freq = 1.0 / (base
+            **(torch.arange(0, head_dim, 2).float() / head_dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.head_dim = head_dim
-        self.seq_len_cached = None
-        self.batch_size_cached = None
-        self.cos_cached: torch.Tensor | None = None
-        self.sin_cached: torch.Tensor | None = None
+        self.cos_cache_dict: dict = {}
+        self.sin_cache_dict: dict = {}
 
     def cos_sin(
         self,
@@ -70,27 +69,24 @@ class RotaryEmbedding(torch.nn.Module):
         device="cuda",
         dtype=torch.bfloat16,
     ) -> torch.Tensor:
-        if seq_len != self.seq_len_cached:
+        if seq_len not in self.cos_cache_dict or seq_len not in self.sin_cache_dict:
             self.seq_len_cached = seq_len
             t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             emb = torch.cat((freqs, freqs), dim=-1).to(device)
-
             if dtype in [torch.float16, torch.bfloat16]:
                 emb = emb.float()
+            self.cos_cache_dict[seq_len] = emb.cos()[None, :, :].type(dtype)
+            self.sin_cache_dict[seq_len] = emb.sin()[None, :, :].type(dtype)
 
-            self.cos_cached = emb.cos()[None, :, :]
-            self.sin_cached = emb.sin()[None, :, :]
-
-            self.cos_cached = self.cos_cached.type(dtype)
-            self.sin_cached = self.sin_cached.type(dtype)
-
-        return self.cos_cached, self.sin_cached
+        return self.cos_cache_dict[seq_len], self.sin_cache_dict[seq_len]
 
     def forward(self, q, k):
         batch, seq_len, head_dim = q.shape
+        assert seq_len is not None, "seq_len must be known and not None"
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) *
+            sin)
 
 
 def _make_causal_mask(
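
For readers skimming the diff, here is a minimal standalone sketch of the caching pattern the patch switches to: the cos/sin tables are kept in dictionaries keyed by sequence length, instead of a single cached pair that is invalidated whenever seq_len changes. The DictCachedRotary name, the CPU/float32 defaults, and the loop at the end are illustrative only and are not part of modelling_RW.py.

# Illustrative sketch only -- simplified from the patched RotaryEmbedding,
# with CPU/float32 defaults so it runs anywhere.
import torch


class DictCachedRotary(torch.nn.Module):
    def __init__(self, head_dim: int, base: int = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.cos_cache_dict: dict = {}  # seq_len -> cos table
        self.sin_cache_dict: dict = {}  # seq_len -> sin table

    def cos_sin(self, seq_len: int, device="cpu", dtype=torch.float32):
        # Build the tables only for sequence lengths not seen before;
        # a repeated seq_len is served straight from the dictionaries.
        if seq_len not in self.cos_cache_dict or seq_len not in self.sin_cache_dict:
            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(device)
            self.cos_cache_dict[seq_len] = emb.cos()[None, :, :].type(dtype)
            self.sin_cache_dict[seq_len] = emb.sin()[None, :, :].type(dtype)
        return self.cos_cache_dict[seq_len], self.sin_cache_dict[seq_len]


rope = DictCachedRotary(head_dim=64)
for seq_len in (128, 256, 128, 256):
    cos, sin = rope.cos_sin(seq_len)  # tables for 128 and 256 are each built once
    print(seq_len, tuple(cos.shape))  # (1, seq_len, 64)

Under the previous single-slot cache (seq_len_cached plus cos_cached/sin_cached), the 128/256/128/256 pattern above would rebuild the tables on every call, because each new length overwrote the cached pair; the dictionaries trade a small amount of extra memory per distinct sequence length for that repeated work.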