Update RotaryEmbedding caching #33
by beibin79 - opened

modelling_RW.py CHANGED (+11 -15)
@@ -56,13 +56,12 @@ class RotaryEmbedding(torch.nn.Module):
         base=10000,
     ):
         super().__init__()
-        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+        inv_freq = 1.0 / (base
+                          ** (torch.arange(0, head_dim, 2).float() / head_dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.head_dim = head_dim
-        self.seq_len_cached = None
-        self.batch_size_cached = None
-        self.cos_cached: torch.Tensor | None = None
-        self.sin_cached: torch.Tensor | None = None
+        self.cos_cache_dict: dict = {}
+        self.sin_cache_dict: dict = {}
 
     def cos_sin(
         self,
@@ -70,27 +69,24 @@ class RotaryEmbedding(torch.nn.Module):
         device="cuda",
         dtype=torch.bfloat16,
     ) -> torch.Tensor:
-        if seq_len != self.seq_len_cached:
+        if seq_len not in self.cos_cache_dict or seq_len not in self.sin_cache_dict:
             self.seq_len_cached = seq_len
             t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             emb = torch.cat((freqs, freqs), dim=-1).to(device)
-
             if dtype in [torch.float16, torch.bfloat16]:
                 emb = emb.float()
+            self.cos_cache_dict[seq_len] = emb.cos()[None, :, :].type(dtype)
+            self.sin_cache_dict[seq_len] = emb.sin()[None, :, :].type(dtype)
 
-            self.cos_cached = emb.cos()[None, :, :]
-            self.sin_cached = emb.sin()[None, :, :]
-
-            self.cos_cached = self.cos_cached.type(dtype)
-            self.sin_cached = self.sin_cached.type(dtype)
-
-        return self.cos_cached, self.sin_cached
+        return self.cos_cache_dict[seq_len], self.sin_cache_dict[seq_len]
 
     def forward(self, q, k):
         batch, seq_len, head_dim = q.shape
+        assert seq_len is not None, "seq_len must be known and not None"
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) *
+                                                                sin)
 
 
 def _make_causal_mask(
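For readers skimming the diff, the sketch below is a minimal, self-contained illustration of the caching scheme this PR proposes (cos/sin tensors memoized in dicts keyed by seq_len), not the full modelling_RW.py; the rotate_half helper and the small __main__ demo are added here only so the snippet runs on its own.

```python
# Minimal sketch of the dict-based rotary cache proposed in this PR.
# RotaryEmbedding, cos_sin, and the cache names mirror the diff above;
# rotate_half and the __main__ demo are added just to make it runnable.
import torch


def rotate_half(x):
    # Rotate the last dimension by swapping its two halves and negating one.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, head_dim: int, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.head_dim = head_dim
        # One cached cos/sin tensor per sequence length, instead of a single
        # cos_cached/sin_cached slot that is overwritten whenever seq_len changes.
        self.cos_cache_dict: dict = {}
        self.sin_cache_dict: dict = {}

    def cos_sin(self, seq_len: int, device="cuda", dtype=torch.bfloat16):
        if seq_len not in self.cos_cache_dict or seq_len not in self.sin_cache_dict:
            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(device)
            if dtype in [torch.float16, torch.bfloat16]:
                emb = emb.float()  # take cos/sin in fp32, then cast down
            self.cos_cache_dict[seq_len] = emb.cos()[None, :, :].type(dtype)
            self.sin_cache_dict[seq_len] = emb.sin()[None, :, :].type(dtype)
        return self.cos_cache_dict[seq_len], self.sin_cache_dict[seq_len]

    def forward(self, q, k):
        batch, seq_len, head_dim = q.shape
        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


if __name__ == "__main__":
    rope = RotaryEmbedding(head_dim=64)
    q, k = torch.randn(2, 8, 64), torch.randn(2, 8, 64)
    rope(q, k)                    # seq_len=8 computed once and cached
    rope(q[:, :5], k[:, :5])      # seq_len=5 gets its own cache entry
    print(sorted(rope.cos_cache_dict))  # -> [5, 8]
```

Compared with the single cos_cached/sin_cached slot, the dict keeps one entry per distinct sequence length, so alternating between lengths no longer forces a recompute; the trade-off is that the cache grows with the number of distinct lengths seen.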