zaydzuhri commited on
Commit
adbece6
·
verified ·
1 Parent(s): 8d737a7

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fla/models/__pycache__/__init__.cpython-312.pyc +0 -0
  2. fla/models/abc/configuration_abc.py +91 -0
  3. fla/models/forgetting_transformer/__init__.py +16 -0
  4. fla/models/gated_deltanet/__init__.py +12 -0
  5. fla/models/gated_deltanet/modeling_gated_deltanet.py +412 -0
  6. fla/models/gated_deltaproduct/__init__.py +14 -0
  7. fla/models/gla/configuration_gla.py +95 -0
  8. fla/models/linear_attn/modeling_linear_attn.py +406 -0
  9. fla/models/mamba2/modeling_mamba2.py +1093 -0
  10. fla/models/rwkv6/__init__.py +13 -0
  11. fla/models/rwkv7/modeling_rwkv7.py +505 -0
  12. fla/models/samba/configuration_samba.py +92 -0
  13. fla/ops/common/__pycache__/__init__.cpython-312.pyc +0 -0
  14. fla/ops/common/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  15. fla/ops/delta_rule/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  16. fla/ops/delta_rule/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  17. fla/ops/forgetting_attn/__pycache__/__init__.cpython-312.pyc +0 -0
  18. fla/ops/forgetting_attn/__pycache__/parallel.cpython-312.pyc +0 -0
  19. fla/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc +0 -0
  20. fla/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-312.pyc +0 -0
  21. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-312.pyc +0 -0
  22. fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-312.pyc +0 -0
  23. fla/ops/generalized_delta_rule/iplr/__init__.py +7 -0
  24. fla/ops/generalized_delta_rule/iplr/__pycache__/chunk.cpython-312.pyc +0 -0
  25. fla/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-312.pyc +0 -0
  26. fla/ops/generalized_delta_rule/iplr/chunk.py +528 -0
  27. fla/ops/hgrn/__pycache__/chunk.cpython-312.pyc +0 -0
  28. fla/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  29. fla/ops/retention/__pycache__/chunk.cpython-312.pyc +0 -0
  30. fla/ops/retention/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  31. fla/ops/utils/__pycache__/logcumsumexp.cpython-312.pyc +0 -0
  32. fla/ops/utils/__pycache__/logsumexp.cpython-312.pyc +0 -0
  33. fla/ops/utils/__pycache__/pooling.cpython-312.pyc +0 -0
  34. profile_trace/iteration_111616/rank0_trace.json +0 -0
  35. profile_trace/iteration_111616/rank1_trace.json +0 -0
  36. profile_trace/iteration_111616/rank2_trace.json +0 -0
  37. profile_trace/iteration_111616/rank3_trace.json +0 -0
  38. profile_trace/iteration_111616/rank4_trace.json +0 -0
  39. profile_trace/iteration_111616/rank6_trace.json +0 -0
  40. profile_trace/iteration_111616/rank7_trace.json +0 -0
  41. profile_trace/iteration_112128/rank0_trace.json +0 -0
  42. profile_trace/iteration_112128/rank2_trace.json +0 -0
  43. profile_trace/iteration_112128/rank3_trace.json +0 -0
  44. profile_trace/iteration_112128/rank4_trace.json +0 -0
  45. profile_trace/iteration_112128/rank5_trace.json +0 -0
  46. profile_trace/iteration_112128/rank6_trace.json +0 -0
  47. profile_trace/iteration_112128/rank7_trace.json +0 -0
  48. profile_trace/iteration_121344/rank0_trace.json +0 -0
  49. profile_trace/iteration_121344/rank1_trace.json +0 -0
  50. profile_trace/iteration_121344/rank2_trace.json +0 -0
fla/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (3.07 kB). View file
 
fla/models/abc/configuration_abc.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class ABCConfig(PretrainedConfig):
9
+
10
+ model_type = 'abc'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ hidden_size: int = 2048,
16
+ gate_low_rank_dim: int = 16,
17
+ clamp_min: float = -32,
18
+ clamp_max: float = 32,
19
+ hidden_ratio: Optional[int] = 4,
20
+ intermediate_size: Optional[int] = None,
21
+ num_hidden_layers: int = 24,
22
+ num_heads: int = 4,
23
+ num_slots: Optional[int] = 64,
24
+ use_short_conv: bool = False,
25
+ conv_size: int = 4,
26
+ exapnd_k: float = 0.5,
27
+ exapnd_v: float = 1,
28
+ hidden_act: str = "swish",
29
+ max_position_embeddings: int = 2048,
30
+ elementwise_affine: Optional[bool] = True,
31
+ norm_eps: float = 1e-6,
32
+ use_rope: bool = True,
33
+ attn: Optional[Dict] = None,
34
+ use_cache: bool = True,
35
+ pad_token_id: int = None,
36
+ bos_token_id: int = 1,
37
+ eos_token_id: int = 2,
38
+ tie_word_embeddings: bool = False,
39
+ initializer_range: float = 0.006,
40
+ fuse_norm: bool = True,
41
+ fuse_swiglu: bool = True,
42
+ fuse_cross_entropy: bool = True,
43
+ vocab_size: int = 32000,
44
+ **kwargs
45
+ ):
46
+ self.hidden_size = hidden_size
47
+ self.gate_low_rank_dim = gate_low_rank_dim
48
+ self.clamp_min = clamp_min
49
+ self.clamp_max = clamp_max
50
+ self.hidden_ratio = hidden_ratio
51
+ self.intermediate_size = intermediate_size
52
+ self.num_hidden_layers = num_hidden_layers
53
+ self.num_heads = num_heads
54
+ self.num_slots = num_slots
55
+ self.use_short_conv = use_short_conv
56
+ self.conv_size = conv_size
57
+ self.expand_k = exapnd_k
58
+ self.expand_v = exapnd_v
59
+ self.hidden_act = hidden_act
60
+ self.max_position_embeddings = max_position_embeddings
61
+ self.elementwise_affine = elementwise_affine
62
+ self.norm_eps = norm_eps
63
+ self.use_rope = use_rope
64
+ self.attn = attn
65
+ self.use_cache = use_cache
66
+ self.initializer_range = initializer_range
67
+
68
+ self.fuse_norm = fuse_norm
69
+ self.fuse_swiglu = fuse_swiglu
70
+ self.fuse_cross_entropy = fuse_cross_entropy
71
+ self.vocab_size = vocab_size
72
+
73
+ if attn is not None:
74
+ if not isinstance(attn, Dict):
75
+ raise ValueError("attn must be a dictionary")
76
+ if 'layers' not in attn:
77
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
78
+ if 'num_heads' not in attn:
79
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
80
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
81
+ attn['qkv_bias'] = attn.get('qkv_bias', False)
82
+ attn['window_size'] = attn.get('window_size', None)
83
+ attn['rope_theta'] = attn.get('rope_theta', 10000.)
84
+
85
+ super().__init__(
86
+ pad_token_id=pad_token_id,
87
+ bos_token_id=bos_token_id,
88
+ eos_token_id=eos_token_id,
89
+ tie_word_embeddings=tie_word_embeddings,
90
+ **kwargs,
91
+ )
fla/models/forgetting_transformer/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.forgetting_transformer.configuration_forgetting_transformer import ForgettingTransformerConfig
6
+ from fla.models.forgetting_transformer.modeling_forgetting_transformer import (
7
+ ForgettingTransformerForCausalLM,
8
+ ForgettingTransformerModel
9
+ )
10
+
11
+ AutoConfig.register(ForgettingTransformerConfig.model_type, ForgettingTransformerConfig)
12
+ AutoModel.register(ForgettingTransformerConfig, ForgettingTransformerModel)
13
+ AutoModelForCausalLM.register(ForgettingTransformerConfig, ForgettingTransformerForCausalLM)
14
+
15
+
16
+ __all__ = ['ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel']
fla/models/gated_deltanet/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
6
+ from fla.models.gated_deltanet.modeling_gated_deltanet import GatedDeltaNetForCausalLM, GatedDeltaNetModel
7
+
8
+ AutoConfig.register(GatedDeltaNetConfig.model_type, GatedDeltaNetConfig)
9
+ AutoModel.register(GatedDeltaNetConfig, GatedDeltaNetModel)
10
+ AutoModelForCausalLM.register(GatedDeltaNetConfig, GatedDeltaNetForCausalLM)
11
+
12
+ __all__ = ['GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel']
fla/models/gated_deltanet/modeling_gated_deltanet.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.gated_deltanet import GatedDeltaNet
20
+ from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as GatedDeltaNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class GatedDeltaNetBlock(nn.Module):
34
+ def __init__(self, config: GatedDeltaNetConfig, layer_idx: int):
35
+ super().__init__()
36
+
37
+ self.config = config
38
+ self.layer_idx = layer_idx
39
+
40
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
41
+ if config.attn is not None and layer_idx in config.attn['layers']:
42
+ self.attn = Attention(
43
+ hidden_size=config.hidden_size,
44
+ num_heads=config.attn['num_heads'],
45
+ num_kv_heads=config.attn['num_kv_heads'],
46
+ qkv_bias=config.attn['qkv_bias'],
47
+ window_size=config.attn['window_size'],
48
+ rope_theta=config.attn['rope_theta'],
49
+ max_position_embeddings=config.max_position_embeddings,
50
+ layer_idx=layer_idx
51
+ )
52
+ else:
53
+ self.attn = GatedDeltaNet(
54
+ mode=config.attn_mode,
55
+ hidden_size=config.hidden_size,
56
+ expand_v=config.expand_v,
57
+ head_dim=config.head_dim,
58
+ num_heads=config.num_heads,
59
+ use_gate=config.use_gate,
60
+ use_short_conv=config.use_short_conv,
61
+ conv_size=config.conv_size,
62
+ norm_eps=config.norm_eps,
63
+ layer_idx=layer_idx
64
+ )
65
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
66
+ self.mlp = GatedDeltaNetMLP(
67
+ hidden_size=config.hidden_size,
68
+ hidden_ratio=config.hidden_ratio,
69
+ intermediate_size=config.intermediate_size,
70
+ hidden_act=config.hidden_act,
71
+ fuse_swiglu=config.fuse_swiglu
72
+ )
73
+
74
+ def forward(
75
+ self,
76
+ hidden_states: torch.Tensor,
77
+ attention_mask: Optional[torch.Tensor] = None,
78
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
79
+ use_cache: Optional[bool] = False,
80
+ output_attentions: Optional[bool] = False,
81
+ **kwargs: Unpack[Dict]
82
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
83
+ residual = hidden_states
84
+ hidden_states = self.attn_norm(hidden_states)
85
+ hidden_states, attentions, past_key_values = self.attn(
86
+ hidden_states=hidden_states,
87
+ attention_mask=attention_mask,
88
+ past_key_values=past_key_values,
89
+ use_cache=use_cache,
90
+ output_attentions=output_attentions,
91
+ **kwargs
92
+ )
93
+ if self.config.fuse_norm:
94
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
95
+ else:
96
+ hidden_states = residual + hidden_states
97
+ residual = hidden_states
98
+ hidden_states = self.mlp_norm(hidden_states)
99
+ hidden_states = self.mlp(hidden_states, **kwargs)
100
+ hidden_states = residual + hidden_states
101
+
102
+ outputs = (hidden_states, attentions, past_key_values)
103
+
104
+ return outputs
105
+
106
+
107
+ class GatedDeltaNetPreTrainedModel(PreTrainedModel):
108
+
109
+ config_class = GatedDeltaNetConfig
110
+ base_model_prefix = 'model'
111
+ supports_gradient_checkpointing = True
112
+ _no_split_modules = ['GatedDeltaNetBlock']
113
+ _supports_cache_class = True
114
+
115
+ def __init__(self, *inputs, **kwargs):
116
+ super().__init__(*inputs, **kwargs)
117
+
118
+ def _init_weights(
119
+ self,
120
+ module: nn.Module,
121
+ prenorm_residual_strategy: Optional[str] = 'rescale',
122
+ num_residuals_per_layer: int = 2,
123
+ ):
124
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
125
+ # Slightly different from the TF version which uses truncated_normal for initialization
126
+ # cf https://github.com/pytorch/pytorch/pull/5617
127
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
128
+ if module.bias is not None:
129
+ nn.init.zeros_(module.bias)
130
+ elif isinstance(module, nn.Embedding):
131
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
132
+ elif hasattr(module, 'reset_parameters'):
133
+ module.reset_parameters()
134
+
135
+ if prenorm_residual_strategy is not None:
136
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
137
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
138
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
139
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
140
+ #
141
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
142
+ p = None
143
+ if hasattr(module, 'o_proj'):
144
+ p = module.o_proj.weight
145
+ elif hasattr(module, 'down_proj'):
146
+ p = module.down_proj.weight
147
+ if p is not None:
148
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
149
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
150
+ # We need to reinit p since this code could be called multiple times
151
+ # Having just p *= scale would repeatedly scale it down
152
+ if prenorm_residual_strategy == 'rescale':
153
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
154
+ with torch.no_grad():
155
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
156
+ elif prenorm_residual_strategy == 'zero':
157
+ nn.init.zeros_(p)
158
+ else:
159
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
160
+
161
+
162
+ class GatedDeltaNetModel(GatedDeltaNetPreTrainedModel):
163
+
164
+ def __init__(self, config: GatedDeltaNetConfig):
165
+ super().__init__(config)
166
+ self.padding_idx = config.pad_token_id
167
+ self.vocab_size = config.vocab_size
168
+
169
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
170
+ self.layers = nn.ModuleList([GatedDeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
171
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
172
+
173
+ self.gradient_checkpointing = False
174
+
175
+ self.post_init()
176
+
177
+ def get_input_embeddings(self):
178
+ return self.embeddings
179
+
180
+ def set_input_embeddings(self, value):
181
+ self.embeddings = value
182
+
183
+ def forward(
184
+ self,
185
+ input_ids: Optional[torch.LongTensor] = None,
186
+ attention_mask: Optional[torch.Tensor] = None, # noqa
187
+ inputs_embeds: Optional[torch.FloatTensor] = None,
188
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
189
+ use_cache: Optional[bool] = None,
190
+ output_attentions: Optional[bool] = None,
191
+ output_hidden_states: Optional[bool] = None,
192
+ return_dict: Optional[bool] = None,
193
+ **kwargs: Unpack[Dict]
194
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
195
+ if output_attentions:
196
+ warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.")
197
+ output_attentions = False
198
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
199
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
200
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
201
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
202
+
203
+ # retrieve input_ids and inputs_embeds
204
+ if input_ids is not None and inputs_embeds is not None:
205
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
206
+ if input_ids is None and inputs_embeds is None:
207
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
208
+
209
+ if inputs_embeds is None:
210
+ inputs_embeds = self.embeddings(input_ids)
211
+ hidden_states = inputs_embeds
212
+
213
+ if use_cache and not isinstance(past_key_values, Cache):
214
+ past_key_values = Cache.from_legacy_cache(past_key_values)
215
+
216
+ if self.gradient_checkpointing and self.training and use_cache:
217
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
218
+ use_cache = False
219
+
220
+ all_hidden_states = () if output_hidden_states else None
221
+ all_attns = () if output_attentions else None
222
+ for layer in self.layers:
223
+ if output_hidden_states:
224
+ all_hidden_states += (hidden_states,)
225
+
226
+ if self.gradient_checkpointing and self.training:
227
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
228
+ layer.__call__,
229
+ hidden_states,
230
+ attention_mask,
231
+ past_key_values,
232
+ use_cache,
233
+ output_attentions,
234
+ **kwargs
235
+ )
236
+ else:
237
+ hidden_states, attentions, past_key_values = layer(
238
+ hidden_states,
239
+ attention_mask=attention_mask,
240
+ past_key_values=past_key_values,
241
+ use_cache=use_cache,
242
+ output_attentions=output_attentions,
243
+ **kwargs
244
+ )
245
+
246
+ if output_attentions:
247
+ all_attns += (attentions,)
248
+
249
+ hidden_states = self.norm(hidden_states)
250
+
251
+ # add hidden states from the last decoder layer
252
+ if output_hidden_states:
253
+ all_hidden_states += (hidden_states,)
254
+
255
+ if not return_dict:
256
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
257
+ return BaseModelOutputWithPast(
258
+ last_hidden_state=hidden_states,
259
+ past_key_values=past_key_values,
260
+ hidden_states=all_hidden_states,
261
+ attentions=all_attns
262
+ )
263
+
264
+
265
+ class GatedDeltaNetForCausalLM(GatedDeltaNetPreTrainedModel, GenerationMixin):
266
+
267
+ _tied_weights_keys = ["lm_head.weight"]
268
+
269
+ def __init__(self, config):
270
+ super().__init__(config)
271
+ self.model = GatedDeltaNetModel(config)
272
+ self.vocab_size = config.vocab_size
273
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
274
+ self.criterion = None
275
+
276
+ # Initialize weights and apply final processing
277
+ self.post_init()
278
+
279
+ def get_input_embeddings(self):
280
+ return self.model.embeddings
281
+
282
+ def set_input_embeddings(self, value):
283
+ self.model.embeddings = value
284
+
285
+ def get_output_embeddings(self):
286
+ return self.lm_head
287
+
288
+ def set_output_embeddings(self, new_embeddings):
289
+ self.lm_head = new_embeddings
290
+
291
+ def set_decoder(self, decoder):
292
+ self.model = decoder
293
+
294
+ def get_decoder(self):
295
+ return self.model
296
+
297
+ def generate(self, *args, **kwargs):
298
+ try:
299
+ return super().generate(*args, **kwargs)
300
+ except AttributeError as exception:
301
+ if 'past_key_values' in str(exception):
302
+ raise AttributeError(
303
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
304
+ f"which is not supported for {self.__class__.__name__}. "
305
+ f"Try another generation strategy instead. "
306
+ f"For the available generation strategies, check this doc: "
307
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
308
+ )
309
+ else:
310
+ raise exception
311
+
312
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
313
+ def prepare_inputs_for_generation(
314
+ self,
315
+ input_ids: torch.LongTensor = None,
316
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
317
+ attention_mask: Optional[torch.Tensor] = None,
318
+ inputs_embeds: Optional[torch.Tensor] = None,
319
+ use_cache: bool = True,
320
+ logits_to_keep: Optional[int] = None,
321
+ **kwargs
322
+ ):
323
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
324
+ if past_key_values is not None and len(past_key_values) > 0:
325
+ input_ids = input_ids[:, -1:]
326
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
327
+ if inputs_embeds is not None and len(past_key_values) == 0:
328
+ model_inputs = {'inputs_embeds': inputs_embeds}
329
+ else:
330
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
331
+ # recompiles graphs as the stride of the inputs is a guard.
332
+ # Ref: https://github.com/huggingface/transformers/pull/29114
333
+ # TODO: use `next_tokens` directly instead.
334
+ model_inputs = {'input_ids': input_ids.contiguous()}
335
+
336
+ if logits_to_keep is not None:
337
+ model_inputs['logits_to_keep'] = logits_to_keep
338
+
339
+ model_inputs.update({
340
+ 'past_key_values': past_key_values,
341
+ 'use_cache': use_cache,
342
+ 'attention_mask': attention_mask,
343
+ })
344
+ return model_inputs
345
+
346
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
347
+ def forward(
348
+ self,
349
+ input_ids: torch.LongTensor = None,
350
+ attention_mask: Optional[torch.Tensor] = None,
351
+ inputs_embeds: Optional[torch.Tensor] = None,
352
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
353
+ labels: Optional[torch.LongTensor] = None,
354
+ use_cache: Optional[bool] = None,
355
+ output_attentions: Optional[bool] = None,
356
+ output_hidden_states: Optional[bool] = None,
357
+ return_dict: Optional[bool] = None,
358
+ logits_to_keep: Optional[int] = 0,
359
+ **kwargs: Unpack[Dict]
360
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
361
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
362
+ output_hidden_states = (
363
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
364
+ )
365
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
366
+
367
+ outputs = self.model(
368
+ input_ids=input_ids,
369
+ attention_mask=attention_mask,
370
+ inputs_embeds=inputs_embeds,
371
+ past_key_values=past_key_values,
372
+ use_cache=use_cache,
373
+ output_attentions=output_attentions,
374
+ output_hidden_states=output_hidden_states,
375
+ return_dict=return_dict,
376
+ **kwargs
377
+ )
378
+
379
+ hidden_states = outputs[0]
380
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
381
+
382
+ loss, logits = None, None
383
+ if not fuse_linear_and_cross_entropy or labels is None:
384
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
385
+ if labels is not None:
386
+ if getattr(self, 'criterion', None) is None:
387
+ if fuse_linear_and_cross_entropy:
388
+ criterion = FusedLinearCrossEntropyLoss()
389
+ elif self.config.fuse_cross_entropy:
390
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
391
+ else:
392
+ criterion = nn.CrossEntropyLoss()
393
+ else:
394
+ criterion = self.criterion
395
+ labels = labels.to(hidden_states.device)
396
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
397
+ if fuse_linear_and_cross_entropy:
398
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
399
+ else:
400
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
401
+
402
+ if not return_dict:
403
+ output = (logits,) + outputs[1:]
404
+ return (loss,) + output if loss is not None else output
405
+
406
+ return CausalLMOutputWithPast(
407
+ loss=loss,
408
+ logits=logits,
409
+ past_key_values=outputs.past_key_values,
410
+ hidden_states=outputs.hidden_states,
411
+ attentions=outputs.attentions,
412
+ )
fla/models/gated_deltaproduct/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
2
+
3
+ from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig
4
+ from fla.models.gated_deltaproduct.modeling_gated_deltaproduct import GatedDeltaProductForCausalLM, GatedDeltaProductModel
5
+
6
+ AutoConfig.register(GatedDeltaProductConfig.model_type, GatedDeltaProductConfig)
7
+ AutoModel.register(GatedDeltaProductConfig, GatedDeltaProductModel)
8
+ AutoModelForCausalLM.register(GatedDeltaProductConfig, GatedDeltaProductForCausalLM)
9
+
10
+ __all__ = [
11
+ "GatedDeltaProductConfig",
12
+ "GatedDeltaProductForCausalLM",
13
+ "GatedDeltaProductModel",
14
+ ]
fla/models/gla/configuration_gla.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class GLAConfig(PretrainedConfig):
9
+
10
+ model_type = 'gla'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ hidden_size: int = 2048,
16
+ expand_k: int = 0.5,
17
+ expand_v: int = 1,
18
+ hidden_ratio: Optional[int] = 4,
19
+ intermediate_size: Optional[int] = None,
20
+ num_hidden_layers: int = 24,
21
+ num_heads: int = 4,
22
+ num_kv_heads: Optional[int] = None,
23
+ feature_map: Optional[str] = None,
24
+ attn_mode: str = "chunk",
25
+ use_short_conv: bool = False,
26
+ conv_size: int = 4,
27
+ use_output_gate: bool = True,
28
+ clamp_min: Optional[float] = None,
29
+ hidden_act: str = "swish",
30
+ max_position_embeddings: int = 2048,
31
+ elementwise_affine: Optional[bool] = True,
32
+ norm_eps: float = 1e-6,
33
+ use_gk: bool = True,
34
+ use_gv: bool = False,
35
+ attn: Optional[Dict] = None,
36
+ use_cache: bool = True,
37
+ pad_token_id: int = None,
38
+ bos_token_id: int = 1,
39
+ eos_token_id: int = 2,
40
+ tie_word_embeddings: bool = False,
41
+ initializer_range: float = 0.006,
42
+ fuse_norm: bool = True,
43
+ fuse_swiglu: bool = True,
44
+ fuse_cross_entropy: bool = True,
45
+ vocab_size: int = 32000,
46
+ **kwargs
47
+ ):
48
+ self.hidden_size = hidden_size
49
+ self.expand_k = expand_k
50
+ self.expand_v = expand_v
51
+ self.hidden_ratio = hidden_ratio
52
+ self.intermediate_size = intermediate_size
53
+ self.num_hidden_layers = num_hidden_layers
54
+ self.num_heads = num_heads
55
+ self.num_kv_heads = num_kv_heads
56
+ self.feature_map = feature_map
57
+ self.attn_mode = attn_mode
58
+ self.use_short_conv = use_short_conv
59
+ self.conv_size = conv_size
60
+ self.use_output_gate = use_output_gate
61
+ self.clamp_min = clamp_min
62
+ self.hidden_act = hidden_act
63
+ self.max_position_embeddings = max_position_embeddings
64
+ self.elementwise_affine = elementwise_affine
65
+ self.norm_eps = norm_eps
66
+ self.use_gk = use_gk
67
+ self.use_gv = use_gv
68
+ self.attn = attn
69
+ self.use_cache = use_cache
70
+ self.initializer_range = initializer_range
71
+
72
+ self.fuse_norm = fuse_norm
73
+ self.fuse_swiglu = fuse_swiglu
74
+ self.fuse_cross_entropy = fuse_cross_entropy
75
+ self.vocab_size = vocab_size
76
+
77
+ if attn is not None:
78
+ if not isinstance(attn, Dict):
79
+ raise ValueError("attn must be a dictionary")
80
+ if 'layers' not in attn:
81
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
82
+ if 'num_heads' not in attn:
83
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
84
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
85
+ attn['qkv_bias'] = attn.get('qkv_bias', False)
86
+ attn['window_size'] = attn.get('window_size', None)
87
+ attn['rope_theta'] = attn.get('rope_theta', 10000.)
88
+
89
+ super().__init__(
90
+ pad_token_id=pad_token_id,
91
+ bos_token_id=bos_token_id,
92
+ eos_token_id=eos_token_id,
93
+ tie_word_embeddings=tie_word_embeddings,
94
+ **kwargs,
95
+ )
fla/models/linear_attn/modeling_linear_attn.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.linear_attn import LinearAttention
20
+ from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as LinearAttentionMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class LinearAttentionBlock(nn.Module):
30
+ def __init__(self, config: LinearAttentionConfig, layer_idx: int):
31
+ super().__init__()
32
+
33
+ self.config = config
34
+ self.layer_idx = layer_idx
35
+
36
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
37
+ if config.attn is not None and layer_idx in config.attn['layers']:
38
+ self.attn = Attention(
39
+ hidden_size=config.hidden_size,
40
+ num_heads=config.attn['num_heads'],
41
+ num_kv_heads=config.attn['num_kv_heads'],
42
+ qkv_bias=config.attn['qkv_bias'],
43
+ window_size=config.attn['window_size'],
44
+ rope_theta=config.attn['rope_theta'],
45
+ max_position_embeddings=config.max_position_embeddings,
46
+ layer_idx=layer_idx
47
+ )
48
+ else:
49
+ self.attn = LinearAttention(
50
+ mode=config.attn_mode,
51
+ hidden_size=config.hidden_size,
52
+ expand_k=config.expand_k,
53
+ expand_v=config.expand_v,
54
+ num_heads=config.num_heads,
55
+ num_kv_heads=config.num_kv_heads,
56
+ feature_map=config.feature_map,
57
+ tie_feature_map_qk=config.tie_feature_map_qk,
58
+ norm_q=config.norm_q,
59
+ norm_k=config.norm_k,
60
+ do_feature_map_norm=config.norm_feature_map,
61
+ elementwise_affine=config.elementwise_affine,
62
+ norm_eps=config.norm_eps,
63
+ layer_idx=layer_idx
64
+ )
65
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
66
+ self.mlp = LinearAttentionMLP(
67
+ hidden_size=config.hidden_size,
68
+ hidden_ratio=config.hidden_ratio,
69
+ intermediate_size=config.intermediate_size,
70
+ hidden_act=config.hidden_act,
71
+ fuse_swiglu=config.fuse_swiglu
72
+ )
73
+
74
+ def forward(
75
+ self,
76
+ hidden_states: torch.Tensor,
77
+ attention_mask: Optional[torch.Tensor] = None,
78
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
79
+ use_cache: Optional[bool] = False,
80
+ output_attentions: Optional[bool] = False,
81
+ **kwargs,
82
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
83
+ residual = hidden_states
84
+ # currently not supported
85
+ attentions, past_key_values = None, None
86
+ hidden_states = self.attn_norm(hidden_states)
87
+ hidden_states = self.attn(hidden_states=hidden_states, **kwargs)
88
+ if self.config.fuse_norm:
89
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
90
+ else:
91
+ hidden_states = residual + hidden_states
92
+ residual = hidden_states
93
+ hidden_states = self.mlp_norm(hidden_states)
94
+ hidden_states = self.mlp(hidden_states, **kwargs)
95
+ hidden_states = residual + hidden_states
96
+
97
+ outputs = (hidden_states, attentions, past_key_values)
98
+
99
+ return outputs
100
+
101
+
102
+ class LinearAttentionPreTrainedModel(PreTrainedModel):
103
+
104
+ config_class = LinearAttentionConfig
105
+ base_model_prefix = 'model'
106
+ supports_gradient_checkpointing = True
107
+ _no_split_modules = ['LinearAttentionBlock']
108
+ _supports_cache_class = True
109
+
110
+ def __init__(self, *inputs, **kwargs):
111
+ super().__init__(*inputs, **kwargs)
112
+
113
+ def _init_weights(
114
+ self,
115
+ module: nn.Module,
116
+ prenorm_residual_strategy: Optional[str] = 'rescale',
117
+ num_residuals_per_layer: int = 2,
118
+ ):
119
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
120
+ # Slightly different from the TF version which uses truncated_normal for initialization
121
+ # cf https://github.com/pytorch/pytorch/pull/5617
122
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
123
+ if module.bias is not None:
124
+ nn.init.zeros_(module.bias)
125
+ elif isinstance(module, nn.Embedding):
126
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
127
+ elif hasattr(module, 'reset_parameters'):
128
+ module.reset_parameters()
129
+
130
+ if prenorm_residual_strategy is not None:
131
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
132
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
133
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
134
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
135
+ #
136
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
137
+ p = None
138
+ if hasattr(module, 'o_proj'):
139
+ p = module.o_proj.weight
140
+ elif hasattr(module, 'down_proj'):
141
+ p = module.down_proj.weight
142
+ if p is not None:
143
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
144
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
145
+ # We need to reinit p since this code could be called multiple times
146
+ # Having just p *= scale would repeatedly scale it down
147
+ if prenorm_residual_strategy == 'rescale':
148
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
149
+ with torch.no_grad():
150
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
151
+ elif prenorm_residual_strategy == 'zero':
152
+ nn.init.zeros_(p)
153
+ else:
154
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
155
+
156
+
157
+ class LinearAttentionModel(LinearAttentionPreTrainedModel):
158
+
159
+ def __init__(self, config: LinearAttentionConfig):
160
+ super().__init__(config)
161
+ self.padding_idx = config.pad_token_id
162
+ self.vocab_size = config.vocab_size
163
+
164
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
165
+ self.layers = nn.ModuleList([LinearAttentionBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
166
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
167
+
168
+ self.gradient_checkpointing = False
169
+
170
+ self.post_init()
171
+
172
+ def get_input_embeddings(self):
173
+ return self.embeddings
174
+
175
+ def set_input_embeddings(self, value):
176
+ self.embeddings = value
177
+
178
+ def forward(
179
+ self,
180
+ input_ids: Optional[torch.LongTensor] = None,
181
+ attention_mask: Optional[torch.Tensor] = None, # noqa
182
+ inputs_embeds: Optional[torch.FloatTensor] = None,
183
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
184
+ use_cache: Optional[bool] = None,
185
+ output_attentions: Optional[bool] = None,
186
+ output_hidden_states: Optional[bool] = None,
187
+ return_dict: Optional[bool] = None
188
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
189
+ if output_attentions:
190
+ warnings.warn(
191
+ "`LinearAttentionModel` does not support output attention weights now, "
192
+ "so `output_attentions` is set to `False`."
193
+ )
194
+ output_attentions = False
195
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
196
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
197
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
198
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
199
+
200
+ # retrieve input_ids and inputs_embeds
201
+ if input_ids is not None and inputs_embeds is not None:
202
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
203
+ if input_ids is None and inputs_embeds is None:
204
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
205
+
206
+ if inputs_embeds is None:
207
+ inputs_embeds = self.embeddings(input_ids)
208
+ hidden_states = inputs_embeds
209
+
210
+ if use_cache and not isinstance(past_key_values, Cache):
211
+ past_key_values = Cache.from_legacy_cache(past_key_values)
212
+
213
+ if self.gradient_checkpointing and self.training and use_cache:
214
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
215
+ use_cache = False
216
+
217
+ all_hidden_states = () if output_hidden_states else None
218
+ all_attns = () if output_attentions else None
219
+
220
+ for i, layer in enumerate(self.layers):
221
+ if output_hidden_states:
222
+ all_hidden_states += (hidden_states,)
223
+
224
+ if self.gradient_checkpointing and self.training:
225
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
226
+ layer.__call__,
227
+ hidden_states,
228
+ attention_mask,
229
+ past_key_values,
230
+ use_cache,
231
+ output_attentions,
232
+ )
233
+ else:
234
+ hidden_states, attentions, past_key_values = layer(
235
+ hidden_states,
236
+ attention_mask=attention_mask,
237
+ past_key_values=past_key_values,
238
+ use_cache=use_cache,
239
+ output_attentions=output_attentions
240
+ )
241
+
242
+ if output_attentions:
243
+ all_attns += (attentions,)
244
+
245
+ hidden_states = self.norm(hidden_states)
246
+
247
+ # add hidden states from the last decoder layer
248
+ if output_hidden_states:
249
+ all_hidden_states += (hidden_states,)
250
+
251
+ if not return_dict:
252
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
253
+ return BaseModelOutputWithPast(
254
+ last_hidden_state=hidden_states,
255
+ past_key_values=past_key_values,
256
+ hidden_states=all_hidden_states,
257
+ attentions=all_attns
258
+ )
259
+
260
+
261
+ class LinearAttentionForCausalLM(LinearAttentionPreTrainedModel, GenerationMixin):
262
+
263
+ _tied_weights_keys = ["lm_head.weight"]
264
+
265
+ def __init__(self, config):
266
+ super().__init__(config)
267
+ self.model = LinearAttentionModel(config)
268
+ self.vocab_size = config.vocab_size
269
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
270
+ self.criterion = None
271
+
272
+ # Initialize weights and apply final processing
273
+ self.post_init()
274
+
275
+ def get_input_embeddings(self):
276
+ return self.model.embeddings
277
+
278
+ def set_input_embeddings(self, value):
279
+ self.model.embeddings = value
280
+
281
+ def get_output_embeddings(self):
282
+ return self.lm_head
283
+
284
+ def set_output_embeddings(self, new_embeddings):
285
+ self.lm_head = new_embeddings
286
+
287
+ def set_decoder(self, decoder):
288
+ self.model = decoder
289
+
290
+ def get_decoder(self):
291
+ return self.model
292
+
293
+ def generate(self, *args, **kwargs):
294
+ try:
295
+ return super().generate(*args, **kwargs)
296
+ except AttributeError as exception:
297
+ if 'past_key_values' in str(exception):
298
+ raise AttributeError(
299
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
300
+ f"which is not supported for {self.__class__.__name__}. "
301
+ f"Try another generation strategy instead. "
302
+ f"For the available generation strategies, check this doc: "
303
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
304
+ )
305
+ else:
306
+ raise exception
307
+
308
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
309
+ def prepare_inputs_for_generation(
310
+ self,
311
+ input_ids: torch.LongTensor = None,
312
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
313
+ attention_mask: Optional[torch.Tensor] = None,
314
+ inputs_embeds: Optional[torch.Tensor] = None,
315
+ use_cache: bool = True,
316
+ logits_to_keep: Optional[int] = None,
317
+ **kwargs
318
+ ):
319
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
320
+ if past_key_values is not None and len(past_key_values) > 0:
321
+ input_ids = input_ids[:, -1:]
322
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
323
+ if inputs_embeds is not None and len(past_key_values) == 0:
324
+ model_inputs = {'inputs_embeds': inputs_embeds}
325
+ else:
326
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
327
+ # recompiles graphs as the stride of the inputs is a guard.
328
+ # Ref: https://github.com/huggingface/transformers/pull/29114
329
+ # TODO: use `next_tokens` directly instead.
330
+ model_inputs = {'input_ids': input_ids.contiguous()}
331
+
332
+ if logits_to_keep is not None:
333
+ model_inputs['logits_to_keep'] = logits_to_keep
334
+
335
+ model_inputs.update({
336
+ 'past_key_values': past_key_values,
337
+ 'use_cache': use_cache,
338
+ 'attention_mask': attention_mask,
339
+ })
340
+ return model_inputs
341
+
342
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
343
+ def forward(
344
+ self,
345
+ input_ids: torch.LongTensor = None,
346
+ attention_mask: Optional[torch.Tensor] = None,
347
+ inputs_embeds: Optional[torch.Tensor] = None,
348
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
349
+ labels: Optional[torch.LongTensor] = None,
350
+ use_cache: Optional[bool] = None,
351
+ output_attentions: Optional[bool] = None,
352
+ output_hidden_states: Optional[bool] = None,
353
+ return_dict: Optional[bool] = None,
354
+ logits_to_keep: Optional[int] = 0
355
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
356
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
357
+ output_hidden_states = (
358
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
359
+ )
360
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
361
+
362
+ outputs = self.model(
363
+ input_ids=input_ids,
364
+ attention_mask=attention_mask,
365
+ inputs_embeds=inputs_embeds,
366
+ past_key_values=past_key_values,
367
+ use_cache=use_cache,
368
+ output_attentions=output_attentions,
369
+ output_hidden_states=output_hidden_states,
370
+ return_dict=return_dict
371
+ )
372
+
373
+ hidden_states = outputs[0]
374
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
375
+
376
+ loss, logits = None, None
377
+ if not fuse_linear_and_cross_entropy or labels is None:
378
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
379
+ if labels is not None:
380
+ if getattr(self, 'criterion', None) is None:
381
+ if fuse_linear_and_cross_entropy:
382
+ criterion = FusedLinearCrossEntropyLoss()
383
+ elif self.config.fuse_cross_entropy:
384
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
385
+ else:
386
+ criterion = nn.CrossEntropyLoss()
387
+ else:
388
+ criterion = self.criterion
389
+ labels = labels.to(hidden_states.device)
390
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
391
+ if fuse_linear_and_cross_entropy:
392
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
393
+ else:
394
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
395
+
396
+ if not return_dict:
397
+ output = (logits,) + outputs[1:]
398
+ return (loss,) + output if loss is not None else output
399
+
400
+ return CausalLMOutputWithPast(
401
+ loss=loss,
402
+ logits=logits,
403
+ past_key_values=outputs.past_key_values,
404
+ hidden_states=outputs.hidden_states,
405
+ attentions=outputs.attentions,
406
+ )
fla/models/mamba2/modeling_mamba2.py ADDED
@@ -0,0 +1,1093 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """PyTorch MAMBA2 model."""
15
+
16
+ import math
17
+ import warnings
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from transformers.activations import ACT2FN
25
+ from transformers.generation import GenerationMixin
26
+ from transformers.modeling_utils import PreTrainedModel
27
+ from transformers.utils import ModelOutput, logging
28
+ from transformers.utils.deprecation import deprecate_kwarg
29
+
30
+ from fla.models.mamba2.configuration_mamba2 import Mamba2Config
31
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
32
+ from fla.modules.layernorm_gated import RMSNormGated
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+ with warnings.catch_warnings():
37
+ warnings.simplefilter('ignore')
38
+ try:
39
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
40
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
41
+ except ImportError:
42
+ (
43
+ selective_state_update,
44
+ mamba_chunk_scan_combined,
45
+ mamba_split_conv1d_scan_combined,
46
+ ) = (None, None, None)
47
+ try:
48
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
49
+ except ImportError:
50
+ causal_conv1d_update, causal_conv1d_fn = None, None
51
+ is_fast_path_available = all((
52
+ selective_state_update,
53
+ causal_conv1d_fn,
54
+ causal_conv1d_update
55
+ ))
56
+
57
+
58
+ def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
59
+ """
60
+ Padding x tensor with `pad_size` on the seq_len dim (dim=1)
61
+
62
+ Assumes that we only have tensors of either size 4 or 3
63
+ """
64
+ pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
65
+
66
+ return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
67
+
68
+
69
+ def reshape_into_chunks(input_tensor, pad_size, chunk_size):
70
+ """
71
+ Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
72
+ simultaneously splitting it into chunk sequences.
73
+
74
+ Assumes that we only have tensors of either size 4 or 3
75
+ """
76
+ # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
77
+ input_tensor = pad_tensor_by_size(input_tensor, pad_size)
78
+
79
+ if len(input_tensor.shape) == 3:
80
+ # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
81
+ return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
82
+ else:
83
+ # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] ->
84
+ # [bsz, -1, chunk_size, num_heads, head_dim or state_size]
85
+ return input_tensor.reshape(
86
+ input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
87
+ )
88
+
89
+
90
+ def segment_sum(input_tensor):
91
+ """
92
+ More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
93
+ """
94
+ chunk_size = input_tensor.size(-1)
95
+ # 1. expand input tensor to have an additional dimension and repeat along that dimension
96
+ # [..., chunk_size] -> [..., chunk_size, chunk_size]
97
+ input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
98
+ # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
99
+ mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
100
+ input_tensor = input_tensor.masked_fill(~mask, 0)
101
+ # 3. compute actual cumsum
102
+ tensor_segsum = torch.cumsum(input_tensor, dim=-2)
103
+
104
+ # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
105
+ mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
106
+ tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
107
+ return tensor_segsum
108
+
109
+
110
+ def apply_mask_to_padding_states(hidden_states, attention_mask):
111
+ """
112
+ Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
113
+ """
114
+ if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
115
+ dtype = hidden_states.dtype
116
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
117
+
118
+ return hidden_states
119
+
120
+
121
+ class Mamba2Cache:
122
+ """
123
+ Arguments:
124
+ config: Mamba2Config
125
+ batch_size: int
126
+ dtype: torch.dtype
127
+ device: torch.device
128
+
129
+ Attributes:
130
+ dtype: (`torch.dtype`):
131
+ The default `dtype` used to initializing the cache.
132
+ conv_kernel_size: (`int`):
133
+ Model's convolution kernel size taken from config.
134
+ n_groups: (`int`):
135
+ Model's number of groups taken from the config - similar to tensor parallel in Transformer.
136
+ state_size: (`int`):
137
+ Model's SSM state size taken from config.
138
+ num_heads: (`int`):
139
+ The number of heads used in the linear attention / SSM.
140
+ head_dim: (`int`):
141
+ The respective dimension of the heads used in the linear attention / SSM.
142
+ intermediate_size: (`int`):
143
+ Model's intermediate_size based on (expand * hidden_dim) from config.
144
+ conv_states: (`torch.Tensor`):
145
+ A tensor of shape `[num_layers, batch_size, conv_kernel_size, intermediate_size + 2 * n_groups * state_size]`
146
+ that holds convolutional states.
147
+ ssm_states: (`torch.Tensor`):
148
+ A tensor of shape `[num_layers, batch_size, num_heads, head_dim, state_size]` that holds ssm states.
149
+ """
150
+
151
+ def __init__(
152
+ self,
153
+ config: Mamba2Config,
154
+ batch_size: int,
155
+ dtype: torch.dtype = torch.float16,
156
+ device: Optional[str] = None,
157
+ ):
158
+ self.dtype = dtype
159
+ self.conv_kernel_size = config.conv_kernel
160
+ self.n_groups = config.n_groups
161
+ self.state_size = config.state_size
162
+ self.num_heads = config.num_heads
163
+ self.head_dim = config.head_dim
164
+ self.intermediate_size = int(config.expand * config.hidden_size)
165
+
166
+ self.conv_states = torch.zeros(
167
+ config.num_hidden_layers,
168
+ batch_size,
169
+ self.intermediate_size + 2 * self.n_groups * self.state_size,
170
+ self.conv_kernel_size,
171
+ device=device,
172
+ dtype=dtype,
173
+ )
174
+ self.ssm_states = torch.zeros(
175
+ config.num_hidden_layers,
176
+ batch_size,
177
+ self.num_heads,
178
+ self.head_dim,
179
+ self.state_size,
180
+ device=device,
181
+ dtype=dtype,
182
+ )
183
+
184
+ def update_conv_state(
185
+ self,
186
+ layer_idx: int,
187
+ new_conv_state: torch.Tensor,
188
+ cache_init: bool = False
189
+ ) -> torch.Tensor:
190
+ if cache_init:
191
+ self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device)
192
+ else:
193
+ self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
194
+ self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device)
195
+ return self.conv_states[layer_idx]
196
+
197
+ def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
198
+ self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
199
+ return self.ssm_states[layer_idx]
200
+
201
+ def reset(self):
202
+ self.conv_states.zero_()
203
+ self.ssm_states.zero_()
204
+
205
+
206
+ class Mamba2Mixer(nn.Module):
207
+ """
208
+ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
209
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
210
+ ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
211
+ and is why Mamba is called **selective** state spaces)
212
+ """
213
+
214
+ def __init__(self, config: Mamba2Config, layer_idx: int):
215
+ super().__init__()
216
+ self.num_heads = config.num_heads
217
+ self.hidden_size = config.hidden_size
218
+ self.ssm_state_size = config.state_size
219
+ self.conv_kernel_size = config.conv_kernel
220
+ self.intermediate_size = int(config.expand * self.hidden_size)
221
+ self.time_step_rank = int(config.time_step_rank)
222
+ self.layer_idx = layer_idx
223
+ self.use_conv_bias = config.use_conv_bias
224
+ self.activation = config.hidden_act
225
+ self.act = ACT2FN[config.hidden_act]
226
+
227
+ self.layer_norm_epsilon = config.layer_norm_epsilon
228
+ self.rms_norm = config.rms_norm
229
+
230
+ self.n_groups = config.n_groups
231
+ self.head_dim = config.head_dim
232
+ self.chunk_size = config.chunk_size
233
+
234
+ self.time_step_limit = config.time_step_limit
235
+ self.time_step_min = config.time_step_min
236
+ self.time_step_max = config.time_step_max
237
+
238
+ self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
239
+ self.conv1d = nn.Conv1d(
240
+ in_channels=self.conv_dim,
241
+ out_channels=self.conv_dim,
242
+ bias=config.use_conv_bias,
243
+ kernel_size=config.conv_kernel,
244
+ groups=self.conv_dim,
245
+ padding=config.conv_kernel - 1,
246
+ )
247
+
248
+ # projection of the input hidden states
249
+ projection_size = self.intermediate_size + self.conv_dim + self.num_heads
250
+ self.in_proj = nn.Linear(
251
+ self.hidden_size,
252
+ projection_size,
253
+ bias=config.use_bias,
254
+ )
255
+ # selective projection used to make dt, B and C input dependant
256
+
257
+ # time step projection (discretization)
258
+ # instantiate once and copy inv_dt in init_weights of PretrainedModel
259
+ self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
260
+
261
+ # S4D real initialization. These are not discretized!
262
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
263
+ A = torch.arange(1, self.num_heads + 1)
264
+ self.A_log = nn.Parameter(torch.log(A))
265
+ self.A_log._no_weight_decay = True
266
+ self.norm = RMSNormGated(
267
+ self.intermediate_size, eps=self.layer_norm_epsilon, norm_before_gate=False
268
+ )
269
+ self.D = nn.Parameter(torch.ones(self.num_heads))
270
+ self.D._no_weight_decay = True
271
+
272
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
273
+ self.use_bias = config.use_bias
274
+
275
+ if not is_fast_path_available:
276
+ logger.warning_once(
277
+ "The fast path is not available because one of "
278
+ "`(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. "
279
+ "Falling back to the naive implementation. "
280
+ "To install follow https://github.com/state-spaces/mamba/#installation and"
281
+ "https://github.com/Dao-AILab/causal-conv1d"
282
+ )
283
+
284
+ def cuda_kernels_forward(
285
+ self,
286
+ hidden_states: torch.Tensor,
287
+ cache_params: Optional[Mamba2Cache] = None,
288
+ cache_position: Optional[torch.LongTensor] = None,
289
+ attention_mask: Optional[torch.Tensor] = None,
290
+ ):
291
+ # 1. Gated MLP's linear projection
292
+ hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
293
+ projected_states = self.in_proj(hidden_states)
294
+
295
+ # Set up dimensions for reshapes later
296
+ batch_size, seq_len, _ = hidden_states.shape
297
+ groups_time_state_size = self.n_groups * self.ssm_state_size
298
+ d_mlp = (
299
+ projected_states.shape[-1]
300
+ - 2 * self.intermediate_size
301
+ - 2 * self.n_groups * self.ssm_state_size
302
+ - self.num_heads
303
+ ) // 2
304
+
305
+ # Single step calculations via cache
306
+ if cache_params is not None and cache_position is not None and cache_position[0] > 0:
307
+ _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
308
+ [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
309
+ )
310
+
311
+ # 2. Convolution sequence transformation
312
+ hidden_states_B_C = causal_conv1d_update(
313
+ hidden_states_B_C,
314
+ cache_params.conv_states[self.layer_idx],
315
+ self.conv1d.weight.squeeze(1),
316
+ self.conv1d.bias,
317
+ self.activation,
318
+ )
319
+
320
+ hidden_states, B, C = torch.split(
321
+ hidden_states_B_C,
322
+ [
323
+ self.intermediate_size,
324
+ groups_time_state_size,
325
+ groups_time_state_size,
326
+ ],
327
+ dim=-1,
328
+ )
329
+
330
+ # 3. SSM transformation
331
+ A = -torch.exp(self.A_log.float()) # (nheads,)
332
+ A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
333
+ dt = dt[:, :, None].expand(-1, -1, self.head_dim)
334
+ dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
335
+ D = self.D[:, None, ...].expand(-1, self.head_dim)
336
+ B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
337
+ C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
338
+ hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
339
+
340
+ hidden_states = selective_state_update(
341
+ cache_params.ssm_states[self.layer_idx],
342
+ hidden_states_reshaped,
343
+ dt,
344
+ A,
345
+ B,
346
+ C,
347
+ D,
348
+ z=None,
349
+ dt_bias=dt_bias,
350
+ dt_softplus=True,
351
+ )
352
+ hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
353
+ hidden_states = self.norm(hidden_states, gate)
354
+
355
+ # 4. Final linear projection
356
+ out = self.out_proj(hidden_states)[:, None, ...]
357
+
358
+ # Fused calculations or step by step if no initialized cache is found
359
+ else:
360
+ A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
361
+ dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
362
+
363
+ # 2-4. Fused kernel for conv1d, SSM, and the final projection
364
+ if self.training and cache_params is None:
365
+ out = mamba_split_conv1d_scan_combined(
366
+ projected_states,
367
+ self.conv1d.weight.squeeze(1),
368
+ self.conv1d.bias,
369
+ self.dt_bias,
370
+ A,
371
+ D=self.D,
372
+ chunk_size=self.chunk_size,
373
+ seq_idx=None, # was seq_idx
374
+ activation=self.activation,
375
+ rmsnorm_weight=self.norm.weight,
376
+ rmsnorm_eps=self.norm.eps,
377
+ outproj_weight=self.out_proj.weight,
378
+ outproj_bias=self.out_proj.bias,
379
+ headdim=self.head_dim,
380
+ ngroups=self.n_groups,
381
+ norm_before_gate=False,
382
+ return_final_states=False,
383
+ **dt_limit_kwargs,
384
+ )
385
+
386
+ else:
387
+ _, _, gate, hidden_states_B_C, dt = projected_states.split(
388
+ [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
389
+ )
390
+
391
+ # 2. Convolution sequence transformation
392
+ # Init cache
393
+ if cache_params is not None:
394
+ hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
395
+ conv_states = nn.functional.pad(
396
+ hidden_states_B_C_transposed,
397
+ (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
398
+ )
399
+ cache_params.update_conv_state(
400
+ layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True
401
+ )
402
+
403
+ if self.activation not in ["silu", "swish"]:
404
+ hidden_states_B_C = self.act(
405
+ self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
406
+ )
407
+ else:
408
+ hidden_states_B_C = causal_conv1d_fn(
409
+ x=hidden_states_B_C.transpose(1, 2),
410
+ weight=self.conv1d.weight.squeeze(1),
411
+ bias=self.conv1d.bias,
412
+ activation=self.activation,
413
+ ).transpose(1, 2)
414
+
415
+ hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
416
+ hidden_states, B, C = torch.split(
417
+ hidden_states_B_C,
418
+ [self.intermediate_size, groups_time_state_size, groups_time_state_size],
419
+ dim=-1,
420
+ )
421
+
422
+ # 3. SSM transformation
423
+ scan_output, ssm_state = mamba_chunk_scan_combined(
424
+ hidden_states.view(batch_size, seq_len, -1, self.head_dim),
425
+ dt,
426
+ A,
427
+ B.view(batch_size, seq_len, self.n_groups, -1),
428
+ C.view(batch_size, seq_len, self.n_groups, -1),
429
+ chunk_size=self.chunk_size,
430
+ D=self.D,
431
+ z=None,
432
+ seq_idx=None,
433
+ return_final_states=True,
434
+ dt_bias=self.dt_bias,
435
+ dt_softplus=True,
436
+ **dt_limit_kwargs,
437
+ )
438
+
439
+ # Init cache
440
+ if ssm_state is not None and cache_params is not None:
441
+ cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)
442
+
443
+ scan_output = scan_output.view(batch_size, seq_len, -1)
444
+ # Multiply "gate" branch and apply extra normalization layer
445
+ scan_output = self.norm(scan_output, gate)
446
+
447
+ # 4. Final linear projection
448
+ out = self.out_proj(scan_output)
449
+ return out
450
+
451
+ # fmt: off
452
+ def torch_forward(
453
+ self,
454
+ input_states,
455
+ cache_params: Optional[Mamba2Cache] = None,
456
+ cache_position: Optional[torch.LongTensor] = None,
457
+ attention_mask: Optional[torch.Tensor] = None
458
+ ):
459
+ batch_size, seq_len, _ = input_states.shape
460
+ dtype = input_states.dtype
461
+
462
+ # 1. Gated MLP's linear projection
463
+ input_states = apply_mask_to_padding_states(input_states, attention_mask)
464
+ projected_states = self.in_proj(input_states)
465
+ d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size -
466
+ 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2
467
+ _, _, gate, hidden_states_B_C, dt = projected_states.split(
468
+ [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
469
+ )
470
+
471
+ # 2. Convolution sequence transformation
472
+ if cache_params is not None and cache_position is not None and cache_position[0] > 0:
473
+ cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False)
474
+
475
+ # We need to guarantee that anything regarding the cache is on the same device
476
+ conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device)
477
+
478
+ hidden_states_B_C = torch.sum(
479
+ conv_states * self.conv1d.weight.squeeze(1), dim=-1
480
+ )
481
+ if self.use_conv_bias:
482
+ hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
483
+ hidden_states_B_C = self.act(hidden_states_B_C)
484
+ else:
485
+ # Init cache
486
+ if cache_params is not None:
487
+ hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
488
+ conv_states = nn.functional.pad(
489
+ hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0)
490
+ )
491
+ cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True)
492
+
493
+ hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2))
494
+
495
+ hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
496
+ hidden_states, B, C = torch.split(
497
+ hidden_states_B_C,
498
+ [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
499
+ dim=-1
500
+ )
501
+
502
+ # 3. SSM transformation
503
+ A = -torch.exp(self.A_log.float()) # [num_heads]
504
+ if cache_params is not None and cache_position is not None and cache_position[0] > 0:
505
+ # We need to guarantee that anything regarding the cache is on the same device
506
+ cache_device = cache_params.ssm_states.device
507
+
508
+ # Note: there is no need to pad parameter matrices here, as there is just one new token
509
+ # for batched generation
510
+ dt = dt[:, 0, :][:, None, ...]
511
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
512
+ # [num_heads] -> [num_heads, head_dim]
513
+ dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
514
+
515
+ dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
516
+ dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
517
+ A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
518
+ # [bsz, num_heads, head_dim, state_size]
519
+ dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)
520
+
521
+ # Discretize B
522
+ # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
523
+ # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
524
+ B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
525
+ B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
526
+ B = B.reshape(batch_size, -1, B.shape[-1])
527
+ # [bsz, num_heads, head_dim, state_size]
528
+ dB = dt[..., None] * B[..., None, :]
529
+
530
+ # Discretize x into dB
531
+ # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
532
+ hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
533
+ dBx = (dB * hidden_states[..., None]).to(device=cache_device)
534
+
535
+ # State calculation
536
+ cache_params.update_ssm_state(
537
+ layer_idx=self.layer_idx,
538
+ new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx
539
+ )
540
+
541
+ # Subsequent output
542
+ # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
543
+ C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
544
+ C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
545
+ C = C.reshape(batch_size, -1, C.shape[-1])
546
+ # [bsz, num_heads, head_dim]
547
+
548
+ ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n]
549
+ # Reshape ssm_states to merge the first two dimensions
550
+ # Shape: [b*h, d, n]
551
+ ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size)
552
+ C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
553
+ y = torch.bmm(ssm_states_reshaped, C_reshaped)
554
+ y = y.view(batch_size, self.num_heads, self.head_dim)
555
+
556
+ # D skip connection
557
+ # [num_heads] -> [num_heads, head_dim]
558
+ D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
559
+ y = (y + hidden_states * D).to(y.dtype)
560
+
561
+ # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
562
+ y = y.reshape(batch_size, -1)[:, None, ...]
563
+ else:
564
+ # begin ssd naive implementation without einsums
565
+ dt = nn.functional.softplus(dt + self.dt_bias)
566
+ dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
567
+ hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
568
+ B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
569
+ C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
570
+ B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
571
+ C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
572
+ pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
573
+
574
+ D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
575
+
576
+ # Discretize x and A
577
+ hidden_states = hidden_states * dt[..., None]
578
+ A = A.to(hidden_states.dtype) * dt
579
+
580
+ # Rearrange into blocks/chunks
581
+ hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
582
+
583
+ # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
584
+ A = A.permute(0, 3, 1, 2)
585
+ A_cumsum = torch.cumsum(A, dim=-1)
586
+
587
+ # 1. Compute the output for each intra-chunk (diagonal blocks)
588
+ # This is the analog of a causal mask
589
+ L = torch.exp(segment_sum(A))
590
+
591
+ # Contraction of C and B to get G (attention-weights like)
592
+ # shape: (b, c, l, s, h, n)
593
+ G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]
594
+ G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
595
+
596
+ # Compute M, equivalent to applying attention mask to weights
597
+ M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
598
+ M = M_intermediate.sum(dim=-1)
599
+
600
+ # Compute Y_diag (apply to values)
601
+ Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)
602
+
603
+ # 2. Compute the state for each intra-chunk
604
+ # (right term of low-rank factorization of off-diagonal blocks; B terms)
605
+ decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
606
+ B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
607
+ states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)
608
+
609
+ # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
610
+ # (middle term of factorization of off-diag blocks; A terms)
611
+ if cache_params is not None and cache_position is not None and cache_position[0] > 0:
612
+ previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
613
+ else:
614
+ previous_states = torch.zeros_like(states[:, :1])
615
+ states = torch.cat([previous_states, states], dim=1)
616
+ decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
617
+ decay_chunk = decay_chunk.transpose(1, 3)
618
+ new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
619
+ states, ssm_state = new_states[:, :-1], new_states[:, -1]
620
+
621
+ # 4. Compute state -> output conversion per chunk
622
+ # (left term of low-rank factorization of off-diagonal blocks; C terms)
623
+ state_decay_out = torch.exp(A_cumsum)
624
+ C_times_states = (C[..., None, :] * states[:, :, None, ...])
625
+ state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
626
+ Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
627
+
628
+ # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
629
+ y = Y_diag + Y_off
630
+ # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
631
+ y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
632
+
633
+ y = y + D_residual
634
+ # Cutting off padded chunks
635
+ if pad_size > 0:
636
+ y = y[:, :seq_len, :, :]
637
+ y = y.reshape(batch_size, seq_len, -1)
638
+
639
+ # Init cache
640
+ if ssm_state is not None and cache_params is not None:
641
+ cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)
642
+
643
+ scan_output = self.norm(y, gate)
644
+
645
+ # end ssd naive
646
+
647
+ # 4. Final linear projection
648
+ contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
649
+ return contextualized_states
650
+ # fmt: on
651
+
652
+ def forward(
653
+ self,
654
+ hidden_states,
655
+ cache_params: Optional[Mamba2Cache] = None,
656
+ cache_position: Optional[torch.LongTensor] = None,
657
+ attention_mask: Optional[torch.Tensor] = None,
658
+ ):
659
+ if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
660
+ return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
661
+ dtype = hidden_states.dtype
662
+ if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
663
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
664
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
665
+
666
+ return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
667
+
668
+
669
+ class Mamba2Block(nn.Module):
670
+ def __init__(self, config, layer_idx):
671
+ super().__init__()
672
+ self.config = config
673
+ self.layer_idx = layer_idx
674
+ self.residual_in_fp32 = config.residual_in_fp32
675
+ self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
676
+ self.mixer = Mamba2Mixer(config, layer_idx=layer_idx)
677
+
678
+ def forward(
679
+ self,
680
+ hidden_states,
681
+ cache_params: Optional[Mamba2Cache] = None,
682
+ cache_position: Optional[torch.LongTensor] = None,
683
+ attention_mask: Optional[torch.Tensor] = None,
684
+ ):
685
+ residual = hidden_states
686
+ hidden_states = self.norm(hidden_states)
687
+ if self.residual_in_fp32:
688
+ residual = residual.to(torch.float32)
689
+
690
+ hidden_states = self.mixer(
691
+ hidden_states,
692
+ cache_params=cache_params,
693
+ cache_position=cache_position,
694
+ attention_mask=attention_mask,
695
+ )
696
+ hidden_states = residual + hidden_states
697
+ if self.residual_in_fp32:
698
+ hidden_states = hidden_states.to(dtype=self.norm.weight.dtype)
699
+ return hidden_states
700
+
701
+
702
+ class Mamba2PreTrainedModel(PreTrainedModel, GenerationMixin):
703
+ """
704
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
705
+ models.
706
+ """
707
+
708
+ config_class = Mamba2Config
709
+ base_model_prefix = "backbone"
710
+ _no_split_modules = ["Mamba2Block"]
711
+ supports_gradient_checkpointing = True
712
+ _is_stateful = True
713
+
714
+ def _init_weights(
715
+ self,
716
+ module: nn.Module,
717
+ num_residuals_per_layer: int = 1,
718
+ ):
719
+ """Initialize the weights."""
720
+ if isinstance(module, Mamba2Mixer):
721
+
722
+ # --- A_log ---
723
+ A = torch.arange(1, module.num_heads + 1)
724
+ with torch.no_grad():
725
+ if not isinstance(module.A_log, torch.distributed.tensor.DTensor):
726
+ module.A_log.copy_(torch.log(A))
727
+ else:
728
+ logger.warning_once("`A_log` is a DTensor, skipping initialization")
729
+ module.A_log._no_weight_decay = True
730
+
731
+ # --- D ---
732
+ nn.init.ones_(module.D)
733
+ module.D._no_weight_decay = True
734
+
735
+ # --- dt_bias ---
736
+ dt = torch.exp(
737
+ torch.rand(self.config.num_heads)
738
+ * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
739
+ + math.log(self.config.time_step_min)
740
+ ).clamp(min=self.config.time_step_floor)
741
+
742
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
743
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
744
+ with torch.no_grad():
745
+ if not isinstance(module.dt_bias, torch.distributed.tensor.DTensor):
746
+ module.dt_bias.copy_(inv_dt)
747
+ else:
748
+ logger.warning_once("`dt_bias` is a DTensor, skipping initialization")
749
+ module.dt_bias._no_reinit = True
750
+
751
+ elif isinstance(module, (nn.Linear, nn.Conv1d)):
752
+ # Slightly different from the TF version which uses truncated_normal for initialization
753
+ # cf https://github.com/pytorch/pytorch/pull/5617
754
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
755
+ if module.bias is not None:
756
+ nn.init.zeros_(module.bias)
757
+ # guard against deprecated behavior
758
+ if hasattr(module.bias, "_no_reinit"):
759
+ raise ValueError("This is not supposed to happen")
760
+ elif isinstance(module, nn.Embedding):
761
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
762
+ elif hasattr(module, 'reset_parameters'):
763
+ module.reset_parameters()
764
+
765
+ if self.config.rescale_prenorm_residual:
766
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
767
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
768
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
769
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
770
+ #
771
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
772
+ p = None
773
+ if hasattr(module, 'o_proj'):
774
+ # p = module.o_proj.weight
775
+ # guard against deprecated behavior
776
+ raise ValueError("This is not supposed to happen")
777
+ elif hasattr(module, 'out_proj'):
778
+ p = module.out_proj.weight
779
+ elif hasattr(module, 'down_proj'):
780
+ p = module.down_proj.weight
781
+ if p is not None:
782
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
783
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
784
+ # We need to reinit p since this code could be called multiple times
785
+ # Having just p *= scale would repeatedly scale it down
786
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
787
+ with torch.no_grad():
788
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
789
+
790
+
791
+ @dataclass
792
+ # Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2
793
+ class Mamba2Output(ModelOutput):
794
+ """
795
+ Class for the MAMBA2 model outputs.
796
+
797
+ Args:
798
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
799
+ Sequence of hidden-states at the output of the last layer of the model.
800
+ cache_params (`Mamba2Cache`):
801
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
802
+ avoid providing the old `input_ids`.
803
+
804
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
805
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*,
806
+ returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
807
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
808
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
809
+
810
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
811
+ """
812
+
813
+ last_hidden_state: Optional[torch.FloatTensor] = None
814
+ cache_params: Optional[Mamba2Cache] = None
815
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
816
+
817
+
818
+ @dataclass
819
+ # Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->Mamba2
820
+ class Mamba2CausalLMOutput(ModelOutput):
821
+ """
822
+ Base class for causal language model (or autoregressive) outputs.
823
+
824
+ Args:
825
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
826
+ Language modeling loss (for next-token prediction).
827
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
828
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
829
+ cache_params (`Mamba2Cache`):
830
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
831
+ avoid providing the old `input_ids`.
832
+
833
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
834
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*,
835
+ returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
836
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
837
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
838
+
839
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
840
+ """
841
+
842
+ loss: Optional[torch.FloatTensor] = None
843
+ logits: Optional[torch.FloatTensor] = None
844
+ cache_params: Optional[Mamba2Cache] = None
845
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
846
+
847
+
848
+ class Mamba2Model(Mamba2PreTrainedModel):
849
+ def __init__(self, config):
850
+ super().__init__(config)
851
+
852
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
853
+ self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
854
+
855
+ self.gradient_checkpointing = False
856
+ self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
857
+ # Initialize weights and apply final processing
858
+ self._register_load_state_dict_pre_hook(self.load_hook)
859
+ self.post_init()
860
+
861
+ def load_hook(self, state_dict, prefix, *args):
862
+ for k in state_dict:
863
+ if "embedding." in k:
864
+ state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
865
+ break
866
+
867
+ def get_input_embeddings(self):
868
+ return self.embeddings
869
+
870
+ def set_input_embeddings(self, new_embeddings):
871
+ self.embeddings = new_embeddings
872
+
873
+ def forward(
874
+ self,
875
+ input_ids: Optional[torch.LongTensor] = None,
876
+ inputs_embeds: Optional[torch.LongTensor] = None,
877
+ cache_params: Optional[Mamba2Cache] = None,
878
+ use_cache: Optional[bool] = None,
879
+ output_hidden_states: Optional[bool] = None,
880
+ return_dict: Optional[bool] = None,
881
+ cache_position: Optional[torch.LongTensor] = None,
882
+ attention_mask: Optional[torch.Tensor] = None,
883
+ **kwargs,
884
+ ) -> Union[Tuple, Mamba2Output]:
885
+ output_hidden_states = (
886
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
887
+ )
888
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
889
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
890
+
891
+ if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
892
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
893
+
894
+ if inputs_embeds is None:
895
+ inputs_embeds = self.embeddings(input_ids)
896
+
897
+ if self.gradient_checkpointing and self.training and use_cache:
898
+ use_cache = False
899
+
900
+ if use_cache:
901
+ if cache_params is None:
902
+ cache_params = Mamba2Cache(
903
+ self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
904
+ )
905
+ cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
906
+ elif cache_position is None:
907
+ # cases when we do manual forward instead of using `model.generate` which will initiate
908
+ # `cache_position` and makes sure it is not None, throw error here instead of doing some
909
+ # hack to conjecture the current cache position
910
+ raise ValueError(
911
+ "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
912
+ "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
913
+ "be initialized for you automatically"
914
+ )
915
+ else:
916
+ cache_params = None
917
+
918
+ hidden_states = inputs_embeds
919
+ all_hidden_states = () if output_hidden_states else None
920
+ for mixer_block in self.layers:
921
+ if self.gradient_checkpointing and self.training:
922
+ hidden_states = self._gradient_checkpointing_func(
923
+ mixer_block.__call__,
924
+ hidden_states,
925
+ cache_params,
926
+ cache_position,
927
+ attention_mask,
928
+ )
929
+ else:
930
+ hidden_states = mixer_block(
931
+ hidden_states,
932
+ cache_params=cache_params,
933
+ cache_position=cache_position,
934
+ attention_mask=attention_mask,
935
+ )
936
+
937
+ if output_hidden_states:
938
+ all_hidden_states = all_hidden_states + (hidden_states,)
939
+
940
+ hidden_states = self.norm_f(hidden_states)
941
+
942
+ if output_hidden_states:
943
+ all_hidden_states = all_hidden_states + (hidden_states,)
944
+
945
+ if not return_dict:
946
+ return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
947
+
948
+ return Mamba2Output(
949
+ last_hidden_state=hidden_states,
950
+ cache_params=cache_params if use_cache else None,
951
+ hidden_states=all_hidden_states,
952
+ )
953
+
954
+
955
+ class Mamba2ForCausalLM(Mamba2PreTrainedModel):
956
+ _tied_weights_keys = []
957
+
958
+ def __init__(self, config):
959
+ super().__init__(config)
960
+ self.backbone = Mamba2Model(config)
961
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
962
+ self.criterion = None
963
+
964
+ # Initialize weights and apply final processing
965
+ self.post_init()
966
+
967
+ def get_output_embeddings(self):
968
+ return self.lm_head
969
+
970
+ def set_output_embeddings(self, new_embeddings):
971
+ self.lm_head = new_embeddings
972
+
973
+ def get_input_embeddings(self):
974
+ return self.backbone.get_input_embeddings()
975
+
976
+ def set_input_embeddings(self, new_embeddings):
977
+ return self.backbone.set_input_embeddings(new_embeddings)
978
+
979
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
980
+ def prepare_inputs_for_generation(
981
+ self,
982
+ input_ids,
983
+ inputs_embeds=None,
984
+ use_cache=None,
985
+ cache_params: Optional[Mamba2Cache] = None,
986
+ cache_position: Optional[torch.LongTensor] = None,
987
+ attention_mask: Optional[torch.Tensor] = None,
988
+ logits_to_keep: Optional[int] = None,
989
+ **kwargs,
990
+ ):
991
+ if use_cache:
992
+ # `cache_position` should have been initialized in `generate`
993
+ if cache_position is None:
994
+ raise ValueError(
995
+ "`cache_position` should not be None as it should have been initialized in "
996
+ "`model.generate`, you are responsible for passing in a valid `cache_position` if "
997
+ "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
998
+ )
999
+ if cache_position[0] > 0:
1000
+ input_ids = input_ids[:, -1][..., None]
1001
+
1002
+ if attention_mask is not None:
1003
+ attention_mask = None
1004
+ else:
1005
+ # we initialize the `cache_position` to full size of `conv_states` at prefill stage
1006
+ # considering padding will be applied when input length is shorter, and truncation
1007
+ # will be applied when it is longer, so it will be equivalent to always have it match
1008
+ # the length of `cache_params.conv_states`, which is `config.conv_kernel`
1009
+ cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
1010
+
1011
+ if inputs_embeds is not None and cache_params is None:
1012
+ model_inputs = {"inputs_embeds": inputs_embeds}
1013
+ else:
1014
+ model_inputs = {"input_ids": input_ids}
1015
+
1016
+ if logits_to_keep is not None:
1017
+ model_inputs['logits_to_keep'] = logits_to_keep
1018
+
1019
+ model_inputs.update({
1020
+ 'attention_mask': attention_mask,
1021
+ 'cache_params': cache_params,
1022
+ 'use_cache': use_cache,
1023
+ 'cache_position': cache_position,
1024
+ 'logits_to_keep': logits_to_keep
1025
+ })
1026
+ return model_inputs
1027
+
1028
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
1029
+ def forward(
1030
+ self,
1031
+ input_ids: Optional[torch.LongTensor] = None,
1032
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1033
+ cache_params: Optional[Mamba2Cache] = None,
1034
+ labels: Optional[torch.LongTensor] = None,
1035
+ output_hidden_states: Optional[bool] = None,
1036
+ return_dict: Optional[bool] = None,
1037
+ use_cache: Optional[bool] = None,
1038
+ cache_position: Optional[torch.Tensor] = None,
1039
+ attention_mask: Optional[torch.Tensor] = None,
1040
+ logits_to_keep: Optional[int] = 0,
1041
+ **kwargs, # for now we need this for generation
1042
+ ) -> Union[Tuple, Mamba2CausalLMOutput]:
1043
+ r"""
1044
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1045
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1046
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
1047
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
1048
+ """
1049
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1050
+
1051
+ outputs = self.backbone(
1052
+ input_ids,
1053
+ cache_params=cache_params,
1054
+ inputs_embeds=inputs_embeds,
1055
+ output_hidden_states=output_hidden_states,
1056
+ return_dict=return_dict,
1057
+ use_cache=use_cache,
1058
+ cache_position=cache_position,
1059
+ attention_mask=attention_mask,
1060
+ )
1061
+ hidden_states = outputs[0]
1062
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
1063
+
1064
+ loss, logits = None, None
1065
+ if not fuse_linear_and_cross_entropy or labels is None:
1066
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
1067
+ if labels is not None:
1068
+ if getattr(self, 'criterion', None) is None:
1069
+ if fuse_linear_and_cross_entropy:
1070
+ criterion = FusedLinearCrossEntropyLoss()
1071
+ elif self.config.fuse_cross_entropy:
1072
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
1073
+ else:
1074
+ criterion = nn.CrossEntropyLoss()
1075
+ else:
1076
+ criterion = self.criterion
1077
+ labels = labels.to(hidden_states.device)
1078
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
1079
+ if fuse_linear_and_cross_entropy:
1080
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
1081
+ else:
1082
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
1083
+
1084
+ if not return_dict:
1085
+ output = (logits,) + outputs[1:]
1086
+ return (loss,) + output if loss is not None else output
1087
+
1088
+ return Mamba2CausalLMOutput(
1089
+ loss=loss,
1090
+ logits=logits,
1091
+ cache_params=outputs.cache_params,
1092
+ hidden_states=outputs.hidden_states,
1093
+ )
fla/models/rwkv6/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config
6
+ from fla.models.rwkv6.modeling_rwkv6 import RWKV6ForCausalLM, RWKV6Model
7
+
8
+ AutoConfig.register(RWKV6Config.model_type, RWKV6Config, True)
9
+ AutoModel.register(RWKV6Config, RWKV6Model, True)
10
+ AutoModelForCausalLM.register(RWKV6Config, RWKV6ForCausalLM, True)
11
+
12
+
13
+ __all__ = ['RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model']
fla/models/rwkv7/modeling_rwkv7.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.rwkv7 import RWKV7Attention
20
+ from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, LayerNorm
23
+ from fla.modules.activations import ACT2FN
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers.processing_utils import Unpack
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class RWKV7FeedForward(nn.Module):
32
+
33
+ def __init__(
34
+ self,
35
+ hidden_size: int,
36
+ hidden_ratio: Optional[int] = None,
37
+ intermediate_size: Optional[int] = None,
38
+ hidden_act: str = 'sqrelu',
39
+ layer_idx: int = None
40
+ ) -> RWKV7FeedForward:
41
+ super().__init__()
42
+
43
+ self.hidden_size = hidden_size
44
+ if hidden_ratio is None:
45
+ hidden_ratio = 4
46
+ if intermediate_size is None:
47
+ intermediate_size = int(hidden_size * hidden_ratio)
48
+ intermediate_size = 32 * ((intermediate_size + 32 - 1) // 32)
49
+ self.hidden_ratio = hidden_ratio
50
+ self.intermediate_size = intermediate_size
51
+
52
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
53
+
54
+ self.x_k = nn.Parameter(torch.zeros(hidden_size))
55
+
56
+ self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
57
+ self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
58
+ self.act_fn = ACT2FN[hidden_act]
59
+
60
+ self.layer_idx = layer_idx
61
+
62
+ def forward(
63
+ self,
64
+ x: torch.Tensor,
65
+ attention_mask: Optional[torch.Tensor] = None,
66
+ state: Optional[Cache] = None
67
+ ) -> torch.Tensor:
68
+ if attention_mask is not None:
69
+ x = x.mul(attention_mask[:, -x.shape[-2]:, None])
70
+ if x.shape[1] == 1 and state is not None and state[self.layer_idx]['ffn_state'] is not None:
71
+ shifted = state[self.layer_idx]['ffn_state'].unsqueeze(1)
72
+ else:
73
+ shifted = self.time_shift(x)
74
+ if state is not None and state[self.layer_idx]['ffn_state'] is not None:
75
+ shifted[:, 0] = state[self.layer_idx]['ffn_state'][-1]
76
+ if state is not None:
77
+ # no need to update the offset twice
78
+ state.update(ffn_state=x[:, -1], layer_idx=self.layer_idx, offset=0)
79
+ return self.value(self.act_fn(self.key(x.addcmul(shifted - x, self.x_k)))), state
80
+
81
+
82
+ class RWKV7Block(nn.Module):
83
+
84
+ def __init__(
85
+ self,
86
+ config: RWKV7Config,
87
+ layer_idx: int
88
+ ) -> RWKV7Block:
89
+ super().__init__()
90
+
91
+ self.config = config
92
+ self.layer_idx = layer_idx
93
+
94
+ if config.norm_first and layer_idx == 0:
95
+ self.pre_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
96
+ config.hidden_size,
97
+ bias=config.norm_bias,
98
+ eps=config.norm_eps
99
+ )
100
+ self.attn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
101
+ config.hidden_size,
102
+ bias=config.norm_bias,
103
+ eps=config.norm_eps
104
+ )
105
+ if config.attn is not None and layer_idx in config.attn['layers']:
106
+ self.attn = Attention(
107
+ hidden_size=config.hidden_size,
108
+ num_heads=config.attn['num_heads'],
109
+ num_kv_heads=config.attn['num_kv_heads'],
110
+ qkv_bias=config.attn['qkv_bias'],
111
+ window_size=config.attn['window_size'],
112
+ rope_theta=config.attn['rope_theta'],
113
+ max_position_embeddings=config.max_position_embeddings,
114
+ layer_idx=layer_idx
115
+ )
116
+ else:
117
+ self.attn = RWKV7Attention(
118
+ mode=config.attn_mode,
119
+ hidden_size=config.hidden_size,
120
+ head_dim=config.head_dim,
121
+ num_heads=config.num_heads,
122
+ decay_low_rank_dim=config.decay_low_rank_dim,
123
+ gate_low_rank_dim=config.gate_low_rank_dim,
124
+ a_low_rank_dim=config.a_low_rank_dim,
125
+ v_low_rank_dim=config.v_low_rank_dim,
126
+ norm_eps=config.norm_eps,
127
+ fuse_norm=config.fuse_norm,
128
+ layer_idx=layer_idx,
129
+ value_dim=config.value_dim[layer_idx]
130
+ )
131
+ self.ffn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
132
+ config.hidden_size,
133
+ bias=config.norm_bias,
134
+ eps=config.norm_eps
135
+ )
136
+ self.ffn = RWKV7FeedForward(
137
+ hidden_size=config.hidden_size,
138
+ hidden_ratio=config.hidden_ratio,
139
+ intermediate_size=config.intermediate_size,
140
+ hidden_act=config.hidden_act,
141
+ layer_idx=layer_idx
142
+ )
143
+
144
+ def forward(
145
+ self,
146
+ hidden_states: torch.Tensor,
147
+ attention_mask: Optional[torch.Tensor] = None,
148
+ past_key_values: Optional[Cache] = None,
149
+ use_cache: Optional[bool] = False,
150
+ output_attentions: Optional[bool] = False,
151
+ v_first: torch.Tensor = None,
152
+ **kwargs,
153
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
154
+ residual = self.pre_norm(hidden_states) if hasattr(self, 'pre_norm') else hidden_states
155
+ hidden_states = self.attn_norm(residual)
156
+ hidden_states, attentions, past_key_values, v_first = self.attn(
157
+ hidden_states=hidden_states,
158
+ attention_mask=attention_mask,
159
+ past_key_values=past_key_values,
160
+ use_cache=use_cache,
161
+ output_attentions=output_attentions,
162
+ v_first=v_first,
163
+ **kwargs
164
+ )
165
+ if self.config.fuse_norm:
166
+ hidden_states, residual = self.ffn_norm(hidden_states, residual, True)
167
+ else:
168
+ hidden_states = residual + hidden_states
169
+ residual = hidden_states
170
+ hidden_states = self.ffn_norm(hidden_states)
171
+ hidden_states, past_key_values = self.ffn(hidden_states, attention_mask, past_key_values)
172
+ hidden_states = residual + hidden_states
173
+
174
+ outputs = (hidden_states, attentions, past_key_values, v_first)
175
+
176
+ return outputs
177
+
178
+
179
+ class RWKV7PreTrainedModel(PreTrainedModel):
180
+
181
+ config_class = RWKV7Config
182
+ base_model_prefix = 'model'
183
+ supports_gradient_checkpointing = True
184
+ _no_split_modules = ['RWKV7Block']
185
+ _supports_cache_class = True
186
+ _skip_keys_device_placement = ["past_key_values"]
187
+
188
+ def __init__(self, *inputs, **kwargs):
189
+ super().__init__(*inputs, **kwargs)
190
+
191
+ def _init_weights(
192
+ self,
193
+ module: nn.Module,
194
+ rescale_prenorm_residual: bool = True,
195
+ num_residuals_per_layer: int = 2,
196
+ ):
197
+ warnings.warn(
198
+ "RWKV-7 employs a carefully designed initialization strategy tailored to its architecture. "
199
+ "The detailed initialization scheme is currently not implemented here but can be found in the "
200
+ "official code repository. We emphasize that using the recommended initialization is essential "
201
+ "for replicating the results in RWKV-7 paper. Deviations from the prescribed initialization "
202
+ "may lead to performance degradation.\n"
203
+ "Alternatively, please generate initial weights from the official RWKV code repository, and "
204
+ "convert the PyTorch checkpoint into FLA supported format."
205
+ )
206
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
207
+ # Slightly different from the TF version which uses truncated_normal for initialization
208
+ # cf https://github.com/pytorch/pytorch/pull/5617
209
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
210
+ if module.bias is not None:
211
+ nn.init.zeros_(module.bias)
212
+ elif isinstance(module, nn.Parameter):
213
+ nn.init.normal_(module, mean=0.0, std=self.config.initializer_range)
214
+ elif isinstance(module, nn.Embedding):
215
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
216
+ elif hasattr(module, 'reset_parameters'):
217
+ module.reset_parameters()
218
+
219
+ if rescale_prenorm_residual:
220
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
221
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
222
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
223
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
224
+ #
225
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
226
+ p = None
227
+ if hasattr(module, 'o_proj'):
228
+ p = module.o_proj.weight
229
+ elif hasattr(module, 'down_proj'):
230
+ p = module.down_proj.weight
231
+ if p is not None:
232
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
233
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
234
+ # We need to reinit p since this code could be called multiple times
235
+ # Having just p *= scale would repeatedly scale it down
236
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
237
+ with torch.no_grad():
238
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
239
+
240
+
241
+ class RWKV7Model(RWKV7PreTrainedModel):
242
+
243
+ def __init__(self, config: RWKV7Config):
244
+ super().__init__(config)
245
+ self.padding_idx = config.pad_token_id
246
+ self.vocab_size = config.vocab_size
247
+
248
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
249
+ self.layers = nn.ModuleList([RWKV7Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
250
+ self.norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
251
+ config.hidden_size,
252
+ bias=config.norm_bias,
253
+ eps=config.norm_eps
254
+ )
255
+
256
+ self.gradient_checkpointing = False
257
+
258
+ self.post_init()
259
+
260
+ def get_input_embeddings(self):
261
+ return self.embeddings
262
+
263
+ def set_input_embeddings(self, value):
264
+ self.embeddings = value
265
+
266
+ def forward(
267
+ self,
268
+ input_ids: Optional[torch.LongTensor] = None,
269
+ attention_mask: Optional[torch.Tensor] = None, # noqa
270
+ inputs_embeds: Optional[torch.FloatTensor] = None,
271
+ past_key_values: Optional[Cache] = None,
272
+ use_cache: Optional[bool] = None,
273
+ output_attentions: Optional[bool] = None,
274
+ output_hidden_states: Optional[bool] = None,
275
+ return_dict: Optional[bool] = None,
276
+ **kwargs: Unpack[Dict]
277
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
278
+ if output_attentions:
279
+ warnings.warn("`RWKV7Model` does not `output_attentions` now, setting it to `False`.")
280
+ output_attentions = False
281
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
282
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
283
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
284
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
285
+
286
+ # retrieve input_ids and inputs_embeds
287
+ if input_ids is not None and inputs_embeds is not None:
288
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
289
+ if input_ids is None and inputs_embeds is None:
290
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
291
+
292
+ if inputs_embeds is None:
293
+ inputs_embeds = self.embeddings(input_ids)
294
+ hidden_states = inputs_embeds
295
+
296
+ if use_cache and not isinstance(past_key_values, Cache):
297
+ past_key_values = Cache.from_legacy_cache(past_key_values)
298
+
299
+ if self.gradient_checkpointing and self.training and use_cache:
300
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
301
+ use_cache = False
302
+
303
+ all_hidden_states = () if output_hidden_states else None
304
+ all_attns = () if output_attentions else None
305
+
306
+ v_first = torch.zeros_like(hidden_states)
307
+ for layer in self.layers:
308
+ if output_hidden_states:
309
+ all_hidden_states += (hidden_states,)
310
+
311
+ if self.gradient_checkpointing and self.training:
312
+ hidden_states, attentions, past_key_values, v_first = self._gradient_checkpointing_func(
313
+ layer.__call__,
314
+ hidden_states,
315
+ attention_mask,
316
+ past_key_values,
317
+ use_cache,
318
+ output_attentions,
319
+ v_first,
320
+ **kwargs
321
+ )
322
+ else:
323
+ hidden_states, attentions, past_key_values, v_first = layer(
324
+ hidden_states,
325
+ attention_mask=attention_mask,
326
+ past_key_values=past_key_values,
327
+ use_cache=use_cache,
328
+ output_attentions=output_attentions,
329
+ v_first=v_first,
330
+ **kwargs
331
+ )
332
+
333
+ if output_attentions:
334
+ all_attns += (attentions,)
335
+
336
+ hidden_states = self.norm(hidden_states)
337
+
338
+ # add hidden states from the last decoder layer
339
+ if output_hidden_states:
340
+ all_hidden_states += (hidden_states,)
341
+
342
+ if not return_dict:
343
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
344
+ return BaseModelOutputWithPast(
345
+ last_hidden_state=hidden_states,
346
+ past_key_values=past_key_values,
347
+ hidden_states=all_hidden_states,
348
+ attentions=all_attns
349
+ )
350
+
351
+
352
+ class RWKV7ForCausalLM(RWKV7PreTrainedModel, GenerationMixin):
353
+
354
+ _tied_weights_keys = ["lm_head.weight"]
355
+
356
+ def __init__(self, config):
357
+ super().__init__(config)
358
+ self.model = RWKV7Model(config)
359
+ self.vocab_size = config.vocab_size
360
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
361
+ self.criterion = None
362
+
363
+ # Initialize weights and apply final processing
364
+ self.post_init()
365
+
366
+ def get_input_embeddings(self):
367
+ return self.model.embeddings
368
+
369
+ def set_input_embeddings(self, value):
370
+ self.model.embeddings = value
371
+
372
+ def get_output_embeddings(self):
373
+ return self.lm_head
374
+
375
+ def set_output_embeddings(self, new_embeddings):
376
+ self.lm_head = new_embeddings
377
+
378
+ def set_decoder(self, decoder):
379
+ self.model = decoder
380
+
381
+ def get_decoder(self):
382
+ return self.model
383
+
384
+ def generate(self, *args, **kwargs):
385
+ try:
386
+ return super().generate(*args, **kwargs)
387
+ except AttributeError as exception:
388
+ if 'past_key_values' in str(exception):
389
+ raise AttributeError(
390
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
391
+ f"which is not supported for {self.__class__.__name__}. "
392
+ f"Try another generation strategy instead. "
393
+ f"For the available generation strategies, check this doc: "
394
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
395
+ )
396
+ else:
397
+ raise exception
398
+
399
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
400
+ def prepare_inputs_for_generation(
401
+ self,
402
+ input_ids: torch.LongTensor = None,
403
+ past_key_values: Optional[Cache] = None,
404
+ attention_mask: Optional[torch.Tensor] = None,
405
+ inputs_embeds: Optional[torch.Tensor] = None,
406
+ use_cache: bool = True,
407
+ logits_to_keep: Optional[int] = None,
408
+ **kwargs
409
+ ):
410
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
411
+ if past_key_values is not None and len(past_key_values) > 0:
412
+ input_ids = input_ids[:, -1:]
413
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
414
+ if inputs_embeds is not None and len(past_key_values) == 0:
415
+ model_inputs = {'inputs_embeds': inputs_embeds}
416
+ else:
417
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
418
+ # recompiles graphs as the stride of the inputs is a guard.
419
+ # Ref: https://github.com/huggingface/transformers/pull/29114
420
+ # TODO: use `next_tokens` directly instead.
421
+ model_inputs = {'input_ids': input_ids.contiguous()}
422
+
423
+ if logits_to_keep is not None:
424
+ model_inputs['logits_to_keep'] = logits_to_keep
425
+
426
+ model_inputs.update({
427
+ 'past_key_values': past_key_values,
428
+ 'use_cache': use_cache,
429
+ 'attention_mask': attention_mask,
430
+ })
431
+ return model_inputs
432
+
433
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
434
+ def forward(
435
+ self,
436
+ input_ids: torch.LongTensor = None,
437
+ attention_mask: Optional[torch.Tensor] = None,
438
+ inputs_embeds: Optional[torch.Tensor] = None,
439
+ past_key_values: Optional[Cache] = None,
440
+ labels: Optional[torch.LongTensor] = None,
441
+ shift_labels: Optional[torch.LongTensor] = None,
442
+ use_cache: Optional[bool] = None,
443
+ output_attentions: Optional[bool] = None,
444
+ output_hidden_states: Optional[bool] = None,
445
+ return_dict: Optional[bool] = None,
446
+ logits_to_keep: Optional[int] = 0,
447
+ **kwargs: Unpack[Dict]
448
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
449
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
450
+ output_hidden_states = (
451
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
452
+ )
453
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
454
+
455
+ outputs = self.model(
456
+ input_ids=input_ids,
457
+ attention_mask=attention_mask,
458
+ inputs_embeds=inputs_embeds,
459
+ past_key_values=past_key_values,
460
+ use_cache=use_cache,
461
+ output_attentions=output_attentions,
462
+ output_hidden_states=output_hidden_states,
463
+ return_dict=return_dict,
464
+ **kwargs
465
+ )
466
+
467
+ hidden_states = outputs[0]
468
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
469
+
470
+ loss, logits = None, None
471
+ has_labels = (labels is not None) or (shift_labels is not None)
472
+ if not (fuse_linear_and_cross_entropy and has_labels):
473
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
474
+ if has_labels:
475
+ if getattr(self, 'criterion', None) is None:
476
+ if fuse_linear_and_cross_entropy:
477
+ criterion = FusedLinearCrossEntropyLoss()
478
+ elif self.config.fuse_cross_entropy:
479
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
480
+ else:
481
+ criterion = nn.CrossEntropyLoss()
482
+ else:
483
+ criterion = self.criterion
484
+
485
+ # shift_labels: See https://github.com/huggingface/transformers/pull/36607/files.
486
+ if shift_labels is None:
487
+ shift_labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
488
+ shift_labels = shift_labels.to(hidden_states.device)
489
+
490
+ if fuse_linear_and_cross_entropy:
491
+ loss = criterion(hidden_states, shift_labels, self.lm_head.weight, self.lm_head.bias)
492
+ else:
493
+ loss = criterion(logits.view(shift_labels.numel(), -1), shift_labels.view(-1))
494
+
495
+ if not return_dict:
496
+ output = (logits,) + outputs[1:]
497
+ return (loss,) + output if loss is not None else output
498
+
499
+ return CausalLMOutputWithPast(
500
+ loss=loss,
501
+ logits=logits,
502
+ past_key_values=outputs.past_key_values,
503
+ hidden_states=outputs.hidden_states,
504
+ attentions=outputs.attentions,
505
+ )
fla/models/samba/configuration_samba.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import math
4
+ from typing import Dict, Optional
5
+
6
+ from transformers.configuration_utils import PretrainedConfig
7
+
8
+
9
+ class SambaConfig(PretrainedConfig):
10
+
11
+ model_type = "samba"
12
+
13
+ def __init__(
14
+ self,
15
+ hidden_size: int = 2304,
16
+ state_size: int = 16,
17
+ num_hidden_layers: int = 18,
18
+ norm_eps=1e-5,
19
+ pad_token_id: int = 0,
20
+ bos_token_id: int = 1,
21
+ eos_token_id: int = 2,
22
+ expand: int = 2,
23
+ conv_kernel: int = 4,
24
+ use_bias: bool = False,
25
+ use_conv_bias: bool = True,
26
+ hidden_act: str = "swish",
27
+ initializer_range: str = 0.02,
28
+ residual_in_fp32: bool = False,
29
+ time_step_rank: str = "auto",
30
+ time_step_scale: float = 1.0,
31
+ time_step_min: float = 0.001,
32
+ time_step_max: float = 0.1,
33
+ time_step_init_scheme: str = "random",
34
+ time_step_floor: float = 1e-4,
35
+ max_position_embeddings: int = 2048,
36
+ attn: Optional[Dict] = {
37
+ 'layers': (1, 3, 5, 7, 9, 11, 13, 15, 17),
38
+ 'num_heads': 18,
39
+ 'num_kv_heads': 18,
40
+ 'qkv_bias': False,
41
+ 'window_size': 2048,
42
+ 'rope_theta': 10000.
43
+ },
44
+ hidden_ratio: Optional[int] = 4,
45
+ rescale_prenorm_residual: bool = False,
46
+ use_cache: bool = True,
47
+ fuse_norm: bool = True,
48
+ fuse_swiglu: bool = True,
49
+ fuse_cross_entropy: bool = True,
50
+ vocab_size: int = 32000,
51
+ tie_word_embeddings: bool = False,
52
+ **kwargs,
53
+ ):
54
+ self.hidden_size = hidden_size
55
+ self.state_size = state_size
56
+ self.num_hidden_layers = num_hidden_layers
57
+ self.norm_eps = norm_eps
58
+ self.conv_kernel = conv_kernel
59
+ self.expand = expand
60
+ self.intermediate_size = int(expand * self.hidden_size)
61
+ self.bos_token_id = bos_token_id
62
+ self.eos_token_id = eos_token_id
63
+ self.pad_token_id = pad_token_id
64
+ self.use_bias = use_bias
65
+ self.use_conv_bias = use_conv_bias
66
+ self.hidden_act = hidden_act
67
+ self.initializer_range = initializer_range
68
+ self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
69
+ self.time_step_scale = time_step_scale
70
+ self.time_step_min = time_step_min
71
+ self.time_step_max = time_step_max
72
+ self.time_step_init_scheme = time_step_init_scheme
73
+ self.time_step_floor = time_step_floor
74
+ self.max_position_embeddings = max_position_embeddings
75
+ self.attn = attn
76
+ self.hidden_ratio = hidden_ratio
77
+ self.rescale_prenorm_residual = rescale_prenorm_residual
78
+ self.residual_in_fp32 = residual_in_fp32
79
+ self.use_cache = use_cache
80
+
81
+ self.fuse_norm = fuse_norm
82
+ self.fuse_swiglu = fuse_swiglu
83
+ self.fuse_cross_entropy = fuse_cross_entropy
84
+ self.vocab_size = vocab_size
85
+
86
+ super().__init__(
87
+ bos_token_id=bos_token_id,
88
+ eos_token_id=eos_token_id,
89
+ pad_token_id=pad_token_id,
90
+ tie_word_embeddings=tie_word_embeddings,
91
+ **kwargs
92
+ )
fla/ops/common/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (139 Bytes). View file
 
fla/ops/common/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (32.4 kB). View file
 
fla/ops/delta_rule/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (392 Bytes). View file
 
fla/ops/delta_rule/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (34 kB). View file
 
fla/ops/forgetting_attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (242 Bytes). View file
 
fla/ops/forgetting_attn/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (39 kB). View file
 
fla/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (14.4 kB). View file
 
fla/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (328 Bytes). View file
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-312.pyc ADDED
Binary file (12.2 kB). View file
 
fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-312.pyc ADDED
Binary file (21.3 kB). View file
 
fla/ops/generalized_delta_rule/iplr/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .chunk import chunk_iplr_delta_rule
2
+ from .fused_recurrent import fused_recurrent_iplr_delta_rule
3
+
4
+ __all__ = [
5
+ 'chunk_iplr_delta_rule',
6
+ 'fused_recurrent_iplr_delta_rule'
7
+ ]
fla/ops/generalized_delta_rule/iplr/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (27 kB). View file
 
fla/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-312.pyc ADDED
Binary file (23.1 kB). View file
 
fla/ops/generalized_delta_rule/iplr/chunk.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.chunk_delta_h import prepare_chunk_offsets
11
+ from fla.ops.generalized_delta_rule.iplr.wy_fast import fwd_prepare_wy_repr
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, use_cuda_graph
13
+
14
+ BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
19
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
20
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
21
+ })
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({}, num_warps=num_warps)
25
+ for num_warps in [2, 4, 8, 16]
26
+ ],
27
+ key=['BT', 'BK', 'BV'],
28
+ use_cuda_graph=use_cuda_graph,
29
+ )
30
+ @triton.jit(do_not_specialize=['T'])
31
+ def chunk_generalized_iplr_delta_rule_fwd_kernel_h(
32
+ k,
33
+ v,
34
+ d,
35
+ b,
36
+ u,
37
+ v_new,
38
+ h,
39
+ h0,
40
+ ht,
41
+ offsets,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ NT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ STORE_FINAL_STATE: tl.constexpr,
54
+ USE_OFFSETS: tl.constexpr,
55
+ HEAD_FIRST: tl.constexpr,
56
+ ):
57
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
61
+ T = eos - bos
62
+ NT = tl.cdiv(T, BT)
63
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
64
+ else:
65
+ bos, eos = i_n * T, i_n * T + T
66
+ NT = tl.cdiv(T, BT)
67
+ boh = i_n * NT
68
+
69
+ # [BK, BV]
70
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
71
+ if USE_INITIAL_STATE:
72
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
73
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
74
+
75
+ for i_t in range(NT):
76
+ if HEAD_FIRST:
77
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ else:
79
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
81
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
82
+ # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden
83
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
84
+ if HEAD_FIRST:
85
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
86
+ p_b = tl.make_block_ptr(b + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
87
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
88
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
89
+ p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
90
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
91
+ else:
92
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
93
+ p_b = tl.make_block_ptr(b+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
94
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
95
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
96
+ p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
97
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
98
+ # [BK, BC]
99
+ b_k = tl.load(p_k, boundary_check=(0, 1))
100
+ b_v = tl.load(p_v, boundary_check=(0, 1))
101
+ b_d = tl.load(p_d, boundary_check=(0, 1))
102
+ b_b = tl.load(p_b, boundary_check=(0, 1))
103
+ b_v2 = tl.dot(b_d, b_h.to(b_d.dtype)) + tl.load(p_u, boundary_check=(0, 1))
104
+ b_hc += tl.dot(b_k, b_v)
105
+ b_hc += tl.dot(b_b, b_v2.to(b_k.dtype))
106
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
107
+ b_h += b_hc
108
+
109
+ if STORE_FINAL_STATE:
110
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
111
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
112
+
113
+
114
+ @triton.heuristics({
115
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
116
+ })
117
+ @triton.autotune(
118
+ configs=[
119
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
120
+ for BK in BKV_LIST
121
+ for BV in BKV_LIST
122
+ for num_warps in [2, 4, 8]
123
+ for num_stages in [2, 3]
124
+ ],
125
+ key=['BT'],
126
+ use_cuda_graph=use_cuda_graph,
127
+ )
128
+ @triton.jit(do_not_specialize=['T'])
129
+ def chunk_generalized_iplr_delta_rule_fwd_kernel_o(
130
+ q,
131
+ k,
132
+ v,
133
+ u,
134
+ b,
135
+ h,
136
+ o,
137
+ offsets,
138
+ indices,
139
+ scale,
140
+ T,
141
+ H: tl.constexpr,
142
+ K: tl.constexpr,
143
+ V: tl.constexpr,
144
+ BT: tl.constexpr,
145
+ BK: tl.constexpr,
146
+ BV: tl.constexpr,
147
+ USE_OFFSETS: tl.constexpr,
148
+ HEAD_FIRST: tl.constexpr,
149
+ ):
150
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
151
+ i_b, i_h = i_bh // H, i_bh % H
152
+
153
+ if USE_OFFSETS:
154
+ i_tg = i_t
155
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
156
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
157
+ T = eos - bos
158
+ NT = tl.cdiv(T, BT)
159
+ else:
160
+ NT = tl.cdiv(T, BT)
161
+ i_tg = i_b * NT + i_t
162
+ bos, eos = i_b * T, i_b * T + T
163
+
164
+ # offset calculation
165
+ q += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
166
+ k += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
167
+ b += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
168
+ v += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
169
+ u += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
170
+ o += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
171
+ h += ((i_bh * NT + i_t) * K * V) if HEAD_FIRST else ((i_tg * H + i_h) * K * V)
172
+ stride_qk = K if HEAD_FIRST else H*K
173
+ stride_vo = V if HEAD_FIRST else H*V
174
+
175
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
176
+ b_Aqk = tl.zeros([BT, BT], dtype=tl.float32)
177
+ b_Aqb = tl.zeros([BT, BT], dtype=tl.float32)
178
+
179
+ for i_k in range(tl.cdiv(K, BK)):
180
+ p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
181
+ p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
182
+ p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
183
+ p_b = tl.make_block_ptr(b, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
184
+ # [BT, BK]
185
+ b_q = tl.load(p_q, boundary_check=(0, 1))
186
+ # [BK, BT]
187
+ b_k = tl.load(p_k, boundary_check=(0, 1))
188
+ b_b = tl.load(p_b, boundary_check=(0, 1))
189
+ # [BK, BV]
190
+ b_h = tl.load(p_h, boundary_check=(0, 1))
191
+ # [BT, BK] @ [BK, BV] -> [BT, BV]
192
+ b_o += tl.dot(b_q, b_h)
193
+ # [BT, BK] @ [BK, BT] -> [BT, BT]
194
+ b_Aqk += tl.dot(b_q, b_k)
195
+ # [BT, BK] @ [BK, BT] -> [BT, BT]
196
+ b_Aqb += tl.dot(b_q, b_b)
197
+
198
+ o_i = tl.arange(0, BT)
199
+ m_A = o_i[:, None] >= o_i[None, :]
200
+ b_Aqk = tl.where(m_A, b_Aqk, 0)
201
+ b_Aqb = tl.where(m_A, b_Aqb, 0)
202
+
203
+ p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
204
+ p_u = tl.make_block_ptr(u, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
205
+ p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
206
+ b_v = tl.load(p_v, boundary_check=(0, 1))
207
+ b_u = tl.load(p_u, boundary_check=(0, 1))
208
+ b_o = (b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_u.dtype), b_u)) * scale
209
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
210
+
211
+
212
+ def chunk_generalized_iplr_delta_rule_fwd_o(
213
+ q: torch.Tensor,
214
+ k: torch.Tensor,
215
+ v: torch.Tensor,
216
+ v_new: torch.Tensor,
217
+ b: torch.Tensor,
218
+ h: torch.Tensor,
219
+ scale: Optional[float] = None,
220
+ offsets: Optional[torch.LongTensor] = None,
221
+ indices: Optional[torch.LongTensor] = None,
222
+ head_first: bool = True,
223
+ chunk_size: int = 64
224
+ ) -> torch.Tensor:
225
+ if head_first:
226
+ B, H, T, K, V = *q.shape, v.shape[-1]
227
+ else:
228
+ B, T, H, K, V = *q.shape, v.shape[-1]
229
+ if scale is None:
230
+ scale = k.shape[-1] ** -0.5
231
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
232
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
233
+
234
+ o = torch.empty_like(v)
235
+
236
+ def grid(meta): return (
237
+ triton.cdiv(V, meta['BV']),
238
+ NT,
239
+ B * H
240
+ )
241
+ chunk_generalized_iplr_delta_rule_fwd_kernel_o[grid](
242
+ q=q,
243
+ k=k,
244
+ v=v,
245
+ u=v_new,
246
+ b=b,
247
+ h=h,
248
+ o=o,
249
+ offsets=offsets,
250
+ indices=indices,
251
+ scale=scale,
252
+ T=T,
253
+ H=H,
254
+ K=K,
255
+ V=V,
256
+ BT=BT,
257
+ HEAD_FIRST=head_first
258
+ )
259
+ return o
260
+
261
+
262
+ def chunk_generalized_iplr_delta_rule_fwd_h(
263
+ k: torch.Tensor,
264
+ v: torch.Tensor,
265
+ w: torch.Tensor,
266
+ u: torch.Tensor,
267
+ b: torch.Tensor,
268
+ initial_state: Optional[torch.Tensor] = None,
269
+ output_final_state: bool = False,
270
+ offsets: Optional[torch.LongTensor] = None,
271
+ indices: Optional[torch.LongTensor] = None,
272
+ head_first: bool = True,
273
+ chunk_size: int = 64
274
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
275
+ if head_first:
276
+ B, H, T, K, V = *k.shape, u.shape[-1]
277
+ else:
278
+ B, T, H, K, V = *k.shape, u.shape[-1]
279
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
280
+ # N: the actual number of sequences in the batch with either equal or variable lengths
281
+ if offsets is None:
282
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
283
+ else:
284
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
285
+
286
+ BK = triton.next_power_of_2(K)
287
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
288
+ # H100 can have larger block size
289
+
290
+ if check_shared_mem('hopper', k.device.index):
291
+ BV = 64
292
+ BC = 64 if K <= 128 else 32
293
+ elif check_shared_mem('ampere', k.device.index): # A100
294
+ BV = 32
295
+ BC = 32
296
+ else:
297
+ BV = 16
298
+ BC = 16
299
+
300
+ BC = min(BT, BC)
301
+ NK = triton.cdiv(K, BK)
302
+ NV = triton.cdiv(V, BV)
303
+
304
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
305
+
306
+ if head_first:
307
+ h = k.new_empty(B, H, NT, K, V)
308
+ else:
309
+ h = k.new_empty(B, NT, H, K, V)
310
+ final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
311
+
312
+ v_new = torch.empty_like(u)
313
+ grid = (NK, NV, N * H)
314
+
315
+ chunk_generalized_iplr_delta_rule_fwd_kernel_h[grid](
316
+ k=k,
317
+ v=v,
318
+ d=w,
319
+ b=b,
320
+ u=u,
321
+ v_new=v_new,
322
+ h=h,
323
+ h0=initial_state,
324
+ ht=final_state,
325
+ offsets=offsets,
326
+ chunk_offsets=chunk_offsets,
327
+ T=T,
328
+ H=H,
329
+ K=K,
330
+ V=V,
331
+ BT=BT,
332
+ BC=BC,
333
+ BK=BK,
334
+ BV=BV,
335
+ NT=NT,
336
+ HEAD_FIRST=head_first
337
+ )
338
+ return h, v_new, final_state
339
+
340
+
341
+ def chunk_generalized_iplr_delta_rule_fwd(
342
+ q: torch.Tensor,
343
+ k: torch.Tensor,
344
+ v: torch.Tensor,
345
+ a: torch.Tensor,
346
+ b: torch.Tensor,
347
+ scale: float,
348
+ initial_state: torch.Tensor,
349
+ output_final_state: bool,
350
+ offsets: Optional[torch.LongTensor] = None,
351
+ indices: Optional[torch.LongTensor] = None,
352
+ head_first: bool = True,
353
+ chunk_size: int = 64
354
+ ):
355
+ T = q.shape[2] if head_first else q.shape[1]
356
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
357
+ w, u, _ = fwd_prepare_wy_repr(
358
+ a=a,
359
+ b=b,
360
+ k=k,
361
+ v=v,
362
+ offsets=offsets,
363
+ indices=indices,
364
+ head_first=head_first,
365
+ chunk_size=BT
366
+ )
367
+
368
+ h, v_new, final_state = chunk_generalized_iplr_delta_rule_fwd_h(
369
+ k=k,
370
+ v=v,
371
+ b=b,
372
+ w=w,
373
+ u=u,
374
+ initial_state=initial_state,
375
+ output_final_state=output_final_state,
376
+ offsets=offsets,
377
+ indices=indices,
378
+ head_first=head_first,
379
+ chunk_size=BT
380
+ )
381
+ o = chunk_generalized_iplr_delta_rule_fwd_o(
382
+ q=q,
383
+ k=k,
384
+ v=v,
385
+ v_new=v_new,
386
+ b=b,
387
+ h=h,
388
+ scale=scale,
389
+ offsets=offsets,
390
+ indices=indices,
391
+ head_first=head_first,
392
+ chunk_size=BT
393
+ )
394
+ return o, final_state
395
+
396
+
397
+ class ChunkGeneralizedIPLRDeltaRuleFunction(torch.autograd.Function):
398
+
399
+ @staticmethod
400
+ @input_guard
401
+ @autocast_custom_fwd
402
+ def forward(
403
+ ctx,
404
+ q: torch.Tensor,
405
+ k: torch.Tensor,
406
+ v: torch.Tensor,
407
+ a: torch.Tensor,
408
+ b: torch.Tensor,
409
+ scale: float,
410
+ initial_state: torch.Tensor,
411
+ output_final_state: bool,
412
+ offsets: Optional[torch.LongTensor] = None,
413
+ head_first: bool = True
414
+ ):
415
+ chunk_size = 64
416
+
417
+ # 2-d indices denoting the offsets of chunks in each sequence
418
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
419
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
420
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
421
+ indices = None
422
+ if offsets is not None:
423
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
424
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
425
+
426
+ o, final_state = chunk_generalized_iplr_delta_rule_fwd(
427
+ q=q,
428
+ k=k,
429
+ v=v,
430
+ a=a,
431
+ b=b,
432
+ scale=scale,
433
+ initial_state=initial_state,
434
+ output_final_state=output_final_state,
435
+ offsets=offsets,
436
+ indices=indices,
437
+ head_first=head_first,
438
+ chunk_size=chunk_size
439
+ )
440
+ return o.to(q.dtype), final_state
441
+
442
+ @staticmethod
443
+ @input_guard
444
+ @autocast_custom_bwd
445
+ def backward(
446
+ ctx,
447
+ do: torch.Tensor,
448
+ dht: torch.Tensor
449
+ ):
450
+ raise NotImplementedError(
451
+ "Backward pass for ChunkGeneralizedIPLRDeltaRuleFunction is not implemented yet. "
452
+ "Stay tuned!"
453
+ )
454
+
455
+
456
+ @torch.compiler.disable
457
+ def chunk_iplr_delta_rule(
458
+ q: torch.Tensor,
459
+ k: torch.Tensor,
460
+ v: torch.Tensor,
461
+ a: torch.Tensor,
462
+ b: torch.Tensor,
463
+ scale: float = None,
464
+ initial_state: torch.Tensor = None,
465
+ output_final_state: bool = False,
466
+ cu_seqlens: Optional[torch.LongTensor] = None,
467
+ head_first: bool = True
468
+ ):
469
+ r"""
470
+ Args:
471
+ q (torch.Tensor):
472
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
473
+ k (torch.Tensor):
474
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
475
+ v (torch.Tensor):
476
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
477
+ a (torch.Tensor):
478
+ activations of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
479
+ b (torch.Tensor):
480
+ betas of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
481
+ scale (Optional[int]):
482
+ Scale factor for the RetNet attention scores.
483
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
484
+ initial_state (Optional[torch.Tensor]):
485
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
486
+ For equal-length input sequences, `N` equals the batch size `B`.
487
+ Default: `None`.
488
+ output_final_state (Optional[bool]):
489
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
490
+ cu_seqlens (torch.LongTensor):
491
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
492
+ consistent with the FlashAttention API.
493
+ head_first (Optional[bool]):
494
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
495
+ Default: `True`.
496
+
497
+ Returns:
498
+ o (torch.Tensor):
499
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
500
+ final_state (torch.Tensor):
501
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
502
+ """
503
+ assert q.dtype == k.dtype == v.dtype
504
+ assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
505
+
506
+ if cu_seqlens is not None:
507
+ if q.shape[0] != 1:
508
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
509
+ f"Please flatten variable-length inputs before processing.")
510
+ if head_first:
511
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
512
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
513
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
514
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
515
+ scale = k.shape[-1] ** -0.5 if scale is None else scale
516
+ o, final_state = ChunkGeneralizedIPLRDeltaRuleFunction.apply(
517
+ q,
518
+ k,
519
+ v,
520
+ a,
521
+ b,
522
+ scale,
523
+ initial_state,
524
+ output_final_state,
525
+ cu_seqlens,
526
+ head_first
527
+ )
528
+ return o, final_state
fla/ops/hgrn/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (16.2 kB). View file
 
fla/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (14.3 kB). View file
 
fla/ops/retention/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (3.62 kB). View file
 
fla/ops/retention/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (2.12 kB). View file
 
fla/ops/utils/__pycache__/logcumsumexp.cpython-312.pyc ADDED
Binary file (2.91 kB). View file
 
fla/ops/utils/__pycache__/logsumexp.cpython-312.pyc ADDED
Binary file (3.62 kB). View file
 
fla/ops/utils/__pycache__/pooling.cpython-312.pyc ADDED
Binary file (11.2 kB). View file
 
profile_trace/iteration_111616/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_111616/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_111616/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_111616/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_111616/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_111616/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_111616/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_112128/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_112128/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_112128/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_112128/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_112128/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_112128/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_112128/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_121344/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_121344/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_121344/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff