zaydzuhri committed
Commit ecdc989 · verified · 1 parent: 6f577e3

Add files using upload-large-folder tool

Files changed (46)
  1. .gitattributes +10 -0
  2. fla/models/abc/__pycache__/__init__.cpython-312.pyc +0 -0
  3. fla/models/abc/__pycache__/modeling_abc.cpython-312.pyc +0 -0
  4. fla/models/bitnet/__pycache__/__init__.cpython-312.pyc +0 -0
  5. fla/models/bitnet/__pycache__/modeling_bitnet.cpython-312.pyc +0 -0
  6. fla/models/delta_net/__pycache__/__init__.cpython-312.pyc +0 -0
  7. fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-312.pyc +0 -0
  8. fla/models/gated_deltanet/__pycache__/__init__.cpython-312.pyc +0 -0
  9. fla/models/gated_deltanet/configuration_gated_deltanet.py +83 -0
  10. fla/models/gated_deltaproduct/__pycache__/modeling_gated_deltaproduct.cpython-312.pyc +0 -0
  11. fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py +520 -0
  12. fla/models/gla/modeling_gla.py +417 -0
  13. fla/models/gsa/configuration_gsa.py +97 -0
  14. fla/models/hgrn/__pycache__/__init__.cpython-312.pyc +0 -0
  15. fla/models/hgrn/__pycache__/configuration_hgrn.cpython-312.pyc +0 -0
  16. fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc +0 -0
  17. fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc +0 -0
  18. fla/models/lightnet/modeling_lightnet.py +410 -0
  19. fla/models/linear_attn/__pycache__/configuration_linear_attn.cpython-312.pyc +0 -0
  20. fla/models/linear_attn/__pycache__/modeling_linear_attn.cpython-312.pyc +0 -0
  21. fla/models/mamba2/__pycache__/__init__.cpython-312.pyc +0 -0
  22. fla/models/mamba2/__pycache__/configuration_mamba2.cpython-312.pyc +0 -0
  23. fla/models/nsa/__pycache__/__init__.cpython-312.pyc +0 -0
  24. fla/models/nsa/modeling_nsa.py +398 -0
  25. fla/models/retnet/configuration_retnet.py +92 -0
  26. fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc +0 -0
  27. fla/models/rwkv7/__pycache__/__init__.cpython-312.pyc +0 -0
  28. fla/models/rwkv7/__pycache__/configuration_rwkv7.cpython-312.pyc +0 -0
  29. fla/models/transformer/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
  30. fla/models/transformer_mtp/__pycache__/modeling_transformer.cpython-312.pyc +0 -0
  31. fla/models/transformer_mtp/modeling_transformer.py +608 -0
  32. fla/models/transformer_top/__pycache__/modeling_transformer.cpython-312.pyc +0 -0
  33. fla/modules/__pycache__/feature_map.cpython-312.pyc +0 -0
  34. logs/none_yagntt11/attempt_0/0/stderr.log +3 -0
  35. logs/none_yagntt11/attempt_0/1/stderr.log +3 -0
  36. logs/none_yagntt11/attempt_0/2/stderr.log +3 -0
  37. logs/none_yagntt11/attempt_0/3/stderr.log +3 -0
  38. logs/none_yagntt11/attempt_0/4/stderr.log +3 -0
  39. logs/none_yagntt11/attempt_0/5/stderr.log +3 -0
  40. logs/none_yagntt11/attempt_0/5/stdout.log +0 -0
  41. logs/none_yagntt11/attempt_0/6/stderr.log +3 -0
  42. logs/none_yagntt11/attempt_0/7/stderr.log +3 -0
  43. model-00001-of-00002.safetensors +3 -0
  44. model-00002-of-00002.safetensors +3 -0
  45. tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/files/output.log +3 -0
  46. tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/run-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201.wandb +3 -0
.gitattributes CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/4/stderr.log filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/1/stderr.log filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/2/stderr.log filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/5/stderr.log filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/3/stderr.log filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/7/stderr.log filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/6/stderr.log filter=lfs diff=lfs merge=lfs -text
+tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/files/output.log filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/0/stderr.log filter=lfs diff=lfs merge=lfs -text
+tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/run-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201.wandb filter=lfs diff=lfs merge=lfs -text
fla/models/abc/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (657 Bytes).
 
fla/models/abc/__pycache__/modeling_abc.cpython-312.pyc ADDED
Binary file (18.4 kB).
 
fla/models/bitnet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (682 Bytes).
 
fla/models/bitnet/__pycache__/modeling_bitnet.cpython-312.pyc ADDED
Binary file (18.6 kB).
 
fla/models/delta_net/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (701 Bytes).
 
fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-312.pyc ADDED
Binary file (17.2 kB).
 
fla/models/gated_deltanet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (746 Bytes).
 
fla/models/gated_deltanet/configuration_gated_deltanet.py ADDED
@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-

from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class GatedDeltaNetConfig(PretrainedConfig):
    model_type = 'gated_deltanet'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "chunk",
        hidden_size: int = 2048,
        expand_v: int = 2,
        use_gate: bool = True,
        use_short_conv: bool = True,
        conv_size: int = 4,
        head_dim: int = 256,
        num_heads: int = 6,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        num_hidden_layers: int = 21,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.expand_v = expand_v
        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.max_position_embeddings = max_position_embeddings

        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_hidden_layers = num_hidden_layers
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
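For reference, a minimal usage sketch of the configuration class added above; the values are illustrative only and are not taken from this checkpoint:

from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig

# Pure Gated DeltaNet stack using the defaults defined above.
config = GatedDeltaNetConfig(hidden_size=2048, num_hidden_layers=21, num_heads=6)

# Hybrid variant: 'layers' and 'num_heads' are required by the validation in __init__;
# num_kv_heads, qkv_bias, window_size and rope_theta fall back to the defaults it fills in.
hybrid = GatedDeltaNetConfig(attn={'layers': [2, 10], 'num_heads': 16})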
fla/models/gated_deltaproduct/__pycache__/modeling_gated_deltaproduct.cpython-312.pyc ADDED
Binary file (20.7 kB).
 
fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py ADDED
@@ -0,0 +1,520 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg

from fla.layers.attn import Attention
from fla.layers.gated_deltaproduct import GatedDeltaProduct
from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig
from fla.models.utils import Cache
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
from fla.modules.activations import swiglu_linear
from fla.modules.layernorm import rms_norm_linear

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

logger = logging.get_logger(__name__)


class GatedDeltaNetMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        norm_first: bool = True,
        norm_eps: float = 1e-5,
    ) -> GatedDeltaNetMLP:
        super().__init__()

        self.hidden_size = hidden_size
        # the final number of params is `hidden_ratio * hidden_size^2`
        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
        if hidden_ratio is None:
            hidden_ratio = 4
        if intermediate_size is None:
            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.norm_first = norm_first

        if norm_first:
            self.norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps)

        self.gate_proj = nn.Linear(
            self.hidden_size, self.intermediate_size * 2, bias=False
        )
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(
        self,
        x: torch.Tensor,
        **kwargs: Unpack[Dict],
    ) -> torch.Tensor:
        if self.norm_first:
            x = rms_norm_linear(
                x,
                self.norm.weight,
                self.norm.bias,
                self.gate_proj.weight,
                self.gate_proj.bias,
            )
        else:
            x = self.gate_proj(x)
        gate, y = x.chunk(2, -1)
        return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)


class GatedDeltaProductBlock(nn.Module):
    def __init__(self, config: GatedDeltaProductConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        if not config.norm_first:
            self.attn_norm = RMSNorm(
                hidden_size=config.hidden_size, eps=config.norm_eps
            )
        if config.attn is not None and layer_idx in config.attn["layers"]:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn["num_heads"],
                num_kv_heads=config.attn["num_kv_heads"],
                window_size=config.attn["window_size"],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx,
            )
        else:
            self.attn = GatedDeltaProduct(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_v=config.expand_v,
                head_dim=config.head_dim,
                num_heads=config.num_heads,
                use_gate=config.use_gate,
                use_forget_gate=config.use_forget_gate,
                use_short_conv=config.use_short_conv,
                conv_size=config.conv_size,
                norm_first=config.norm_first,
                norm_eps=config.norm_eps,
                allow_neg_eigval=config.allow_neg_eigval,
                num_householder=config.num_householder,
                layer_idx=layer_idx,
                use_beta_conv=config.use_beta_conv
            )
        if not config.norm_first:
            self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
        self.mlp = GatedDeltaNetMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            norm_first=config.norm_first,
            norm_eps=config.norm_eps,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict],
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        residual = hidden_states
        if hasattr(self, "attn_norm"):
            hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs,
        )
        if hasattr(self, "mlp_norm"):
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs


class GatedDeltaProductPreTrainedModel(PreTrainedModel):
    config_class = GatedDeltaProductConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ["GatedDeltaNetBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = True,
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["o_proj.weight", "down_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    with torch.no_grad():
                        p /= math.sqrt(
                            num_residuals_per_layer * self.config.num_hidden_layers
                        )


class GatedDeltaProductModel(GatedDeltaProductPreTrainedModel):
    def __init__(self, config: GatedDeltaProductConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                GatedDeltaProductBlock(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn(
                "`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.",
                stacklevel=2,
            )
            output_attentions = False
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = (
            use_cache
            if use_cache is not None
            else (self.config.use_cache if not self.training else False)
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = (
                    self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        past_key_values,
                        use_cache,
                        output_attentions,
                        **kwargs,
                    )
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs,
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                i
                for i in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_attns,
                ]
                if i is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns,
        )


class GatedDeltaProductForCausalLM(GatedDeltaProductPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = GatedDeltaProductModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if "past_key_values" in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        num_logits_to_keep: Optional[int] = None,
        logits_to_keep: Optional[int] = None,
        **kwargs,
    ):
        # only last token for `inputs_ids` if the `past_key_values` is passed along is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "num_logits_to_keep": num_logits_to_keep,
            }
        )
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_logits_to_keep: Optional[int] = 0,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        num_logits_to_keep = 0 if num_logits_to_keep is None else num_logits_to_keep
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        kwargs.pop("num_items_in_batch", None)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if self.config.fuse_cross_entropy:
                if fuse_linear_and_cross_entropy:
                    loss_fct = FusedLinearCrossEntropyLoss()
                else:
                    loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
            else:
                loss_fct = nn.CrossEntropyLoss()
            # Enable model parallelism
            labels = labels.to(hidden_states.device)
            labels = torch.cat(
                (
                    labels[..., 1:],
                    torch.full_like(labels[:, :1], loss_fct.ignore_index),
                ),
                1,
            )
            if fuse_linear_and_cross_entropy:
                loss = loss_fct(
                    hidden_states.view(-1, self.config.hidden_size),
                    labels.view(-1),
                    self.lm_head.weight,
                    self.lm_head.bias,
                )
            else:
                loss = loss_fct(
                    logits.view(-1, self.config.vocab_size), labels.view(-1)
                )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss, *output) if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
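The loss path in `GatedDeltaProductForCausalLM.forward` shifts the labels one position to the left and masks the final step with `ignore_index` before applying the (optionally fused) cross-entropy. A self-contained sketch of that label handling in plain PyTorch, using `nn.CrossEntropyLoss` as a stand-in for the fused kernels:

import torch
import torch.nn as nn

loss_fct = nn.CrossEntropyLoss()  # stand-in for FusedCrossEntropyLoss / FusedLinearCrossEntropyLoss
batch_size, seq_len, vocab_size = 2, 8, 32000
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

# Shift targets left; the last position has no next token, so it is ignored.
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
loss = loss_fct(logits.view(-1, vocab_size), shifted.view(-1))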
fla/models/gla/modeling_gla.py ADDED
@@ -0,0 +1,417 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg

from fla.layers.attn import Attention
from fla.layers.gla import GatedLinearAttention
from fla.models.gla.configuration_gla import GLAConfig
from fla.models.utils import Cache
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
from fla.modules import GatedMLP as GLAMLP
from fla.modules import RMSNorm

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

logger = logging.get_logger(__name__)


class GLABlock(nn.Module):
    def __init__(self, config: GLAConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        if config.attn is not None and layer_idx in config.attn['layers']:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                rope_theta=config.attn['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = GatedLinearAttention(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_k=config.expand_k,
                expand_v=config.expand_v,
                num_heads=config.num_heads,
                num_kv_heads=config.num_kv_heads,
                feature_map=config.feature_map,
                use_short_conv=config.use_short_conv,
                conv_size=config.conv_size,
                use_output_gate=config.use_output_gate,
                gate_fn=config.hidden_act,
                elementwise_affine=config.elementwise_affine,
                norm_eps=config.norm_eps,
                clamp_min=config.clamp_min,
                fuse_norm=config.fuse_norm,
                layer_idx=layer_idx
            )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = GLAMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs


class GLAPreTrainedModel(PreTrainedModel):

    config_class = GLAConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['GLABlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        prenorm_residual_strategy: Optional[str] = 'rescale',
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if prenorm_residual_strategy is not None:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                if prenorm_residual_strategy == 'rescale':
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
                elif prenorm_residual_strategy == 'zero':
                    nn.init.zeros_(p)
                else:
                    raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")


class GLAModel(GLAPreTrainedModel):

    def __init__(self, config: GLAConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([GLABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn("`GLAModel` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    **kwargs
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )


class GLAForCausalLM(GLAPreTrainedModel, GenerationMixin):

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = GLAModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
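`GLAForCausalLM` is a standard `PreTrainedModel`/`GenerationMixin` pair, so it can be exercised directly from a config. A hedged sketch, assuming `GLAConfig` (defined in configuration_gla.py, which is not part of this commit) accepts the fields referenced above and that the fla Triton kernels are available on the current device:

import torch

from fla.models.gla.configuration_gla import GLAConfig
from fla.models.gla.modeling_gla import GLAForCausalLM

# Illustrative sizes only; not the sizes used by this checkpoint.
config = GLAConfig(hidden_size=256, num_hidden_layers=2, num_heads=4, vocab_size=32000)
model = GLAForCausalLM(config).eval()  # eval() keeps the unfused lm_head path, so logits are returned

input_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    logits = model(input_ids=input_ids).logits  # (1, 16, vocab_size)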
fla/models/gsa/configuration_gsa.py ADDED
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-

from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class GSAConfig(PretrainedConfig):

    model_type = 'gsa'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        gate_logit_normalizer: Optional[int] = 8,
        clamp_min: Optional[float] = None,
        clamp_max: Optional[float] = None,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        num_slots: Optional[int] = 64,
        use_short_conv: bool = False,
        conv_size: int = 4,
        exapnd_k: float = 1,
        exapnd_v: float = 1,
        feature_map: str = 'swish',
        use_output_gate: bool = False,
        use_norm: bool = True,
        max_position_embeddings: int = 2048,
        hidden_act: str = "swish",
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        initializer_range: float = 0.006,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.gate_logit_normalizer = gate_logit_normalizer
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_slots = num_slots
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.expand_k = exapnd_k
        self.expand_v = exapnd_v
        self.feature_map = feature_map
        self.use_output_gate = use_output_gate
        self.use_norm = use_norm
        self.max_position_embeddings = max_position_embeddings
        self.hidden_act = hidden_act
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
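Note that the constructor spells two keyword arguments as `exapnd_k`/`exapnd_v` while storing them under the intended names `expand_k`/`expand_v`; the spelling is kept as-is above because renaming the parameters would change the public signature. A brief illustration of the resulting asymmetry, assuming only the class defined above:

from fla.models.gsa.configuration_gsa import GSAConfig

config = GSAConfig(exapnd_k=1, exapnd_v=2)            # keywords use the misspelled names
assert config.expand_k == 1 and config.expand_v == 2  # attributes use the corrected names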
fla/models/hgrn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (665 Bytes).
 
fla/models/hgrn/__pycache__/configuration_hgrn.cpython-312.pyc ADDED
Binary file (3.28 kB).
 
fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc ADDED
Binary file (3.36 kB).
 
fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc ADDED
Binary file (18.3 kB).
 
fla/models/lightnet/modeling_lightnet.py ADDED
@@ -0,0 +1,410 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg

from fla.layers.attn import Attention
from fla.layers.lightnet import LightNetAttention
from fla.models.lightnet.configuration_lightnet import LightNetConfig
from fla.models.utils import Cache
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
from fla.modules import GatedMLP as LightNetMLP
from fla.modules import RMSNorm

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

logger = logging.get_logger(__name__)


class LightNetBlock(nn.Module):
    def __init__(self, config: LightNetConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        if config.attn is not None and layer_idx in config.attn['layers']:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = LightNetAttention(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                num_heads=config.num_heads,
                expand_ratio=config.expand_ratio,
                use_short_conv=config.use_short_conv,
                conv_size=config.conv_size,
                gate_low_rank_dim=config.gate_low_rank_dim,
                elementwise_affine=config.elementwise_affine,
                norm_eps=config.norm_eps,
                layer_idx=layer_idx
            )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = LightNetMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs


class LightNetPreTrainedModel(PreTrainedModel):

    config_class = LightNetConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ['LightNetBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        prenorm_residual_strategy: Optional[str] = 'rescale',
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if prenorm_residual_strategy is not None:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                if prenorm_residual_strategy == 'rescale':
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
                elif prenorm_residual_strategy == 'zero':
                    nn.init.zeros_(p)
                else:
                    raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")


class LightNetModel(LightNetPreTrainedModel):

    def __init__(self, config: LightNetConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LightNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn("`LightNetModel` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    **kwargs
                )
            else:
                hidden_states, attentions, past_key_values = layer(
237
+ attention_mask=attention_mask,
238
+ past_key_values=past_key_values,
239
+ use_cache=use_cache,
240
+ output_attentions=output_attentions,
241
+ **kwargs
242
+ )
243
+
244
+ if output_attentions:
245
+ all_attns += (attentions,)
246
+
247
+ hidden_states = self.norm(hidden_states)
248
+
249
+ # add hidden states from the last decoder layer
250
+ if output_hidden_states:
251
+ all_hidden_states += (hidden_states,)
252
+
253
+ if not return_dict:
254
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
255
+ return BaseModelOutputWithPast(
256
+ last_hidden_state=hidden_states,
257
+ past_key_values=past_key_values,
258
+ hidden_states=all_hidden_states,
259
+ attentions=all_attns
260
+ )
261
+
262
+
263
+ class LightNetForCausalLM(LightNetPreTrainedModel, GenerationMixin):
264
+
265
+ _tied_weights_keys = ["lm_head.weight"]
266
+
267
+ def __init__(self, config):
268
+ super().__init__(config)
269
+ self.model = LightNetModel(config)
270
+ self.vocab_size = config.vocab_size
271
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
272
+ self.criterion = None
273
+
274
+ # Initialize weights and apply final processing
275
+ self.post_init()
276
+
277
+ def get_input_embeddings(self):
278
+ return self.model.embeddings
279
+
280
+ def set_input_embeddings(self, value):
281
+ self.model.embeddings = value
282
+
283
+ def get_output_embeddings(self):
284
+ return self.lm_head
285
+
286
+ def set_output_embeddings(self, new_embeddings):
287
+ self.lm_head = new_embeddings
288
+
289
+ def set_decoder(self, decoder):
290
+ self.model = decoder
291
+
292
+ def get_decoder(self):
293
+ return self.model
294
+
295
+ def generate(self, *args, **kwargs):
296
+ try:
297
+ return super().generate(*args, **kwargs)
298
+ except AttributeError as exception:
299
+ if 'past_key_values' in str(exception):
300
+ raise AttributeError(
301
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
302
+ f"which is not supported for {self.__class__.__name__}. "
303
+ f"Try another generation strategy instead. "
304
+ f"For the available generation strategies, check this doc: "
305
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
306
+ )
307
+ else:
308
+ raise exception
309
+
310
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
311
+ def prepare_inputs_for_generation(
312
+ self,
313
+ input_ids: torch.LongTensor = None,
314
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
315
+ attention_mask: Optional[torch.Tensor] = None,
316
+ inputs_embeds: Optional[torch.Tensor] = None,
317
+ use_cache: bool = True,
318
+ logits_to_keep: Optional[int] = None,
319
+ **kwargs: Unpack[Dict]
320
+ ):
321
+ # only last token for `input_ids` if the `past_key_values` is not empty.
322
+ if past_key_values is not None and len(past_key_values) > 0:
323
+ input_ids = input_ids[:, -1:]
324
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
325
+ if inputs_embeds is not None and len(past_key_values) == 0:
326
+ model_inputs = {'inputs_embeds': inputs_embeds}
327
+ else:
328
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
329
+ # recompiles graphs as the stride of the inputs is a guard.
330
+ # Ref: https://github.com/huggingface/transformers/pull/29114
331
+ # TODO: use `next_tokens` directly instead.
332
+ model_inputs = {'input_ids': input_ids.contiguous()}
333
+
334
+ if logits_to_keep is not None:
335
+ model_inputs['logits_to_keep'] = logits_to_keep
336
+
337
+ model_inputs.update({
338
+ 'past_key_values': past_key_values,
339
+ 'use_cache': use_cache,
340
+ 'attention_mask': attention_mask,
341
+ })
342
+ return model_inputs
343
+
344
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
345
+ def forward(
346
+ self,
347
+ input_ids: torch.LongTensor = None,
348
+ attention_mask: Optional[torch.Tensor] = None,
349
+ inputs_embeds: Optional[torch.Tensor] = None,
350
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
351
+ labels: Optional[torch.LongTensor] = None,
352
+ use_cache: Optional[bool] = None,
353
+ output_attentions: Optional[bool] = None,
354
+ output_hidden_states: Optional[bool] = None,
355
+ return_dict: Optional[bool] = None,
356
+ logits_to_keep: Optional[int] = 0,
357
+ **kwargs: Unpack[Dict]
358
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
359
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
360
+ output_hidden_states = (
361
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
362
+ )
363
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
364
+
365
+ outputs = self.model(
366
+ input_ids=input_ids,
367
+ attention_mask=attention_mask,
368
+ inputs_embeds=inputs_embeds,
369
+ past_key_values=past_key_values,
370
+ use_cache=use_cache,
371
+ output_attentions=output_attentions,
372
+ output_hidden_states=output_hidden_states,
373
+ return_dict=return_dict,
374
+ **kwargs
375
+ )
376
+
377
+ hidden_states = outputs[0]
378
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
379
+
380
+ loss, logits = None, None
381
+ if not fuse_linear_and_cross_entropy or labels is None:
382
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
383
+ if labels is not None:
384
+ if getattr(self, 'criterion', None) is None:
385
+ if fuse_linear_and_cross_entropy:
386
+ criterion = FusedLinearCrossEntropyLoss()
387
+ elif self.config.fuse_cross_entropy:
388
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
389
+ else:
390
+ criterion = nn.CrossEntropyLoss()
391
+ else:
392
+ criterion = self.criterion
393
+ labels = labels.to(hidden_states.device)
394
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
395
+ if fuse_linear_and_cross_entropy:
396
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
397
+ else:
398
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
399
+
400
+ if not return_dict:
401
+ output = (logits,) + outputs[1:]
402
+ return (loss,) + output if loss is not None else output
403
+
404
+ return CausalLMOutputWithPast(
405
+ loss=loss,
406
+ logits=logits,
407
+ past_key_values=outputs.past_key_values,
408
+ hidden_states=outputs.hidden_states,
409
+ attentions=outputs.attentions,
410
+ )
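For reference, a minimal sketch (not taken from this upload) of the label shift performed in `LightNetForCausalLM.forward` above: position t is trained to predict token t+1, and the final position is filled with the criterion's ignore index (-100, the default of `nn.CrossEntropyLoss` and of the fused criteria) so it drops out of the loss.

import torch

# Hypothetical toy labels; -100 stands in for criterion.ignore_index.
labels = torch.tensor([[11, 12, 13, 14]])
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], -100)), 1)
print(shifted)  # tensor([[  12,   13,   14, -100]])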
fla/models/linear_attn/__pycache__/configuration_linear_attn.cpython-312.pyc ADDED
Binary file (3.65 kB).
 
fla/models/linear_attn/__pycache__/modeling_linear_attn.cpython-312.pyc ADDED
Binary file (18.5 kB).
 
fla/models/mamba2/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (695 Bytes).
 
fla/models/mamba2/__pycache__/configuration_mamba2.cpython-312.pyc ADDED
Binary file (7.5 kB).
 
fla/models/nsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (657 Bytes).
 
fla/models/nsa/modeling_nsa.py ADDED
@@ -0,0 +1,398 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.nsa import NativeSparseAttention
19
+ from fla.models.nsa.configuration_nsa import NSAConfig
20
+ from fla.models.utils import Cache
21
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
22
+ from fla.modules import GatedMLP as NSAMLP
23
+ from fla.modules import RMSNorm
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers.processing_utils import Unpack
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class NSABlock(nn.Module):
32
+ def __init__(self, config: NSAConfig, layer_idx: int):
33
+ super().__init__()
34
+
35
+ self.config = config
36
+ self.layer_idx = layer_idx
37
+
38
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
39
+ self.attn = NativeSparseAttention(
40
+ hidden_size=config.hidden_size,
41
+ num_heads=config.num_heads,
42
+ num_kv_heads=config.num_kv_heads,
43
+ qkv_bias=config.qkv_bias,
44
+ block_size=config.block_size,
45
+ block_counts=config.block_counts,
46
+ window_size=config.window_size,
47
+ rope_theta=config.rope_theta,
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
52
+ self.mlp = NSAMLP(
53
+ hidden_size=config.hidden_size,
54
+ hidden_ratio=config.hidden_ratio,
55
+ intermediate_size=config.intermediate_size,
56
+ hidden_act=config.hidden_act,
57
+ fuse_swiglu=config.fuse_swiglu
58
+ )
59
+
60
+ def forward(
61
+ self,
62
+ hidden_states: torch.Tensor,
63
+ attention_mask: Optional[torch.Tensor] = None,
64
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
65
+ use_cache: Optional[bool] = False,
66
+ output_attentions: Optional[bool] = False,
67
+ **kwargs: Unpack[Dict]
68
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
69
+ residual = hidden_states
70
+ hidden_states = self.attn_norm(hidden_states)
71
+ hidden_states, attentions, past_key_values = self.attn(
72
+ hidden_states=hidden_states,
73
+ attention_mask=attention_mask,
74
+ past_key_values=past_key_values,
75
+ use_cache=use_cache,
76
+ output_attentions=output_attentions,
77
+ **kwargs
78
+ )
79
+ if self.config.fuse_norm:
80
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
81
+ else:
82
+ hidden_states = residual + hidden_states
83
+ residual = hidden_states
84
+ hidden_states = self.mlp_norm(hidden_states)
85
+ hidden_states = self.mlp(hidden_states, **kwargs)
86
+ hidden_states = residual + hidden_states
87
+
88
+ outputs = (hidden_states, attentions, past_key_values)
89
+
90
+ return outputs
91
+
92
+
93
+ class NSAPreTrainedModel(PreTrainedModel):
94
+
95
+ config_class = NSAConfig
96
+ base_model_prefix = 'model'
97
+ supports_gradient_checkpointing = True
98
+ _no_split_modules = ['NSABlock']
99
+ _supports_cache_class = True
100
+
101
+ def __init__(self, *inputs, **kwargs):
102
+ super().__init__(*inputs, **kwargs)
103
+
104
+ def _init_weights(
105
+ self,
106
+ module: nn.Module,
107
+ prenorm_residual_strategy: Optional[str] = 'rescale',
108
+ num_residuals_per_layer: int = 2,
109
+ ):
110
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
111
+ # Slightly different from the TF version which uses truncated_normal for initialization
112
+ # cf https://github.com/pytorch/pytorch/pull/5617
113
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
114
+ if module.bias is not None:
115
+ nn.init.zeros_(module.bias)
116
+ elif isinstance(module, nn.Embedding):
117
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
118
+ elif hasattr(module, 'reset_parameters'):
119
+ module.reset_parameters()
120
+
121
+ if prenorm_residual_strategy is not None:
122
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
123
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
124
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
125
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
126
+ #
127
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
128
+ p = None
129
+ if hasattr(module, 'o_proj'):
130
+ p = module.o_proj.weight
131
+ elif hasattr(module, 'down_proj'):
132
+ p = module.down_proj.weight
133
+ if p is not None:
134
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
135
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
136
+ # We need to reinit p since this code could be called multiple times
137
+ # Having just p *= scale would repeatedly scale it down
138
+ if prenorm_residual_strategy == 'rescale':
139
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
140
+ with torch.no_grad():
141
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
142
+ elif prenorm_residual_strategy == 'zero':
143
+ nn.init.zeros_(p)
144
+ else:
145
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
146
+
147
+
148
+ class NSAModel(NSAPreTrainedModel):
149
+
150
+ def __init__(self, config: NSAConfig):
151
+ super().__init__(config)
152
+ self.padding_idx = config.pad_token_id
153
+ self.vocab_size = config.vocab_size
154
+
155
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
156
+ self.layers = nn.ModuleList([NSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
157
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
158
+
159
+ self.gradient_checkpointing = False
160
+
161
+ self.post_init()
162
+
163
+ def get_input_embeddings(self):
164
+ return self.embeddings
165
+
166
+ def set_input_embeddings(self, value):
167
+ self.embeddings = value
168
+
169
+ def forward(
170
+ self,
171
+ input_ids: Optional[torch.LongTensor] = None,
172
+ attention_mask: Optional[torch.Tensor] = None, # noqa
173
+ inputs_embeds: Optional[torch.FloatTensor] = None,
174
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
175
+ use_cache: Optional[bool] = None,
176
+ output_attentions: Optional[bool] = None,
177
+ output_hidden_states: Optional[bool] = None,
178
+ return_dict: Optional[bool] = None,
179
+ **kwargs: Unpack[Dict]
180
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
181
+ if output_attentions:
182
+ warnings.warn("`NSAModel` does not `output_attentions` now, setting it to `False`.")
183
+ output_attentions = False
184
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
185
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
186
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
187
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
188
+
189
+ # retrieve input_ids and inputs_embeds
190
+ if input_ids is not None and inputs_embeds is not None:
191
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
192
+ if input_ids is None and inputs_embeds is None:
193
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
194
+
195
+ if inputs_embeds is None:
196
+ inputs_embeds = self.embeddings(input_ids)
197
+ hidden_states = inputs_embeds
198
+
199
+ if use_cache and not isinstance(past_key_values, Cache):
200
+ past_key_values = Cache.from_legacy_cache(past_key_values)
201
+
202
+ if self.gradient_checkpointing and self.training and use_cache:
203
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
204
+ use_cache = False
205
+
206
+ all_hidden_states = () if output_hidden_states else None
207
+ all_attns = () if output_attentions else None
208
+ for layer in self.layers:
209
+ if output_hidden_states:
210
+ all_hidden_states += (hidden_states,)
211
+
212
+ if self.gradient_checkpointing and self.training:
213
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
214
+ layer.__call__,
215
+ hidden_states,
216
+ attention_mask,
217
+ past_key_values,
218
+ use_cache,
219
+ output_attentions,
220
+ **kwargs
221
+ )
222
+ else:
223
+ hidden_states, attentions, past_key_values = layer(
224
+ hidden_states,
225
+ attention_mask=attention_mask,
226
+ past_key_values=past_key_values,
227
+ use_cache=use_cache,
228
+ output_attentions=output_attentions,
229
+ **kwargs
230
+ )
231
+
232
+ if output_attentions:
233
+ all_attns += (attentions,)
234
+
235
+ hidden_states = self.norm(hidden_states)
236
+
237
+ # add hidden states from the last decoder layer
238
+ if output_hidden_states:
239
+ all_hidden_states += (hidden_states,)
240
+
241
+ if not return_dict:
242
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
243
+ return BaseModelOutputWithPast(
244
+ last_hidden_state=hidden_states,
245
+ past_key_values=past_key_values,
246
+ hidden_states=all_hidden_states,
247
+ attentions=all_attns
248
+ )
249
+
250
+
251
+ class NSAForCausalLM(NSAPreTrainedModel, GenerationMixin):
252
+
253
+ _tied_weights_keys = ["lm_head.weight"]
254
+
255
+ def __init__(self, config):
256
+ super().__init__(config)
257
+ self.model = NSAModel(config)
258
+ self.vocab_size = config.vocab_size
259
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
260
+ self.criterion = None
261
+
262
+ # Initialize weights and apply final processing
263
+ self.post_init()
264
+
265
+ def get_input_embeddings(self):
266
+ return self.model.embeddings
267
+
268
+ def set_input_embeddings(self, value):
269
+ self.model.embeddings = value
270
+
271
+ def get_output_embeddings(self):
272
+ return self.lm_head
273
+
274
+ def set_output_embeddings(self, new_embeddings):
275
+ self.lm_head = new_embeddings
276
+
277
+ def set_decoder(self, decoder):
278
+ self.model = decoder
279
+
280
+ def get_decoder(self):
281
+ return self.model
282
+
283
+ def generate(self, *args, **kwargs):
284
+ try:
285
+ return super().generate(*args, **kwargs)
286
+ except AttributeError as exception:
287
+ if 'past_key_values' in str(exception):
288
+ raise AttributeError(
289
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
290
+ f"which is not supported for {self.__class__.__name__}. "
291
+ f"Try another generation strategy instead. "
292
+ f"For the available generation strategies, check this doc: "
293
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
294
+ )
295
+ else:
296
+ raise exception
297
+
298
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
299
+ def prepare_inputs_for_generation(
300
+ self,
301
+ input_ids: torch.LongTensor = None,
302
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
303
+ attention_mask: Optional[torch.Tensor] = None,
304
+ inputs_embeds: Optional[torch.Tensor] = None,
305
+ use_cache: bool = True,
306
+ logits_to_keep: Optional[int] = None,
307
+ **kwargs
308
+ ):
309
+ # only last token for `input_ids` if the `past_key_values` is not empty.
310
+ if past_key_values is not None and len(past_key_values) > 0:
311
+ input_ids = input_ids[:, -1:]
312
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
313
+ if inputs_embeds is not None and len(past_key_values) == 0:
314
+ model_inputs = {'inputs_embeds': inputs_embeds}
315
+ else:
316
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
317
+ # recompiles graphs as the stride of the inputs is a guard.
318
+ # Ref: https://github.com/huggingface/transformers/pull/29114
319
+ # TODO: use `next_tokens` directly instead.
320
+ model_inputs = {'input_ids': input_ids.contiguous()}
321
+
322
+ if logits_to_keep is not None:
323
+ model_inputs['logits_to_keep'] = logits_to_keep
324
+
325
+ model_inputs.update({
326
+ 'past_key_values': past_key_values,
327
+ 'use_cache': use_cache,
328
+ 'attention_mask': attention_mask,
329
+ })
330
+ return model_inputs
331
+
332
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
333
+ def forward(
334
+ self,
335
+ input_ids: torch.LongTensor = None,
336
+ attention_mask: Optional[torch.Tensor] = None,
337
+ inputs_embeds: Optional[torch.Tensor] = None,
338
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
339
+ labels: Optional[torch.LongTensor] = None,
340
+ use_cache: Optional[bool] = None,
341
+ output_attentions: Optional[bool] = None,
342
+ output_hidden_states: Optional[bool] = None,
343
+ return_dict: Optional[bool] = None,
344
+ logits_to_keep: Optional[int] = 0,
345
+ **kwargs: Unpack[Dict]
346
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
347
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
348
+ output_hidden_states = (
349
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
350
+ )
351
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
352
+
353
+ outputs = self.model(
354
+ input_ids=input_ids,
355
+ attention_mask=attention_mask,
356
+ inputs_embeds=inputs_embeds,
357
+ past_key_values=past_key_values,
358
+ use_cache=use_cache,
359
+ output_attentions=output_attentions,
360
+ output_hidden_states=output_hidden_states,
361
+ return_dict=return_dict,
362
+ **kwargs
363
+ )
364
+
365
+ hidden_states = outputs[0]
366
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
367
+
368
+ loss, logits = None, None
369
+ if not fuse_linear_and_cross_entropy or labels is None:
370
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
371
+ if labels is not None:
372
+ if getattr(self, 'criterion', None) is None:
373
+ if fuse_linear_and_cross_entropy:
374
+ criterion = FusedLinearCrossEntropyLoss()
375
+ elif self.config.fuse_cross_entropy:
376
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
377
+ else:
378
+ criterion = nn.CrossEntropyLoss()
379
+ else:
380
+ criterion = self.criterion
381
+ labels = labels.to(hidden_states.device)
382
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
383
+ if fuse_linear_and_cross_entropy:
384
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
385
+ else:
386
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
387
+
388
+ if not return_dict:
389
+ output = (logits,) + outputs[1:]
390
+ return (loss,) + output if loss is not None else output
391
+
392
+ return CausalLMOutputWithPast(
393
+ loss=loss,
394
+ logits=logits,
395
+ past_key_values=outputs.past_key_values,
396
+ hidden_states=outputs.hidden_states,
397
+ attentions=outputs.attentions,
398
+ )
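A minimal smoke-test sketch for the `NSAForCausalLM` class added above. This is illustrative only: it relies on the default values defined in `fla/models/nsa/configuration_nsa.py` (not shown in this upload), and the sparse-attention kernels are expected to need a CUDA device with Triton available.

import torch
from fla.models.nsa.configuration_nsa import NSAConfig
from fla.models.nsa.modeling_nsa import NSAForCausalLM

config = NSAConfig()                    # library defaults; shrink sizes for a quick test
model = NSAForCausalLM(config).cuda()
input_ids = torch.randint(0, config.vocab_size, (1, 128), device='cuda')
out = model(input_ids=input_ids, labels=input_ids)  # labels are shifted inside forward
print(out.loss)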
fla/models/retnet/configuration_retnet.py ADDED
@@ -0,0 +1,92 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict, Optional
6
+
7
+ from transformers.configuration_utils import PretrainedConfig
8
+
9
+
10
+ class RetNetConfig(PretrainedConfig):
11
+
12
+ model_type = 'retnet'
13
+ keys_to_ignore_at_inference = ['past_key_values']
14
+
15
+ def __init__(
16
+ self,
17
+ attn_mode: str = "chunk",
18
+ hidden_size: int = 2048,
19
+ expand_k: int = 1,
20
+ expand_v: int = 2,
21
+ hidden_ratio: Optional[int] = 2,
22
+ intermediate_size: Optional[int] = None,
23
+ num_hidden_layers: int = 24,
24
+ num_heads: int = 8,
25
+ num_kv_heads: Optional[int] = None,
26
+ feature_map: Optional[str] = None,
27
+ hidden_act: str = "swish",
28
+ use_short_conv: bool = False,
29
+ conv_size: int = 4,
30
+ use_output_gate: bool = True,
31
+ max_position_embeddings: int = 2048,
32
+ elementwise_affine: Optional[bool] = True,
33
+ norm_eps: float = 1e-6,
34
+ attn: Optional[Dict] = None,
35
+ use_cache: bool = True,
36
+ pad_token_id: int = None,
37
+ bos_token_id: int = 1,
38
+ eos_token_id: int = 2,
39
+ tie_word_embeddings: bool = False,
40
+ initializer_range: float = 0.006,
41
+ fuse_norm: bool = True,
42
+ fuse_swiglu: bool = True,
43
+ fuse_cross_entropy: bool = True,
44
+ vocab_size: int = 32000,
45
+ **kwargs
46
+ ) -> RetNetConfig:
47
+ self.attn_mode = attn_mode
48
+ self.hidden_size = hidden_size
49
+ self.expand_k = expand_k
50
+ self.expand_v = expand_v
51
+ self.hidden_ratio = hidden_ratio
52
+ self.intermediate_size = intermediate_size
53
+ self.num_hidden_layers = num_hidden_layers
54
+ self.num_heads = num_heads
55
+ self.num_kv_heads = num_kv_heads
56
+ self.feature_map = feature_map
57
+ self.hidden_act = hidden_act
58
+ self.use_short_conv = use_short_conv
59
+ self.conv_size = conv_size
60
+ self.use_output_gate = use_output_gate
61
+ self.hidden_act = hidden_act
62
+ self.max_position_embeddings = max_position_embeddings
63
+ self.elementwise_affine = elementwise_affine
64
+ self.norm_eps = norm_eps
65
+ self.attn = attn
66
+ self.use_cache = use_cache
67
+ self.initializer_range = initializer_range
68
+
69
+ self.fuse_norm = fuse_norm
70
+ self.fuse_swiglu = fuse_swiglu
71
+ self.fuse_cross_entropy = fuse_cross_entropy
72
+ self.vocab_size = vocab_size
73
+
74
+ if attn is not None:
75
+ if not isinstance(attn, Dict):
76
+ raise ValueError("attn must be a dictionary")
77
+ if 'layers' not in attn:
78
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
79
+ if 'num_heads' not in attn:
80
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
81
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
82
+ attn['qkv_bias'] = attn.get('qkv_bias', False)
83
+ attn['window_size'] = attn.get('window_size', None)
84
+ attn['rope_theta'] = attn.get('rope_theta', 10000.)
85
+
86
+ super().__init__(
87
+ pad_token_id=pad_token_id,
88
+ bos_token_id=bos_token_id,
89
+ eos_token_id=eos_token_id,
90
+ tie_word_embeddings=tie_word_embeddings,
91
+ **kwargs,
92
+ )
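An illustrative construction of the hybrid-attention `attn` dict that `RetNetConfig` validates above: 'layers' and 'num_heads' are required, while the remaining keys fall back to the defaults the constructor fills in (the layer indices below are made up).

from fla.models.retnet.configuration_retnet import RetNetConfig

config = RetNetConfig(
    attn={
        'layers': [0, 12],   # hypothetical indices of layers using standard attention
        'num_heads': 8,
    },
)
print(config.attn)
# {'layers': [0, 12], 'num_heads': 8, 'num_kv_heads': 8, 'qkv_bias': False,
#  'window_size': None, 'rope_theta': 10000.0}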
fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (687 Bytes).
 
fla/models/rwkv7/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (687 Bytes).
 
fla/models/rwkv7/__pycache__/configuration_rwkv7.cpython-312.pyc ADDED
Binary file (4.24 kB).
 
fla/models/transformer/__pycache__/configuration_transformer.cpython-312.pyc ADDED
Binary file (2.52 kB).
 
fla/models/transformer_mtp/__pycache__/modeling_transformer.cpython-312.pyc ADDED
Binary file (24.7 kB).
 
fla/models/transformer_mtp/modeling_transformer.py ADDED
@@ -0,0 +1,608 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from dataclasses import dataclass
14
+ from transformers.generation import GenerationMixin
15
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import logging
18
+ from transformers.utils.deprecation import deprecate_kwarg
19
+
20
+ import triton
21
+ import triton.language as tl
22
+
23
+ from fla.layers.attn import Attention
24
+ from fla.models.transformer_mtp.configuration_transformer import MTPTransformerConfig
25
+ from fla.models.utils import Cache
26
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
27
+ from fla.modules import GatedMLP as TransformerMLP
28
+ from fla.modules import RMSNorm
29
+
30
+ if TYPE_CHECKING:
31
+ from transformers.processing_utils import Unpack
32
+
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+ class SequentialHeadsCustomBackward(torch.autograd.Function):
37
+ @staticmethod
38
+ def forward(ctx, trunk_output, lm_head, norm_layer, logits_to_keep, *prediction_heads):
39
+ # We now need the norm layer in the forward pass calculation
40
+ ctx.prediction_heads = prediction_heads
41
+ ctx.lm_head = lm_head
42
+ ctx.norm_layer = norm_layer
43
+ ctx.logits_to_keep = logits_to_keep
44
+ ctx.save_for_backward(trunk_output)
45
+
46
+ latents = []
47
+ for head in prediction_heads:
48
+ # Assuming head forward signature is `head(hidden_states)`
49
+ latent = head(trunk_output)[0]
50
+ latents.append(latent)
51
+
52
+ latents_stacked = torch.stack(latents, dim=-2)
53
+ # Apply the final norm before the lm_head
54
+ normalized_latents = norm_layer(latents_stacked)
55
+ all_logits = lm_head(normalized_latents[:, -logits_to_keep:])
56
+ return all_logits
57
+
58
+ @staticmethod
59
+ def backward(ctx, grad_output):
60
+ trunk_output, = ctx.saved_tensors
61
+ prediction_heads = ctx.prediction_heads
62
+ lm_head = ctx.lm_head
63
+ norm_layer = ctx.norm_layer
64
+ logits_to_keep = ctx.logits_to_keep
65
+
66
+ d = trunk_output.detach().requires_grad_(True)
67
+ grad_output_per_head = grad_output.unbind(dim=2)
68
+
69
+ # We need to manually handle the backward pass for the final norm layer once
70
+ # before the loop, as its gradient depends on all heads.
71
+ # To do this, we reconstruct the input to the lm_head and do a backward pass.
72
+ with torch.enable_grad():
73
+ # Re-run the head computations to get the input to the norm layer
74
+ latents = []
75
+ for head in prediction_heads:
76
+ latents.append(head(d)[0])
77
+ latents_stacked = torch.stack(latents, dim=-2)
78
+ latents_stacked.requires_grad_(True)
79
+ # The part of the graph we need to backprop through first
80
+ normalized_latents = norm_layer(latents_stacked)
81
+
82
+ # Backpropagate through the lm_head and norm_layer
83
+ normalized_latents.backward(grad_output @ lm_head.weight)  # manual backward through the bias-free lm_head: dL/d(normalized_latents) = dL/d(logits) @ W (assumes logits_to_keep == 0 during training)
84
+
85
+ # Now, `latents_stacked.grad` contains the sum of gradients from all heads
86
+ # just before the final normalization. We can now unbind it.
87
+ grad_per_head_latent = latents_stacked.grad.unbind(dim=-2)
88
+
89
+ # Now, backpropagate through each head individually.
90
+ for i, head in enumerate(prediction_heads):
91
+ with torch.enable_grad():
92
+ head_latent = head(d)[0]
93
+ # Backpropagate using the gradient for this specific head's output
94
+ head_latent.backward(gradient=grad_per_head_latent[i])
95
+
96
+ num_nones = 2 + len(prediction_heads) # for lm_head, norm_layer, and *prediction_heads
97
+ return (d.grad,) + (None,) * num_nones
98
+
99
+ def seq_to_mtp(
100
+ long_input_ids: torch.Tensor,
101
+ model_seq_len: int,
102
+ n_future_tokens: int
103
+ ) -> torch.Tensor:
104
+ """
105
+ Generates a tensor of future targets on the fly from a long input sequence.
106
+
107
+ This version assumes `long_input_ids` contains both the tokens for the model's
108
+ input AND the future tokens needed for the labels.
109
+ It extracts the correct targets without adding artificial padding.
110
+
111
+ Args:
112
+ long_input_ids (torch.Tensor): The input sequences from the dataloader,
113
+ shape (B, T + n_future_tokens).
114
+ model_seq_len (int): The sequence length `T` that the model processes.
115
+ n_future_tokens (int): The number of future tokens to predict for each time step.
116
+
117
+ Returns:
118
+ torch.Tensor: The target tensor of shape (B, n_future_tokens, T).
119
+ y[b, k, t] corresponds to the (k+1)-th token after input_ids[b, t].
120
+ """
121
+ B, total_len = long_input_ids.shape
122
+ assert total_len >= model_seq_len + n_future_tokens, \
123
+ "long_input_ids must be at least model_seq_len + n_future_tokens long."
124
+
125
+ # 1. Create sliding windows (views) over the long tensor.
126
+ # .unfold() is a highly efficient way to create sliding windows.
127
+ # We create windows of size `n_future_tokens + 1`. For each time step `t`,
128
+ # the window will contain the input token and its `n_future_tokens` targets.
129
+ # Example (n=3, window_size=4):
130
+ # For t=0, window is [t0, t1, t2, t3]
131
+ # For t=1, window is [t1, t2, t3, t4]
132
+ # Shape of windows: (B, total_len - n_future_tokens, n_future_tokens + 1)
133
+ windows = long_input_ids.unfold(dimension=1, size=n_future_tokens + 1, step=1)
134
+
135
+ # 2. Slice the windows to get only the targets.
136
+ # We slice off the first element of each window (the input token itself)
137
+ # to keep only the future tokens.
138
+ # Example window [t0, t1, t2, t3] -> becomes targets [t1, t2, t3]
139
+ all_targets = windows[:, :, 1:]
140
+
141
+ # 3. Trim the result to match the model's output sequence length.
142
+ # We only need the targets for the first `model_seq_len` positions.
143
+ output_targets = all_targets[:, :model_seq_len, :]
144
+
145
+ return output_targets.transpose(1, 2)
146
+
147
+
148
+ @dataclass
149
+ class MTPLMOutputWithPast(CausalLMOutputWithPast):
150
+ ntp_loss: Optional[torch.FloatTensor] = None
151
+ mtp_loss: Optional[torch.FloatTensor] = None
152
+
153
+ class MTPTransformerBlock(nn.Module):
154
+
155
+ def __init__(self, config: MTPTransformerConfig, layer_idx: int):
156
+ super().__init__()
157
+
158
+ self.config = config
159
+ self.layer_idx = layer_idx
160
+
161
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
162
+ self.attn = Attention(
163
+ hidden_size=config.hidden_size,
164
+ num_heads=config.num_heads,
165
+ num_kv_heads=config.num_kv_heads,
166
+ qkv_bias=config.qkv_bias,
167
+ qk_norm=config.qk_norm,
168
+ window_size=config.window_size,
169
+ rope_theta=config.rope_theta,
170
+ max_position_embeddings=config.max_position_embeddings,
171
+ layer_idx=layer_idx
172
+ )
173
+
174
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
175
+ self.mlp = TransformerMLP(
176
+ hidden_size=config.hidden_size,
177
+ hidden_ratio=config.hidden_ratio,
178
+ intermediate_size=config.intermediate_size,
179
+ hidden_act=config.hidden_act,
180
+ fuse_swiglu=config.fuse_swiglu
181
+ )
182
+
183
+ def forward(
184
+ self,
185
+ hidden_states: torch.Tensor,
186
+ attention_mask: Optional[torch.Tensor] = None,
187
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
188
+ output_attentions: Optional[bool] = False,
189
+ use_cache: Optional[bool] = False,
190
+ **kwargs: Unpack[Any]
191
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
192
+
193
+ residual = hidden_states
194
+ hidden_states = self.attn_norm(hidden_states)
195
+ hidden_states, attentions, past_key_values = self.attn(
196
+ hidden_states=hidden_states,
197
+ attention_mask=attention_mask,
198
+ past_key_values=past_key_values,
199
+ use_cache=use_cache,
200
+ output_attentions=output_attentions,
201
+ **kwargs
202
+ )
203
+ if self.config.fuse_norm:
204
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
205
+ else:
206
+ hidden_states = residual + hidden_states
207
+ residual = hidden_states
208
+ hidden_states = self.mlp_norm(hidden_states)
209
+ hidden_states = self.mlp(hidden_states, **kwargs)
210
+ hidden_states = residual + hidden_states
211
+
212
+ outputs = (hidden_states,)
213
+
214
+ if output_attentions:
215
+ outputs += (attentions,)
216
+
217
+ if use_cache:
218
+ outputs += (past_key_values,)
219
+
220
+ return outputs
221
+
222
+
223
+ class MTPTransformerPreTrainedModel(PreTrainedModel):
224
+
225
+ config_class = MTPTransformerConfig
226
+ base_model_prefix = 'model'
227
+ supports_gradient_checkpointing = True
228
+ _no_split_modules = ['MTPTransformerBlock']
229
+ _supports_cache_class = True
230
+
231
+ def __init__(self, *inputs, **kwargs):
232
+ super().__init__(*inputs, **kwargs)
233
+
234
+ def _init_weights(
235
+ self,
236
+ module: nn.Module,
237
+ rescale_prenorm_residual: bool = False,
238
+ num_residuals_per_layer: int = 2,
239
+ ):
240
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
241
+ # Slightly different from the TF version which uses truncated_normal for initialization
242
+ # cf https://github.com/pytorch/pytorch/pull/5617
243
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
244
+ if module.bias is not None:
245
+ nn.init.zeros_(module.bias)
246
+ elif isinstance(module, nn.Embedding):
247
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
248
+ elif hasattr(module, 'reset_parameters'):
249
+ module.reset_parameters()
250
+
251
+ if rescale_prenorm_residual:
252
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
253
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
254
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
255
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
256
+ #
257
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
258
+ p = None
259
+ if hasattr(module, 'o_proj'):
260
+ p = module.o_proj.weight
261
+ elif hasattr(module, 'down_proj'):
262
+ p = module.down_proj.weight
263
+ if p is not None:
264
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
265
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
266
+ # We need to reinit p since this code could be called multiple times
267
+ # Having just p *= scale would repeatedly scale it down
268
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
269
+ with torch.no_grad():
270
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
271
+
272
+
273
+ class MTPTransformerModel(MTPTransformerPreTrainedModel):
274
+
275
+ def __init__(
276
+ self,
277
+ config: MTPTransformerConfig
278
+ ) -> MTPTransformerModel:
279
+ super().__init__(config)
280
+ self.padding_idx = config.pad_token_id
281
+ self.vocab_size = config.vocab_size
282
+
283
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
284
+ self.layers = nn.ModuleList([MTPTransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers - config.n_future_tokens)])
285
+ self.extra_heads = nn.ModuleList([MTPTransformerBlock(config, layer_idx) for layer_idx in range(config.n_future_tokens)])
286
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
287
+
288
+ self.gradient_checkpointing = False
289
+
290
+ self.post_init()
291
+
292
+ def get_input_embeddings(self):
293
+ return self.embeddings
294
+
295
+ def set_input_embeddings(self, value):
296
+ self.embeddings = value
297
+
298
+ def forward(
299
+ self,
300
+ input_ids: Optional[torch.LongTensor] = None,
301
+ attention_mask: Optional[torch.Tensor] = None,
302
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
303
+ inputs_embeds: Optional[torch.FloatTensor] = None,
304
+ use_cache: Optional[bool] = None,
305
+ output_attentions: Optional[bool] = None,
306
+ output_hidden_states: Optional[bool] = None,
307
+ return_dict: Optional[bool] = None,
308
+ return_all_heads: bool = False, # if Training, this is True
309
+ **kwargs: Unpack[Any]
310
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
311
+ if output_attentions:
312
+ warnings.warn(
313
+ "`TransformerModel` does not support output attention weights now, so `output_attentions` is set to `False`."
314
+ )
315
+ output_attentions = False
316
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
317
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
318
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
319
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
320
+ use_custom_backward = self.config.use_custom_backward and self.training
321
+ if self.training and return_all_heads is False:
322
+ logger.warning_once(
323
+ "`return_all_heads=False` is incompatible with training. Setting `return_all_heads=True`..."
324
+ )
325
+ return_all_heads = True
326
+
327
+ # retrieve input_ids and inputs_embeds
328
+ if input_ids is not None and inputs_embeds is not None:
329
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
330
+ elif input_ids is None and inputs_embeds is None:
331
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
332
+
333
+ if use_cache and not isinstance(past_key_values, Cache):
334
+ past_key_values = Cache.from_legacy_cache(past_key_values)
335
+
336
+ if inputs_embeds is None:
337
+ inputs_embeds = self.embeddings(input_ids)
338
+
339
+ # embed positions
340
+ hidden_states = inputs_embeds
341
+
342
+ if self.gradient_checkpointing and self.training:
343
+ if use_cache:
344
+ logger.warning_once(
345
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
346
+ )
347
+ use_cache = False
348
+
349
+ all_hidden_states = () if output_hidden_states else None
350
+ all_attns = () if output_attentions else None
351
+ next_cache = None
352
+
353
+ for layer in self.layers:
354
+ if output_hidden_states:
355
+ all_hidden_states += (hidden_states,)
356
+
357
+ if self.gradient_checkpointing and self.training:
358
+ layer_outputs = self._gradient_checkpointing_func(
359
+ layer.__call__,
360
+ hidden_states,
361
+ attention_mask,
362
+ past_key_values,
363
+ output_attentions,
364
+ use_cache,
365
+ **kwargs
366
+ )
367
+ else:
368
+ layer_outputs = layer(
369
+ hidden_states,
370
+ attention_mask=attention_mask,
371
+ past_key_values=past_key_values,
372
+ output_attentions=output_attentions,
373
+ use_cache=use_cache,
374
+ **kwargs
375
+ )
376
+
377
+ hidden_states = layer_outputs[0]
378
+
379
+ if use_cache:
380
+ next_cache = layer_outputs[2 if output_attentions else 1]
381
+
382
+ if output_attentions:
383
+ all_attns += (layer_outputs[1],)
384
+
385
+ trunk = hidden_states
386
+
387
+ n_heads_to_use = self.config.n_future_tokens if return_all_heads else 1
388
+ prediction_heads_to_use = self.extra_heads[:n_heads_to_use]
389
+
390
+ if use_custom_backward and self.training:
391
+ # all_logits = SequentialHeadsCustomBackward.apply(trunk, self.lm_head, *prediction_heads)
392
+ hidden_states = trunk # return hidden states and apply custom backward on the MTPTransformersLM
393
+ else:
394
+ latents = []
395
+ for i, layer in enumerate(prediction_heads_to_use):
396
+ if output_hidden_states:
397
+ all_hidden_states += (hidden_states,)
398
+
399
+ if self.gradient_checkpointing and self.training:
400
+ layer_outputs = self._gradient_checkpointing_func(
401
+ layer.__call__,
402
+ trunk, # Use trunk instead of hidden states
403
+ attention_mask,
404
+ past_key_values,
405
+ output_attentions,
406
+ use_cache,
407
+ **kwargs
408
+ )
409
+ else:
410
+ layer_outputs = layer(
411
+ trunk, # Use trunk instead of hidden states
412
+ attention_mask=attention_mask,
413
+ past_key_values=past_key_values,
414
+ output_attentions=output_attentions,
415
+ use_cache=use_cache,
416
+ **kwargs
417
+ )
418
+ hidden_states = layer_outputs[0]
419
+ latents.append(hidden_states)
420
+
421
+ if use_cache:
422
+ next_cache = layer_outputs[2 if output_attentions else 1]
423
+
424
+ if output_attentions:
425
+ all_attns += (layer_outputs[1],)
426
+
427
+ hidden_states = torch.stack(latents, dim=-2) # (B, T, n_heads_to_use, D)
428
+ hidden_states = self.norm(hidden_states)
429
+
430
+ # add hidden states from the last decoder layer
431
+ if output_hidden_states and not use_custom_backward:
432
+ all_hidden_states += (hidden_states,)
433
+
434
+ if not return_dict:
435
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)
436
+
437
+ return BaseModelOutputWithPast(
438
+ last_hidden_state=hidden_states,
439
+ past_key_values=next_cache,
440
+ hidden_states=all_hidden_states,
441
+ attentions=all_attns
442
+ )
443
+
444
+
445
+ class MTPTransformerForCausalLM(MTPTransformerPreTrainedModel, GenerationMixin):
446
+
447
+ _tied_weights_keys = ["lm_head.weight"]
448
+
449
+ def __init__(self, config):
450
+ super().__init__(config)
451
+ self.model = MTPTransformerModel(config)
452
+ self.vocab_size = config.vocab_size
453
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
454
+ self.criterion = None
455
+ self.pad_token_id = config.pad_token_id
456
+
457
+ # Initialize weights and apply final processing
458
+ self.post_init()
459
+
460
+ def get_input_embeddings(self):
461
+ return self.model.embeddings
462
+
463
+ def set_input_embeddings(self, value):
464
+ self.model.embeddings = value
465
+
466
+ def get_output_embeddings(self):
467
+ return self.lm_head
468
+
469
+ def set_output_embeddings(self, new_embeddings):
470
+ self.lm_head = new_embeddings
471
+
472
+ def set_decoder(self, decoder):
473
+ self.model = decoder
474
+
475
+ def get_decoder(self):
476
+ return self.model
477
+
478
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
479
+ def prepare_inputs_for_generation(
480
+ self,
481
+ input_ids: torch.LongTensor = None,
482
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
483
+ attention_mask: Optional[torch.Tensor] = None,
484
+ inputs_embeds: Optional[torch.Tensor] = None,
485
+ use_cache: bool = True,
486
+ logits_to_keep: Optional[int] = None,
487
+ **kwargs
488
+ ):
489
+ # only last token for `input_ids` if the `past_key_values` is not empty.
490
+ if past_key_values is not None and len(past_key_values) > 0:
491
+ input_ids = input_ids[:, -1:]
492
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
493
+ if inputs_embeds is not None and len(past_key_values) == 0:
494
+ model_inputs = {'inputs_embeds': inputs_embeds}
495
+ else:
496
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
497
+ # recompiles graphs as the stride of the inputs is a guard.
498
+ # Ref: https://github.com/huggingface/transformers/pull/29114
499
+ # TODO: use `next_tokens` directly instead.
500
+ model_inputs = {'input_ids': input_ids.contiguous()}
501
+
502
+ if logits_to_keep is not None:
503
+ model_inputs['logits_to_keep'] = logits_to_keep
504
+
505
+ model_inputs.update({
506
+ 'past_key_values': past_key_values,
507
+ 'use_cache': use_cache,
508
+ 'attention_mask': attention_mask,
509
+ })
510
+ return model_inputs
511
+
512
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
513
+ def forward(
514
+ self,
515
+ input_ids: torch.LongTensor = None,
516
+ attention_mask: Optional[torch.Tensor] = None,
517
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
518
+ inputs_embeds: Optional[torch.FloatTensor] = None,
519
+ labels: Optional[torch.LongTensor] = None,
520
+ use_cache: Optional[bool] = None,
521
+ output_attentions: Optional[bool] = None,
522
+ output_hidden_states: Optional[bool] = None,
523
+ return_dict: Optional[bool] = None,
524
+ logits_to_keep: Optional[int] = 0,
525
+ **kwargs: Unpack[Any]
526
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
527
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
528
+ output_hidden_states = (
529
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
530
+ )
531
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
532
+
533
+ outputs = self.model(
534
+ input_ids=input_ids,
535
+ attention_mask=attention_mask,
536
+ past_key_values=past_key_values,
537
+ inputs_embeds=inputs_embeds,
538
+ use_cache=use_cache,
539
+ output_attentions=output_attentions,
540
+ output_hidden_states=output_hidden_states,
541
+ return_dict=return_dict,
542
+ return_all_heads=self.training,
543
+ **kwargs
544
+ )
545
+
546
+ hidden_states = outputs[0] # (B, T, n_heads_to_use, D)
547
+
548
+ all_logits = None
549
+ if not self.training:
550
+ if hidden_states.dim() == 4:
551
+ hidden_states = hidden_states.squeeze(-2) # Remove the n_heads_to_use dimension if not training
552
+ all_logits = self.lm_head(hidden_states)
553
+ else:
554
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
555
+ use_custom_backward = self.config.use_custom_backward and self.training
556
+ if use_custom_backward:
557
+ all_logits = SequentialHeadsCustomBackward.apply(
558
+ hidden_states, self.lm_head, self.model.norm, logits_to_keep, *self.model.extra_heads
559
+ )
560
+ else:
561
+ all_logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])
562
+
563
+ loss = None
564
+ if labels is not None:
565
+ B, T, n_heads_prediction, D = hidden_states.shape
566
+ loss = torch.zeros(1, device=hidden_states.device)
567
+ ntp_loss = torch.zeros(1, device=hidden_states.device)
568
+ mtp_loss = torch.zeros(1, device=hidden_states.device)
569
+ if getattr(self, 'criterion', None) is None:
570
+ if fuse_linear_and_cross_entropy:
571
+ criterion = FusedLinearCrossEntropyLoss()
572
+ elif self.config.fuse_cross_entropy:
573
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
574
+ else:
575
+ criterion = nn.CrossEntropyLoss()
576
+ else:
577
+ criterion = self.criterion
578
+ # Enable model parallelism
579
+ labels = labels.to(hidden_states.device)
580
+ all_labels = seq_to_mtp(labels, n_future_tokens=n_heads_prediction, model_seq_len=T)
581
+ # Loop across prediction heads
582
+ for i in range(n_heads_prediction):
583
+ # labels in the shape of (B, n_heads_prediction, T)
584
+ labels = all_labels[:, i, :]
585
+ if fuse_linear_and_cross_entropy:
586
+ current_loss = criterion(hidden_states[:, :, i, :], labels.contiguous(), self.lm_head.weight, self.lm_head.bias)
587
+ else:
588
+ logits = all_logits[:, :, i, :]
589
+ current_loss = criterion(logits.view(labels.numel(), -1), labels.reshape(-1))
590
+ if i == 0: # NTP
591
+ ntp_loss = current_loss
592
+ else:
593
+ mtp_loss += current_loss
594
+ loss += current_loss
595
+
596
+ if not return_dict:
597
+ output = (all_logits,) + outputs[1:]
598
+ return (loss,) + output if loss is not None else output
599
+
600
+ return MTPLMOutputWithPast(
601
+ loss=loss,
602
+ ntp_loss=ntp_loss if loss is not None else None,
603
+ mtp_loss=mtp_loss if loss is not None else None,
604
+ logits=all_logits,
605
+ past_key_values=outputs.past_key_values,
606
+ hidden_states=outputs.hidden_states,
607
+ attentions=outputs.attentions,
608
+ )
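A worked example (not part of the upload) of `seq_to_mtp` defined above, assuming the module's dependencies (e.g. triton) are installed so it imports cleanly: with model_seq_len=4 and n_future_tokens=2, row k of the result holds the (k+1)-th future token for every input position, matching the (B, n_future_tokens, T) layout consumed by the loss loop in `MTPTransformerForCausalLM.forward`.

import torch
from fla.models.transformer_mtp.modeling_transformer import seq_to_mtp

long_input_ids = torch.arange(10, 17).unsqueeze(0)   # [[10, 11, 12, 13, 14, 15, 16]]
targets = seq_to_mtp(long_input_ids, model_seq_len=4, n_future_tokens=2)
print(targets.shape)  # torch.Size([1, 2, 4])
print(targets[0, 0])  # tensor([11, 12, 13, 14])  next-token (NTP) targets
print(targets[0, 1])  # tensor([12, 13, 14, 15])  second-next-token targets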
fla/models/transformer_top/__pycache__/modeling_transformer.cpython-312.pyc ADDED
Binary file (18.7 kB).
 
fla/modules/__pycache__/feature_map.cpython-312.pyc ADDED
Binary file (17.6 kB).
 
logs/none_yagntt11/attempt_0/0/stderr.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef32c39ad6f7ca02c833bf4d4f8196743faf7f289a49798fb5b451327ea3b019
3
+ size 25537738
logs/none_yagntt11/attempt_0/1/stderr.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2b1dba31ba3e451fdd9b07d71992f0a323cc218972d1006358219cbe1b65db2
3
+ size 15389397
logs/none_yagntt11/attempt_0/2/stderr.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b06065c7a987a777cf34d02283c391ce448589bf548d0054c2fb2360d9bd0f84
3
+ size 15389394
logs/none_yagntt11/attempt_0/3/stderr.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2ae4d559f5ace303133096b8d857c05c92e5ba11ef04042d664753267a5c871
3
+ size 15448342
logs/none_yagntt11/attempt_0/4/stderr.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:898546e9d6c18959efdf0095393dcb46d73c602f8ff6ceb3acd6230bf9dc8198
3
+ size 15389392
logs/none_yagntt11/attempt_0/5/stderr.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb65879d3b844e10b4a4a278450b49fb28cdddea127970882da65bf2d008f18d
3
+ size 15389393
logs/none_yagntt11/attempt_0/5/stdout.log ADDED
File without changes
logs/none_yagntt11/attempt_0/6/stderr.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cfa0202f405108b85b5c3d3c21eb0311b7cbdb7503729d63d74b8b7260a96093
3
+ size 15389394
logs/none_yagntt11/attempt_0/7/stderr.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d4af7c40646947d1d997442de38fc8e1b37755f4c4347ba8b758bad92f9df4d
3
+ size 15389389
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14bf00f42c711dafca4cd11373c3e7eee50c53323ce810d9b4b4893e77c76b68
3
+ size 4989532648
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fcd72cf1fcb3823fb818e67d13418d894973df4a14cf4a8bc0af1cc9466c20c
3
+ size 2111988680
tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/files/output.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6eff1add295e31de6ccda5b78a6d7949ea83e385c5400a7138cc4ee5c6078f7a
3
+ size 15411430
tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/run-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201.wandb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3c4838163e23cde9bacb4a0b52d47fb6cad45de8a3efa9704359a68e119d037
3
+ size 265364709