zaydzuhri committed
Commit f91217a · verified · 1 Parent(s): 400ed77

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +1 -0
  2. fla/models/abc/__pycache__/configuration_abc.cpython-312.pyc +0 -0
  3. fla/models/abc/__pycache__/modeling_abc.cpython-312.pyc +0 -0
  4. fla/models/abc/configuration_abc.py +91 -0
  5. fla/models/bitnet/__pycache__/__init__.cpython-312.pyc +0 -0
  6. fla/models/delta_net/__pycache__/modeling_delta_net.cpython-312.pyc +0 -0
  7. fla/models/delta_net/modeling_delta_net.py +415 -0
  8. fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-312.pyc +0 -0
  9. fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-312.pyc +0 -0
  10. fla/models/forgetting_transformer/configuration_forgetting_transformer.py +68 -0
  11. fla/models/gated_deltanet/__init__.py +12 -0
  12. fla/models/gated_deltanet/__pycache__/__init__.cpython-312.pyc +0 -0
  13. fla/models/gated_deltanet/__pycache__/modeling_gated_deltanet.cpython-312.pyc +0 -0
  14. fla/models/gated_deltaproduct/__pycache__/modeling_gated_deltaproduct.cpython-312.pyc +0 -0
  15. fla/models/gla/__pycache__/__init__.cpython-312.pyc +0 -0
  16. fla/models/gsa/__init__.py +13 -0
  17. fla/models/gsa/__pycache__/__init__.cpython-312.pyc +0 -0
  18. fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc +0 -0
  19. fla/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc +0 -0
  20. fla/models/hgrn/__pycache__/configuration_hgrn.cpython-312.pyc +0 -0
  21. fla/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc +0 -0
  22. fla/models/hgrn2/__pycache__/configuration_hgrn2.cpython-312.pyc +0 -0
  23. fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc +0 -0
  24. fla/models/hgrn2/configuration_hgrn2.py +91 -0
  25. fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc +0 -0
  26. fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc +0 -0
  27. fla/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc +0 -0
  28. fla/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc +0 -0
  29. fla/models/mamba2/__pycache__/configuration_mamba2.cpython-312.pyc +0 -0
  30. fla/models/mamba2/__pycache__/modeling_mamba2.cpython-312.pyc +0 -0
  31. fla/models/nsa/__pycache__/__init__.cpython-312.pyc +0 -0
  32. fla/models/nsa/__pycache__/modeling_nsa.cpython-312.pyc +0 -0
  33. fla/models/retnet/__pycache__/__init__.cpython-312.pyc +0 -0
  34. fla/models/retnet/__pycache__/configuration_retnet.cpython-312.pyc +0 -0
  35. fla/models/retnet/__pycache__/modeling_retnet.cpython-312.pyc +0 -0
  36. fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc +0 -0
  37. fla/models/samba/__pycache__/__init__.cpython-312.pyc +0 -0
  38. fla/models/samba/__pycache__/configuration_samba.cpython-312.pyc +0 -0
  39. fla/models/samba/__pycache__/modeling_samba.cpython-312.pyc +0 -0
  40. fla/models/transformer/__pycache__/__init__.cpython-312.pyc +0 -0
  41. fla/models/transformer/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
  42. fla/models/transformer/__pycache__/modeling_transformer.cpython-312.pyc +0 -0
  43. fla/models/transformer/modeling_transformer.py +406 -0
  44. fla/models/transformer_mtp/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
  45. fla/models/transformer_mtp/__pycache__/modeling_transformer.cpython-312.pyc +0 -0
  46. fla/models/transformer_mtp/configuration_transformer.py +76 -0
  47. fla/models/transformer_top/__pycache__/__init__.cpython-312.pyc +0 -0
  48. fla/modules/__pycache__/__init__.cpython-312.pyc +0 -0
  49. fla/modules/__pycache__/activations.cpython-312.pyc +0 -0
  50. fla/modules/__pycache__/convolution.cpython-312.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tb/20250722-0737/wandb/run-20250722_073713-mtp_transformer-mtp.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine-202507220732/run-mtp_transformer-mtp.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine-202507220732.wandb filter=lfs diff=lfs merge=lfs -text
fla/models/abc/__pycache__/configuration_abc.cpython-312.pyc ADDED
Binary file (3.61 kB). View file
 
fla/models/abc/__pycache__/modeling_abc.cpython-312.pyc ADDED
Binary file (18.4 kB). View file
 
fla/models/abc/configuration_abc.py ADDED
@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-

from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class ABCConfig(PretrainedConfig):

    model_type = 'abc'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        gate_low_rank_dim: int = 16,
        clamp_min: float = -32,
        clamp_max: float = 32,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_slots: Optional[int] = 64,
        use_short_conv: bool = False,
        conv_size: int = 4,
        expand_k: float = 0.5,
        expand_v: float = 1,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_rope: bool = True,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.gate_low_rank_dim = gate_low_rank_dim
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_slots = num_slots
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_rope = use_rope
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
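For illustration, a minimal sketch (not part of the uploaded files) of how the hybrid `attn` validation above behaves; the layer indices and sizes are made up:

    from fla.models.abc.configuration_abc import ABCConfig

    # 'layers' and 'num_heads' are required; missing keys such as num_kv_heads,
    # qkv_bias, window_size and rope_theta are filled in with defaults by __init__.
    config = ABCConfig(
        hidden_size=1024,
        num_hidden_layers=12,
        attn={'layers': [3, 7], 'num_heads': 16},  # softmax attention only in layers 3 and 7
    )
    print(config.attn['num_kv_heads'], config.attn['rope_theta'])  # 16 10000.0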
fla/models/bitnet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (682 Bytes). View file
 
fla/models/delta_net/__pycache__/modeling_delta_net.cpython-312.pyc ADDED
Binary file (18.5 kB). View file
 
fla/models/delta_net/modeling_delta_net.py ADDED
@@ -0,0 +1,415 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg

from fla.layers.attn import Attention
from fla.layers.delta_net import DeltaNet
from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
from fla.models.utils import Cache
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
from fla.modules import GatedMLP as DeltaNetMLP
from fla.modules import RMSNorm

logger = logging.get_logger(__name__)

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack


class DeltaNetBlock(nn.Module):
    def __init__(self, config: DeltaNetConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        if config.attn is not None and layer_idx in config.attn['layers']:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                rope_theta=config.attn['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = DeltaNet(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_k=config.expand_k,
                expand_v=config.expand_v,
                num_heads=config.num_heads,
                use_gate=config.use_gate,
                use_beta=config.use_beta,
                use_short_conv=config.use_short_conv,
                use_output_norm=config.use_output_norm,
                conv_size=config.conv_size,
                qk_norm=config.qk_norm,
                qk_activation=config.qk_activation,
                norm_eps=config.norm_eps,
                layer_idx=layer_idx
            )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = DeltaNetMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs


class DeltaNetPreTrainedModel(PreTrainedModel):

    config_class = DeltaNetConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['DeltaNetBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        prenorm_residual_strategy: Optional[str] = 'rescale',
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if prenorm_residual_strategy is not None:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                if prenorm_residual_strategy == 'rescale':
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
                elif prenorm_residual_strategy == 'zero':
                    nn.init.zeros_(p)
                else:
                    raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")


class DeltaNetModel(DeltaNetPreTrainedModel):

    def __init__(self, config: DeltaNetConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([DeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn("`DeltaNetModel` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    **kwargs
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )


class DeltaNetForCausalLM(DeltaNetPreTrainedModel, GenerationMixin):

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = DeltaNetModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
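The loss path in `DeltaNetForCausalLM.forward` shifts the labels rather than the logits: targets move left by one position and the final slot is filled with the criterion's ignore index, so position t predicts token t+1. A standalone sketch of that step (made-up token ids, assuming the usual ignore index of -100):

    import torch

    ignore_index = -100  # default for nn.CrossEntropyLoss and the fused criteria used above
    labels = torch.tensor([[11, 12, 13, 14]])
    labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], ignore_index)), 1)
    print(labels)  # tensor([[  12,   13,   14, -100]])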
fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-312.pyc ADDED
Binary file (2.5 kB). View file
 
fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-312.pyc ADDED
Binary file (17.2 kB). View file
 
fla/models/forgetting_transformer/configuration_forgetting_transformer.py ADDED
@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class ForgettingTransformerConfig(PretrainedConfig):

    model_type = 'forgetting_transformer'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: Optional[int] = None,
        use_output_gate: bool = False,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.006,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm
        self.window_size = window_size
        self.use_output_gate = use_output_gate
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/gated_deltanet/__init__.py ADDED
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
from fla.models.gated_deltanet.modeling_gated_deltanet import GatedDeltaNetForCausalLM, GatedDeltaNetModel

AutoConfig.register(GatedDeltaNetConfig.model_type, GatedDeltaNetConfig)
AutoModel.register(GatedDeltaNetConfig, GatedDeltaNetModel)
AutoModelForCausalLM.register(GatedDeltaNetConfig, GatedDeltaNetForCausalLM)

__all__ = ['GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel']
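Because these registrations run at import time, the model becomes reachable through the standard Auto* factories as soon as the package is imported. A hypothetical usage sketch (the config sizes are illustrative only, not part of this commit):

    from transformers import AutoModelForCausalLM

    from fla.models.gated_deltanet import GatedDeltaNetConfig  # importing the package runs the registrations above

    config = GatedDeltaNetConfig(hidden_size=512, num_hidden_layers=4)
    model = AutoModelForCausalLM.from_config(config)
    print(type(model).__name__)  # GatedDeltaNetForCausalLM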
fla/models/gated_deltanet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (746 Bytes). View file
 
fla/models/gated_deltanet/__pycache__/modeling_gated_deltanet.cpython-312.pyc ADDED
Binary file (18.5 kB). View file
 
fla/models/gated_deltaproduct/__pycache__/modeling_gated_deltaproduct.cpython-312.pyc ADDED
Binary file (20.7 kB). View file
 
fla/models/gla/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (657 Bytes). View file
 
fla/models/gsa/__init__.py ADDED
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.gsa.configuration_gsa import GSAConfig
from fla.models.gsa.modeling_gsa import GSAForCausalLM, GSAModel

AutoConfig.register(GSAConfig.model_type, GSAConfig)
AutoModel.register(GSAConfig, GSAModel)
AutoModelForCausalLM.register(GSAConfig, GSAForCausalLM)


__all__ = ['GSAConfig', 'GSAForCausalLM', 'GSAModel']
fla/models/gsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (657 Bytes). View file
 
fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc ADDED
Binary file (3.84 kB). View file
 
fla/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc ADDED
Binary file (18.7 kB). View file
 
fla/models/hgrn/__pycache__/configuration_hgrn.cpython-312.pyc ADDED
Binary file (3.28 kB). View file
 
fla/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc ADDED
Binary file (18.8 kB). View file
 
fla/models/hgrn2/__pycache__/configuration_hgrn2.cpython-312.pyc ADDED
Binary file (3.55 kB). View file
 
fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc ADDED
Binary file (18.9 kB). View file
 
fla/models/hgrn2/configuration_hgrn2.py ADDED
@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-

from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class HGRN2Config(PretrainedConfig):

    model_type = 'hgrn2'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        attn_mode: str = "chunk",
        num_heads: Optional[int] = None,
        expand_ratio: Optional[int] = 128,
        use_short_conv: bool = False,
        conv_size: int = 4,
        use_lower_bound: bool = True,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.attn_mode = attn_mode

        if expand_ratio is None and num_heads is not None:
            expand_ratio = hidden_size // num_heads
        elif expand_ratio is not None and num_heads is None:
            num_heads = hidden_size // expand_ratio
        elif expand_ratio is None and num_heads is None:
            raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
        self.num_heads = num_heads
        self.expand_ratio = expand_ratio

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.use_lower_bound = use_lower_bound
        self.max_position_embeddings = max_position_embeddings
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
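A small sketch (not part of the commit) of the `num_heads`/`expand_ratio` coupling in `HGRN2Config.__init__` above: whichever of the two is omitted is derived from `hidden_size`.

    from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config

    config = HGRN2Config(hidden_size=2048, expand_ratio=128)  # num_heads = 2048 // 128 = 16
    print(config.num_heads, config.expand_ratio)              # 16 128

    config = HGRN2Config(hidden_size=2048, num_heads=32, expand_ratio=None)  # expand_ratio = 2048 // 32 = 64
    print(config.num_heads, config.expand_ratio)              # 32 64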
fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc ADDED
Binary file (3.36 kB). View file
 
fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc ADDED
Binary file (18.3 kB). View file
 
fla/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc ADDED
Binary file (7.06 kB). View file
 
fla/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc ADDED
Binary file (41.5 kB). View file
 
fla/models/mamba2/__pycache__/configuration_mamba2.cpython-312.pyc ADDED
Binary file (7.5 kB). View file
 
fla/models/mamba2/__pycache__/modeling_mamba2.cpython-312.pyc ADDED
Binary file (52.4 kB). View file
 
fla/models/nsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (657 Bytes). View file
 
fla/models/nsa/__pycache__/modeling_nsa.cpython-312.pyc ADDED
Binary file (17.6 kB). View file
 
fla/models/retnet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (682 Bytes). View file
 
fla/models/retnet/__pycache__/configuration_retnet.cpython-312.pyc ADDED
Binary file (3.73 kB). View file
 
fla/models/retnet/__pycache__/modeling_retnet.cpython-312.pyc ADDED
Binary file (18.4 kB). View file
 
fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (687 Bytes). View file
 
fla/models/samba/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (717 Bytes). View file
 
fla/models/samba/__pycache__/configuration_samba.cpython-312.pyc ADDED
Binary file (3.39 kB). View file
 
fla/models/samba/__pycache__/modeling_samba.cpython-312.pyc ADDED
Binary file (20.9 kB). View file
 
fla/models/transformer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (728 Bytes). View file
 
fla/models/transformer/__pycache__/configuration_transformer.cpython-312.pyc ADDED
Binary file (2.52 kB). View file
 
fla/models/transformer/__pycache__/modeling_transformer.cpython-312.pyc ADDED
Binary file (17.1 kB). View file
 
fla/models/transformer/modeling_transformer.py ADDED
@@ -0,0 +1,406 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg

from fla.layers.attn import Attention
from fla.models.transformer.configuration_transformer import TransformerConfig
from fla.models.utils import Cache
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
from fla.modules import GatedMLP as TransformerMLP
from fla.modules import RMSNorm

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack


logger = logging.get_logger(__name__)


class TransformerBlock(nn.Module):

    def __init__(self, config: TransformerConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.attn = Attention(
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            qkv_bias=config.qkv_bias,
            qk_norm=config.qk_norm,
            window_size=config.window_size,
            rope_theta=config.rope_theta,
            max_position_embeddings=config.max_position_embeddings,
            layer_idx=layer_idx
        )

        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = TransformerMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs: Unpack[Any]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attentions,)

        if use_cache:
            outputs += (past_key_values,)

        return outputs


class TransformerPreTrainedModel(PreTrainedModel):

    config_class = TransformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['TransformerBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = False,
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)


class TransformerModel(TransformerPreTrainedModel):

    def __init__(
        self,
        config: TransformerConfig
    ) -> TransformerModel:
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([TransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if output_attentions:
            warnings.warn(
                "`TransformerModel` does not support output attention weights now, so `output_attentions` is set to `False`."
            )
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        # embed positions
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        next_cache = None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    **kwargs
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    **kwargs
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )


class TransformerForCausalLM(TransformerPreTrainedModel, GenerationMixin):

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = TransformerModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
        logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])

        loss = None
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            # Enable model parallelism
            labels = labels.to(hidden_states.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            labels = labels[..., :hidden_states.shape[1]].contiguous()
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/models/transformer_mtp/__pycache__/configuration_transformer.cpython-312.pyc ADDED
Binary file (2.69 kB). View file
 
fla/models/transformer_mtp/__pycache__/modeling_transformer.cpython-312.pyc ADDED
Binary file (24.7 kB). View file
 
fla/models/transformer_mtp/configuration_transformer.py ADDED
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class MTPTransformerConfig(PretrainedConfig):

    model_type = 'mtp_transformer'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        num_heads: int = 32,
        num_kv_heads: int = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.006,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        n_future_tokens: int = 1,
        use_custom_backward: Optional[bool] = False,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        self.n_future_tokens = n_future_tokens
        self.use_custom_backward = use_custom_backward

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/transformer_top/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (749 Bytes). View file
 
fla/modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.35 kB). View file
 
fla/modules/__pycache__/activations.cpython-312.pyc ADDED
Binary file (23 kB). View file
 
fla/modules/__pycache__/convolution.cpython-312.pyc ADDED
Binary file (21 kB). View file