loua19 committed · Commit d7346e7 · 0 Parent(s)
__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "architectures": [
+     "AriaForCausalLM"
+   ],
+   "bos_token_id": 0,
+   "eos_token_id": 1,
+   "hidden_size": 1536,
+   "intermediate_size": 6144,
+   "max_position_embeddings": 8192,
+   "model_type": "aria",
+   "num_attention_heads": 24,
+   "num_hidden_layers": 16,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.45.0",
+   "use_cache": true,
+   "vocab_size": 17727,
+   "auto_map": {
+     "AutoConfig": "configuration_aria.AriaConfig",
+     "AutoModel": "modeling_aria.AriaModel",
+     "AutoModelForCausalLM": "modeling_aria.AriaForCausalLM"
+   }
+ }
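
The `auto_map` block above is what lets the stock `Auto*` classes resolve to the custom code shipped in this repo. A minimal loading sketch follows; the repo id is a placeholder for wherever this checkpoint is hosted, and `trust_remote_code=True` is required because the classes live in `configuration_aria.py` / `modeling_aria.py` rather than in `transformers` itself.

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "loua19/aria-medium-base"  # placeholder: substitute the actual hub repo id

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)            # -> AriaConfig
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)   # -> AriaForCausalLM
print(config.model_type, config.vocab_size)  # aria 17727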
configuration_aria.py ADDED
@@ -0,0 +1,57 @@
+ from transformers import PretrainedConfig
+
+
+ class AriaConfig(PretrainedConfig):
+     model_type = "aria"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         vocab_size: int = 17727,
+         hidden_size: int = 1536,
+         embedding_size: int | None = None,
+         num_hidden_layers: int = 16,
+         num_attention_heads: int = 64,
+         intermediate_size: int = 6144,
+         max_position_embeddings: int = 8192,
+         use_cache: bool = True,
+         bos_token_id: int = 0,
+         eos_token_id: int = 1,
+         tie_word_embeddings: bool = False,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = False,
+         initializer_range: float = 0.02,  # referenced by AriaPreTrainedModel._init_weights
+         **kwargs,
+     ):
+         super().__init__(
+             bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs
+         )
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.embedding_size = embedding_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.max_position_embeddings = max_position_embeddings
+         self.use_cache = use_cache
+         self.tie_word_embeddings = tie_word_embeddings
+         self.output_attentions = output_attentions
+         self.output_hidden_states = output_hidden_states
+         self.return_dict = return_dict
+         self.initializer_range = initializer_range
+
+         if self.intermediate_size % self.hidden_size != 0:
+             raise ValueError(
+                 "The intermediate size needs to be divisible by the hidden size."
+             )
+
+         if self.hidden_size % self.num_attention_heads != 0:
+             raise ValueError(
+                 "The hidden size needs to be divisible by the number of attention heads."
+             )
+
+     @property
+     def ff_mult(self):
+         return self.intermediate_size // self.hidden_size
+
+
+ __all__ = ["AriaConfig"]
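
As a quick illustration of the validation logic and the `ff_mult` property above, here is a small sketch. It assumes it is run from a checkout of this repo so that `configuration_aria` is importable directly; the values mirror config.json.

from configuration_aria import AriaConfig

cfg = AriaConfig(hidden_size=1536, intermediate_size=6144, num_attention_heads=24)
print(cfg.ff_mult)  # 4, i.e. 6144 // 1536

# Incompatible sizes are rejected at construction time:
try:
    AriaConfig(hidden_size=1536, intermediate_size=4000)
except ValueError as err:
    print(err)  # the intermediate size must be divisible by the hidden size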
modeling_aria.py ADDED
@@ -0,0 +1,748 @@
+ # This is lightly adapted from https://github.com/EleutherAI/aria/blob/main/aria/model.py
+
+ from typing import Optional, Union, Tuple
+
+ import torch
+ import torch.utils.checkpoint
+
+ from torch import nn
+ from torch.nn import functional as F, CrossEntropyLoss
+
+ from transformers import Cache, DynamicCache, StaticCache
+ from transformers.utils import logging
+ from transformers.generation import GenerationMixin
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPast,
+     CausalLMOutputWithPast,
+     BaseModelOutputWithPoolingAndProjection,
+ )
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+ from .configuration_aria import AriaConfig
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class AriaPreTrainedModel(PreTrainedModel):
+     config_class = AriaConfig
+     base_model_prefix = "aria"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["AriaBlock"]
+     _skip_keys_device_placement = "past_key_values"
+     _supports_flash_attn_2 = False
+     _supports_cache_class = True
+     _supports_quantized_cache = True
+     _supports_static_cache = True
+     _supports_sdpa = True
+     _supports_flex_attn = False
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(
+                 mean=0.0, std=self.config.initializer_range
+             )
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(
+                 mean=0.0, std=self.config.initializer_range
+             )
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, model_config: AriaConfig, layer_idx: int):
+         super().__init__()
+
+         self.drop_p = 0.0
+         self.n_heads = model_config.num_attention_heads
+         self.d_model = model_config.hidden_size
+         self.d_head = (
+             model_config.hidden_size // model_config.num_attention_heads
+         )
+         self.max_seq_len = model_config.max_position_embeddings
+         self.layer_idx = layer_idx
+
+         # Attention
+         self.mixed_qkv = nn.Linear(
+             in_features=self.d_model,
+             out_features=3 * self.d_model,
+             bias=False,
+         )
+         self.att_proj_linear = nn.Linear(
+             in_features=self.d_model,
+             out_features=self.d_model,
+             bias=False,
+         )
+
+         # FF Layer
+         self.ff_gate_proj = nn.Linear(
+             in_features=self.d_model,
+             out_features=self.d_model * model_config.ff_mult,
+             bias=False,
+         )
+         self.ff_up_proj = nn.Linear(
+             in_features=self.d_model,
+             out_features=self.d_model * model_config.ff_mult,
+             bias=False,
+         )
+         self.ff_down_proj = nn.Linear(
+             in_features=self.d_model * model_config.ff_mult,
+             out_features=self.d_model,
+             bias=False,
+         )
+
+         # Pre layer norms
+         self.norm1 = nn.LayerNorm(self.d_model)
+         self.norm2 = nn.LayerNorm(self.d_model)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attention_mask: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         position_ids: Optional[torch.Tensor] = None,
+         past_key_values: Optional[
+             Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
+         ] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.Tensor] = None,
+     ):
+         attn_output, attn_weights, present = self._att_block(
+             self.norm1(x),
+             attention_mask,
+             freqs_cis,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             cache_position=cache_position,
+         )
+
+         x = x + attn_output
+         x = x + self._ff_block(self.norm2(x))
+
+         if use_cache:
+             outputs = (x, present, attn_weights)
+         else:
+             outputs = (x, attn_weights)
+
+         return outputs
+
+     def _att_block(
+         self,
+         x: torch.Tensor,
+         attention_mask: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         past_key_values: Optional[
+             Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
+         ] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         cache_position: Optional[torch.Tensor] = None,
+     ):
+         batch_size, seq_len, _ = x.shape
+         mixed_qkv = self.mixed_qkv(x)
+         xq, xk, xv = mixed_qkv.chunk(3, -1)
+
+         # Reshape for rotary embeddings
+         # Need contiguous for q, k since in-place RoPE cannot be applied on a view
+         xq = xq.reshape(
+             batch_size, seq_len, self.n_heads, self.d_head
+         ).contiguous()
+         xk = xk.reshape(
+             batch_size, seq_len, self.n_heads, self.d_head
+         ).contiguous()
+         xv = xv.view(batch_size, seq_len, self.n_heads, self.d_head)
+
+         # apply_rotary_emb expects: (b_sz, s_len, n_head, d_head)
+         xq = apply_rotary_emb(xq, freqs_cis)
+         xk = apply_rotary_emb(xk, freqs_cis)
+         xq, xk, xv = map(lambda t: t.transpose(1, 2), (xq, xk, xv))
+
+         if past_key_values is not None:
+             cache_kwargs = {
+                 # "sin": sin,
+                 # "cos": cos,
+                 # "partial_rotation_size": self.rotary_ndims,
+                 "cache_position": cache_position,
+             }
+             xk, xv = past_key_values.update(
+                 xk, xv, self.layer_idx, cache_kwargs
+             )
+
+         # scaled_dot_product_attention expects: (b_sz, n_head, s_len, d_head).
+         # An explicit attn_mask and is_causal=True are mutually exclusive for SDPA,
+         # so only request the built-in causal mask when no mask was prepared.
+         att = F.scaled_dot_product_attention(
+             query=xq,
+             key=xk,
+             value=xv,
+             attn_mask=attention_mask,
+             is_causal=attention_mask is None and seq_len > 1,
+         )
+
+         # Reshape for out: (b_sz, s_len, n_head, d_head)
+         out = att.transpose(1, 2).contiguous()
+         out = out.view(batch_size, seq_len, self.n_heads * self.d_head)
+
+         if not output_attentions:
+             att = None
+
+         return self.att_proj_linear(out), att, past_key_values
+
+     def _ff_block(self, x: torch.Tensor):
+         return self.ff_down_proj(
+             F.silu(self.ff_gate_proj(x)) * self.ff_up_proj(x)
+         )
+
+
+ class AriaModel(AriaPreTrainedModel):
+     """Transformer decoder with no language model head.
+
+     Args:
+         model_config (AriaConfig): Model config settings.
+     """
+
+     def __init__(self, model_config: AriaConfig):
+         super().__init__(model_config)
+         self.model_config = model_config
+         self.freqs_cis = None
+
+         self.tok_embeddings = nn.Embedding(
+             num_embeddings=model_config.vocab_size,
+             embedding_dim=model_config.hidden_size,
+         )
+
+         self.out_layer_norm = nn.LayerNorm(model_config.hidden_size)
+         self.encode_layers = nn.ModuleList()
+         for i in range(model_config.num_hidden_layers):
+             self.encode_layers.append(TransformerBlock(model_config, i))
+
+         self.gradient_checkpointing = False
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         past_key_values: Optional[
+             Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
+         ] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.Tensor] = None,
+     ):
+         """Forward pass of the Transformer decoder.
+
+         Args:
+             input_ids (Optional[torch.Tensor]): Input token ids of shape
+                 (batch_size, seq_len).
+             attention_mask (Optional[torch.Tensor]): Attention mask of shape
+                 (batch_size, seq_len). Defaults to None.
+             past_key_values (Optional[Union[Cache, Tuple]]): Cached key/value
+                 states from previous forward passes.
+
+         Returns:
+             BaseModelOutputWithPast (or a tuple when return_dict=False) whose
+             last_hidden_state has shape (batch_size, seq_len, hidden_size).
+         """
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.model_config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.model_config.output_hidden_states
+         )
+         return_dict = (
+             return_dict
+             if return_dict is not None
+             else self.model_config.use_return_dict
+         )
+         use_cache = (
+             use_cache if use_cache is not None else self.model_config.use_cache
+         )
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError(
+                 "You must specify exactly one of input_ids or inputs_embeds"
+             )
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 logger.warning_once(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
+
+         if inputs_embeds is None:
+             inputs_embeds = self.tok_embeddings(input_ids)
+
+         return_legacy_cache = False
+         if use_cache and not isinstance(past_key_values, Cache):
+             return_legacy_cache = True
+             if past_key_values is None:
+                 past_key_values = DynamicCache()
+             else:
+                 past_key_values = DynamicCache.from_legacy_cache(
+                     past_key_values
+                 )
+                 logger.warning_once(
+                     "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                     "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                     "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                 )
+
+         seq_length = inputs_embeds.shape[1]
+         if cache_position is None:
+             past_seen_tokens = (
+                 past_key_values.get_seq_length()
+                 if past_key_values is not None
+                 else 0
+             )
+             cache_position = torch.arange(
+                 past_seen_tokens,
+                 past_seen_tokens + seq_length,
+                 device=inputs_embeds.device,
+             )
+
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+         hidden_states = inputs_embeds
+
+         causal_mask = self._update_causal_mask(
+             attention_mask,
+             inputs_embeds,
+             cache_position,
+             past_key_values,
+             output_attentions,
+         )
+
+         if self.freqs_cis is None:
+             self.freqs_cis = precompute_freqs_cis(
+                 seq_len=self.model_config.max_position_embeddings,
+                 n_elem=self.model_config.hidden_size
+                 // self.model_config.num_attention_heads,
+                 base=500000,
+                 dtype=hidden_states.dtype,
+             ).to(inputs_embeds.device)
+         freqs_cis = self.freqs_cis[:seq_length]
+
+         kwargs = {
+             "position_ids": position_ids,
+             "past_key_values": past_key_values,
+             "use_cache": use_cache,
+             "output_attentions": output_attentions,
+             "output_hidden_states": output_hidden_states,
+             "return_dict": return_dict,
+             "cache_position": cache_position,
+         }
+         next_decoder_cache = None
+         all_attentions = () if output_attentions else None
+         all_hidden_states = () if output_hidden_states else None
+         if self.gradient_checkpointing:
+             for layer in self.encode_layers:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*args, **kwargs):
+                         return module(*args, **kwargs)[0]
+
+                     return custom_forward
+
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(layer),
+                     hidden_states,
+                     causal_mask,
+                     freqs_cis,
+                     preserve_rng_state=True,
+                     use_reentrant=False,
+                     **kwargs,
+                 )
+         else:
+             for layer in self.encode_layers:
+                 if output_hidden_states:
+                     all_hidden_states = all_hidden_states + (hidden_states,)
+                 outputs = layer(
+                     hidden_states, causal_mask, freqs_cis=freqs_cis, **kwargs
+                 )
+                 hidden_states = outputs[0]
+                 if use_cache is True:
+                     next_decoder_cache = outputs[1]
+                 if output_attentions:
+                     all_attentions = all_attentions + (
+                         outputs[2 if use_cache else 1],
+                     )
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+         hidden_states = self.out_layer_norm(hidden_states)
+         next_cache = next_decoder_cache if use_cache else None
+
+         if return_legacy_cache:
+             next_cache = next_cache.to_legacy_cache()
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [
+                     hidden_states,
+                     next_cache,
+                     all_hidden_states,
+                     all_attentions,
+                 ]
+                 if v is not None
+             )
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_attentions,
+         )
+
+     def _update_causal_mask(
+         self,
+         attention_mask: torch.Tensor,
+         input_tensor: torch.Tensor,
+         cache_position: torch.Tensor,
+         past_key_values: Cache,
+         output_attentions: bool,
+     ):
+         if self.model_config._attn_implementation == "flash_attention_2":
+             if attention_mask is not None and (attention_mask == 0.0).any():
+                 return attention_mask
+             return None
+
+         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+         # to infer the attention mask.
+         past_seen_tokens = (
+             past_key_values.get_seq_length()
+             if past_key_values is not None
+             else 0
+         )
+         using_static_cache = isinstance(past_key_values, StaticCache)
+
+         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+         if (
+             self.model_config._attn_implementation == "sdpa"
+             and not using_static_cache
+             and not output_attentions
+         ):
+             if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                 attention_mask,
+                 inputs_embeds=input_tensor,
+                 past_key_values_length=past_seen_tokens,
+                 is_training=self.training,
+             ):
+                 return None
+
+         dtype, device = input_tensor.dtype, input_tensor.device
+         sequence_length = input_tensor.shape[1]
+         if using_static_cache:
+             target_length = past_key_values.get_max_cache_shape()
+         else:
+             target_length = (
+                 attention_mask.shape[-1]
+                 if isinstance(attention_mask, torch.Tensor)
+                 else past_seen_tokens + sequence_length + 1
+             )
+
+         # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+         causal_mask = (
+             self._prepare_4d_causal_attention_mask_with_cache_position(
+                 attention_mask,
+                 sequence_length=sequence_length,
+                 target_length=target_length,
+                 dtype=dtype,
+                 device=device,
+                 cache_position=cache_position,
+                 batch_size=input_tensor.shape[0],
+             )
+         )
+
+         if (
+             self.model_config._attn_implementation == "sdpa"
+             and attention_mask is not None
+             and attention_mask.device.type == "cuda"
+             and not output_attentions
+         ):
+             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+             # Details: https://github.com/pytorch/pytorch/issues/110213
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = AttentionMaskConverter._unmask_unattended(
+                 causal_mask, min_dtype
+             )
+
+         return causal_mask
+
+     @staticmethod
+     # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+     def _prepare_4d_causal_attention_mask_with_cache_position(
+         attention_mask: torch.Tensor,
+         sequence_length: int,
+         target_length: int,
+         dtype: torch.dtype,
+         device: torch.device,
+         cache_position: torch.Tensor,
+         batch_size: int,
+         **kwargs,
+     ):
+         if attention_mask is not None and attention_mask.dim() == 4:
+             # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+             causal_mask = attention_mask
+         else:
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = torch.full(
+                 (sequence_length, target_length),
+                 fill_value=min_dtype,
+                 dtype=dtype,
+                 device=device,
+             )
+             if sequence_length != 1:
+                 causal_mask = torch.triu(causal_mask, diagonal=1)
+             causal_mask *= torch.arange(
+                 target_length, device=device
+             ) > cache_position.reshape(-1, 1)
+             causal_mask = causal_mask[None, None, :, :].expand(
+                 batch_size, 1, -1, -1
+             )
+             if attention_mask is not None:
+                 causal_mask = (
+                     causal_mask.clone()
+                 )  # copy to contiguous memory for in-place edit
+                 mask_length = attention_mask.shape[-1]
+                 padding_mask = (
+                     causal_mask[:, :, :, :mask_length]
+                     + attention_mask[:, None, None, :]
+                 )
+                 padding_mask = padding_mask == 0
+                 causal_mask[:, :, :, :mask_length] = causal_mask[
+                     :, :, :, :mask_length
+                 ].masked_fill(padding_mask, min_dtype)
+
+         return causal_mask
+
+
+ class AriaForCausalLM(AriaPreTrainedModel, GenerationMixin):
+     """Transformer decoder with head for language modelling.
+
+     Args:
+         model_config (AriaConfig): Model config settings.
+     """
+
+     def __init__(self, model_config: AriaConfig):
+         super().__init__(model_config)
+         self.model_config = model_config
+         self.max_seq_len = model_config.max_position_embeddings
+         self.model = AriaModel(model_config)
+         self.lm_head = nn.Linear(
+             model_config.hidden_size, model_config.vocab_size, bias=False
+         )
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         past_key_values: Optional[
+             Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
+         ] = None,
+         labels: Optional[torch.Tensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.Tensor] = None,
+     ):
+         """Forward pass of Transformer decoder with LM head."""
+         return_dict = (
+             return_dict
+             if return_dict is not None
+             else self.model_config.use_return_dict
+         )
+         outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             inputs_embeds=inputs_embeds,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+         )
+         hidden = outputs[0]
+         lm_logits = self.lm_head(hidden)
+
+         lm_loss = None
+         if labels is not None:
+             # move labels to correct device to enable model parallelism
+             labels = labels.to(lm_logits.device)
+             # we are doing next-token prediction; shift prediction scores and input ids by one
+             shift_logits = lm_logits[:, :-1, :].contiguous()
+             labels = labels[:, 1:].contiguous()
+             loss_fct = CrossEntropyLoss()
+             lm_loss = loss_fct(
+                 shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
+             )
+
+         if not return_dict:
+             output = (lm_logits,) + outputs[1:]
+             return ((lm_loss,) + output) if lm_loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=lm_loss,
+             logits=lm_logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ class AriaForSequenceEmbeddings(AriaPreTrainedModel):
+     """Transformer decoder with an embedding head for contrastive learning.
+
+     Args:
+         model_config (AriaConfig): Model config settings.
+     """
+
+     def __init__(self, model_config: AriaConfig):
+         super().__init__(model_config)
+         assert model_config.embedding_size
+
+         self.model_config = model_config
+         self.max_seq_len = model_config.max_position_embeddings
+         self.model = AriaModel(model_config)
+         self.emb_head = nn.Linear(
+             model_config.hidden_size, model_config.embedding_size, bias=False
+         )
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         past_key_values: Optional[
+             Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
+         ] = None,
+         labels: Optional[torch.Tensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.Tensor] = None,
+     ):
+         """Forward pass of Transformer decoder with embedding head. The pooled
+         embedding is extracted from the EOS token."""
+
+         return_dict = (
+             return_dict
+             if return_dict is not None
+             else self.model_config.use_return_dict
+         )
+
+         if (
+             position_ids is not None
+             or inputs_embeds is not None
+             or past_key_values is not None
+             or labels is not None
+             or cache_position is not None
+             or use_cache
+         ):
+             raise ValueError("Provided args unsupported for embedding head")
+
+         _batch_size = input_ids.shape[0]
+         eos_mask = input_ids == self.config.eos_token_id
+         if not eos_mask.any(dim=1).all():
+             raise ValueError(
+                 "Each sequence must contain at least one EOS token"
+             )
+         eos_pos = eos_mask.int().argmax(dim=1)
+
+         outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             use_cache=False,
+         )
+         hidden = outputs[0]
+         embedding = self.emb_head(hidden)
+         pooled_embedding = embedding[
+             torch.arange(_batch_size, device=input_ids.device), eos_pos
+         ]
+         if not return_dict:
+             output = (pooled_embedding,) + outputs[1:]
+             return output
+
+         return BaseModelOutputWithPoolingAndProjection(
+             last_hidden_state=embedding,
+             pooler_output=pooled_embedding,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ def precompute_freqs_cis(
+     seq_len: int,
+     n_elem: int,
+     base: int = 500000,
+     dtype: torch.dtype = torch.bfloat16,
+ ):
+     freqs = 1.0 / (
+         base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+     )
+     t = torch.arange(seq_len, device=freqs.device)
+     freqs = torch.outer(t, freqs)
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+
+     return cache.to(dtype=dtype)
+
+
+ @torch.jit.script
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+     """
+     In-place RoPE. Credits to Katherine Crowson:
+     x shape (b_sz, s_len, n_head, d_head).
+     cos, sin shape (s_len, d_head // 2).
+     """
+
+     d = x.shape[-1] // 2
+     cos = freqs_cis[..., 0][None, :, None]
+     sin = freqs_cis[..., 1][None, :, None]
+     x1, x2 = x[..., :d], x[..., d : d * 2]
+     tmp = x1.clone()
+     x1.mul_(cos).addcmul_(x2, sin, value=-1)
+     x2.mul_(cos).addcmul_(tmp, sin, value=1)
+     return x
+
+
+ __all__ = [
+     "AriaPreTrainedModel",
+     "AriaModel",
+     "TransformerBlock",
+     "AriaForCausalLM",
+     "AriaForSequenceEmbeddings",
+ ]
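
To make the rotary-embedding helpers at the bottom of the file concrete, here is a small shape check. It is only a sketch: because modeling_aria.py uses a relative import of `configuration_aria`, it assumes the repo checkout is importable as a package (hypothetically named `aria_model` here), and it runs in float32 on CPU.

import torch
from aria_model.modeling_aria import precompute_freqs_cis, apply_rotary_emb  # hypothetical package name

b_sz, s_len, n_head, d_head = 2, 16, 24, 64
freqs_cis = precompute_freqs_cis(seq_len=s_len, n_elem=d_head, dtype=torch.float32)
print(freqs_cis.shape)  # torch.Size([16, 32, 2]) -> (s_len, d_head // 2, [cos, sin])

q = torch.randn(b_sz, s_len, n_head, d_head)
q = apply_rotary_emb(q, freqs_cis)  # rotation is applied in place; shape is unchanged
print(q.shape)  # torch.Size([2, 16, 24, 64])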
tokenization_aria.py ADDED
@@ -0,0 +1,195 @@
+ from typing import List, Optional, Tuple
+
+ from transformers.tokenization_utils import PreTrainedTokenizer, BatchEncoding
+ from transformers.utils import logging, TensorType, to_py_obj
+
+ try:
+     from ariautils.midi import MidiDict
+     from ariautils.tokenizer import AbsTokenizer
+     from ariautils.tokenizer._base import Token
+ except ImportError:
+     raise ImportError(
+         "ariautils is not installed. Please try `pip install git+https://github.com/EleutherAI/aria-utils.git`."
+     )
+
+ logger = logging.get_logger(__name__)
+
+
+ class AriaTokenizer(PreTrainedTokenizer):
+     """
+     The Aria tokenizer is NOT a BPE tokenizer. A MIDI file is first converted to a
+     MidiDict (despite the name, a MidiDict is closer to a list of note and meta
+     events than to a single dict), and the tokenizer then maps those events to
+     discrete indices according to a hard-coded rule.
+
+     For a FIM-finetuned model, we also follow a simple FIM format to guide a piece
+     of music towards a (possibly very different) suffix according to the prompts:
+     <GUIDANCE-START> ... <GUIDANCE-END> <S> <PROMPT-START> ... <PROMPT-END>
+     This way, we expect a continuation that connects PROMPT and GUIDANCE.
+     """
+
+     vocab_files_names = {}
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self,
+         add_bos_token=True,
+         add_eos_token=True,
+         add_dim_token=True,
+         clean_up_tokenization_spaces=False,
+         use_default_system_prompt=False,
+         **kwargs,
+     ):
+         self._tokenizer = AbsTokenizer()
+
+         self.add_bos_token = add_bos_token
+         self.add_eos_token = add_eos_token
+         self.add_dim_token = add_dim_token
+         self.use_default_system_prompt = use_default_system_prompt
+
+         bos_token = self._tokenizer.bos_tok
+         eos_token = self._tokenizer.eos_tok
+         pad_token = self._tokenizer.pad_tok
+         unk_token = self._tokenizer.unk_tok
+
+         super().__init__(
+             bos_token=bos_token,
+             eos_token=eos_token,
+             unk_token=unk_token,
+             pad_token=pad_token,
+             use_default_system_prompt=use_default_system_prompt,
+             **kwargs,
+         )
+
+     def __getstate__(self):
+         return {}
+
+     def __setstate__(self, d):
+         raise NotImplementedError()
+
+     @property
+     def vocab_size(self):
+         """Returns vocab size"""
+         return self._tokenizer.vocab_size
+
+     def get_vocab(self):
+         return self._tokenizer.tok_to_id
+
+     def tokenize(
+         self,
+         midi_dict: MidiDict,
+         add_dim_tok: Optional[bool] = None,
+         add_eos_tok: Optional[bool] = None,
+         **kwargs,
+     ) -> List[Token]:
+         return self._tokenizer.tokenize(
+             midi_dict=midi_dict,
+             add_dim_tok=(
+                 add_dim_tok if add_dim_tok is not None else self.add_dim_token
+             ),
+             add_eos_tok=(
+                 add_eos_tok if add_eos_tok is not None else self.add_eos_token
+             ),
+         )
+
+     def _tokenize(
+         self,
+         midi_dict: MidiDict,
+         add_dim_tok: Optional[bool] = None,
+         add_eos_tok: Optional[bool] = None,
+         **kwargs,
+     ) -> List[Token]:
+         return self._tokenizer.tokenize(
+             midi_dict=midi_dict,
+             add_dim_tok=(
+                 add_dim_tok if add_dim_tok is not None else self.add_dim_token
+             ),
+             add_eos_tok=(
+                 add_eos_tok if add_eos_tok is not None else self.add_eos_token
+             ),
+         )
+
+     def __call__(
+         self,
+         midi_dicts: MidiDict | list[MidiDict],
+         padding: bool = False,
+         max_length: int | None = None,
+         pad_to_multiple_of: int | None = None,
+         return_tensors: str | TensorType | None = None,
+         return_attention_mask: bool | None = None,
+         **kwargs,
+     ) -> BatchEncoding:
+         """We cannot rely on the parent method because the inputs are MidiDict(s)
+         rather than strings, and forcing two entirely different input types through
+         the same code path would get hacky. __call__ is therefore reimplemented here
+         with limited support for the most useful arguments. This should not conflict
+         with other "string-in, ids-out" tokenizers; if you need to mix the API of
+         string-based tokenizers with this MIDI-based one, reconsider the design."""
+         if isinstance(midi_dicts, MidiDict):
+             midi_dicts = [midi_dicts]
+
+         all_tokens: list[list[int]] = []
+         all_attn_masks: list[list[int]] = []
+         max_len_encoded = 0
+         for md in midi_dicts:
+             tokens = self._tokenizer.encode(self._tokenizer.tokenize(md))
+             if max_length is not None:
+                 tokens = tokens[:max_length]
+             max_len_encoded = max(max_len_encoded, len(tokens))
+             all_tokens.append(tokens)
+             all_attn_masks.append([True] * len(tokens))
+
+         if pad_to_multiple_of is not None:
+             # Round up to the nearest multiple of `pad_to_multiple_of`
+             max_len_encoded = (
+                 (max_len_encoded + pad_to_multiple_of - 1) // pad_to_multiple_of
+             ) * pad_to_multiple_of
+         if padding:
+             for tokens, attn_mask in zip(all_tokens, all_attn_masks):
+                 # Compute the pad length before extending `tokens` in place
+                 pad_len = max_len_encoded - len(tokens)
+                 tokens.extend([self._tokenizer.pad_id] * pad_len)
+                 attn_mask.extend([False] * pad_len)
+
+         return BatchEncoding(
+             {
+                 "input_ids": all_tokens,
+                 "attention_mask": all_attn_masks,
+             },
+             tensor_type=return_tensors,
+         )
+
+     def decode(self, token_ids: List[int], **kwargs) -> MidiDict:
+         token_ids = to_py_obj(token_ids)
+
+         return self._tokenizer.detokenize(self._tokenizer.decode(token_ids))
+
+     def batch_decode(
+         self, token_ids_list: List[List[Token]], **kwargs
+     ) -> List[MidiDict]:
+         results = []
+         for token_ids in token_ids_list:
+             results.append(self.decode(token_ids))
+         return results
+
+     def encode_from_file(self, filename: str, **kwargs) -> BatchEncoding:
+         midi_dict = MidiDict.from_midi(filename)
+         return self(midi_dict, **kwargs)
+
+     def encode_from_files(
+         self, filenames: list[str], **kwargs
+     ) -> BatchEncoding:
+         midi_dicts = [MidiDict.from_midi(file) for file in filenames]
+         return self(midi_dicts, **kwargs)
+
+     def _convert_token_to_id(self, token: Token):
+         """Converts a token (tuple or str) into an id."""
+         return self._tokenizer.tok_to_id.get(
+             token, self._tokenizer.tok_to_id[self.unk_token]
+         )
+
+     def _convert_id_to_token(self, index: int):
+         """Converts an index (integer) into a token (tuple or str)."""
+         return self._tokenizer.id_to_tok.get(index, self.unk_token)
+
+     def convert_tokens_to_string(self, tokens: List[Token]) -> MidiDict:
+         """Converts a sequence of tokens into a single MidiDict."""
+         return self._tokenizer.detokenize(tokens)
+
+     def save_vocabulary(
+         self, save_directory, filename_prefix: Optional[str] = None
+     ) -> Tuple[str]:
+         raise NotImplementedError()
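
Putting the tokenizer together, the sketch below encodes MIDI files into padded `input_ids` and round-trips ids back into a MidiDict. It assumes `ariautils` is installed, that the script runs from a checkout of this repo so `tokenization_aria` is importable, and that the `.mid` paths are placeholders.

from tokenization_aria import AriaTokenizer

tokenizer = AriaTokenizer()

# Batch-encode two files, padded to a common length.
batch = tokenizer.encode_from_files(
    ["song_a.mid", "song_b.mid"],  # placeholder paths
    padding=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # (2, padded_seq_len)

# Round-trip a single (unpadded) sequence back into a MidiDict.
ids = tokenizer.encode_from_file("song_a.mid")["input_ids"][0]
midi_dict = tokenizer.decode(ids)
print(type(midi_dict).__name__)  # MidiDict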
tokenizer_config.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "add_bos_token": false,
+   "add_eos_token": false,
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenization_aria.AriaTokenizer",
+       null
+     ]
+   },
+   "tokenizer_class": "AriaTokenizer"
+ }
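
With the `auto_map` entry above (the `null` slot is the fast-tokenizer class, which this repo does not provide), `AutoTokenizer` resolves to the custom `AriaTokenizer` when remote code is trusted. A minimal sketch, with the repo id again a placeholder:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "loua19/aria-medium-base",  # placeholder hub repo id
    trust_remote_code=True,
)
print(type(tokenizer).__name__)  # AriaTokenizer
print(tokenizer.vocab_size)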