zaydzuhri committed
Commit 3c70147 · verified · 1 Parent(s): adbece6

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. fla/models/abc/__pycache__/configuration_abc.cpython-312.pyc +0 -0
  2. fla/models/bitnet/modeling_bitnet.py +441 -0
  3. fla/models/forgetting_transformer/__pycache__/__init__.cpython-312.pyc +0 -0
  4. fla/models/gated_deltanet/__pycache__/modeling_gated_deltanet.cpython-312.pyc +0 -0
  5. fla/models/gated_deltaproduct/__pycache__/__init__.cpython-312.pyc +0 -0
  6. fla/models/gated_deltaproduct/__pycache__/configuration_gated_deltaproduct.cpython-312.pyc +0 -0
  7. fla/models/gla/__pycache__/__init__.cpython-312.pyc +0 -0
  8. fla/models/gla/__pycache__/configuration_gla.cpython-312.pyc +0 -0
  9. fla/models/gsa/__pycache__/__init__.cpython-312.pyc +0 -0
  10. fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc +0 -0
  11. fla/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc +0 -0
  12. fla/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc +0 -0
  13. fla/models/hgrn2/__pycache__/configuration_hgrn2.cpython-312.pyc +0 -0
  14. fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc +0 -0
  15. fla/models/hgrn2/configuration_hgrn2.py +91 -0
  16. fla/models/lightnet/__pycache__/__init__.cpython-312.pyc +0 -0
  17. fla/models/linear_attn/__pycache__/__init__.cpython-312.pyc +0 -0
  18. fla/models/mamba/__pycache__/__init__.cpython-312.pyc +0 -0
  19. fla/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc +0 -0
  20. fla/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc +0 -0
  21. fla/models/nsa/__pycache__/configuration_nsa.cpython-312.pyc +0 -0
  22. fla/models/nsa/__pycache__/modeling_nsa.cpython-312.pyc +0 -0
  23. fla/models/retnet/__pycache__/configuration_retnet.cpython-312.pyc +0 -0
  24. fla/models/retnet/__pycache__/modeling_retnet.cpython-312.pyc +0 -0
  25. fla/models/rwkv6/__pycache__/configuration_rwkv6.cpython-312.pyc +0 -0
  26. fla/models/rwkv6/configuration_rwkv6.py +82 -0
  27. fla/models/samba/__pycache__/configuration_samba.cpython-312.pyc +0 -0
  28. fla/models/samba/__pycache__/modeling_samba.cpython-312.pyc +0 -0
  29. fla/models/transformer/__init__.py +13 -0
  30. fla/models/transformer/__pycache__/__init__.cpython-312.pyc +0 -0
  31. fla/models/transformer/__pycache__/modeling_transformer.cpython-312.pyc +0 -0
  32. fla/models/transformer/configuration_transformer.py +71 -0
  33. fla/models/transformer_mtp/__pycache__/__init__.cpython-312.pyc +0 -0
  34. fla/models/transformer_mtp/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
  35. fla/modules/__pycache__/__init__.cpython-312.pyc +0 -0
  36. fla/modules/__pycache__/convolution.cpython-312.pyc +0 -0
  37. fla/modules/__pycache__/fused_bitlinear.cpython-312.pyc +0 -0
  38. fla/modules/__pycache__/fused_cross_entropy.cpython-312.pyc +0 -0
  39. fla/modules/__pycache__/fused_linear_cross_entropy.cpython-312.pyc +0 -0
  40. fla/modules/__pycache__/fused_linear_listnet_loss.cpython-312.pyc +0 -0
  41. fla/modules/__pycache__/fused_norm_gate.cpython-312.pyc +0 -0
  42. fla/modules/__pycache__/l2norm.cpython-312.pyc +0 -0
  43. fla/modules/__pycache__/layernorm.cpython-312.pyc +0 -0
  44. fla/modules/__pycache__/layernorm_gated.cpython-312.pyc +0 -0
  45. fla/modules/__pycache__/parallel.cpython-312.pyc +0 -0
  46. fla/modules/__pycache__/rotary.cpython-312.pyc +0 -0
  47. fla/modules/__pycache__/seq_to_top.cpython-312.pyc +0 -0
  48. logs/none_yagntt11/attempt_0/0/stdout.log +0 -0
  49. logs/none_yagntt11/attempt_0/1/stdout.log +0 -0
  50. logs/none_yagntt11/attempt_0/2/stdout.log +0 -0
fla/models/abc/__pycache__/configuration_abc.cpython-312.pyc ADDED
Binary file (3.61 kB).
 
fla/models/bitnet/modeling_bitnet.py ADDED
@@ -0,0 +1,441 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import math
+import warnings
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from transformers.generation import GenerationMixin
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from transformers.utils.deprecation import deprecate_kwarg
+
+from fla.layers.bitattn import BitAttention
+from fla.models.bitnet.configuration_bitnet import BitNetConfig
+from fla.models.utils import Cache
+from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
+from fla.modules.activations import swiglu
+from fla.modules.fused_bitlinear import FusedBitLinear
+
+if TYPE_CHECKING:
+    from transformers.processing_utils import Unpack
+
+logger = logging.get_logger(__name__)
+
+
+class BitNetMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        hidden_ratio: Optional[int] = None,
+        intermediate_size: Optional[int] = None,
+        hidden_act: str = 'swish',
+        fuse_swiglu: bool = True
+    ) -> BitNetMLP:
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        # the final number of params is `hidden_ratio * hidden_size^2`
+        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
+        if hidden_ratio is None:
+            hidden_ratio = 4
+        if intermediate_size is None:
+            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
+            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.fuse_swiglu = fuse_swiglu
+
+        if hidden_act != 'swish':
+            raise ValueError(f'Unsupported hidden_act: {hidden_act}')
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        **kwargs: Unpack[Any]
+    ) -> torch.Tensor:
+        gate, y = self.gate_proj(x), self.up_proj(x)
+        return self.down_proj(swiglu(gate, y))
+
+
+class BitNetBlock(nn.Module):
+
+    def __init__(self, config: BitNetConfig, layer_idx: int):
+        super().__init__()
+
+        self.config = config
+        self.layer_idx = layer_idx
+
+        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+        self.attn = BitAttention(
+            hidden_size=config.hidden_size,
+            num_heads=config.num_heads,
+            num_kv_heads=config.num_kv_heads,
+            window_size=config.window_size,
+            rope_theta=config.rope_theta,
+            max_position_embeddings=config.max_position_embeddings,
+            layer_idx=layer_idx
+        )
+        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+        self.mlp = BitNetMLP(
+            hidden_size=config.hidden_size,
+            hidden_ratio=config.hidden_ratio,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            fuse_swiglu=config.fuse_swiglu
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        **kwargs: Unpack[Any]
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+
+        residual = hidden_states
+        hidden_states = self.attn_norm(hidden_states)
+        hidden_states, attentions, past_key_values = self.attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            **kwargs
+        )
+        if self.config.fuse_norm:
+            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
+        else:
+            hidden_states = residual + hidden_states
+            residual = hidden_states
+            hidden_states = self.mlp_norm(hidden_states)
+        hidden_states = self.mlp(hidden_states, **kwargs)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attentions,)
+
+        if use_cache:
+            outputs += (past_key_values,)
+
+        return outputs
+
+
+class BitNetPreTrainedModel(PreTrainedModel):
+
+    config_class = BitNetConfig
+    base_model_prefix = 'model'
+    supports_gradient_checkpointing = True
+    _no_split_modules = ['BitNetBlock']
+    _supports_cache_class = True
+
+    def __init__(self, *inputs, **kwargs):
+        super().__init__(*inputs, **kwargs)
+
+    def _init_weights(
+        self,
+        module: nn.Module,
+        rescale_prenorm_residual: bool = False,
+        num_residuals_per_layer: int = 2,
+    ):
+        if isinstance(module, (nn.Linear, nn.Conv1d, FusedBitLinear)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+        elif hasattr(module, 'reset_parameters'):
+            module.reset_parameters()
+
+        if rescale_prenorm_residual:
+            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+            #
+            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+            p = None
+            if hasattr(module, 'o_proj'):
+                p = module.o_proj.weight
+            elif hasattr(module, 'down_proj'):
+                p = module.down_proj.weight
+            if p is not None:
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                # We need to reinit p since this code could be called multiple times
+                # Having just p *= scale would repeatedly scale it down
+                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                with torch.no_grad():
+                    p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
+
+
+class BitNetModel(BitNetPreTrainedModel):
+
+    def __init__(
+        self,
+        config: BitNetConfig
+    ) -> BitNetModel:
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList([BitNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+
+        self.gradient_checkpointing = False
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs: Unpack[Any]
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        if output_attentions:
+            warnings.warn(
+                "`BitNetModel` does not support output attention weights now, so `output_attentions` is set to `False`."
+            )
+            output_attentions = False
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is None and inputs_embeds is None:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if use_cache and not isinstance(past_key_values, Cache):
+            past_key_values = Cache.from_legacy_cache(past_key_values)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)
+
+        # embed positions
+        hidden_states = inputs_embeds
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        all_hidden_states = () if output_hidden_states else None
+        all_attns = () if output_attentions else None
+        next_cache = None
+
+        for layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    **kwargs
+                )
+            else:
+                layer_outputs = layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    past_key_values=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    **kwargs
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_attns
+        )
+
+
+class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
+
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = BitNetModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.criterion = None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embeddings
+
+    def set_input_embeddings(self, value):
+        self.model.embeddings = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: bool = True,
+        logits_to_keep: Optional[int] = None,
+        **kwargs
+    ):
+        # only last token for `input_ids` if the `past_key_values` is not empty.
+        if past_key_values is not None and len(past_key_values) > 0:
+            input_ids = input_ids[:, -1:]
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and len(past_key_values) == 0:
+            model_inputs = {'inputs_embeds': inputs_embeds}
+        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard.
+            # Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {'input_ids': input_ids.contiguous()}
+
+        if logits_to_keep is not None:
+            model_inputs['logits_to_keep'] = logits_to_keep
+
+        model_inputs.update({
+            'past_key_values': past_key_values,
+            'use_cache': use_cache,
+            'attention_mask': attention_mask,
+        })
+        return model_inputs
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        logits_to_keep: Optional[int] = 0,
+        **kwargs: Unpack[Any]
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs
+        )
+
+        hidden_states = outputs[0]
+        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
+
+        loss, logits = None, None
+        if not fuse_linear_and_cross_entropy or labels is None:
+            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
+        if labels is not None:
+            if getattr(self, 'criterion', None) is None:
+                if fuse_linear_and_cross_entropy:
+                    criterion = FusedLinearCrossEntropyLoss()
+                elif self.config.fuse_cross_entropy:
+                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
+                else:
+                    criterion = nn.CrossEntropyLoss()
+            else:
+                criterion = self.criterion
+            labels = labels.to(hidden_states.device)
+            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
+            if fuse_linear_and_cross_entropy:
+                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
+            else:
+                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
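A minimal usage sketch of the classes added above (editorial, not part of the commit). It assumes `fla` is installed, that `BitNetConfig` accepts the field names this file reads from it, and that a CUDA device is available, since the fused kernels (`FusedBitLinear`, the fused losses) are Triton-based; the sizes are illustrative only.

# Sketch, not part of the commit: a tiny forward pass through BitNetForCausalLM.
# Assumes `fla` is importable and BitNetConfig exposes the fields referenced above
# (hidden_size, num_hidden_layers, num_heads, vocab_size, ...); sizes are illustrative.
import torch

from fla.models.bitnet.configuration_bitnet import BitNetConfig
from fla.models.bitnet.modeling_bitnet import BitNetForCausalLM

config = BitNetConfig(hidden_size=256, num_hidden_layers=2, num_heads=4, vocab_size=1000)
model = BitNetForCausalLM(config).cuda().eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16), device='cuda')
with torch.no_grad():
    out = model(input_ids=input_ids)
print(out.logits.shape)  # (1, 16, 1000); the default logits_to_keep=0 keeps the full sequence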
fla/models/forgetting_transformer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (817 Bytes).

fla/models/gated_deltanet/__pycache__/modeling_gated_deltanet.cpython-312.pyc ADDED
Binary file (18.5 kB).

fla/models/gated_deltaproduct/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (777 Bytes).

fla/models/gated_deltaproduct/__pycache__/configuration_gated_deltaproduct.cpython-312.pyc ADDED
Binary file (3.38 kB).

fla/models/gla/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (657 Bytes).

fla/models/gla/__pycache__/configuration_gla.cpython-312.pyc ADDED
Binary file (3.73 kB).

fla/models/gsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (657 Bytes).

fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc ADDED
Binary file (3.84 kB).

fla/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc ADDED
Binary file (18.7 kB).

fla/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc ADDED
Binary file (18.8 kB).

fla/models/hgrn2/__pycache__/configuration_hgrn2.cpython-312.pyc ADDED
Binary file (3.55 kB).

fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc ADDED
Binary file (18.9 kB).

fla/models/hgrn2/configuration_hgrn2.py ADDED
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+
+from typing import Dict, Optional
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class HGRN2Config(PretrainedConfig):
+
+    model_type = 'hgrn2'
+    keys_to_ignore_at_inference = ['past_key_values']
+
+    def __init__(
+        self,
+        hidden_size: int = 2048,
+        num_hidden_layers: int = 24,
+        attn_mode: str = "chunk",
+        num_heads: Optional[int] = None,
+        expand_ratio: Optional[int] = 128,
+        use_short_conv: bool = False,
+        conv_size: int = 4,
+        use_lower_bound: bool = True,
+        hidden_ratio: Optional[int] = 4,
+        intermediate_size: Optional[int] = None,
+        hidden_act: str = "swish",
+        max_position_embeddings: int = 2048,
+        elementwise_affine: Optional[bool] = True,
+        norm_eps: float = 1e-6,
+        attn: Optional[Dict] = None,
+        use_cache: bool = True,
+        pad_token_id: int = None,
+        bos_token_id: int = 1,
+        eos_token_id: int = 2,
+        tie_word_embeddings: bool = False,
+        initializer_range: float = 0.006,
+        fuse_norm: bool = True,
+        fuse_swiglu: bool = True,
+        fuse_cross_entropy: bool = True,
+        vocab_size: int = 32000,
+        **kwargs
+    ):
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.attn_mode = attn_mode
+
+        if expand_ratio is None and num_heads is not None:
+            expand_ratio = hidden_size // num_heads
+        elif expand_ratio is not None and num_heads is None:
+            num_heads = hidden_size // expand_ratio
+        elif expand_ratio is None and num_heads is None:
+            raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
+        self.num_heads = num_heads
+        self.expand_ratio = expand_ratio
+
+        self.use_short_conv = use_short_conv
+        self.conv_size = conv_size
+        self.use_lower_bound = use_lower_bound
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.elementwise_affine = elementwise_affine
+        self.norm_eps = norm_eps
+        self.attn = attn
+        self.use_cache = use_cache
+        self.initializer_range = initializer_range
+
+        self.fuse_norm = fuse_norm
+        self.fuse_swiglu = fuse_swiglu
+        self.fuse_cross_entropy = fuse_cross_entropy
+        self.vocab_size = vocab_size
+
+        if attn is not None:
+            if not isinstance(attn, Dict):
+                raise ValueError("attn must be a dictionary")
+            if 'layers' not in attn:
+                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+            if 'num_heads' not in attn:
+                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+            attn['qkv_bias'] = attn.get('qkv_bias', False)
+            attn['window_size'] = attn.get('window_size', None)
+            attn['rope_theta'] = attn.get('rope_theta', 10000.)
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
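As a quick illustration (editorial, not part of the commit), the `num_heads`/`expand_ratio` coupling and the hybrid-attention defaults above resolve as follows; the import path simply mirrors the file location in this diff.

# Sketch, not part of the commit: how HGRN2Config fills in num_heads / expand_ratio.
from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config

# Only `expand_ratio` given: num_heads = hidden_size // expand_ratio = 2048 // 128 = 16
cfg = HGRN2Config(hidden_size=2048, expand_ratio=128)
assert cfg.num_heads == 16

# Only `num_heads` given: expand_ratio = hidden_size // num_heads = 2048 // 32 = 64
cfg = HGRN2Config(hidden_size=2048, num_heads=32, expand_ratio=None)
assert cfg.expand_ratio == 64

# Hybrid attention layers: `layers` and `num_heads` are required, the rest get defaults.
cfg = HGRN2Config(attn={'layers': [1, 3], 'num_heads': 8})
assert cfg.attn['rope_theta'] == 10000.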
fla/models/lightnet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (699 Bytes).

fla/models/linear_attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (737 Bytes).

fla/models/mamba/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (717 Bytes).

fla/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc ADDED
Binary file (7.06 kB).

fla/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc ADDED
Binary file (41.5 kB).

fla/models/nsa/__pycache__/configuration_nsa.cpython-312.pyc ADDED
Binary file (2.64 kB).

fla/models/nsa/__pycache__/modeling_nsa.cpython-312.pyc ADDED
Binary file (17.6 kB).

fla/models/retnet/__pycache__/configuration_retnet.cpython-312.pyc ADDED
Binary file (3.73 kB).

fla/models/retnet/__pycache__/modeling_retnet.cpython-312.pyc ADDED
Binary file (18.4 kB).

fla/models/rwkv6/__pycache__/configuration_rwkv6.cpython-312.pyc ADDED
Binary file (3.32 kB).

fla/models/rwkv6/configuration_rwkv6.py ADDED
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+
+from typing import Dict, Optional
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class RWKV6Config(PretrainedConfig):
+
+    model_type = 'rwkv6'
+    keys_to_ignore_at_inference = ['past_key_values']
+
+    def __init__(
+        self,
+        attn_mode: str = "chunk",
+        hidden_size: int = 2048,
+        expand_k: int = 0.5,
+        expand_v: int = 1,
+        hidden_ratio: Optional[int] = 3.5,
+        intermediate_size: Optional[int] = None,
+        num_hidden_layers: int = 24,
+        num_heads: int = 4,
+        proj_low_rank_dim: int = 32,
+        gate_low_rank_dim: int = 64,
+        hidden_act: str = "sqrelu",
+        max_position_embeddings: int = 2048,
+        norm_first: bool = True,
+        norm_bias: bool = True,
+        norm_eps: float = 1e-5,
+        attn: Optional[Dict] = None,
+        use_cache: bool = True,
+        pad_token_id: int = None,
+        bos_token_id: int = 1,
+        eos_token_id: int = 2,
+        tie_word_embeddings: bool = False,
+        initializer_range: float = 0.006,
+        fuse_norm: bool = True,
+        fuse_cross_entropy: bool = True,
+        vocab_size: int = 32000,
+        **kwargs
+    ):
+        self.attn_mode = attn_mode
+        self.hidden_size = hidden_size
+        self.expand_k = expand_k
+        self.expand_v = expand_v
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+        self.norm_first = norm_first
+        self.num_hidden_layers = num_hidden_layers
+        self.num_heads = num_heads
+        self.proj_low_rank_dim = proj_low_rank_dim
+        self.gate_low_rank_dim = gate_low_rank_dim
+        self.hidden_act = hidden_act
+        self.max_position_embeddings = max_position_embeddings
+        self.norm_bias = norm_bias
+        self.norm_eps = norm_eps
+        self.attn = attn
+        self.use_cache = use_cache
+        self.initializer_range = initializer_range
+        self.fuse_norm = fuse_norm
+        self.fuse_cross_entropy = fuse_cross_entropy
+        self.vocab_size = vocab_size
+
+        if attn is not None:
+            if not isinstance(attn, Dict):
+                raise ValueError("attn must be a dictionary")
+            if 'layers' not in attn:
+                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+            if 'num_heads' not in attn:
+                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+            attn['qkv_bias'] = attn.get('qkv_bias', False)
+            attn['window_size'] = attn.get('window_size', None)
+            attn['rope_theta'] = attn.get('rope_theta', 10000.)
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
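A small round-trip sketch (editorial, not part of the commit): `RWKV6Config` serializes through the standard `PretrainedConfig` dict machinery, which is what `save_pretrained`/`from_pretrained` use for `config.json`; values are illustrative.

# Sketch, not part of the commit: dict round-trip via the PretrainedConfig API.
from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config

cfg = RWKV6Config(hidden_size=1024, num_hidden_layers=12, num_heads=8)
d = cfg.to_dict()
assert d['model_type'] == 'rwkv6'
assert d['expand_k'] == 0.5 and d['expand_v'] == 1  # defaults set above

restored = RWKV6Config.from_dict(d)
assert restored.hidden_size == 1024 and restored.num_heads == 8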
fla/models/samba/__pycache__/configuration_samba.cpython-312.pyc ADDED
Binary file (3.39 kB).

fla/models/samba/__pycache__/modeling_samba.cpython-312.pyc ADDED
Binary file (20.9 kB).

fla/models/transformer/__init__.py ADDED
@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+from fla.models.transformer.configuration_transformer import TransformerConfig
+from fla.models.transformer.modeling_transformer import TransformerForCausalLM, TransformerModel
+
+AutoConfig.register(TransformerConfig.model_type, TransformerConfig)
+AutoModel.register(TransformerConfig, TransformerModel)
+AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM)
+
+
+__all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel']
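For context (editorial, not part of the commit): the `register` calls above are what let the generic `Auto*` factories resolve this config class to the fla model classes once the package is imported. A minimal sketch with illustrative sizes:

# Sketch, not part of the commit: the Auto factories resolve TransformerConfig after import.
from transformers import AutoModelForCausalLM

from fla.models.transformer import TransformerConfig  # importing runs the register() calls above

config = TransformerConfig(hidden_size=256, num_hidden_layers=2, num_heads=4, vocab_size=1000)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # TransformerForCausalLM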
fla/models/transformer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (728 Bytes).

fla/models/transformer/__pycache__/modeling_transformer.cpython-312.pyc ADDED
Binary file (17.1 kB).

fla/models/transformer/configuration_transformer.py ADDED
@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+
+from typing import Optional
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class TransformerConfig(PretrainedConfig):
+
+    model_type = 'transformer'
+    keys_to_ignore_at_inference = ['past_key_values']
+
+    def __init__(
+        self,
+        hidden_size: int = 2048,
+        num_hidden_layers: int = 24,
+        num_heads: int = 32,
+        num_kv_heads: int = None,
+        qkv_bias: bool = False,
+        qk_norm: bool = False,
+        window_size: Optional[int] = None,
+        rope_theta: Optional[float] = 10000.,
+        max_position_embeddings: int = 2048,
+        hidden_ratio: Optional[int] = 4,
+        intermediate_size: Optional[int] = None,
+        hidden_act: str = "swish",
+        initializer_range: float = 0.006,
+        elementwise_affine: Optional[bool] = True,
+        norm_eps: float = 1e-6,
+        use_cache: bool = True,
+        pad_token_id: int = None,
+        bos_token_id: int = 1,
+        eos_token_id: int = 2,
+        tie_word_embeddings: bool = False,
+        fuse_norm: bool = True,
+        fuse_swiglu: bool = True,
+        fuse_cross_entropy: bool = True,
+        vocab_size: int = 32000,
+        **kwargs,
+    ):
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.qkv_bias = qkv_bias
+        self.qk_norm = qk_norm
+        self.window_size = window_size
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+
+        self.initializer_range = initializer_range
+        self.elementwise_affine = elementwise_affine
+        self.norm_eps = norm_eps
+        self.use_cache = use_cache
+
+        self.fuse_norm = fuse_norm
+        self.fuse_swiglu = fuse_swiglu
+        self.fuse_cross_entropy = fuse_cross_entropy
+        self.vocab_size = vocab_size
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
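One note on the defaults above (editorial sketch, not part of the commit): `intermediate_size=None` leaves the MLP width to the model code, and the sizing convention used elsewhere in this commit (see `BitNetMLP` earlier in this diff) rounds `2/3 * hidden_ratio * hidden_size` up to a multiple of 256. The helper name below is illustrative, not from the repo.

# Sketch, not part of the commit: reproduces the MLP sizing arithmetic for a sanity check.
def mlp_intermediate_size(hidden_size: int, hidden_ratio: int = 4) -> int:
    size = int(hidden_size * hidden_ratio * 2 / 3)
    return 256 * ((size + 256 - 1) // 256)

print(mlp_intermediate_size(2048))  # 5632 for the default hidden_size=2048, hidden_ratio=4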
fla/models/transformer_mtp/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (795 Bytes).

fla/models/transformer_mtp/__pycache__/configuration_transformer.cpython-312.pyc ADDED
Binary file (2.69 kB).

fla/modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.35 kB).

fla/modules/__pycache__/convolution.cpython-312.pyc ADDED
Binary file (21 kB).

fla/modules/__pycache__/fused_bitlinear.cpython-312.pyc ADDED
Binary file (23.6 kB).

fla/modules/__pycache__/fused_cross_entropy.cpython-312.pyc ADDED
Binary file (16 kB).

fla/modules/__pycache__/fused_linear_cross_entropy.cpython-312.pyc ADDED
Binary file (20.6 kB).

fla/modules/__pycache__/fused_linear_listnet_loss.cpython-312.pyc ADDED
Binary file (17.8 kB).

fla/modules/__pycache__/fused_norm_gate.cpython-312.pyc ADDED
Binary file (35.3 kB).

fla/modules/__pycache__/l2norm.cpython-312.pyc ADDED
Binary file (6.96 kB).

fla/modules/__pycache__/layernorm.cpython-312.pyc ADDED
Binary file (43.4 kB).

fla/modules/__pycache__/layernorm_gated.cpython-312.pyc ADDED
Binary file (23.5 kB).

fla/modules/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (2.04 kB).

fla/modules/__pycache__/rotary.cpython-312.pyc ADDED
Binary file (23.2 kB).

fla/modules/__pycache__/seq_to_top.cpython-312.pyc ADDED
Binary file (4.08 kB).

logs/none_yagntt11/attempt_0/0/stdout.log ADDED
File without changes
logs/none_yagntt11/attempt_0/1/stdout.log ADDED
File without changes
logs/none_yagntt11/attempt_0/2/stdout.log ADDED
File without changes