namespace-Pt committed (verified)
Commit bf9b66d · 1 Parent(s): 53d5690

Upload folder using huggingface_hub

Files changed (5)
  1. README.md +6 -0
  2. configuration_qwen2.py +2 -2
  3. modeling_beacon.py +53 -7
  4. modeling_qwen2.py +41 -330
  5. modeling_utils.py +493 -10
README.md CHANGED
@@ -16,6 +16,12 @@ pipeline_tag: text-generation
16
  - **Low-Cost**
17
  - it is lightweight and can be trained efficiently with roughly 1B tokens.
18
19
 
20
  # Usage
21
  ```python
 
16
  - **Low-Cost**
17
  - it is lightweight and can be trained efficiently with roughly 1B tokens.
18
 
19
+ # Environment
20
+ ```bash
21
+ pip install transformers
22
+ pip install flash-attn --no-build-isolation
23
+ ```
24
+
25
 
26
  # Usage
27
  ```python
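(The Usage snippet itself is cut off by the hunk boundary above. As a rough, hedged sketch — the repository id, dtype, and generation settings below are placeholders/assumptions, not taken from this commit — loading the model with the custom code shipped in this repo typically looks like:)

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "namespace-Pt/<this-repo>"  # placeholder id, replace with the actual repo name
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,                   # loads modeling_qwen2.py / modeling_beacon.py from the repo
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # matches the flash-attn requirement above
).cuda().eval()

inputs = tokenizer("Summarize the following document: ...", return_tensors="pt").to("cuda")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```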
configuration_qwen2.py CHANGED
@@ -115,8 +115,8 @@ class Qwen2Config(PretrainedConfig):
115
  rope_scaling=None,
116
  max_window_layers=28,
117
  attention_dropout=0.0,
118
- beacon_window=2048,
119
- beacon_stride=2048,
120
  beacon_attn="full-coverage",
121
  beacon_ratio=[2,4,8,16,32],
122
  beacon_ratio_mix="step-random",
 
115
  rope_scaling=None,
116
  max_window_layers=28,
117
  attention_dropout=0.0,
118
+ beacon_window=1024,
119
+ beacon_stride=1024,
120
  beacon_attn="full-coverage",
121
  beacon_ratio=[2,4,8,16,32],
122
  beacon_ratio_mix="step-random",
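The change above only lowers the default compression window and stride from 2048 to 1024 tokens; nothing else in the config schema moves. A small sketch (the path is a placeholder, and `trust_remote_code=True` is assumed so the custom `Qwen2Config` is used) of inspecting or overriding these values at load time:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("path/to/this/repo", trust_remote_code=True)  # placeholder path
print(config.beacon_window, config.beacon_stride)  # 1024 1024 with the new defaults

# The values can be overridden, but they are expected to stay equal to each other
# when beacon_skip_* is used (see the assert added in modeling_beacon.py below).
config.beacon_window = 2048
config.beacon_stride = 2048
```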
modeling_beacon.py CHANGED
@@ -90,6 +90,10 @@ class Memory(torch.nn.Module):
90
  self.all_attention_mask = None
91
  self.all_labels = None
92
 
93
  # the raw activations of recent tokens
94
  self.raw_activations = [(None, None) for _ in range(self.config.num_hidden_layers)]
95
  # the attention sink activations
@@ -147,7 +151,7 @@ class Memory(torch.nn.Module):
147
  raw_memory_size += self.raw_activations[0][0].shape[self.k_seq_dim]
148
  return sink_memory_size, beacon_memory_size, raw_memory_size
149
 
150
- def prepare(self, input_ids, attention_mask, labels):
151
  """
152
  Prepare inputs for the model. These inputs belong to the same sequence.
153
  """
@@ -179,6 +183,19 @@ class Memory(torch.nn.Module):
179
  else:
180
  self.all_labels = torch.cat([self.all_labels, labels], dim=1)
181
  assert self.all_input_ids.shape[1] == self.all_labels.shape[1], f"Found inconsistent all_input_ids {self.all_input_ids.shape} and all_labels {self.all_labels.shape}!"
182
 
183
  def set_compression_ratio(self, start_idx, end_idx):
184
  """Choose a condensing ratio from self.config.beacon_ratio"""
@@ -399,10 +416,27 @@ class Memory(torch.nn.Module):
399
  # In the last window, we do not need to append beacons because they will not be used at all
400
  if self.training and end_idx == self.all_sequence_length:
401
  next_start_idx = start_idx
402
  raw_size_to_cache = -1
403
  beacon_size = 0
404
- compression_ratio = 1
405
  is_full_window = False
 
 
 
406
 
407
  else:
408
  #============================================#
@@ -511,9 +545,9 @@ class Memory(torch.nn.Module):
511
  # update the remainder
512
  self._interleave_remainder = (input_len + self._interleave_remainder) % compression_ratio
513
 
514
- # NOTE: skip computing loss in the very first window because the beacon tokens will be used in the next window
515
- if self.training and self._step_idx == 0 and not (self.config.beacon_pos == 'interleave' and self.config.beacon_attn == 'full-coverage'):
516
- labels[:] = -100
517
 
518
  # t2 = time.time()
519
 
@@ -607,12 +641,15 @@ class Memory(torch.nn.Module):
607
  self._end_idx = end_idx
608
  self._step_idx += 1
609
 
 
 
610
  # print(f"beacon_size: {beacon_size}")
611
  # print(f"raw_size_to_cache: {raw_size_to_cache}")
 
612
  # print(f"input_ids: {input_ids}")
613
  # print(f"beacon_indices: {beacon_indices}")
614
  # print(f"position_ids: {position_ids}")
615
- # print(f"attention_mask:\n{attention_mask}")
616
  # x = input()
617
  # if x == "s":
618
  # return
@@ -627,6 +664,16 @@ class Memory(torch.nn.Module):
627
  # NOTE: the past_key_values are incrementally returned (only the new keys and values are returned)
628
  previous_raw_key, previous_raw_value = self.raw_activations[layer_idx]
629
 
630
  if self.beacon_activations[layer_idx][0] is None and self.config.beacon_sink_size > 0:
631
  # save the sink activations
632
  # NOTE: we do not slice the key/value activations, which may cause duplication when beacon_ratio=-1 for the first window, but it's okay
@@ -696,7 +743,6 @@ class Memory(torch.nn.Module):
696
  # NOTE: we must use dict to override values, otherwise trainer cannot find loss
697
  model_outputs["loss"] = loss
698
  model_outputs["batch_loss"] = batch_loss
699
- model_outputs["valid_token_num"] = self._valid_token_num
700
 
701
  # override last_hidden_states (used in generation)
702
  beacon_size = self._all_beacon_sizes[-1]
 
90
  self.all_attention_mask = None
91
  self.all_labels = None
92
 
93
+ # NOTE: will be reset in prepare()
94
+ self.beacon_skip_first = None
95
+ self.beacon_skip_last = None
96
+
97
  # the raw activations of recent tokens
98
  self.raw_activations = [(None, None) for _ in range(self.config.num_hidden_layers)]
99
  # the attention sink activations
 
151
  raw_memory_size += self.raw_activations[0][0].shape[self.k_seq_dim]
152
  return sink_memory_size, beacon_memory_size, raw_memory_size
153
 
154
+ def prepare(self, input_ids, attention_mask, labels, skip_first=None, skip_last=None):
155
  """
156
  Prepare inputs for the model. These inputs belong to the same sequence.
157
  """
 
183
  else:
184
  self.all_labels = torch.cat([self.all_labels, labels], dim=1)
185
  assert self.all_input_ids.shape[1] == self.all_labels.shape[1], f"Found inconsistent all_input_ids {self.all_input_ids.shape} and all_labels {self.all_labels.shape}!"
186
+
187
+ # number of tokens to skip at the beginning of the sequence (they will be packed into a single chunk and processed by the model, after which their activations are cached in sink_activations)
188
+ if skip_first is not None:
189
+ assert self.config.beacon_parallel_window == 1, f"Make sure the parallel window is set to 1 when using beacon_skip!"
190
+ assert self.config.beacon_window == self.config.beacon_stride, f"Make sure beacon_window equals beacon_stride when using beacon_skip."
191
+ assert self.config.beacon_sink_size == 0, f"Make sure the beacon_sink_size is set to 0 when using beacon_skip!"
192
+ # number of tokens after which compression stops
193
+ if skip_last is not None:
194
+ skip_first = skip_first if skip_first is not None else 0
195
+ assert (skip_last - skip_first) % self.config.beacon_window == 0, f"skip_last ({skip_last}) - skip_first ({skip_first}) = {skip_last - skip_first} is not divisible by window size {self.config.beacon_window}"
196
+ assert self.config.beacon_sink_size == 0, "Make sure the beacon_sink_size is zero when using skip_last!"
197
+ self.beacon_skip_first = skip_first
198
+ self.beacon_skip_last = skip_last
199
 
200
  def set_compression_ratio(self, start_idx, end_idx):
201
  """Choose a condensing ratio from self.config.beacon_ratio"""
 
416
  # In the last window, we do not need to append beacons because they will not be used at all
417
  if self.training and end_idx == self.all_sequence_length:
418
  next_start_idx = start_idx
419
+ is_full_window = False
420
+ raw_size_to_cache = -1
421
+ beacon_size = 0
422
+ compression_ratio = -1
423
+
424
+ elif self._step_idx == 0 and self.beacon_skip_first is not None:
425
+ end_idx = start_idx + self.beacon_skip_first
426
+ assert end_idx < self.all_sequence_length
427
+ next_start_idx = end_idx
428
+ is_full_window = True
429
  raw_size_to_cache = -1
430
  beacon_size = 0
431
+ compression_ratio = -1
432
+
433
+ elif self.beacon_skip_last is not None and start_idx >= self.beacon_skip_last:
434
+ end_idx = min(start_idx + self.config.beacon_window, self.all_sequence_length)
435
+ next_start_idx = end_idx
436
  is_full_window = False
437
+ raw_size_to_cache = -1
438
+ beacon_size = 0
439
+ compression_ratio = -1
440
 
441
  else:
442
  #============================================#
 
545
  # update the remainder
546
  self._interleave_remainder = (input_len + self._interleave_remainder) % compression_ratio
547
 
548
+ # NOTE: skip computing loss in the very first window because the beacon tokens will be used in the next window
549
+ if self.training and self._step_idx == 0 and not (self.config.beacon_pos == 'interleave' and self.config.beacon_attn == 'full-coverage'):
550
+ labels[:] = -100
551
 
552
  # t2 = time.time()
553
 
 
641
  self._end_idx = end_idx
642
  self._step_idx += 1
643
 
644
+ # print(f"start_idx: {start_idx}")
645
+ # print(f"next_start_idx: {next_start_idx}")
646
  # print(f"beacon_size: {beacon_size}")
647
  # print(f"raw_size_to_cache: {raw_size_to_cache}")
648
+ # print(f"interleave_remainder:{self._interleave_remainder}")
649
  # print(f"input_ids: {input_ids}")
650
  # print(f"beacon_indices: {beacon_indices}")
651
  # print(f"position_ids: {position_ids}")
652
+ # print(f"attention_mask:\n{attention_mask == 0}")
653
  # x = input()
654
  # if x == "s":
655
  # return
 
664
  # NOTE: the past_key_values are incrementally returned (only the new keys and values are returned)
665
  previous_raw_key, previous_raw_value = self.raw_activations[layer_idx]
666
 
667
+ if self.beacon_skip_first is not None and self.sink_activations[layer_idx][0] is None:
668
+ assert key.shape[self.k_seq_dim] == self.beacon_skip_first
669
+ assert value.shape[self.k_seq_dim] == self.beacon_skip_first
670
+ self.sink_activations[layer_idx] = [
671
+ key,
672
+ value,
673
+ ]
674
+ # NOTE: no need to update raw activations and beacon activations as all activations are kept as sink activations
675
+ continue
676
+
677
  if self.beacon_activations[layer_idx][0] is None and self.config.beacon_sink_size > 0:
678
  # save the sink activations
679
  # NOTE: we do not slice the key/value activations, which may cause duplication when beacon_ratio=-1 for the first window, but it's okay
 
743
  # NOTE: we must use dict to override values, otherwise trainer cannot find loss
744
  model_outputs["loss"] = loss
745
  model_outputs["batch_loss"] = batch_loss
 
746
 
747
  # override last_hidden_states (used in generation)
748
  beacon_size = self._all_beacon_sizes[-1]
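The new `skip_first` / `skip_last` handling effectively splits a long sequence into three regions: an uncompressed prefix cached as sink activations, a middle range compressed window by window, and a tail encoded without beacons. An illustrative sketch of the bookkeeping (the numbers are made up; the constraint mirrors the asserts added in `Memory.prepare`):

```python
beacon_window = 1024                        # new default from configuration_qwen2.py
skip_first = 512                            # tokens [0, 512) are encoded once and cached as sink activations
skip_last = skip_first + 4 * beacon_window  # tokens [512, 4608) are compressed window by window

# constraint enforced in Memory.prepare (beacon_sink_size must also be 0):
assert (skip_last - skip_first) % beacon_window == 0
# tokens from skip_last onward appear to be encoded window by window without beacons
# (beacon_size = 0, compression_ratio = -1 in the new branch above)
```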
modeling_qwen2.py CHANGED
@@ -30,8 +30,7 @@ from torch import nn
30
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
 
32
  from transformers.activations import ACT2FN
33
- from transformers.cache_utils import Cache, DynamicCache
34
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
35
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
36
  from transformers.modeling_utils import PreTrainedModel
37
  from transformers.utils import (
@@ -53,7 +52,7 @@ if is_flash_attn_2_available():
53
 
54
  from .configuration_qwen2 import Qwen2Config
55
  from .modeling_beacon import Memory
56
- from .modeling_utils import optional_grad_ctx, compute_loss, BeaconModelOutput
57
 
58
 
59
  logger = logging.get_logger(__name__)
@@ -99,183 +98,6 @@ class Qwen2RMSNorm(nn.Module):
99
  return self.weight * hidden_states.to(input_dtype)
100
 
101
 
102
- # Copied from transformers.models.llama.modeling_llama.rotate_half
103
- def rotate_half(x):
104
- """Rotates half the hidden dims of the input."""
105
- x1 = x[..., : x.shape[-1] // 2]
106
- x2 = x[..., x.shape[-1] // 2 :]
107
- return torch.cat((-x2, x1), dim=-1)
108
-
109
-
110
- class Qwen2RotaryEmbedding(nn.Module):
111
- def __init__(self, dim, max_position_embeddings=32768, base=10000, device=None):
112
- super().__init__()
113
-
114
- self.dim = dim
115
- self.max_position_embeddings = max_position_embeddings
116
- self.base = base
117
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
118
- self.register_buffer("inv_freq", inv_freq, persistent=False)
119
-
120
- # Build here to make `torch.jit.trace` work.
121
- self._set_cos_sin_cache(
122
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
123
- )
124
-
125
- def _set_cos_sin_cache(self, seq_len, device, dtype):
126
- self.max_seq_len_cached = seq_len
127
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
128
-
129
- freqs = torch.outer(t, self.inv_freq)
130
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
131
- emb = torch.cat((freqs, freqs), dim=-1)
132
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
133
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
134
-
135
- def forward(self, q, k, position_ids):
136
- seq_len = max(position_ids.max().item() + 1, k.shape[2])
137
-
138
- # x: [bs, num_attention_heads, seq_len, head_size]
139
- if seq_len > self.max_seq_len_cached:
140
- self._set_cos_sin_cache(seq_len=seq_len, device=k.device, dtype=k.dtype)
141
-
142
- # batch_size, 1, key_len, head_dim
143
- k_cos = self.cos_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
144
- k_sin = self.sin_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
145
-
146
- q_cos = k_cos[..., -q.shape[2]:, :]
147
- q_sin = k_sin[..., -q.shape[2]:, :]
148
-
149
- q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
150
- k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
151
- return q_embed, k_embed
152
-
153
-
154
- class Qwen2LinearScalingRotaryEmbedding(Qwen2RotaryEmbedding):
155
- """Qwen2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
156
-
157
- def __init__(self, dim, max_position_embeddings=32768, base=10000, device=None, scaling_factor=1.0):
158
- self.scaling_factor = scaling_factor
159
- super().__init__(dim, max_position_embeddings, base, device)
160
-
161
- def _set_cos_sin_cache(self, seq_len, device, dtype):
162
- self.max_seq_len_cached = seq_len
163
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
164
- t = t / self.scaling_factor
165
-
166
- freqs = torch.outer(t, self.inv_freq)
167
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
168
- emb = torch.cat((freqs, freqs), dim=-1)
169
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
170
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
171
-
172
-
173
- class Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding):
174
- """Qwen2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
175
-
176
- def __init__(self, dim, max_position_embeddings=32768, base=10000, device=None, scaling_factor=1.0):
177
- self.scaling_factor = scaling_factor
178
- super().__init__(dim, max_position_embeddings, base, device)
179
-
180
- def _set_cos_sin_cache(self, seq_len, device, dtype):
181
- self.max_seq_len_cached = seq_len
182
-
183
- if seq_len > self.max_position_embeddings:
184
- base = self.base * (
185
- (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
186
- ) ** (self.dim / (self.dim - 2))
187
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
188
- self.register_buffer("inv_freq", inv_freq, persistent=False)
189
-
190
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
191
-
192
- freqs = torch.outer(t, self.inv_freq)
193
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
194
- emb = torch.cat((freqs, freqs), dim=-1)
195
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
196
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
197
-
198
-
199
- class Qwen2YarnRotaryEmbedding(nn.Module):
200
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, beta_slow=2, beta_fast=128):
201
- super().__init__()
202
-
203
- self.base = base
204
- self.dim = dim
205
- self.scaling_factor = scaling_factor
206
- self.beta_slow = beta_slow
207
- self.beta_fast = beta_fast
208
- self.max_position_embeddings = max_position_embeddings
209
-
210
- self._set_cos_sin_cache(
211
- seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype()
212
- )
213
-
214
- def _get_factor(self, device, dtype):
215
- # the dimension whose index is smaller than fast_dim rotates more than beta_fast
216
- fast_dim = self.dim / 2 * (math.log(self.max_position_embeddings / (2 * math.pi * self.beta_fast)) / math.log(self.base))
217
- fast_dim = max(math.floor(fast_dim), 0)
218
- # the dimension whose index is bigger than slow_dim rotates less than beta_slow
219
- slow_dim = self.dim / 2 * (math.log(self.max_position_embeddings / (2 * math.pi * self.beta_slow)) / math.log(self.base))
220
- slow_dim = min(math.ceil(slow_dim), self.dim - 1)
221
-
222
- if fast_dim == slow_dim:
223
- slow_dim += 0.001
224
-
225
- # NOTE: very important to use full precision here so that the factor is correct
226
- dim_arange = torch.arange(0, self.dim // 2, device=device, dtype=torch.float32)
227
- dim_factor = (dim_arange - fast_dim) / (slow_dim - fast_dim)
228
- dim_factor = torch.clamp(dim_factor, 0, 1)
229
-
230
- # align with the paper notation
231
- return (1 - dim_factor)
232
-
233
- def _get_temperature(self):
234
- if self.scaling_factor <= 1:
235
- return 1.0
236
- return 0.07 * math.log(self.scaling_factor) + 1.0
237
-
238
- def _set_cos_sin_cache(self, seq_len, device, dtype):
239
- dim_arange = torch.arange(0, self.dim, 2, device=device) / self.dim
240
- # dim / 2
241
- freq = self.base ** dim_arange
242
- theta = 1 / freq
243
- interleave_theta = theta / self.scaling_factor
244
-
245
- factor = self._get_factor(device, dtype)
246
- yarn_theta = factor * theta + (1 - factor) * interleave_theta
247
- self.register_buffer("inv_freq", yarn_theta, persistent=False)
248
-
249
- t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
250
- freqs = torch.outer(t, self.inv_freq)
251
- emb = torch.cat((freqs, freqs), dim=-1)
252
-
253
- # get attention temperature
254
- temperature = self._get_temperature()
255
-
256
- self.register_buffer("cos_cached", (emb.cos() * temperature).to(dtype), persistent=False)
257
- self.register_buffer("sin_cached", (emb.sin() * temperature).to(dtype), persistent=False)
258
- self.max_seq_len_cached = seq_len
259
-
260
- def forward(self, q, k, position_ids):
261
- seq_len = max(position_ids.max().item() + 1, k.shape[2])
262
-
263
- # x: [bs, num_attention_heads, seq_len, head_size]
264
- if seq_len > self.max_seq_len_cached:
265
- self.scaling_factor = seq_len / self.max_position_embeddings
266
- self._set_cos_sin_cache(seq_len=seq_len, device=k.device, dtype=k.dtype)
267
-
268
- k_cos = self.cos_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
269
- k_sin = self.sin_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
270
-
271
- q_cos = k_cos[..., -q.shape[2]:, :]
272
- q_sin = k_sin[..., -q.shape[2]:, :]
273
-
274
- q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
275
- k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
276
- return q_embed, k_embed
277
-
278
-
279
  # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
280
  class Qwen2MLP(nn.Module):
281
  def __init__(self, config):
@@ -288,54 +110,8 @@ class Qwen2MLP(nn.Module):
288
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
289
  self.act_fn = ACT2FN[config.hidden_act]
290
 
291
- if "mlp" in config.beacon_param:
292
- self.beacon_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
293
- self.beacon_up_proj.weight.data.zero_()
294
- self.beacon_up_proj._is_hf_initialized = True
295
-
296
- self.beacon_down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
297
- self.beacon_down_proj.weight.data.zero_()
298
- self.beacon_down_proj._is_hf_initialized = True
299
-
300
- def _init_beacon_proj(self, missing_keys):
301
- """Initialize the beacon projection weight with that of the ordinal projection."""
302
- if "mlp" in self.config.beacon_param:
303
- if is_deepspeed_zero3_enabled():
304
- # FIXME: after deepspeed initialization, some weights becomes non-zero
305
- # For Mistral, there are rows that are full of zeros
306
- # For Mistral, there are values bigger than 1e29...
307
-
308
- import deepspeed
309
- params = [self.up_proj.weight, self.down_proj.weight, self.beacon_up_proj.weight, self.beacon_down_proj.weight]
310
- with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
311
- if (self.beacon_up_proj.weight.sum(-1) == 0).any() or (self.beacon_up_proj.weight > 1e29).any():
312
- self.beacon_up_proj.weight.data[:] = self.up_proj.weight.data
313
- self.beacon_down_proj.weight.data[:] = self.down_proj.weight.data
314
- else:
315
- if any("beacon_up_proj" in missing_key for missing_key in missing_keys):
316
- # only copy the value in-place, without tieing the weight
317
- self.beacon_up_proj.weight.data[:] = self.up_proj.weight.data
318
- self.beacon_down_proj.weight.data[:] = self.down_proj.weight.data
319
-
320
- def forward(self, x, beacon_size, beacon_indices):
321
- if "mlp" in self.config.beacon_param:
322
- # NOTE: when beacon_pos == "interleave", the beacon_indices points to all beacon tokens in the current window (cached activations + input_ids), so we shall slice out the part corresponding to the input_ids
323
- if beacon_size > 0:
324
- cur_beacon_indices = beacon_indices[-x.shape[1]:]
325
- ordinal_hidden_states = x[:, cur_beacon_indices == 0]
326
- beacon_hidden_states = x[:, cur_beacon_indices == 1]
327
-
328
- ordinal_down_proj = self.down_proj(self.act_fn(self.gate_proj(ordinal_hidden_states)) * self.up_proj(ordinal_hidden_states))
329
- beacon_down_proj = self.beacon_down_proj(self.act_fn(self.gate_proj(beacon_hidden_states)) * self.beacon_up_proj(beacon_hidden_states))
330
-
331
- down_proj = beacon_down_proj.new_ones(x.shape)
332
- down_proj[:, beacon_indices == 0] = ordinal_down_proj
333
- down_proj[:, beacon_indices == 1] = beacon_down_proj
334
- else:
335
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
336
- else:
337
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
338
-
339
  return down_proj
340
 
341
 
@@ -386,7 +162,7 @@ class Qwen2Attention(nn.Module):
386
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
387
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
388
 
389
- self._init_rope()
390
 
391
  # NOTE: add extra parameters for beacon tokens
392
  # skip post initialization to speed up loading
@@ -408,54 +184,6 @@ class Qwen2Attention(nn.Module):
408
  self.beacon_o_proj.weight.data.zero_()
409
  self.beacon_o_proj._is_hf_initialized = True
410
 
411
- def _init_rope(self):
412
- if self.config.rope_scaling is None:
413
- self.rotary_emb = Qwen2RotaryEmbedding(
414
- self.head_dim,
415
- max_position_embeddings=self.max_position_embeddings,
416
- base=self.rope_theta,
417
- )
418
- else:
419
- scaling_type = self.config.rope_scaling["type"]
420
- scaling_factor = self.config.rope_scaling["factor"]
421
- if scaling_type == "linear":
422
- self.rotary_emb = Qwen2LinearScalingRotaryEmbedding(
423
- self.head_dim,
424
- max_position_embeddings=self.max_position_embeddings,
425
- scaling_factor=scaling_factor,
426
- base=self.rope_theta,
427
- )
428
- elif scaling_type == "dynamic":
429
- self.rotary_emb = Qwen2DynamicNTKScalingRotaryEmbedding(
430
- self.head_dim,
431
- max_position_embeddings=self.max_position_embeddings,
432
- scaling_factor=scaling_factor,
433
- base=self.rope_theta,
434
- )
435
- elif scaling_type == "yarn":
436
- self.rotary_emb = Qwen2YarnRotaryEmbedding(
437
- self.head_dim,
438
- max_position_embeddings=self.max_position_embeddings,
439
- scaling_factor=scaling_factor,
440
- base=self.rope_theta,
441
- )
442
- elif scaling_type == "yarn-t":
443
- self.rotary_emb = Qwen2YarnDynamicTemperatureRotaryEmbedding(
444
- self.head_dim,
445
- max_position_embeddings=self.max_position_embeddings,
446
- scaling_factor=scaling_factor,
447
- base=self.rope_theta,
448
- )
449
- elif scaling_type == "yarn-t-logn":
450
- self.rotary_emb = Qwen2YarnDynamicTemperatureLogNRotaryEmbedding(
451
- self.head_dim,
452
- max_position_embeddings=self.max_position_embeddings,
453
- scaling_factor=scaling_factor,
454
- base=self.rope_theta,
455
- )
456
- else:
457
- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
458
-
459
  def _init_beacon_proj(self, missing_keys):
460
  """Initialize the beacon projection weight with that of the ordinal projection."""
461
  beacon_param = self.config.beacon_param
@@ -538,44 +266,37 @@ class Qwen2Attention(nn.Module):
538
  # NOTE: when beacon_pos == "interleave", the beacon_indices points to all beacon tokens in the current window (cached activations + input_ids), so we shall slice out the part corresponding to the input_ids
539
  cur_beacon_indices = beacon_indices[-hidden_states.shape[1]:]
540
 
541
- ordinal_hidden_states = hidden_states[:, cur_beacon_indices == 0]
542
- beacon_hidden_states = hidden_states[:, cur_beacon_indices == 1]
543
-
544
  if "q" in self.config.beacon_param:
545
- ordinal_query_states = self.q_proj(ordinal_hidden_states)
546
- beacon_query_states = self.beacon_q_proj(beacon_hidden_states)
547
- query_states = beacon_query_states.new_zeros((ordinal_query_states.shape[0], cur_beacon_indices.shape[0], ordinal_query_states.shape[2]))
548
- query_states[:, cur_beacon_indices == 0] = ordinal_query_states
549
- query_states[:, cur_beacon_indices == 1] = beacon_query_states
550
- # NOTE: replicate hidden states for beacon tokens in case of parallel windows
551
  if (cur_beacon_indices == 2).any():
552
- query_states[:, cur_beacon_indices == 2] = beacon_query_states[:, :(cur_beacon_indices == 2).sum()]
553
-
 
554
  else:
555
  query_states = self.q_proj(hidden_states)
556
 
557
  if "k" in self.config.beacon_param:
558
- ordinal_key_states = self.k_proj(ordinal_hidden_states)
559
- beacon_key_states = self.beacon_k_proj(beacon_hidden_states)
560
- key_states = beacon_key_states.new_zeros((ordinal_key_states.shape[0], cur_beacon_indices.shape[0], ordinal_key_states.shape[2]))
561
- key_states[:, cur_beacon_indices == 0] = ordinal_key_states
562
- key_states[:, cur_beacon_indices == 1] = beacon_key_states
563
- # NOTE: replicate hidden states for beacon tokens in case of parallel windows
564
  if (cur_beacon_indices == 2).any():
565
- key_states[:, cur_beacon_indices == 2] = beacon_key_states[:, :(cur_beacon_indices == 2).sum()]
566
-
 
567
  else:
568
  key_states = self.k_proj(hidden_states)
569
-
570
  if "v" in self.config.beacon_param:
571
- ordinal_value_states = self.v_proj(ordinal_hidden_states)
572
- beacon_value_states = self.beacon_v_proj(beacon_hidden_states)
573
- value_states = beacon_value_states.new_zeros((ordinal_value_states.shape[0], cur_beacon_indices.shape[0], ordinal_value_states.shape[2]))
574
- value_states[:, cur_beacon_indices == 0] = ordinal_value_states
575
- value_states[:, cur_beacon_indices == 1] = beacon_value_states
576
- # NOTE: replicate hidden states for beacon tokens in case of parallel windows
577
  if (cur_beacon_indices == 2).any():
578
- value_states[:, cur_beacon_indices == 2] = beacon_value_states[:, :(cur_beacon_indices == 2).sum()]
 
 
579
  else:
580
  value_states = self.v_proj(hidden_states)
581
 
@@ -592,14 +313,9 @@ class Qwen2Attention(nn.Module):
592
  cur_beacon_indices = beacon_indices[-attn_output.shape[1]:]
593
 
594
  if "o" in self.config.beacon_param:
595
- ordinal_attn_output = self.o_proj(attn_output[:, cur_beacon_indices == 0])
596
- beacon_attn_output = self.beacon_o_proj(attn_output[:, cur_beacon_indices == 1])
597
- attn_output = beacon_attn_output.new_zeros(attn_output.shape)
598
- attn_output[:, cur_beacon_indices == 0] = ordinal_attn_output
599
- attn_output[:, cur_beacon_indices == 1] = beacon_attn_output
600
- # NOTE: replicate hidden states for beacon tokens in case of parallel windows
601
- # if (cur_beacon_indices == 2).any():
602
- # attn_output[:, cur_beacon_indices == 2] = beacon_attn_output[:, :(cur_beacon_indices == 2).sum()]
603
  else:
604
  attn_output = self.o_proj(attn_output)
605
  else:
@@ -1036,10 +752,6 @@ class Qwen2DecoderLayer(nn.Module):
1036
  (see `past_key_values`).
1037
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1038
  """
1039
-
1040
- # NOTE: get beacon_size in case the mlp is included in beacon_param
1041
- past_key, past_value, beacon_size, beacon_indices = past_key_value
1042
-
1043
  residual = hidden_states
1044
 
1045
  hidden_states = self.input_layernorm(hidden_states)
@@ -1058,7 +770,7 @@ class Qwen2DecoderLayer(nn.Module):
1058
  # Fully Connected
1059
  residual = hidden_states
1060
  hidden_states = self.post_attention_layernorm(hidden_states)
1061
- hidden_states = self.mlp(hidden_states, beacon_size, beacon_indices)
1062
  hidden_states = residual + hidden_states
1063
 
1064
  outputs = (hidden_states,)
@@ -1426,7 +1138,6 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1426
  # initialize weights of possible q,k,v,o,mlp
1427
  for layer in model.model.layers:
1428
  layer.self_attn._init_beacon_proj(missing_keys)
1429
- layer.mlp._init_beacon_proj(missing_keys)
1430
 
1431
  return model
1432
 
@@ -1438,12 +1149,11 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1438
  past_key_values: Optional[List[torch.FloatTensor]] = None,
1439
  inputs_embeds: Optional[torch.FloatTensor] = None,
1440
  labels: Optional[torch.LongTensor] = None,
1441
- shift_labels: Optional[bool] = True,
1442
  use_cache: Optional[bool] = None,
1443
  output_attentions: Optional[bool] = None,
1444
  output_hidden_states: Optional[bool] = None,
1445
  return_dict: Optional[bool] = None,
1446
- ) -> Union[Tuple, BeaconModelOutput]:
1447
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1448
  output_hidden_states = (
1449
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1474,19 +1184,19 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1474
 
1475
  loss = None
1476
  batch_loss = None
1477
- valid_token_num = None
1478
 
1479
  if labels is not None:
1480
- loss, batch_loss, valid_token_num = compute_loss(logits, labels, shift=shift_labels)
1481
 
1482
  if not return_dict:
1483
  output = (logits,) + outputs[1:]
1484
  return (loss,) + output if loss is not None else output
1485
 
1486
- return BeaconModelOutput(
1487
  loss=loss,
1488
  batch_loss=batch_loss,
1489
- valid_token_num=valid_token_num,
1490
  logits=logits,
1491
  past_key_values=outputs.past_key_values,
1492
  hidden_states=outputs.hidden_states,
@@ -1504,6 +1214,8 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1504
  output_attentions: Optional[bool] = None,
1505
  output_hidden_states: Optional[bool] = None,
1506
  return_dict: Optional[bool] = None,
 
 
1507
  ):
1508
  # t1 = time.time()
1509
 
@@ -1511,12 +1223,13 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1511
  self.memory.prepare(
1512
  input_ids=input_ids,
1513
  attention_mask=attention_mask,
1514
- labels=labels
 
 
1515
  )
1516
 
1517
  # t2 = time.time()
1518
 
1519
- # after the first window, one token at a time
1520
  while not self.memory.finish:
1521
 
1522
  # t3 = time.time()
@@ -1536,8 +1249,6 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1536
  output_hidden_states=output_hidden_states,
1537
  return_dict=return_dict,
1538
  labels=labels,
1539
- # NOTE: the labels have been shifted so that all tokens in the window have the proper loss
1540
- shift_labels=False,
1541
  )
1542
 
1543
  # t5 = time.time()
@@ -1549,7 +1260,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1549
 
1550
  if labels is not None:
1551
  # update loss
1552
- self.memory.update_loss(outputs.batch_loss, outputs.valid_token_num)
1553
 
1554
  # t7 = time.time()
1555
 
@@ -1567,7 +1278,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1567
  # input()
1568
 
1569
  return outputs
1570
-
1571
  def forward(self, **kwargs):
1572
  """Forward computation over a batch of sequences.
1573
  """
 
30
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
 
32
  from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache
 
34
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
35
  from transformers.modeling_utils import PreTrainedModel
36
  from transformers.utils import (
 
52
 
53
  from .configuration_qwen2 import Qwen2Config
54
  from .modeling_beacon import Memory
55
+ from .modeling_utils import optional_grad_ctx, compute_loss, get_rope, ModelOutput
56
 
57
 
58
  logger = logging.get_logger(__name__)
 
98
  return self.weight * hidden_states.to(input_dtype)
99
 
100
 
101
  # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
102
  class Qwen2MLP(nn.Module):
103
  def __init__(self, config):
 
110
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
111
  self.act_fn = ACT2FN[config.hidden_act]
112
 
113
+ def forward(self, x):
114
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
115
  return down_proj
116
 
117
 
 
162
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
163
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
164
 
165
+ self.rotary_emb = get_rope(self.head_dim, config.rope_theta, config.max_position_embeddings, getattr(config, "rope_scaling", None))
166
 
167
  # NOTE: add extra parameters for beacon tokens
168
  # skip post initialization to speed up loading
 
184
  self.beacon_o_proj.weight.data.zero_()
185
  self.beacon_o_proj._is_hf_initialized = True
186
 
187
  def _init_beacon_proj(self, missing_keys):
188
  """Initialize the beacon projection weight with that of the ordinal projection."""
189
  beacon_param = self.config.beacon_param
 
266
  # NOTE: when beacon_pos == "interleave", the beacon_indices points to all beacon tokens in the current window (cached activations + input_ids), so we shall slice out the part corresponding to the input_ids
267
  cur_beacon_indices = beacon_indices[-hidden_states.shape[1]:]
268
 
269
+ # NOTE: there is some redundant computation because ordinal tokens never need the beacon projections, but projecting everything and selecting with torch.where is faster in practice than scatter-style indexing
 
 
270
  if "q" in self.config.beacon_param:
271
+ ordinal_query_states = self.q_proj(hidden_states)
272
+ beacon_query_states = self.beacon_q_proj(hidden_states)
273
+ query_states = torch.where((cur_beacon_indices == 0)[:, None], ordinal_query_states, beacon_query_states)
 
 
 
274
  if (cur_beacon_indices == 2).any():
275
+ # beacon_indices == 2 means the beacon token replicates the ones in the previous window for parallel encoding
276
+ # we should slice out all beacon tokens and copy them to the replicated beacon tokens
277
+ query_states[:, cur_beacon_indices == 2] = beacon_query_states[:, cur_beacon_indices == 1][:, :(cur_beacon_indices == 2).sum()]
278
  else:
279
  query_states = self.q_proj(hidden_states)
280
 
281
  if "k" in self.config.beacon_param:
282
+ ordinal_key_states = self.k_proj(hidden_states)
283
+ beacon_key_states = self.beacon_k_proj(hidden_states)
284
+ key_states = torch.where((cur_beacon_indices == 0)[:, None], ordinal_key_states, beacon_key_states)
 
 
 
285
  if (cur_beacon_indices == 2).any():
286
+ # beacon_indices == 2 means the beacon token replicates the ones in the previous window for parallel encoding
287
+ # we should slice out all beacon tokens and copy them to the replicated beacon tokens
288
+ key_states[:, cur_beacon_indices == 2] = beacon_key_states[:, cur_beacon_indices == 1][:, :(cur_beacon_indices == 2).sum()]
289
  else:
290
  key_states = self.k_proj(hidden_states)
291
+
292
  if "v" in self.config.beacon_param:
293
+ ordinal_value_states = self.v_proj(hidden_states)
294
+ beacon_value_states = self.beacon_v_proj(hidden_states)
295
+ value_states = torch.where((cur_beacon_indices == 0)[:, None], ordinal_value_states, beacon_value_states)
 
 
 
296
  if (cur_beacon_indices == 2).any():
297
+ # beacon_indices == 2 means the beacon token replicates the ones in the previous window for parallel encoding
298
+ # we should slice out all beacon tokens and copy them to the replicated beacon tokens
299
+ value_states[:, cur_beacon_indices == 2] = beacon_value_states[:, cur_beacon_indices == 1][:, :(cur_beacon_indices == 2).sum()]
300
  else:
301
  value_states = self.v_proj(hidden_states)
302
 
 
313
  cur_beacon_indices = beacon_indices[-attn_output.shape[1]:]
314
 
315
  if "o" in self.config.beacon_param:
316
+ ordinal_attn_output = self.o_proj(attn_output)
317
+ beacon_attn_output = self.beacon_o_proj(attn_output)
318
+ attn_output = torch.where((cur_beacon_indices == 0)[:, None], ordinal_attn_output, beacon_attn_output)
319
  else:
320
  attn_output = self.o_proj(attn_output)
321
  else:
 
752
  (see `past_key_values`).
753
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
754
  """
755
  residual = hidden_states
756
 
757
  hidden_states = self.input_layernorm(hidden_states)
 
770
  # Fully Connected
771
  residual = hidden_states
772
  hidden_states = self.post_attention_layernorm(hidden_states)
773
+ hidden_states = self.mlp(hidden_states)
774
  hidden_states = residual + hidden_states
775
 
776
  outputs = (hidden_states,)
 
1138
  # initialize weights of possible q,k,v,o,mlp
1139
  for layer in model.model.layers:
1140
  layer.self_attn._init_beacon_proj(missing_keys)
 
1141
 
1142
  return model
1143
 
 
1149
  past_key_values: Optional[List[torch.FloatTensor]] = None,
1150
  inputs_embeds: Optional[torch.FloatTensor] = None,
1151
  labels: Optional[torch.LongTensor] = None,
 
1152
  use_cache: Optional[bool] = None,
1153
  output_attentions: Optional[bool] = None,
1154
  output_hidden_states: Optional[bool] = None,
1155
  return_dict: Optional[bool] = None,
1156
+ ) -> Union[Tuple, ModelOutput]:
1157
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1158
  output_hidden_states = (
1159
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
1184
 
1185
  loss = None
1186
  batch_loss = None
1187
+ token_loss = None
1188
 
1189
  if labels is not None:
1190
+ loss, batch_loss, token_loss = compute_loss(logits, labels, shift=False)
1191
 
1192
  if not return_dict:
1193
  output = (logits,) + outputs[1:]
1194
  return (loss,) + output if loss is not None else output
1195
 
1196
+ return ModelOutput(
1197
  loss=loss,
1198
  batch_loss=batch_loss,
1199
+ token_loss=token_loss,
1200
  logits=logits,
1201
  past_key_values=outputs.past_key_values,
1202
  hidden_states=outputs.hidden_states,
 
1214
  output_attentions: Optional[bool] = None,
1215
  output_hidden_states: Optional[bool] = None,
1216
  return_dict: Optional[bool] = None,
1217
+ beacon_skip_first: Optional[int] = None,
1218
+ beacon_skip_last: Optional[int] = None,
1219
  ):
1220
  # t1 = time.time()
1221
 
 
1223
  self.memory.prepare(
1224
  input_ids=input_ids,
1225
  attention_mask=attention_mask,
1226
+ labels=labels,
1227
+ skip_first=beacon_skip_first,
1228
+ skip_last=beacon_skip_last,
1229
  )
1230
 
1231
  # t2 = time.time()
1232
 
 
1233
  while not self.memory.finish:
1234
 
1235
  # t3 = time.time()
 
1249
  output_hidden_states=output_hidden_states,
1250
  return_dict=return_dict,
1251
  labels=labels,
 
 
1252
  )
1253
 
1254
  # t5 = time.time()
 
1260
 
1261
  if labels is not None:
1262
  # update loss
1263
+ self.memory.update_loss(outputs.batch_loss, (labels != -100).sum(-1))
1264
 
1265
  # t7 = time.time()
1266
 
 
1278
  # input()
1279
 
1280
  return outputs
1281
+
1282
  def forward(self, **kwargs):
1283
  """Forward computation over a batch of sequences.
1284
  """
modeling_utils.py CHANGED
@@ -29,14 +29,28 @@ def move_to_device(data, device):
29
  else:
30
  return data
31
 
32
  def compute_loss(logits, labels, shift=False):
33
  """
34
  Returns:
35
  token_loss: batch_size, seq_length
36
  """
37
  if shift:
38
- logits = logits[:, :-1, :].contiguous()
39
- labels = labels[:, 1:].contiguous()
40
 
41
  labels = labels.to(logits.device)
42
  batch_size = logits.shape[0]
@@ -63,7 +77,7 @@ def compute_loss(logits, labels, shift=False):
63
  if (valid_token_num == 0).any():
64
  batch_loss = batch_loss.masked_fill(valid_token_num == 0, 0.)
65
 
66
- return loss, batch_loss, valid_token_num
67
 
68
 
69
  @torch.no_grad()
@@ -89,14 +103,15 @@ def evaluate_perplexity(model, dataloader, accelerator:Optional[Accelerator]=Non
89
 
90
  output = model(**x)
91
 
 
 
92
  # NOTE: we need the loss for each element in the batch for accurate computation, because the number of valid tokens may differ among elements
93
  if hasattr(output, "batch_loss"):
94
  # output from our model has batch_loss by default
95
  batch_loss = output.batch_loss
96
- valid_token_num = output.valid_token_num
97
  else:
98
  # output from other models does not
99
- loss, batch_loss, valid_token_num = compute_loss(output.logits, x["labels"], shift=True)
100
 
101
  index = index.tolist()
102
  batch_loss = batch_loss.tolist()
@@ -194,14 +209,15 @@ def evaluate_nll(model, dataloader, accelerator:Optional[Accelerator]=None):
194
 
195
  output = model(**x)
196
 
 
 
197
  # NOTE: we need the loss for each element in the batch for accurate computation, because the number of valid tokens may differ among elements
198
  if hasattr(output, "batch_loss"):
199
  # output from our model has batch_loss by default
200
  batch_loss = output.batch_loss
201
- valid_token_num = output.valid_token_num
202
  else:
203
  # output from other models does not
204
- loss, batch_loss, valid_token_num = compute_loss(output.logits, x["labels"], shift=True)
205
 
206
  if accelerator is not None and accelerator.num_processes > 1:
207
  # num_device * batch_size
@@ -216,13 +232,480 @@ def evaluate_nll(model, dataloader, accelerator:Optional[Accelerator]=None):
216
  return all_loss
217
 
218
 
219
-
220
  @dataclass
221
- class BeaconModelOutput(BaseModelOutputWithPast):
222
  loss: Optional[torch.FloatTensor] = None
223
  batch_loss: Optional[torch.FloatTensor] = None
224
- valid_token_num: Optional[torch.LongTensor] = None
225
  logits: torch.FloatTensor = None
226
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
227
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
228
  attentions: Optional[Tuple[torch.FloatTensor]] = None
 
29
  else:
30
  return data
31
 
32
+ def get_shifted_labels(input_ids):
33
+ if isinstance(input_ids, torch.Tensor):
34
+ labels = input_ids.clone()
35
+ labels = torch.cat([labels[:, 1:], labels.new_zeros((input_ids.shape[0], 1)) - 100], dim=-1)
36
+ elif isinstance(input_ids, list) and isinstance(input_ids[0], int):
37
+ labels = input_ids.copy()
38
+ labels = labels[1:] + [-100]
39
+ elif isinstance(input_ids, list) and isinstance(input_ids[0], list):
40
+ labels = input_ids.copy()
41
+ for i, label in enumerate(labels):
42
+ labels[i] = labels[i][1:] + [-100]
43
+ else:
44
+ raise NotImplementedError
45
+ return labels
46
+
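A tiny check of the shifting logic (the tensor branch of `get_shifted_labels` is replicated inline so the snippet runs standalone): labels become the input ids shifted left by one position, padded with -100 so the final position contributes no loss.

```python
import torch

input_ids = torch.tensor([[11, 12, 13, 14]])
labels = torch.cat([input_ids[:, 1:], input_ids.new_zeros((1, 1)) - 100], dim=-1)
print(labels)  # tensor([[  12,   13,   14, -100]])
```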
47
  def compute_loss(logits, labels, shift=False):
48
  """
49
  Returns:
50
  token_loss: batch_size, seq_length
51
  """
52
  if shift:
53
+ labels = get_shifted_labels(labels)
 
54
 
55
  labels = labels.to(logits.device)
56
  batch_size = logits.shape[0]
 
77
  if (valid_token_num == 0).any():
78
  batch_loss = batch_loss.masked_fill(valid_token_num == 0, 0.)
79
 
80
+ return loss, batch_loss, token_loss
81
 
82
 
83
  @torch.no_grad()
 
103
 
104
  output = model(**x)
105
 
106
+ valid_token_num = (x["labels"] != -100).sum(-1)
107
+
108
  # NOTE: we need the loss for each element in the batch for accurate computation, because the number of valid tokens may differ among elements
109
  if hasattr(output, "batch_loss"):
110
  # output from our model has batch_loss by default
111
  batch_loss = output.batch_loss
 
112
  else:
113
  # output from other models does not
114
+ loss, batch_loss, token_loss = compute_loss(output.logits, x["labels"], shift=True)
115
 
116
  index = index.tolist()
117
  batch_loss = batch_loss.tolist()
 
209
 
210
  output = model(**x)
211
 
212
+ valid_token_num = (x["labels"] != -100).sum()
213
+
214
  # NOTE: we need the loss for each element in the batch for accurate computation, because the number of valid tokens may differ among elements
215
  if hasattr(output, "batch_loss"):
216
  # output from our model has batch_loss by default
217
  batch_loss = output.batch_loss
 
218
  else:
219
  # output from other models does not
220
+ loss, batch_loss, token_loss = compute_loss(output.logits, x["labels"], shift=True)
221
 
222
  if accelerator is not None and accelerator.num_processes > 1:
223
  # num_device * batch_size
 
232
  return all_loss
233
 
234
 
 
235
  @dataclass
236
+ class ModelOutput(BaseModelOutputWithPast):
237
  loss: Optional[torch.FloatTensor] = None
238
  batch_loss: Optional[torch.FloatTensor] = None
239
+ token_loss: Optional[torch.FloatTensor] = None
240
  logits: torch.FloatTensor = None
241
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
242
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
243
  attentions: Optional[Tuple[torch.FloatTensor]] = None
244
+
245
+
246
+
247
+ ########## Various RoPE Scaling Methods Below (wrap the encoding process within the module for convenience) ##########
248
+
249
+ def get_rope(head_dim, base, max_position_embeddings, rope_scaling=None):
250
+ """
251
+ Get rope module. {native, linear scaling, dynamic ntk scaling, yarn scaling, llama3 scaling}
252
+ """
253
+ if rope_scaling is None:
254
+ rope = RotaryEmbedding(
255
+ dim=head_dim,
256
+ base=base,
257
+ max_position_embeddings=max_position_embeddings,
258
+ )
259
+ else:
260
+ scaling_type = rope_scaling["type"]
261
+ scaling_factor = rope_scaling["factor"]
262
+ if scaling_type == "linear":
263
+ rope = LinearScalingRotaryEmbedding(
264
+ dim=head_dim,
265
+ base=base,
266
+ max_position_embeddings=max_position_embeddings,
267
+ scaling_factor=scaling_factor,
268
+ )
269
+ elif scaling_type == "dynamic":
270
+ rope = DynamicNTKScalingRotaryEmbedding(
271
+ dim=head_dim,
272
+ base=base,
273
+ max_position_embeddings=max_position_embeddings,
274
+ scaling_factor=scaling_factor,
275
+ )
276
+ elif scaling_type == "yarn":
277
+ rope = YarnRotaryEmbedding(
278
+ dim=head_dim,
279
+ base=base,
280
+ max_position_embeddings=max_position_embeddings,
281
+ scaling_factor=scaling_factor,
282
+ )
283
+ elif scaling_type == "yarn-t":
284
+ rope = YarnDynamicTemperatureRotaryEmbedding(
285
+ dim=head_dim,
286
+ base=base,
287
+ max_position_embeddings=max_position_embeddings,
288
+ scaling_factor=scaling_factor,
289
+ )
290
+ elif scaling_type == "yarn-t-logn":
291
+ rope = YarnDynamicTemperatureLogNRotaryEmbedding(
292
+ dim=head_dim,
293
+ base=base,
294
+ max_position_embeddings=max_position_embeddings,
295
+ scaling_factor=scaling_factor,
296
+ )
297
+ elif scaling_type == "llama3":
298
+ rope = Llama3RotaryEmbedding(
299
+ dim=head_dim,
300
+ base=base,
301
+ max_position_embeddings=max_position_embeddings,
302
+ scaling_factor=scaling_factor,
303
+ original_max_position_embeddings=rope_scaling.get("original_max_position_embeddings", 8192),
304
+ low_freq_factor=rope_scaling.get("low_freq_factor", 1),
305
+ high_freq_factor=rope_scaling.get("high_freq_factor", 4),
306
+ )
307
+ else:
308
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
309
+
310
+ return rope
311
+
312
+
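A short usage sketch of the `get_rope` factory above (it assumes this file is importable as `modeling_utils`; the head dimension, base, and scaling factor are illustrative, not taken from the shipped config):

```python
import torch
from modeling_utils import get_rope  # assumption: the module is on the import path

rope = get_rope(head_dim=64, base=10000.0, max_position_embeddings=4096,
                rope_scaling={"type": "yarn", "factor": 4.0})

q = torch.randn(1, 8, 16, 64)            # (batch, num_heads, q_len, head_dim)
k = torch.randn(1, 8, 16, 64)            # (batch, num_kv_heads, k_len, head_dim)
position_ids = torch.arange(16).unsqueeze(0)
q_rot, k_rot = rope(q, k, position_ids)  # the rotation is applied inside the module
print(q_rot.shape, k_rot.shape)          # torch.Size([1, 8, 16, 64]) twice
```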
313
+ def rotate_half(x):
314
+ """Rotates half the hidden dims of the input."""
315
+ x1 = x[..., : x.shape[-1] // 2]
316
+ x2 = x[..., x.shape[-1] // 2 :]
317
+ return torch.cat((-x2, x1), dim=-1)
318
+
319
+
320
+ class RotaryEmbedding(torch.nn.Module):
321
+ def __init__(self, dim, max_position_embeddings=32768, base=10000, device=None):
322
+ super().__init__()
323
+
324
+ self.dim = dim
325
+ self.max_position_embeddings = max_position_embeddings
326
+ self.base = base
327
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32).to(device) / self.dim))
328
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
329
+
330
+ # Build here to make `torch.jit.trace` work.
331
+ self._set_cos_sin_cache(
332
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
333
+ )
334
+
335
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
336
+ self.max_seq_len_cached = seq_len
337
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
338
+ freqs = torch.outer(t, self.inv_freq)
339
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
340
+ emb = torch.cat((freqs, freqs), dim=-1)
341
+ self.register_buffer("cos_cached", emb.cos(), persistent=False)
342
+ self.register_buffer("sin_cached", emb.sin(), persistent=False)
343
+
344
+ def forward(self, q, k, position_ids):
345
+ seq_len = max(position_ids.max().item() + 1, k.shape[2])
346
+
347
+ # x: [bs, num_attention_heads, seq_len, head_size]
348
+ if seq_len > self.max_seq_len_cached:
349
+ self._set_cos_sin_cache(seq_len=seq_len, device=k.device, dtype=k.dtype)
350
+
351
+ # batch_size, 1, key_len, head_dim
352
+ k_cos = self.cos_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
353
+ k_sin = self.sin_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
354
+
355
+ q_cos = k_cos[..., -q.shape[2]:, :]
356
+ q_sin = k_sin[..., -q.shape[2]:, :]
357
+
358
+ q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
359
+ k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
360
+ return q_embed, k_embed
361
+
362
+
363
+ class LinearScalingRotaryEmbedding(RotaryEmbedding):
364
+ """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
365
+
366
+ def __init__(self, dim, max_position_embeddings=32768, base=10000, device=None, scaling_factor=1.0):
367
+ self.scaling_factor = scaling_factor
368
+ super().__init__(dim, max_position_embeddings, base, device)
369
+
370
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
371
+ self.max_seq_len_cached = seq_len
372
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
373
+ t = t / self.scaling_factor
374
+
375
+ freqs = torch.outer(t, self.inv_freq)
376
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
377
+ emb = torch.cat((freqs, freqs), dim=-1)
378
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
379
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
380
+
381
+
382
+ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
383
+ """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
384
+
385
+ def __init__(self, dim, max_position_embeddings=32768, base=10000, device=None, scaling_factor=1.0):
386
+ self.scaling_factor = scaling_factor
387
+ super().__init__(dim, max_position_embeddings, base, device)
388
+
389
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
390
+ self.max_seq_len_cached = seq_len
391
+
392
+ if seq_len > self.max_position_embeddings:
393
+ base = self.base * (
394
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
395
+ ) ** (self.dim / (self.dim - 2))
396
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.float32).to(device) / self.dim))
397
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
398
+
399
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
400
+
401
+ freqs = torch.outer(t, self.inv_freq)
402
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
403
+ emb = torch.cat((freqs, freqs), dim=-1)
404
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
405
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
406
+
407
+
408
+ class YarnRotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, beta_slow=2, beta_fast=128):
+ super().__init__()
+
+ self.base = base
+ self.dim = dim
+ self.scaling_factor = scaling_factor
+ self.beta_slow = beta_slow
+ self.beta_fast = beta_fast
+ self.max_position_embeddings = max_position_embeddings
+
+ self._set_cos_sin_cache(
+ seq_len=math.ceil(max_position_embeddings * scaling_factor), device=device, dtype=torch.get_default_dtype()
+ )
+
+ def _get_factor(self):
+ # the dimensions whose index is smaller than fast_dim rotate more than beta_fast times
+ fast_dim = self.dim / 2 * (math.log(self.max_position_embeddings / (2 * math.pi * self.beta_fast)) / math.log(self.base))
+ fast_dim = max(math.floor(fast_dim), 0)
+ # the dimensions whose index is bigger than slow_dim rotate less than beta_slow times
+ slow_dim = self.dim / 2 * (math.log(self.max_position_embeddings / (2 * math.pi * self.beta_slow)) / math.log(self.base))
+ slow_dim = min(math.ceil(slow_dim), self.dim - 1)
+
+ if fast_dim == slow_dim:
+ slow_dim += 0.001
+
+ # NOTE: very important to use full precision here so that the factor is correct
+ dim_arange = torch.arange(0, self.dim // 2, dtype=torch.float32)
+ dim_factor = (dim_arange - fast_dim) / (slow_dim - fast_dim)
+ dim_factor = torch.clamp(dim_factor, 0, 1)
+
+ # align with the paper notation
+ return (1 - dim_factor)
+
+ def _get_temperature(self):
+ if self.scaling_factor <= 1:
+ return 1.0
+ return 0.07 * math.log(self.scaling_factor) + 1.0
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ dim_arange = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
+ # dim / 2
+ freq = self.base ** dim_arange
+ theta = 1 / freq
+ interleave_theta = theta / self.scaling_factor
+
+ factor = self._get_factor().to(device)
+ yarn_theta = factor * theta + (1 - factor) * interleave_theta
+ self.register_buffer("inv_freq", yarn_theta, persistent=False)
+
+ t = torch.arange(seq_len, device=device, dtype=torch.float32)
+ freqs = torch.outer(t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1)
+
+ # get attention temperature
+ temperature = self._get_temperature()
+
+ self.register_buffer("cos_cached", emb.cos() * temperature, persistent=False)
+ self.register_buffer("sin_cached", emb.sin() * temperature, persistent=False)
+ self.max_seq_len_cached = seq_len
+
+ def forward(self, q, k, position_ids):
+ seq_len = max(position_ids.max().item() + 1, k.shape[2])
+
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.max_seq_len_cached:
+ self.scaling_factor = seq_len / self.max_position_embeddings
+ self._set_cos_sin_cache(seq_len=seq_len, device=k.device, dtype=k.dtype)
+
+ k_cos = self.cos_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
+ k_sin = self.sin_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
+
+ q_cos = k_cos[..., -q.shape[2]:, :]
+ q_sin = k_sin[..., -q.shape[2]:, :]
+
+ q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
+ k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
+ return q_embed, k_embed
+
+
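To make `_get_factor` concrete, here is a standalone sketch of the YaRN ramp under illustrative settings (head dim 128, base 10000, 4096-token original context, 4x extension; none of these values come from the commit). Dimensions below `fast_dim` keep the original frequencies (factor 1), dimensions above `slow_dim` are fully interpolated by the scaling factor (factor 0), and the band in between blends linearly:

```python
import math
import torch

dim, base, max_pos = 128, 10000.0, 4096          # illustrative, not from the commit
beta_fast, beta_slow = 128, 2                    # defaults of YarnRotaryEmbedding

fast_dim = max(math.floor(dim / 2 * math.log(max_pos / (2 * math.pi * beta_fast)) / math.log(base)), 0)
slow_dim = min(math.ceil(dim / 2 * math.log(max_pos / (2 * math.pi * beta_slow)) / math.log(base)), dim - 1)

dim_arange = torch.arange(0, dim // 2, dtype=torch.float32)
factor = 1 - torch.clamp((dim_arange - fast_dim) / (slow_dim - fast_dim), 0, 1)

theta = 1 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
scaling_factor = 4.0                             # illustrative 4x context extension
yarn_theta = factor * theta + (1 - factor) * (theta / scaling_factor)
```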
+ class YarnDynamicTemperatureRotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, beta_slow=2, beta_fast=128):
+ super().__init__()
+
+ self.base = base
+ self.dim = dim
+ self.scaling_factor = scaling_factor
+ self.beta_slow = beta_slow
+ self.beta_fast = beta_fast
+ self.max_position_embeddings = max_position_embeddings
+
+ self._set_cos_sin_cache(
+ seq_len=math.ceil(max_position_embeddings * scaling_factor), device=device, dtype=torch.get_default_dtype()
+ )
+
+ def _get_factor(self):
+ # the dimensions whose index is smaller than fast_dim rotate more than beta_fast times
+ fast_dim = self.dim / 2 * (math.log(self.max_position_embeddings / (2 * math.pi * self.beta_fast)) / math.log(self.base))
+ fast_dim = max(math.floor(fast_dim), 0)
+ # the dimensions whose index is bigger than slow_dim rotate less than beta_slow times
+ slow_dim = self.dim / 2 * (math.log(self.max_position_embeddings / (2 * math.pi * self.beta_slow)) / math.log(self.base))
+ slow_dim = min(math.ceil(slow_dim), self.dim - 1)
+
+ if fast_dim == slow_dim:
+ slow_dim += 0.001
+
+ # NOTE: very important to use full precision here so that the factor is correct
+ dim_arange = torch.arange(0, self.dim // 2, dtype=torch.float32)
+ dim_factor = (dim_arange - fast_dim) / (slow_dim - fast_dim)
+ dim_factor = torch.clamp(dim_factor, 0, 1)
+
+ # align with the paper notation
+ return (1 - dim_factor)
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ dim_arange = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
+ # dim / 2
+ freq = self.base ** dim_arange
+ theta = 1 / freq
+ interleave_theta = theta / self.scaling_factor
+
+ factor = self._get_factor().to(device)
+ yarn_theta = factor * theta + (1 - factor) * interleave_theta
+ self.register_buffer("inv_freq", yarn_theta, persistent=False)
+
+ positions = torch.arange(seq_len, device=device, dtype=torch.float32)
+ freqs = torch.outer(positions, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1)
+
+ # NOTE: get attention temperature that will be applied on the query vector
+ # temperature = torch.log(positions + 1) / math.log(self.max_position_embeddings)
+ temperature = (0.07 * torch.log((positions + 1) / self.max_position_embeddings) + 1) ** 2
+ temperature[:self.max_position_embeddings] = 1
+ self.register_buffer("temperature", temperature.unsqueeze(1), persistent=False)
+
+ self.register_buffer("cos_cached", emb.cos(), persistent=False)
+ self.register_buffer("sin_cached", emb.sin(), persistent=False)
+ self.max_seq_len_cached = seq_len
+
+ def forward(self, q, k, position_ids):
+ seq_len = max(position_ids.max().item() + 1, k.shape[2])
+
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.max_seq_len_cached:
+ self.scaling_factor = seq_len / self.max_position_embeddings
+ self._set_cos_sin_cache(seq_len=seq_len, device=k.device, dtype=k.dtype)
+
+ # batch_size, 1, key_len, head_dim
+ k_cos = self.cos_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
+ k_sin = self.sin_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
+
+ q_cos = k_cos[..., -q.shape[2]:, :]
+ q_sin = k_sin[..., -q.shape[2]:, :]
+
+ q_position_ids = position_ids[:, -q.shape[2]:]
+ temperature = self.temperature[q_position_ids].to(dtype=k.dtype).unsqueeze(1)
+ q_cos = q_cos * temperature
+ q_sin = q_sin * temperature
+
+ q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
+ k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
+ return q_embed, k_embed
+
+
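A hypothetical call sketch for the class above; the shapes are made up, and it assumes `YarnDynamicTemperatureRotaryEmbedding` and `rotate_half` from this module are importable:

```python
import torch

# Hypothetical shapes: batch 1, 8 heads, head dim 64, 32 cached key positions, 4 new queries.
rope = YarnDynamicTemperatureRotaryEmbedding(dim=64, max_position_embeddings=2048)
q = torch.randn(1, 8, 4, 64)
k = torch.randn(1, 8, 32, 64)
position_ids = torch.arange(32).unsqueeze(0)   # positions of every cached key
q_embed, k_embed = rope(q, k, position_ids)    # queries take the last 4 positions;
                                               # their cos/sin are scaled by the
                                               # position-dependent temperature
                                               # (1.0 here, since 32 < 2048)
```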
+ class YarnDynamicTemperatureLogNRotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, beta_slow=2, beta_fast=128):
+ super().__init__()
+
+ self.base = base
+ self.dim = dim
+ self.scaling_factor = scaling_factor
+ self.beta_slow = beta_slow
+ self.beta_fast = beta_fast
+ self.max_position_embeddings = max_position_embeddings
+
+ self._set_cos_sin_cache(
+ seq_len=math.ceil(max_position_embeddings * scaling_factor), device=device, dtype=torch.get_default_dtype()
+ )
+
+ def _get_factor(self):
+ # the dimensions whose index is smaller than fast_dim rotate more than beta_fast times
+ fast_dim = self.dim / 2 * (math.log(self.max_position_embeddings / (2 * math.pi * self.beta_fast)) / math.log(self.base))
+ fast_dim = max(math.floor(fast_dim), 0)
+ # the dimensions whose index is bigger than slow_dim rotate less than beta_slow times
+ slow_dim = self.dim / 2 * (math.log(self.max_position_embeddings / (2 * math.pi * self.beta_slow)) / math.log(self.base))
+ slow_dim = min(math.ceil(slow_dim), self.dim - 1)
+
+ if fast_dim == slow_dim:
+ slow_dim += 0.001
+
+ # NOTE: very important to use full precision here so that the factor is correct
+ dim_arange = torch.arange(0, self.dim // 2, dtype=torch.float32)
+ dim_factor = (dim_arange - fast_dim) / (slow_dim - fast_dim)
+ dim_factor = torch.clamp(dim_factor, 0, 1)
+
+ # align with the paper notation
+ return (1 - dim_factor)
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ dim_arange = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
+ # dim / 2
+ freq = self.base ** dim_arange
+ theta = 1 / freq
+ interleave_theta = theta / self.scaling_factor
+
+ factor = self._get_factor().to(device)
+ yarn_theta = factor * theta + (1 - factor) * interleave_theta
+ self.register_buffer("inv_freq", yarn_theta, persistent=False)
+
+ positions = torch.arange(seq_len, device=device, dtype=torch.float32)
+ freqs = torch.outer(positions, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1)
+
+ # NOTE: get attention temperature that will be applied on the query vector
+ temperature = torch.log(positions + 1) / math.log(self.max_position_embeddings)
+ # temperature = (0.07 * torch.log((positions + 1) / self.max_position_embeddings) + 1) ** 2
+ temperature[:self.max_position_embeddings] = 1
+ self.register_buffer("temperature", temperature.unsqueeze(1), persistent=False)
+
+ self.register_buffer("cos_cached", emb.cos(), persistent=False)
+ self.register_buffer("sin_cached", emb.sin(), persistent=False)
+ self.max_seq_len_cached = seq_len
+
+ def forward(self, q, k, position_ids):
+ seq_len = max(position_ids.max().item() + 1, k.shape[2])
+
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.max_seq_len_cached:
+ self.scaling_factor = seq_len / self.max_position_embeddings
+ self._set_cos_sin_cache(seq_len=seq_len, device=k.device, dtype=k.dtype)
+
+ # batch_size, 1, key_len, head_dim
+ k_cos = self.cos_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
+ k_sin = self.sin_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
+
+ q_cos = k_cos[..., -q.shape[2]:, :]
+ q_sin = k_sin[..., -q.shape[2]:, :]
+
+ q_position_ids = position_ids[:, -q.shape[2]:]
+ temperature = self.temperature[q_position_ids].to(dtype=k.dtype).unsqueeze(1)
+ q_cos = q_cos * temperature
+ q_sin = q_sin * temperature
+
+ q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
+ k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
+ return q_embed, k_embed
+
+
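The only difference between the two dynamic-temperature variants above is the per-position schedule applied to the query rotation. A side-by-side sketch of both curves; the 4096-token original context and the 4x horizon are illustrative:

```python
import math
import torch

max_pos = 4096                                            # illustrative original context
positions = torch.arange(4 * max_pos, dtype=torch.float32)

# YarnDynamicTemperatureRotaryEmbedding: squared log-ratio schedule
temp_sq = (0.07 * torch.log((positions + 1) / max_pos) + 1) ** 2
temp_sq[:max_pos] = 1

# YarnDynamicTemperatureLogNRotaryEmbedding: log-N schedule
temp_logn = torch.log(positions + 1) / math.log(max_pos)
temp_logn[:max_pos] = 1

# Both stay at 1 inside the original context and grow slowly beyond it.
```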
+ class Llama3RotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=8192, base=10000, device=None, scaling_factor=1.0, original_max_position_embeddings=8192, low_freq_factor=1, high_freq_factor=4):
+ super().__init__()
+
+ self.base = base
+ self.dim = dim
+ self.scaling_factor = scaling_factor
+ self.original_max_position_embeddings = original_max_position_embeddings
+ self.max_position_embeddings = max(max_position_embeddings, int(original_max_position_embeddings * scaling_factor))
+ self.low_freq_factor = low_freq_factor
+ self.high_freq_factor = high_freq_factor
+
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32).to(device) / self.dim))
+ low_freq_wavelen = self.original_max_position_embeddings / low_freq_factor
+ high_freq_wavelen = self.original_max_position_embeddings / high_freq_factor
+ new_freqs = []
+ for freq in inv_freq:
+ wavelen = 2 * math.pi / freq
+ if wavelen < high_freq_wavelen:
+ new_freqs.append(freq)
+ elif wavelen > low_freq_wavelen:
+ new_freqs.append(freq / scaling_factor)
+ else:
+ assert low_freq_wavelen != high_freq_wavelen
+ smooth = (self.original_max_position_embeddings / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+ new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
+ inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ self._set_cos_sin_cache(seq_len=self.max_position_embeddings, device=device)
+
+ def _set_cos_sin_cache(self, seq_len, device):
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
+ freqs = torch.outer(t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos(), persistent=False)
+ self.register_buffer("sin_cached", emb.sin(), persistent=False)
+
+ def forward(self, q, k, position_ids):
+ seq_len = max(position_ids.max().item() + 1, k.shape[2])
+
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.max_seq_len_cached:
+ self._set_cos_sin_cache(seq_len=seq_len, device=k.device)
+
+ k_cos = self.cos_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
+ k_sin = self.sin_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
+
+ q_cos = k_cos[..., -q.shape[2]:, :]
+ q_sin = k_sin[..., -q.shape[2]:, :]
+
+ q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
+ k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
+ return q_embed, k_embed
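
For reference, the frequency remapping performed in `__init__` above can be exercised on its own. This is a sketch, not part of the commit; the head dim, base, and 4x scaling below are illustrative:

```python
import math
import torch

def llama3_scale_inv_freq(inv_freq, original_max_pos, scaling_factor,
                          low_freq_factor=1.0, high_freq_factor=4.0):
    # Mirrors the loop in Llama3RotaryEmbedding.__init__: high-frequency dims are
    # kept, low-frequency dims are divided by the scaling factor, and the band in
    # between is blended smoothly by wavelength.
    low_freq_wavelen = original_max_pos / low_freq_factor
    high_freq_wavelen = original_max_pos / high_freq_factor
    new_freqs = []
    for freq in inv_freq.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scaling_factor)
        else:
            smooth = (original_max_pos / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=inv_freq.dtype)

dim, base = 128, 10000.0                                   # illustrative
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
scaled = llama3_scale_inv_freq(inv_freq, original_max_pos=8192, scaling_factor=4.0)
```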