zaydzuhri committed
Commit 48c5004 · verified · 1 Parent(s): dcf0454

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. fla/layers/__pycache__/__init__.cpython-312.pyc +0 -0
  2. fla/layers/__pycache__/delta_net.cpython-312.pyc +0 -0
  3. fla/layers/__pycache__/forgetting_attn.cpython-312.pyc +0 -0
  4. fla/layers/__pycache__/gated_deltaproduct.cpython-312.pyc +0 -0
  5. fla/layers/__pycache__/hgrn.cpython-312.pyc +0 -0
  6. fla/layers/__pycache__/hgrn2.cpython-312.pyc +0 -0
  7. fla/layers/__pycache__/lightnet.cpython-312.pyc +0 -0
  8. fla/models/bitnet/configuration_bitnet.py +67 -0
  9. fla/models/bitnet/modeling_bitnet.py +441 -0
  10. fla/models/delta_net/__pycache__/__init__.cpython-312.pyc +0 -0
  11. fla/models/gated_deltanet/__init__.py +12 -0
  12. fla/models/gated_deltanet/configuration_gated_deltanet.py +83 -0
  13. fla/models/gated_deltaproduct/__pycache__/__init__.cpython-312.pyc +0 -0
  14. fla/models/gated_deltaproduct/__pycache__/configuration_gated_deltaproduct.cpython-312.pyc +0 -0
  15. fla/models/gla/__init__.py +13 -0
  16. fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc +0 -0
  17. fla/models/hgrn2/__pycache__/__init__.cpython-312.pyc +0 -0
  18. fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc +0 -0
  19. fla/models/linear_attn/__init__.py +12 -0
  20. fla/models/nsa/__pycache__/__init__.cpython-312.pyc +0 -0
  21. fla/models/rwkv6/__pycache__/configuration_rwkv6.cpython-312.pyc +0 -0
  22. fla/models/samba/__pycache__/modeling_samba.cpython-312.pyc +0 -0
  23. fla/models/transformer/__pycache__/__init__.cpython-312.pyc +0 -0
  24. fla/models/transformer_mtp/__pycache__/modeling_transformer.cpython-312.pyc +0 -0
  25. fla/models/transformer_top/__init__.py +13 -0
  26. fla/models/transformer_top/__pycache__/__init__.cpython-312.pyc +0 -0
  27. fla/models/transformer_top/configuration_transformer.py +76 -0
  28. fla/models/transformer_top/modeling_transformer.py +438 -0
  29. fla/modules/__pycache__/convolution.cpython-312.pyc +0 -0
  30. fla/modules/__pycache__/feature_map.cpython-312.pyc +0 -0
  31. fla/modules/__pycache__/fused_norm_gate.cpython-312.pyc +0 -0
  32. fla/modules/feature_map.py +300 -0
  33. fla/modules/fused_cross_entropy.py +419 -0
  34. fla/modules/fused_kl_div.py +323 -0
  35. fla/modules/parallel.py +37 -0
  36. fla/ops/abc/__init__.py +7 -0
  37. fla/ops/abc/__pycache__/__init__.cpython-312.pyc +0 -0
  38. fla/ops/abc/__pycache__/chunk.cpython-312.pyc +0 -0
  39. fla/ops/abc/chunk.py +1116 -0
  40. fla/ops/attn/__pycache__/__init__.cpython-312.pyc +0 -0
  41. fla/ops/attn/__pycache__/parallel.cpython-312.pyc +0 -0
  42. fla/ops/attn/parallel.py +629 -0
  43. fla/ops/based/__init__.py +9 -0
  44. fla/ops/based/__pycache__/__init__.cpython-312.pyc +0 -0
  45. fla/ops/based/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  46. fla/ops/based/__pycache__/parallel.cpython-312.pyc +0 -0
  47. fla/ops/based/fused_chunk.py +374 -0
  48. fla/ops/based/parallel.py +410 -0
  49. fla/ops/common/__init__.py +1 -0
  50. fla/ops/common/__pycache__/__init__.cpython-312.pyc +0 -0
fla/layers/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.2 kB).
 
fla/layers/__pycache__/delta_net.cpython-312.pyc ADDED
Binary file (12.5 kB).
 
fla/layers/__pycache__/forgetting_attn.cpython-312.pyc ADDED
Binary file (5.3 kB).
 
fla/layers/__pycache__/gated_deltaproduct.cpython-312.pyc ADDED
Binary file (14.4 kB).
 
fla/layers/__pycache__/hgrn.cpython-312.pyc ADDED
Binary file (6.7 kB).
 
fla/layers/__pycache__/hgrn2.cpython-312.pyc ADDED
Binary file (8.6 kB).
 
fla/layers/__pycache__/lightnet.cpython-312.pyc ADDED
Binary file (8.85 kB).
 
fla/models/bitnet/configuration_bitnet.py ADDED
@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class BitNetConfig(PretrainedConfig):

    model_type = 'bitnet'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        num_heads: int = 32,
        num_kv_heads: int = None,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.006,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
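Note (illustrative, not part of the commit): a minimal sketch of using the configuration above, assuming the `fla` package from this repository is importable; argument names follow the `__init__` signature shown here.

from fla.models.bitnet.configuration_bitnet import BitNetConfig

# a small illustrative config; non-default values are arbitrary
config = BitNetConfig(
    hidden_size=512,
    num_hidden_layers=4,
    num_heads=8,
    vocab_size=32000,
)
print(config.model_type)         # 'bitnet'
print(config.intermediate_size)  # None -> resolved later by BitNetMLP (see modeling file below)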
fla/models/bitnet/modeling_bitnet.py ADDED
@@ -0,0 +1,441 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg

from fla.layers.bitattn import BitAttention
from fla.models.bitnet.configuration_bitnet import BitNetConfig
from fla.models.utils import Cache
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
from fla.modules.activations import swiglu
from fla.modules.fused_bitlinear import FusedBitLinear

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

logger = logging.get_logger(__name__)


class BitNetMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'swish',
        fuse_swiglu: bool = True
    ) -> BitNetMLP:
        super().__init__()

        self.hidden_size = hidden_size
        # the final number of params is `hidden_ratio * hidden_size^2`
        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
        if hidden_ratio is None:
            hidden_ratio = 4
        if intermediate_size is None:
            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.fuse_swiglu = fuse_swiglu

        if hidden_act != 'swish':
            raise ValueError(f'Unsupported hidden_act: {hidden_act}')

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        **kwargs: Unpack[Any]
    ) -> torch.Tensor:
        gate, y = self.gate_proj(x), self.up_proj(x)
        return self.down_proj(swiglu(gate, y))


class BitNetBlock(nn.Module):

    def __init__(self, config: BitNetConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.attn = BitAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            window_size=config.window_size,
            rope_theta=config.rope_theta,
            max_position_embeddings=config.max_position_embeddings,
            layer_idx=layer_idx
        )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = BitNetMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs: Unpack[Any]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attentions,)

        if use_cache:
            outputs += (past_key_values,)

        return outputs


class BitNetPreTrainedModel(PreTrainedModel):

    config_class = BitNetConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['BitNetBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = False,
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d, FusedBitLinear)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)


class BitNetModel(BitNetPreTrainedModel):

    def __init__(
        self,
        config: BitNetConfig
    ) -> BitNetModel:
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([BitNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if output_attentions:
            warnings.warn(
                "`BitNetModel` does not support output attention weights now, so `output_attentions` is set to `False`."
            )
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        # embed positions
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        next_cache = None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    **kwargs
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    **kwargs
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )


class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = BitNetModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
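Note (illustrative, not part of the commit): the `BitNetMLP` sizing rule above (take `2/3 * hidden_size * hidden_ratio`, then round up to a multiple of 256) can be checked with a small standalone sketch that mirrors the code.

def bitnet_mlp_intermediate_size(hidden_size: int, hidden_ratio: int = 4) -> int:
    # mirrors BitNetMLP.__init__ when intermediate_size is None
    intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
    return 256 * ((intermediate_size + 256 - 1) // 256)

assert bitnet_mlp_intermediate_size(2048) == 5632  # int(2048*4*2/3) = 5461 -> next multiple of 256
assert bitnet_mlp_intermediate_size(1024) == 2816  # int(1024*4*2/3) = 2730 -> next multiple of 256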
fla/models/delta_net/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (701 Bytes).
 
fla/models/gated_deltanet/__init__.py ADDED
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
from fla.models.gated_deltanet.modeling_gated_deltanet import GatedDeltaNetForCausalLM, GatedDeltaNetModel

AutoConfig.register(GatedDeltaNetConfig.model_type, GatedDeltaNetConfig)
AutoModel.register(GatedDeltaNetConfig, GatedDeltaNetModel)
AutoModelForCausalLM.register(GatedDeltaNetConfig, GatedDeltaNetForCausalLM)

__all__ = ['GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel']
fla/models/gated_deltanet/configuration_gated_deltanet.py ADDED
@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-

from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class GatedDeltaNetConfig(PretrainedConfig):
    model_type = 'gated_deltanet'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "chunk",
        hidden_size: int = 2048,
        expand_v: int = 2,
        use_gate: bool = True,
        use_short_conv: bool = True,
        conv_size: int = 4,
        head_dim: int = 256,
        num_heads: int = 6,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        num_hidden_layers: int = 21,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.expand_v = expand_v
        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.max_position_embeddings = max_position_embeddings

        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_hidden_layers = num_hidden_layers
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
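Note (illustrative, not part of the commit): a sketch of a valid `attn` dict for the hybrid-attention path validated above; the layer indices and head count are arbitrary examples, and only `layers` and `num_heads` are required, with the remaining keys filled in by the constructor as shown.

from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig

config = GatedDeltaNetConfig(
    num_hidden_layers=21,
    attn={
        'layers': [5, 10, 15],  # which blocks use standard attention (illustrative)
        'num_heads': 16,        # required
        # 'num_kv_heads', 'qkv_bias', 'window_size', 'rope_theta' get the defaults above
    },
)
print(config.attn['rope_theta'])  # 10000.0, filled in by the constructor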
fla/models/gated_deltaproduct/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (777 Bytes).
 
fla/models/gated_deltaproduct/__pycache__/configuration_gated_deltaproduct.cpython-312.pyc ADDED
Binary file (3.38 kB).
 
fla/models/gla/__init__.py ADDED
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.gla.configuration_gla import GLAConfig
from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel

AutoConfig.register(GLAConfig.model_type, GLAConfig)
AutoModel.register(GLAConfig, GLAModel)
AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM)


__all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel']
fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc ADDED
Binary file (3.84 kB).
 
fla/models/hgrn2/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (674 Bytes).
 
fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc ADDED
Binary file (18.9 kB).
 
fla/models/linear_attn/__init__.py ADDED
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig
from fla.models.linear_attn.modeling_linear_attn import LinearAttentionForCausalLM, LinearAttentionModel

AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig)
AutoModel.register(LinearAttentionConfig, LinearAttentionModel)
AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM)

__all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel']
fla/models/nsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (657 Bytes).
 
fla/models/rwkv6/__pycache__/configuration_rwkv6.cpython-312.pyc ADDED
Binary file (3.32 kB).
 
fla/models/samba/__pycache__/modeling_samba.cpython-312.pyc ADDED
Binary file (20.9 kB).
 
fla/models/transformer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (728 Bytes).
 
fla/models/transformer_mtp/__pycache__/modeling_transformer.cpython-312.pyc ADDED
Binary file (24.5 kB).
 
fla/models/transformer_top/__init__.py ADDED
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.transformer_top.configuration_transformer import TOPTransformerConfig
from fla.models.transformer_top.modeling_transformer import TOPTransformerForCausalLM, TOPTransformerModel

AutoConfig.register(TOPTransformerConfig.model_type, TOPTransformerConfig)
AutoModel.register(TOPTransformerConfig, TOPTransformerModel)
AutoModelForCausalLM.register(TOPTransformerConfig, TOPTransformerForCausalLM)


__all__ = ['TOPTransformerConfig', 'TOPTransformerForCausalLM', 'TOPTransformerModel']
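Note (illustrative, not part of the commit): since the module above registers the classes with the `transformers` Auto classes as an import side effect, a hedged sketch of constructing the model through the Auto API afterwards could look like this.

import fla.models.transformer_top  # noqa: F401  (side effect: registers 'top_transformer')
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.for_model('top_transformer')  # resolves to TOPTransformerConfig
model = AutoModelForCausalLM.from_config(config)  # resolves to TOPTransformerForCausalLM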
fla/models/transformer_top/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (749 Bytes).
 
fla/models/transformer_top/configuration_transformer.py ADDED
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class TOPTransformerConfig(PretrainedConfig):

    model_type = 'top_transformer'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        num_heads: int = 32,
        num_kv_heads: int = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.006,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        use_top_loss: bool = False,
        top_window_size: Optional[int] = None,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        self.use_top_loss = use_top_loss
        self.top_window_size = top_window_size if top_window_size is not None else max_position_embeddings

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
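Note (illustrative, not part of the commit): a small sketch of the TOP-specific options above, showing how `top_window_size` falls back to `max_position_embeddings` when left unset.

from fla.models.transformer_top.configuration_transformer import TOPTransformerConfig

config = TOPTransformerConfig(
    max_position_embeddings=4096,
    use_top_loss=True,     # enables the auxiliary top_head and top_loss in the modeling file below
    top_window_size=None,  # falls back to max_position_embeddings
)
print(config.top_window_size)  # 4096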
fla/models/transformer_top/modeling_transformer.py ADDED
@@ -0,0 +1,438 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from dataclasses import dataclass
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg

import triton
import triton.language as tl

from fla.layers.attn import Attention
from fla.models.transformer_top.configuration_transformer import TOPTransformerConfig
from fla.models.utils import Cache
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, FusedLinearListNetLoss
from fla.modules import GatedMLP as TransformerMLP
from fla.modules import RMSNorm
from fla.modules.seq_to_top import seq_to_top

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack


logger = logging.get_logger(__name__)


@dataclass
class TOPLMOutputWithPast(CausalLMOutputWithPast):
    ntp_loss: Optional[torch.FloatTensor] = None
    top_loss: Optional[torch.FloatTensor] = None


class TOPTransformerBlock(nn.Module):

    def __init__(self, config: TOPTransformerConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.attn = Attention(
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            qkv_bias=config.qkv_bias,
            qk_norm=config.qk_norm,
            window_size=config.window_size,
            rope_theta=config.rope_theta,
            max_position_embeddings=config.max_position_embeddings,
            layer_idx=layer_idx
        )

        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = TransformerMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs: Unpack[Any]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attentions,)

        if use_cache:
            outputs += (past_key_values,)

        return outputs


class TOPTransformerPreTrainedModel(PreTrainedModel):

    config_class = TOPTransformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['TOPTransformerBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = False,
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per TOPTransformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)


class TOPTransformerModel(TOPTransformerPreTrainedModel):

    def __init__(
        self,
        config: TOPTransformerConfig
    ) -> TOPTransformerModel:
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([TOPTransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if output_attentions:
            warnings.warn(
                "`TOPTransformerModel` does not support output attention weights now, so `output_attentions` is set to `False`."
            )
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        # embed positions
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        next_cache = None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    **kwargs
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    **kwargs
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )


class TOPTransformerForCausalLM(TOPTransformerPreTrainedModel, GenerationMixin):

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = TOPTransformerModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.use_top_loss:
            self.top_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.top_criterion = FusedLinearListNetLoss()
            self.top_window_size = config.top_window_size
        self.criterion = None
        self.pad_token_id = config.pad_token_id

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
        logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])

        loss = None
        ntp_loss = None
        top_loss = None
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            # Enable model parallelism
            labels = labels.to(hidden_states.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            ntp_labels = labels[..., :hidden_states.shape[1]].contiguous()
            if fuse_linear_and_cross_entropy:
                ntp_loss = criterion(hidden_states, ntp_labels, self.lm_head.weight, self.lm_head.bias)
            else:
                ntp_loss = criterion(logits.view(ntp_labels.numel(), -1), ntp_labels.reshape(-1))

            if self.config.use_top_loss:
                top_labels = seq_to_top(labels, vocab_size=self.vocab_size, window_size=self.top_window_size, pad_token_id=self.pad_token_id).contiguous()
                top_loss = self.top_criterion(hidden_states, top_labels, self.top_head.weight, self.top_head.bias)
                # print(f"NTP Loss: {ntp_loss.item()}, TOP Loss: {top_loss.item()}")
                # For debugging, get the index where the top label is the highest and print the corresponding logits
                # idx_max = torch.argmax(top_labels.view(-1, self.vocab_size), dim=1)
                # # Print the labels and logits at that index
                # print(f"Labels: {top_labels.view(-1, self.vocab_size)[0, idx_max[0]-3:idx_max[0]+3]}")
                # print(f"Logits: {F.sigmoid(top_logits).view(-1, self.vocab_size)[0, idx_max[0]-3:idx_max[0]+3]}")
                loss = ntp_loss + top_loss
            else:
                loss = ntp_loss

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return TOPLMOutputWithPast(
            loss=loss,
            ntp_loss=ntp_loss,
            top_loss=top_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/modules/__pycache__/convolution.cpython-312.pyc ADDED
Binary file (21 kB).
 
fla/modules/__pycache__/feature_map.cpython-312.pyc ADDED
Binary file (17.6 kB).
 
fla/modules/__pycache__/fused_norm_gate.cpython-312.pyc ADDED
Binary file (35.3 kB).
 
fla/modules/feature_map.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from typing import Optional
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+
12
+ from fla.modules.activations import fast_gelu_impl, sigmoid, sqrelu, swish
13
+ from fla.modules.layernorm import layer_norm
14
+ from fla.utils import checkpoint
15
+
16
+
17
+ @checkpoint
18
+ def flatten_diag_outer_product(x, y):
19
+ z = torch.einsum("...i,...j->...ij", x, y)
20
+ N = z.size(-1)
21
+ indicies = torch.triu_indices(N, N)
22
+ return z[..., indicies[0], indicies[1]]
23
+
24
+
25
+ @checkpoint
26
+ def flatten_diag_outer_product_off1(x, y):
27
+ z = torch.einsum("...i,...j->...ij", x, y)
28
+ N = z.size(-1)
29
+ indicies = torch.triu_indices(N, N, 1)
30
+ indices2 = torch.arange(0, N)
31
+ return z[..., indicies[0], indicies[1]], z[..., indices2, indices2]
32
+
33
+
34
+ def is_power_of_2(n):
35
+ return (n & (n - 1) == 0) and n != 0
36
+
37
+
38
+ class HedgehogFeatureMap(nn.Module):
39
+
40
+ r"""
41
+ Hedgehog feature map as introduced in
42
+ `The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry <https://arxiv.org/abs/2402.04347>`_
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ head_dim: int
48
+ ) -> HedgehogFeatureMap:
49
+ super().__init__()
50
+ # Trainable map
51
+ self.layer = nn.Linear(head_dim, head_dim)
52
+ self.init_weights_()
53
+
54
+ def init_weights_(self):
55
+ """Initialize trainable map as identity"""
56
+ with torch.no_grad():
57
+ identity = torch.eye(*self.layer.weight.shape[-2:], dtype=torch.float)
58
+ self.layer.weight.copy_(identity.to(self.layer.weight))
59
+ nn.init.zeros_(self.layer.bias)
60
+
61
+ def forward(self, x: torch.Tensor):
62
+ x = self.layer(x) # shape b, h, l, d
63
+ return torch.cat([2*x, -2*x], dim=-1).softmax(-1)
64
+
65
+
66
+ class T2RFeatureMap(nn.Module):
67
+
68
+ r"""
69
+ Simple linear mapping feature map as in
70
+ `Finetuning Pretrained Transformers into RNNs <https://arxiv.org/abs/2103.13076>`_
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ head_dim: int,
76
+ dot_dim: int = None,
77
+ bias: Optional[bool] = False
78
+ ) -> T2RFeatureMap:
79
+ super().__init__()
80
+ # Trainable map
81
+ if dot_dim is None:
82
+ dot_dim = head_dim
83
+
84
+ self.head_dim = head_dim
85
+ self.dot_dim = dot_dim
86
+ self.bias = bias
87
+
88
+ self.layer = nn.Linear(head_dim, dot_dim, bias=bias)
89
+
90
+ def __repr__(self) -> str:
91
+ return f"{self.__class__.__name__}(head_dim={self.head_dim}, dot_dim={self.dot_dim}, bias={self.bias})"
92
+
93
+ def forward(self, x: torch.Tensor):
94
+ return self.layer(x).relu()
95
+
96
+
97
+ class DPFPFeatureMap(nn.Module):
98
+
99
+ r"""
100
+ Deterministic Parameter-Free Projection (DPFP) feature map in
101
+ `Linear Transformers Are Secretly Fast Weight Programmers <https://arxiv.org/abs/2102.11174>`_
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ head_dim: int,
107
+ nu: int = 4
108
+ ) -> DPFPFeatureMap:
109
+ super().__init__()
110
+ self.nu = nu
111
+
112
+ def forward(self, x: torch.Tensor):
113
+ x = torch.cat([x.relu(), -x.relu()], dim=-1)
114
+ x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu+1)], dim=-1)
115
+ x_repeat = torch.cat([x] * self.nu, dim=-1)
116
+ return x_repeat * x_rolled
117
+
118
+
119
+ class HadamardFeatureMap(nn.Module):
120
+ def __init__(
121
+ self,
122
+ head_dim: int
123
+ ) -> HadamardFeatureMap:
124
+ super().__init__()
125
+ # Trainable map
126
+ self.layer1 = nn.Linear(head_dim, head_dim)
127
+ self.layer2 = nn.Linear(head_dim, head_dim)
128
+
129
+ def forward(self, x: torch.Tensor):
130
+ return self.layer1(x) * self.layer2(x)
131
+
132
+
133
+ class LearnableOuterProductFeatureMap(nn.Module):
134
+ def __init__(
135
+ self,
136
+ head_dim: int,
137
+ feature_dim: int
138
+ ) -> LearnableOuterProductFeatureMap:
139
+ super().__init__()
140
+ # Trainable map
141
+ self.layer1 = nn.Linear(head_dim, feature_dim, bias=False)
142
+ self.layer2 = nn.Linear(head_dim, feature_dim, bias=False)
143
+ self.normalizer = feature_dim ** -0.5
144
+
145
+ def forward(self, x: torch.Tensor):
146
+ return flatten_diag_outer_product(self.layer1(x), self.layer2(x))
147
+
148
+
149
+ class LearnablePolySketchNonNegativeFeatureMap(nn.Module):
150
+
151
+ def __init__(
152
+ self,
153
+ head_dim: int,
154
+ sketch_size: Optional[int] = None,
155
+ degree: Optional[int] = 2
156
+ ) -> LearnablePolySketchNonNegativeFeatureMap:
157
+ super().__init__()
158
+
159
+ assert is_power_of_2(degree) and degree >= 2, f"The degree {degree} must be a power of 2"
160
+
161
+ self.head_dim = head_dim
162
+ self.sketch_size = sketch_size if sketch_size is not None else head_dim
163
+ self.degree = degree
164
+
165
+ self.gamma = nn.Parameter(torch.ones(head_dim))
166
+ self.beta = nn.Parameter(torch.zeros(head_dim))
167
+ # NOTE: the sketch layers defined here are quite different from the original paper
168
+ # currently we simply use linear layers without any non-linear activations
169
+ self.sketches1 = nn.ModuleList([
170
+ nn.Linear(head_dim, sketch_size, bias=False),
171
+ *[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)]
172
+ ])
173
+ self.sketches2 = nn.ModuleList([
174
+ nn.Linear(head_dim, sketch_size, bias=False),
175
+ *[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)]
176
+ ])
177
+
178
+ def forward(self, x: torch.Tensor):
179
+ # Section 2.1
180
+ x = layer_norm(x, self.gamma, self.beta)
181
+ # first map the input to sketch size with learnable parameters
182
+ x = self.sketches1[0](x) * self.sketches2[0](x) * self.head_dim ** -0.5
183
+ for i in range(1, int(math.log2(self.degree)) - 1):
184
+ x = self.sketches1[i](x) * self.sketches2[i](x) * self.head_dim ** -0.5
185
+ # do sketch mapping for log2(p) - 1 times in total
186
+ # do p=2 mapping to ensure non-negativity
187
+ return flatten_diag_outer_product(x, x)
188
+
189
+
190
+ class TaylorFeatureMap(nn.Module):
191
+ def __init__(
192
+ self,
193
+ head_dim: int
194
+ ) -> TaylorFeatureMap:
195
+ super().__init__()
196
+ self.head_dim = head_dim
197
+ self.r2 = math.sqrt(2)
198
+ self.rd = math.sqrt(self.head_dim)
199
+ self.rrd = math.sqrt(self.rd)
200
+
201
+ def forward(self, x: torch.Tensor):
202
+ x2_1, x2_2 = flatten_diag_outer_product_off1(x, x)
203
+ return torch.cat([torch.ones_like(x[..., 0:1]), x / self.rrd, x2_2 / (self.rd * self.r2), x2_1 / self.rd], dim=-1)
204
+
205
+
206
+ class RebasedFeatureMap(nn.Module):
207
+
208
+ def __init__(
209
+ self,
210
+ head_dim: int,
211
+ use_gamma: Optional[bool] = True,
212
+ use_beta: Optional[bool] = True,
213
+ normalize: Optional[bool] = True
214
+ ) -> RebasedFeatureMap:
215
+ super().__init__()
216
+
217
+ self.head_dim = head_dim
218
+ self.use_gamma = use_gamma
219
+ self.use_beta = use_beta
220
+ self.normalize = normalize
221
+
222
+ self.gamma = None
223
+ self.beta = None
224
+ if use_gamma:
225
+ self.gamma = nn.Parameter(torch.ones(head_dim))
226
+ if use_beta:
227
+ self.beta = nn.Parameter(torch.zeros(head_dim))
228
+
229
+ def forward(self, x: torch.Tensor, flatten: Optional[bool] = True):
230
+ if self.use_beta and self.use_gamma and self.normalize:
231
+ x = layer_norm(x, self.gamma, self.beta)
232
+ elif self.normalize:
233
+ x = F.layer_norm(x, (self.head_dim,), self.gamma, self.beta)
234
+ elif self.use_gamma and self.use_beta:
235
+ x = torch.addcmul(self.beta, x, self.gamma)
236
+ elif self.use_gamma:
237
+ x = x.mul(self.gamma)
238
+ else:
239
+ raise RuntimeError(f"Not supported combination of `use_gamma`, `use_beta` and `normalize`, "
240
+ f"which is currentlt set as (`{self.use_gamma}`, `{self.use_beta}`, `{self.normalize}`)")
241
+ if not flatten:
242
+ return x
243
+ x2_1, x2_2 = flatten_diag_outer_product_off1(x, x)
244
+ # rebased use learnable parameters to approximate any quadratic function
245
+ return torch.cat([x2_2 * self.head_dim ** -0.5, x2_1 * (2 / self.head_dim) ** 0.5], dim=-1)
246
+
247
+
248
+ class ReLUFeatureMap(nn.Module):
249
+
250
+ def __init__(
251
+ self,
252
+ ) -> ReLUFeatureMap:
253
+ super().__init__()
254
+
255
+ def forward(self, x: torch.Tensor):
256
+ return F.relu(x)
257
+
258
+
259
+ class SquaredReLUFeatureMap(nn.Module):
260
+
261
+ def __init__(
262
+ self,
263
+ ) -> SquaredReLUFeatureMap:
264
+ super().__init__()
265
+
266
+ def forward(self, x: torch.Tensor):
267
+ return sqrelu(x)
268
+
269
+
270
+ class GELUFeatureMap(nn.Module):
271
+
272
+ def __init__(
273
+ self,
274
+ ) -> GELUFeatureMap:
275
+ super().__init__()
276
+
277
+ def forward(self, x: torch.Tensor):
278
+ return fast_gelu_impl(x)
279
+
280
+
281
+ class SwishFeatureMap(nn.Module):
282
+
283
+ def __init__(
284
+ self,
285
+ ) -> SwishFeatureMap:
286
+ super().__init__()
287
+
288
+ def forward(self, x: torch.Tensor):
289
+ return swish(x)
290
+
291
+
292
+ class SigmoidFeatureMap(nn.Module):
293
+
294
+ def __init__(
295
+ self,
296
+ ) -> SigmoidFeatureMap:
297
+ super().__init__()
298
+
299
+ def forward(self, x: torch.Tensor):
300
+ return sigmoid(x)
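
All of the feature maps above share the same calling convention: they take per-head activations and return a (usually non-negative) feature representation of the same leading shape. The following is a minimal usage sketch, not part of this commit; the [batch, heads, length, head_dim] layout and the standalone import are illustrative assumptions.

# Hedged sketch: applying one of the feature maps above to queries/keys
# before a linear-attention style product. Shapes are illustrative only.
import torch
from fla.modules.feature_map import T2RFeatureMap

phi = T2RFeatureMap(head_dim=64)          # dot_dim defaults to head_dim
q = torch.randn(2, 4, 128, 64)            # [batch, heads, length, head_dim]
k = torch.randn(2, 4, 128, 64)
q_prime, k_prime = phi(q), phi(k)         # ReLU features with trailing dim = dot_dim
print(q_prime.shape)                      # torch.Size([2, 4, 128, 64])
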
fla/modules/fused_cross_entropy.py ADDED
@@ -0,0 +1,419 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2023, Tri Dao.
4
+
5
+ from typing import Any, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import triton
10
+ import triton.language as tl
11
+
12
+ from fla.ops.utils.op import exp, log
13
+ from fla.utils import input_guard
14
+
15
+ # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
16
+ # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
17
+ # version of PyTorch. The following 2 lines are for backward compatibility with
18
+ # older PyTorch.
19
+ if "all_gather_into_tensor" not in dir(torch.distributed):
20
+ torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
21
+
22
+
23
+ @triton.heuristics({
24
+ "HAS_SMOOTHING": lambda args: args["label_smoothing"] > 0.0,
25
+ })
26
+ @triton.jit
27
+ def cross_entropy_fwd_kernel(
28
+ loss_ptr, # data ptrs
29
+ lse_ptr,
30
+ z_loss_ptr,
31
+ logits_ptr,
32
+ labels_ptr,
33
+ label_smoothing,
34
+ logit_scale,
35
+ lse_square_scale,
36
+ ignore_index,
37
+ total_classes,
38
+ class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
39
+ n_cols, # shapes
40
+ n_rows,
41
+ logits_row_stride, # strides
42
+ BLOCK_SIZE: tl.constexpr,
43
+ HAS_SMOOTHING: tl.constexpr,
44
+ # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
45
+ SPLIT: tl.constexpr,
46
+ ):
47
+ row_idx = tl.program_id(0)
48
+ col_block_idx = tl.program_id(1)
49
+ logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
50
+ col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
51
+ label_idx = tl.load(labels_ptr + row_idx)
52
+ logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf"))
53
+ logits = logits.to(tl.float32) * logit_scale
54
+ max_logits = tl.max(logits, 0)
55
+ if HAS_SMOOTHING:
56
+ sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
57
+ lse = log(tl.sum(exp(logits - max_logits), 0)) + max_logits
58
+ tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
59
+ if label_idx == ignore_index:
60
+ loss = 0.0
61
+ z_loss = 0.0
62
+ else:
63
+ label_idx -= class_start_idx
64
+ if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
65
+ n_cols, (col_block_idx + 1) * BLOCK_SIZE
66
+ ):
67
+ logits_label = tl.load(logits_ptr + label_idx) * logit_scale
68
+ if HAS_SMOOTHING:
69
+ loss = (
70
+ (lse if not SPLIT else 0.0)
71
+ - label_smoothing * sum_logits / total_classes
72
+ - (1 - label_smoothing) * logits_label
73
+ )
74
+ else:
75
+ loss = (lse if not SPLIT else 0.0) - logits_label
76
+ else:
77
+ # If label is out of bounds, we set the CE loss to 0.0. But we still want the label_smoothing loss
78
+ if HAS_SMOOTHING:
79
+ loss = label_smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
80
+ else:
81
+ loss = 0.0
82
+ if not SPLIT:
83
+ z_loss = lse_square_scale * lse * lse
84
+ loss += z_loss
85
+ else:
86
+ z_loss = 0.0
87
+ tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
88
+ if not SPLIT:
89
+ tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
90
+
91
+
92
+ @triton.heuristics({
93
+ "HAS_SMOOTHING": lambda args: args["label_smoothing"] > 0.0,
94
+ })
95
+ @triton.jit
96
+ def cross_entropy_bwd_kernel(
97
+ dlogits_ptr, # data ptrs
98
+ dloss_ptr,
99
+ logits_ptr,
100
+ lse_ptr,
101
+ labels_ptr,
102
+ label_smoothing,
103
+ logit_scale,
104
+ lse_square_scale,
105
+ ignore_index,
106
+ total_classes,
107
+ class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
108
+ n_cols, # shapes
109
+ logits_row_stride, # strides
110
+ dlogits_row_stride,
111
+ dloss_row_stride,
112
+ BLOCK_SIZE: tl.constexpr,
113
+ HAS_SMOOTHING: tl.constexpr,
114
+ ):
115
+ row_idx = tl.program_id(0)
116
+ col_block_idx = tl.program_id(1)
117
+ logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
118
+ dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
119
+ col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
120
+ label_idx = tl.load(labels_ptr + row_idx)
121
+ if label_idx != ignore_index:
122
+ dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
123
+ else:
124
+ dloss = 0.0
125
+ logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
126
+ tl.float32
127
+ ) * logit_scale
128
+ lse = tl.load(lse_ptr + row_idx)
129
+ probs = exp(logits - lse)
130
+ probs += 2.0 * lse_square_scale * lse * probs
131
+ label_idx -= class_start_idx
132
+ if HAS_SMOOTHING:
133
+ smooth_negative = label_smoothing / total_classes
134
+ probs = tl.where(col_offsets == label_idx, probs - (1 - label_smoothing), probs) - smooth_negative
135
+ else:
136
+ probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
137
+ tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
138
+
139
+
140
+ def fused_cross_entropy_forward(
141
+ logits: torch.Tensor,
142
+ target: torch.Tensor,
143
+ label_smoothing: float = 0.0,
144
+ logit_scale: float = 1.0,
145
+ lse_square_scale: float = 0.0,
146
+ ignore_index: int = -100,
147
+ process_group=None,
148
+ ):
149
+ n_rows, n_cols = logits.shape
150
+ assert target.shape == (n_rows,)
151
+ world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
152
+ total_classes = world_size * n_cols
153
+ rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
154
+ class_start_idx = rank * n_cols
155
+
156
+ if logits.stride(-1) != 1:
157
+ logits = logits.contiguous()
158
+ # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
159
+ MAX_BLOCK_SIZE = 64 * 1024
160
+ BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
161
+ num_warps = (
162
+ 4
163
+ if BLOCK_SIZE < 2048
164
+ else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
165
+ )
166
+ # We may split the lse computation across multiple blocks, then do a reduction
167
+ # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
168
+ # where having just one thread block processing more than 64k elements is slow.
169
+ split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
170
+ n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
171
+ loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
172
+ losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
173
+ lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
174
+ z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
175
+
176
+ cross_entropy_fwd_kernel[(n_rows, n_splits)](
177
+ losses, # data ptrs
178
+ lse,
179
+ z_losses,
180
+ logits,
181
+ target,
182
+ label_smoothing,
183
+ logit_scale,
184
+ lse_square_scale,
185
+ ignore_index,
186
+ total_classes,
187
+ class_start_idx,
188
+ n_cols, # shapes
189
+ n_rows,
190
+ logits.stride(0), # strides
191
+ BLOCK_SIZE=BLOCK_SIZE, # constants
192
+ num_warps=num_warps,
193
+ SPLIT=split
194
+ )
195
+
196
+ if split:
197
+ # If there's no label_smoothing and the target is in the vocab of this partition, losses contains
198
+ # - predicted logit, and 0 otherwise.
199
+ # If there's label_smoothing=0.1, for a target in the vocab of this partition, losses contains
200
+ # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
201
+ # For a target not in the vocab of this partition, losses contains
202
+ # -0.1 * sum logit / total_classes.
203
+ if n_splits > 1:
204
+ lse = torch.logsumexp(lse, dim=0)
205
+ losses = losses.sum(dim=0)
206
+ if world_size > 1:
207
+ lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
208
+ torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
209
+ handle_losses = torch.distributed.all_reduce(
210
+ losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
211
+ )
212
+ lse = torch.logsumexp(lse_allgather, dim=0)
213
+ handle_losses.wait()
214
+ # After the allreduce, if there's no label_smoothing, the total losses are - predicted_logit,
215
+ # we just have to add the (global) lse.
216
+ # If there's label_smoothing=0.1, the total losses are
217
+ # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
218
+ # Again, we just have to add the (global) lse.
219
+ losses += lse
220
+ if lse_square_scale != 0.0:
221
+ z_losses = lse_square_scale * lse.square()
222
+ z_losses.masked_fill_(target == ignore_index, 0.0)
223
+ losses += z_losses
224
+ else:
225
+ z_losses = torch.zeros_like(losses)
226
+ losses.masked_fill_(target == ignore_index, 0.0)
227
+
228
+ return losses, z_losses, lse, total_classes, class_start_idx
229
+
230
+
231
+ class CrossEntropyLossFunction(torch.autograd.Function):
232
+
233
+ @staticmethod
234
+ @input_guard
235
+ def forward(
236
+ ctx,
237
+ logits,
238
+ target,
239
+ label_smoothing=0.0,
240
+ logit_scale=1.0,
241
+ lse_square_scale=0.0,
242
+ ignore_index=-100,
243
+ inplace_backward=False,
244
+ process_group=None,
245
+ ):
246
+ losses, z_losses, lse, total_classes, class_start_idx = fused_cross_entropy_forward(
247
+ logits,
248
+ target,
249
+ label_smoothing,
250
+ logit_scale,
251
+ lse_square_scale,
252
+ ignore_index,
253
+ process_group,
254
+ )
255
+ ctx.save_for_backward(logits, lse, target)
256
+ ctx.mark_non_differentiable(z_losses)
257
+ ctx.label_smoothing = label_smoothing
258
+ ctx.logit_scale = logit_scale
259
+ ctx.lse_square_scale = lse_square_scale
260
+ ctx.ignore_index = ignore_index
261
+ ctx.total_classes = total_classes
262
+ ctx.class_start_idx = class_start_idx
263
+ ctx.inplace_backward = inplace_backward
264
+
265
+ return losses, z_losses
266
+
267
+ @staticmethod
268
+ @input_guard
269
+ def backward(ctx, grad_losses, grad_z_losses):
270
+ del grad_z_losses # z_losses are only for logging.
271
+
272
+ logits, lse, target = ctx.saved_tensors
273
+ dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
274
+ n_rows, n_cols = logits.shape
275
+ BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
276
+ num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
277
+ def grid(META): return (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa
278
+ cross_entropy_bwd_kernel[grid](
279
+ dlogits, # data ptrs
280
+ grad_losses,
281
+ logits,
282
+ lse,
283
+ target,
284
+ ctx.label_smoothing,
285
+ ctx.logit_scale,
286
+ ctx.lse_square_scale,
287
+ ctx.ignore_index,
288
+ ctx.total_classes,
289
+ ctx.class_start_idx,
290
+ n_cols, # shapes
291
+ logits.stride(0), # strides
292
+ dlogits.stride(0),
293
+ grad_losses.stride(0),
294
+ BLOCK_SIZE=BLOCK_SIZE, # constants
295
+ num_warps=num_warps,
296
+ )
297
+ return dlogits, None, None, None, None, None, None, None, None
298
+
299
+
300
+ def cross_entropy_loss(
301
+ logits: torch.Tensor,
302
+ target: torch.Tensor,
303
+ label_smoothing: float = 0.0,
304
+ logit_scale: float = 1.0,
305
+ lse_square_scale: float = 0.0,
306
+ ignore_index=-100,
307
+ inplace_backward: bool = False,
308
+ process_group=None,
309
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
310
+ """
311
+ Arguments:
312
+ logits: [batch, vocab_size]
313
+ target: [batch,]
314
+ label_smoothing: float
315
+ logit_scale: float.
316
+ Multiply logits by this scale before calculating the loss.
317
+ lse_square_scale: float.
318
+ If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
319
+ This is also referred to as "z-loss".
320
+ ignore_index: int.
321
+ If target == ignore_index, the loss is set to 0.0.
322
+ inplace_backward: bool.
323
+ If True, we do the backward pass in-place by modifying the logits.
324
+ This saves memory.
325
+ process_group:
326
+ if not None, we're doing Tensor Parallel: each process is responsible for
327
+ one part of the vocab. The loss will be aggregated across processes.
328
+ Returns:
329
+ losses: [batch,], float
330
+ z_losses: [batch,], float
331
+ """
332
+ return CrossEntropyLossFunction.apply(
333
+ logits,
334
+ target,
335
+ label_smoothing,
336
+ logit_scale,
337
+ lse_square_scale,
338
+ ignore_index,
339
+ inplace_backward,
340
+ process_group,
341
+ )
342
+
343
+
344
+ class FusedCrossEntropyLoss(nn.Module):
345
+ def __init__(
346
+ self,
347
+ ignore_index: int = -100,
348
+ reduction: str = "mean",
349
+ label_smoothing: float = 0.0,
350
+ logit_scale: float = 1.0,
351
+ lse_square_scale: float = 0.0,
352
+ inplace_backward: bool = False,
353
+ process_group: Any = None,
354
+ return_z_loss: bool = False,
355
+ ):
356
+ """
357
+ Arguments:
358
+ ignore_index: int. If target == ignore_index, the loss is set to 0.0.
359
+ label_smoothing: float
360
+ lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
361
+ This is also referred to as "z-loss".
362
+ inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
363
+ This saves memory.
364
+ process_group: if not None, we're doing Tensor Parallel: each process is responsible for
365
+ one part of the vocab. The loss will be aggregated across processes.
366
+ return_z_loss: bool. If True, we return the component of the loss contributed by
367
+ the lse_square_scale value. This value is only for logging and does not support
368
+ backprop.
369
+ """
370
+ super().__init__()
371
+ if reduction not in ["mean", "none", "sum"]:
372
+ raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
373
+ self.ignore_index = ignore_index
374
+ self.reduction = reduction
375
+ self.label_smoothing = label_smoothing
376
+ self.logit_scale = logit_scale
377
+ self.lse_square_scale = lse_square_scale
378
+ self.inplace_backward = inplace_backward
379
+ self.process_group = process_group
380
+ self.return_z_loss = return_z_loss
381
+
382
+ def forward(self, input, target):
383
+ """
384
+ Arguments:
385
+ input: (batch, vocab_size)
386
+ target: (batch,)
387
+ Returns:
388
+ losses: (batch,) if reduction is 'none', else (1,), dtype float
389
+ z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
390
+ """
391
+ assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
392
+ loss, z_loss = cross_entropy_loss(
393
+ input,
394
+ target,
395
+ label_smoothing=self.label_smoothing,
396
+ logit_scale=self.logit_scale,
397
+ lse_square_scale=self.lse_square_scale,
398
+ ignore_index=self.ignore_index,
399
+ inplace_backward=self.inplace_backward,
400
+ process_group=self.process_group,
401
+ )
402
+ if self.reduction == "mean":
403
+ loss = loss.sum() / (target != self.ignore_index).sum()
404
+ elif self.reduction == "sum":
405
+ loss = loss.sum()
406
+ else:
407
+ loss = loss
408
+
409
+ if not self.return_z_loss:
410
+ return loss
411
+
412
+ if self.reduction == "mean":
413
+ z_loss = z_loss.sum() / (target != self.ignore_index).sum()
414
+ elif self.reduction == "sum":
415
+ z_loss = z_loss.sum()
416
+ else:
417
+ z_loss = z_loss
418
+
419
+ return loss, z_loss
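
A minimal usage sketch of the loss module defined above; the batch and vocabulary sizes are arbitrary, and a CUDA device is required since the forward pass asserts `input.is_cuda`. This is an illustration, not part of the committed file.

# Hedged sketch, assuming a CUDA device and an arbitrary vocab size.
import torch
from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss

criterion = FusedCrossEntropyLoss(label_smoothing=0.1, reduction="mean")
logits = torch.randn(8, 32000, device="cuda", requires_grad=True)  # [batch, vocab_size]
target = torch.randint(0, 32000, (8,), device="cuda")              # [batch]
loss = criterion(logits, target)
loss.backward()  # gradients flow through the fused Triton backward kernel
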
fla/modules/fused_kl_div.py ADDED
@@ -0,0 +1,323 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from fla.ops.utils.op import exp, log
12
+ from fla.utils import input_guard
13
+
14
+ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576
15
+ # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
16
+ # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
17
+ # The optimal maximum block size depends on your hardware, your kernel, and your dtype
18
+ MAX_FUSED_SIZE = 65536 // 2
19
+
20
+
21
+ @triton.jit
22
+ def kl_div_kernel(
23
+ logits,
24
+ target_logits,
25
+ loss,
26
+ s_logits,
27
+ s_loss,
28
+ reduction: tl.constexpr,
29
+ N: tl.constexpr,
30
+ V: tl.constexpr,
31
+ BV: tl.constexpr
32
+ ):
33
+ # https://github.com/triton-lang/triton/issues/1058
34
+ # If N*V is too large, i_n * stride will overflow out of int32, so we convert to int64
35
+ i_n = tl.program_id(0).to(tl.int64)
36
+
37
+ logits += i_n * s_logits
38
+ target_logits += i_n * s_logits
39
+
40
+ # m is the max value. use the notation from the paper
41
+ sm = float('-inf')
42
+ tm = float('-inf')
43
+ # d is the sum. use the notation from the paper
44
+ sd, td = 0.0, 0.0
45
+
46
+ NV = tl.cdiv(V, BV)
47
+ for iv in range(0, NV):
48
+ o_x = iv * BV + tl.arange(0, BV)
49
+ # for student
50
+ b_sl = tl.load(logits + o_x, mask=o_x < V, other=float('-inf'))
51
+ b_sm = tl.max(b_sl)
52
+ m_new = tl.maximum(sm, b_sm)
53
+ sd = sd * exp(sm - m_new) + tl.sum(exp(b_sl - m_new))
54
+ sm = m_new
55
+ # for teacher
56
+ b_tl = tl.load(target_logits + o_x, mask=o_x < V, other=float('-inf'))
57
+ b_tm = tl.max(b_tl)
58
+ m_new = tl.maximum(tm, b_tm)
59
+ td = td * exp(tm - m_new) + tl.sum(exp(b_tl - m_new))
60
+ tm = m_new
61
+
62
+ b_loss = 0.
63
+ # KL(y_true || y) = exp(y_true) * (log(y_true) - log(y))
64
+ for iv in range(0, NV):
65
+ o_x = iv * BV + tl.arange(0, BV)
66
+ b_sl = tl.load(logits + o_x, mask=o_x < V, other=float('-inf'))
67
+ b_tl = tl.load(target_logits + o_x, mask=o_x < V, other=float('-inf'))
68
+ b_sp_log = b_sl - sm - log(sd)
69
+ b_tp_log = b_tl - tm - log(td)
70
+ b_sp = exp(b_sp_log)
71
+ b_tp = exp(b_tp_log)
72
+ b_kl = tl.where(o_x < V, b_tp * (b_tp_log - b_sp_log), 0)
73
+ b_dl = -b_tp + b_sp
74
+ b_loss += tl.sum(b_kl)
75
+ if reduction == 'batchmean':
76
+ b_dl = b_dl / N
77
+ tl.store(logits + o_x, b_dl, mask=o_x < V)
78
+
79
+ # Normalize the loss by the number of elements if reduction is 'batchmean'
80
+ if reduction == 'batchmean':
81
+ b_loss = b_loss / N
82
+
83
+ tl.store(loss + i_n * s_loss, b_loss)
84
+
85
+
86
+ @triton.jit
87
+ def elementwise_mul_kernel(
88
+ x,
89
+ g,
90
+ N: tl.constexpr,
91
+ B: tl.constexpr
92
+ ):
93
+ """
94
+ This function multiplies each element of the tensor pointed by x with the value pointed by g.
95
+ The multiplication is performed in-place on the tensor pointed by x.
96
+
97
+ Parameters:
98
+ x:
99
+ Pointer to the input tensor.
100
+ g:
101
+ Pointer to the gradient output value.
102
+ N (int):
103
+ The number of columns in the input tensor.
104
+ B (int):
105
+ The block size for Triton operations.
106
+ """
107
+
108
+ # Get the program ID and convert it to int64 to avoid overflow
109
+ i_x = tl.program_id(0).to(tl.int64)
110
+ o_x = i_x * B + tl.arange(0, B)
111
+
112
+ # Load the gradient output value
113
+ b_g = tl.load(g)
114
+ b_x = tl.load(x + o_x, mask=o_x < N)
115
+ tl.store(x + o_x, b_x * b_g, mask=o_x < N)
116
+
117
+
118
+ def fused_kl_div_forward(
119
+ x: torch.Tensor,
120
+ target_x: torch.Tensor,
121
+ weight: torch.Tensor,
122
+ target_weight: torch.Tensor,
123
+ reduction: str = 'batchmean'
124
+ ):
125
+ device = x.device
126
+
127
+ # ideally, we would like to achieve the same memory consumption as [N, H],
128
+ # so the expected chunk size should be:
129
+ # NC = ceil(V / H)
130
+ # C = ceil(N / NC)
131
+ # for ex: N = 4096*4, V = 32000, H = 4096 ==> NC = 8, C = ceil(N / NC) = 2048
132
+ N, H, V = *x.shape, weight.shape[0]
133
+ BV = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
134
+ # TODO: in real cases, we may need to limit the number of chunks NC to
135
+ # ensure the precisions of accumulated gradients
136
+ NC = min(8, triton.cdiv(V, H))
137
+ C = triton.next_power_of_2(triton.cdiv(N, NC))
138
+ NC = triton.cdiv(N, C)
139
+
140
+ dx = torch.zeros_like(x, device=device)
141
+ dw = torch.zeros_like(weight, device=device) if weight is not None else None
142
+ # we use fp32 for loss accumulator
143
+ loss = torch.zeros(N, dtype=torch.float32, device=device)
144
+
145
+ for ic in range(NC):
146
+ start, end = ic * C, min((ic + 1) * C, N)
147
+ # [C, H]
148
+ c_sx = x[start:end]
149
+ c_tx = target_x[start:end]
150
+ # when doing matmul, use the original precision
151
+ # [C, V]
152
+ c_sl = F.linear(c_sx, weight)
153
+ c_tl = F.linear(c_tx, target_weight)
154
+
155
+ # unreduced loss
156
+ c_loss = loss[start:end]
157
+
158
+ # Here we calculate the gradient of c_sx in place so we can save memory.
159
+ kl_div_kernel[(c_sx.shape[0],)](
160
+ logits=c_sl,
161
+ target_logits=c_tl,
162
+ loss=c_loss,
163
+ s_logits=c_sl.stride(-2),
164
+ s_loss=c_loss.stride(-1),
165
+ reduction=reduction,
166
+ N=N,
167
+ V=V,
168
+ BV=BV,
169
+ num_warps=32
170
+ )
171
+
172
+ # gradient of logits is computed in-place by the above triton kernel and is of shape: C x V
173
+ # thus dx[start: end] should be of shape: C x H
174
+ # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
175
+ # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
176
+ # Thus, we need an additional scaling factor of (n_non_ignore/total) to scale the gradients.
177
+ # [C, H]
178
+
179
+ dx[start:end] = torch.mm(c_sl, weight)
180
+
181
+ if weight is not None:
182
+ torch.addmm(input=dw, mat1=c_sl.t(), mat2=c_sx, out=dw)
183
+
184
+ loss = loss.sum()
185
+ return loss, dx, dw
186
+
187
+
188
+ def fused_kl_div_backward(
189
+ do: torch.Tensor,
190
+ dx: torch.Tensor,
191
+ dw: torch.Tensor
192
+ ):
193
+ # If cross entropy is the last layer, do is 1.0. Skip the mul to save time
194
+ if torch.ne(do, torch.tensor(1.0, device=do.device)):
195
+ # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
196
+ # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
197
+ N, H = dx.shape
198
+ B = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
199
+
200
+ elementwise_mul_kernel[(triton.cdiv(N * H, B),)](
201
+ x=dx,
202
+ g=do,
203
+ N=N*H,
204
+ B=B,
205
+ num_warps=32,
206
+ )
207
+
208
+ # handle dw
209
+ if dw is not None:
210
+ V, H = dw.shape
211
+ elementwise_mul_kernel[(triton.cdiv(V * H, B),)](
212
+ x=dw,
213
+ g=do,
214
+ N=V*H,
215
+ B=B,
216
+ num_warps=32,
217
+ )
218
+
219
+ return dx, dw
220
+
221
+
222
+ class FusedKLDivLossFunction(torch.autograd.Function):
223
+
224
+ @staticmethod
225
+ @input_guard
226
+ def forward(
227
+ ctx,
228
+ x: torch.Tensor,
229
+ target_x: torch.Tensor,
230
+ weight: torch.Tensor,
231
+ target_weight: torch.Tensor,
232
+ reduction: str
233
+ ):
234
+ loss, dx, dw = fused_kl_div_forward(
235
+ x=x,
236
+ target_x=target_x,
237
+ weight=weight,
238
+ target_weight=target_weight,
239
+ reduction=reduction
240
+ )
241
+ ctx.save_for_backward(dx, dw)
242
+ return loss
243
+
244
+ @staticmethod
245
+ @input_guard
246
+ def backward(ctx, do):
247
+ dx, dw = ctx.saved_tensors
248
+ dx, dw = fused_kl_div_backward(do, dx, dw)
249
+ return dx, None, dw, None, None
250
+
251
+
252
+ def fused_kl_div_loss(
253
+ x: torch.Tensor,
254
+ target_x: torch.Tensor,
255
+ weight: torch.Tensor,
256
+ target_weight: torch.Tensor,
257
+ reduction: str = 'batchmean'
258
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
259
+ """
260
+ Args:
261
+ x (torch.Tensor): [batch_size * seq_len, hidden_size]
262
+ target_x (torch.Tensor): [batch_size * seq_len, hidden_size]
263
+ weight (torch.Tensor): [vocab_size, hidden_size]
264
+ where `vocab_size` is the number of classes.
265
+ target_weight (torch.Tensor): [vocab_size, hidden_size]
266
+ where `vocab_size` is the number of classes.
267
+ reduction:
268
+ Specifies the reduction to apply to the output: 'batchmean'. Default: 'batchmean'.
269
+ Returns:
270
+ loss
271
+ """
272
+ return FusedKLDivLossFunction.apply(
273
+ x,
274
+ target_x,
275
+ weight,
276
+ target_weight,
277
+ reduction
278
+ )
279
+
280
+
281
+ class FusedKLDivLoss(nn.Module):
282
+
283
+ def __init__(
284
+ self,
285
+ reduction: str = 'batchmean'
286
+ ):
287
+ """
288
+ Args:
289
+ reduction:
290
+ Specifies the reduction to apply to the output: 'batchmean'. Default: 'batchmean'.
291
+ """
292
+ super().__init__()
293
+
294
+ assert reduction in ['batchmean'], f"reduction: {reduction} is not supported"
295
+
296
+ self.reduction = reduction
297
+
298
+ def forward(
299
+ self,
300
+ x: torch.Tensor,
301
+ target_x: torch.Tensor,
302
+ weight: torch.Tensor,
303
+ target_weight: torch.Tensor
304
+ ):
305
+ """
306
+ Args:
307
+ x (torch.Tensor): [batch_size * seq_len, hidden_size]
308
+ target_x (torch.Tensor): [batch_size * seq_len, hidden_size]
309
+ weight (torch.Tensor): [vocab_size, hidden_size]
310
+ where `vocab_size` is the number of classes.
311
+ target_weight (torch.Tensor): [vocab_size, hidden_size]
312
+ where `vocab_size` is the number of classes.
313
+ Returns:
314
+ loss
315
+ """
316
+ loss = fused_kl_div_loss(
317
+ x=x,
318
+ target_x=target_x,
319
+ weight=weight,
320
+ target_weight=target_weight,
321
+ reduction=self.reduction
322
+ )
323
+ return loss
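
A minimal usage sketch of the fused KL-divergence loss above, in a distillation-style setting where student and teacher share the hidden size but have separate output projections. All sizes below are illustrative assumptions and a CUDA device is assumed for the Triton kernels.

# Hedged sketch: chunk-wise KL between student and teacher logits,
# computed without materializing the full [N, V] logit matrices.
import torch
from fla.modules.fused_kl_div import FusedKLDivLoss

kl = FusedKLDivLoss(reduction="batchmean")
x = torch.randn(512, 1024, device="cuda", requires_grad=True)             # student hidden states [N, H]
target_x = torch.randn(512, 1024, device="cuda")                          # teacher hidden states [N, H]
w_student = torch.randn(32000, 1024, device="cuda", requires_grad=True)   # student lm_head weight [V, H]
w_teacher = torch.randn(32000, 1024, device="cuda")                       # teacher lm_head weight [V, H]
loss = kl(x, target_x, w_student, w_teacher)
loss.backward()  # grads are returned for x and w_student only
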
fla/modules/parallel.py ADDED
@@ -0,0 +1,37 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch.nn as nn
7
+ from torch.distributed import DeviceMesh
8
+ from torch.distributed.tensor import DTensor, distribute_module
9
+ from torch.distributed.tensor.parallel import ParallelStyle
10
+ from torch.distributed.tensor.placement_types import Placement
11
+
12
+
13
+ class PrepareModuleWeight(ParallelStyle):
14
+ def __init__(self, *, layouts: Optional[Placement] = None):
15
+ super().__init__()
16
+ self.layouts = layouts
17
+
18
+ def _replicate_module_fn(
19
+ self,
20
+ name: str,
21
+ module: nn.Module,
22
+ device_mesh: DeviceMesh
23
+ ):
24
+ for p_name, param in module.named_parameters():
25
+ replicated_param = nn.Parameter(
26
+ DTensor.from_local(param, device_mesh, [self.layouts], run_check=False)
27
+ )
28
+ module.register_parameter(p_name, replicated_param)
29
+
30
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
31
+ return distribute_module(
32
+ module,
33
+ device_mesh,
34
+ partition_fn=self._replicate_module_fn,
35
+ input_fn=None,
36
+ output_fn=None
37
+ )
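
`PrepareModuleWeight` is a `ParallelStyle` that wraps a module's parameters as replicated `DTensor`s over a device mesh. Below is a hedged sketch of wiring it up with the stock tensor-parallel API; the mesh size and the `Replicate` placement are illustrative assumptions and require an already-initialized process group.

# Hedged sketch, assuming torch.distributed is initialized and the
# mesh size matches the world size.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
from torch.distributed.tensor.placement_types import Replicate
from fla.modules.parallel import PrepareModuleWeight

mesh = init_device_mesh("cuda", (8,))
norm = nn.LayerNorm(1024)
# replicate the norm's weight and bias as DTensors across the mesh
norm = parallelize_module(norm, mesh, PrepareModuleWeight(layouts=Replicate()))
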
fla/ops/abc/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_abc
4
+
5
+ __all__ = [
6
+ 'chunk_abc'
7
+ ]
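
A hedged sketch of importing and calling the op exported here. The argument list, keyword names, and tensor layout are assumptions inferred from the kernel signatures in `fla/ops/abc/chunk.py` below (e.g. the `h0`/`ht` initial- and final-state buffers), not something this diff spells out.

# Hedged sketch: the slot tensor `s`, the shapes, and the keyword
# names are assumptions about the chunk_abc signature.
import torch
from fla.ops.abc import chunk_abc

q = torch.randn(1, 4, 256, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 4, 256, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 4, 256, 64, device="cuda", dtype=torch.bfloat16)
s = torch.randn(1, 4, 256, 32, device="cuda", dtype=torch.bfloat16)  # ABC memory slots
o, final_state = chunk_abc(q, k, v, s, output_final_state=True)
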
fla/ops/abc/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (212 Bytes). View file
 
fla/ops/abc/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (72 kB). View file
 
fla/ops/abc/chunk.py ADDED
@@ -0,0 +1,1116 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils import logcumsumexp_fwd_kernel, softmax_bwd, softmax_fwd
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import input_guard
13
+
14
+
15
+ @triton.jit(do_not_specialize=['T'])
16
+ def chunk_abc_fwd_kernel_h(
17
+ k,
18
+ v,
19
+ z,
20
+ h,
21
+ h0,
22
+ ht,
23
+ T,
24
+ K: tl.constexpr,
25
+ V: tl.constexpr,
26
+ BT: tl.constexpr,
27
+ BK: tl.constexpr,
28
+ BV: tl.constexpr,
29
+ NT: tl.constexpr,
30
+ NORMK: tl.constexpr,
31
+ USE_INITIAL_STATE: tl.constexpr,
32
+ STORE_FINAL_STATE: tl.constexpr
33
+ ):
34
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
35
+
36
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
37
+ if USE_INITIAL_STATE:
38
+ p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
39
+ b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
40
+ if NORMK:
41
+ p_z0 = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_k * BK,), (BK,), (0,))
42
+ else:
43
+ p_z0 = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_v * BV,), (BV,), (0,))
44
+ b_zp = tl.load(p_z0).to(tl.float32)
45
+ for i_t in range(NT):
46
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
47
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
48
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
49
+
50
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
51
+ # [BK, BT]
52
+ b_k = tl.load(p_k, boundary_check=(0, 1))
53
+ # [BT, BV]
54
+ b_v = tl.load(p_v, boundary_check=(0, 1))
55
+ if NORMK:
56
+ p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
57
+ # [BK,]
58
+ b_zc = tl.load(p_zc, boundary_check=(0,))
59
+ b_r, b_zp = exp(b_zp - b_zc), b_zc
60
+ # [BK, BV]
61
+ b_h = b_h * b_r[:, None]
62
+ b_k = exp(b_k - b_zc[:, None]).to(b_k.dtype)
63
+ else:
64
+ p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))
65
+ # [BV,]
66
+ b_zc = tl.load(p_zc, boundary_check=(0,))
67
+ b_r, b_zp = exp(b_zp - b_zc), b_zc
68
+ # [BK, BV]
69
+ b_h = b_h * b_r[None, :]
70
+ b_v = exp(b_v - b_zc[None, :]).to(b_v.dtype)
71
+ # [BK, BV]
72
+ b_h += tl.dot(b_k, b_v, allow_tf32=False)
73
+
74
+ if STORE_FINAL_STATE:
75
+ p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
76
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
77
+
78
+
79
+ @triton.jit(do_not_specialize=['T'])
80
+ def chunk_abc_fwd_kernel_intra_K(
81
+ v,
82
+ z,
83
+ o,
84
+ A,
85
+ T,
86
+ V: tl.constexpr,
87
+ BT: tl.constexpr,
88
+ BC: tl.constexpr,
89
+ BV: tl.constexpr,
90
+ NC: tl.constexpr
91
+ ):
92
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
93
+ i_t, i_i = i_c // NC, i_c % NC
94
+
95
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
96
+ p_zn = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,))
97
+ # [BV,]
98
+ b_zn = tl.load(p_zn, boundary_check=(0,))
99
+ # [BC, BV]
100
+ b_o = tl.zeros([BC, BV], dtype=tl.float32)
101
+ for i_j in range(0, i_i):
102
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
103
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
104
+ # [BC, BV]
105
+ b_v = tl.load(p_v, boundary_check=(0, 1))
106
+ # [BC, BC]
107
+ b_A = tl.load(p_A, boundary_check=(0, 1))
108
+ b_o += tl.dot(b_A, exp(b_v - b_zn[None, :]).to(b_v.dtype), allow_tf32=False)
109
+ b_z = tl.load(p_z, boundary_check=(0, 1))
110
+ b_o *= exp(b_zn[None, :] - b_z)
111
+
112
+ o_i = tl.arange(0, BC)
113
+ o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
114
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
115
+ for j in range(0, BC):
116
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
117
+ # [BC,]
118
+ b_A = tl.load(A + o_A + j, mask=m_A, other=0)
119
+ # [BV,]
120
+ b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32)
121
+ # [BC, BV]
122
+ # avoid 0 * inf = inf
123
+ m_i = o_i[:, None] >= j
124
+ b_o += tl.where(m_i, b_A[:, None] * exp(b_v[None, :] - b_z), 0)
125
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
126
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
127
+
128
+
129
+ @triton.jit(do_not_specialize=['T'])
130
+ def chunk_abc_fwd_kernel_K(
131
+ q,
132
+ k,
133
+ z,
134
+ h,
135
+ o,
136
+ A,
137
+ scale,
138
+ T,
139
+ K: tl.constexpr,
140
+ V: tl.constexpr,
141
+ BT: tl.constexpr,
142
+ BK: tl.constexpr,
143
+ BV: tl.constexpr,
144
+ NT: tl.constexpr
145
+ ):
146
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
147
+ i_p = tl.maximum(i_t * BT - 1, 0)
148
+
149
+ o_i = tl.arange(0, BT)
150
+ m_s = o_i[:, None] >= o_i[None, :]
151
+
152
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
153
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
154
+ for i_k in range(tl.cdiv(K, BK)):
155
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
156
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
157
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
158
+
159
+ # [BT, BK]
160
+ b_q = tl.load(p_q, boundary_check=(0, 1))
161
+ b_q = (b_q * scale).to(b_q.dtype)
162
+ # [BK, BT]
163
+ b_k = tl.load(p_k, boundary_check=(0, 1))
164
+ # [BK, BV]
165
+ b_h = tl.load(p_h, boundary_check=(0, 1))
166
+ # [BT, BV]
167
+ b_o += tl.dot(b_q, b_h, allow_tf32=False)
168
+ # [BT, BT]
169
+ b_A += tl.dot(b_q, b_k, allow_tf32=False)
170
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
171
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
172
+ # [BT, BV]
173
+ b_z = tl.load(p_z, boundary_check=(0, 1))
174
+ # [BT, BV]
175
+ p_zp = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,))
176
+ b_zp = tl.load(p_zp, boundary_check=(0,))
177
+ b_o = b_o * exp(b_zp[None, :] - b_z)
178
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
179
+
180
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
181
+ # [BT, BT]
182
+ b_A = tl.where(m_s, b_A, 0.)
183
+ if i_v == 0:
184
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
185
+
186
+
187
+ @triton.jit(do_not_specialize=['T'])
188
+ def chunk_abc_fwd_kernel_intra_V(
189
+ q,
190
+ k,
191
+ z,
192
+ A,
193
+ scale,
194
+ T,
195
+ K: tl.constexpr,
196
+ BT: tl.constexpr,
197
+ BC: tl.constexpr,
198
+ BK: tl.constexpr,
199
+ NC: tl.constexpr
200
+ ):
201
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
202
+ i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
203
+ n_bh = tl.num_programs(2)
204
+
205
+ if i_i > i_j:
206
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
207
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
208
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
209
+ p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
210
+ p_zn = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
211
+ # [BK,]
212
+ b_zn = tl.load(p_zn, boundary_check=(0,))
213
+ # [BC, BK]
214
+ b_q = tl.load(p_q, boundary_check=(0, 1))
215
+ b_z = tl.load(p_z, boundary_check=(0, 1))
216
+ b_q = (b_q * exp(b_zn[None, :] - b_z) * scale).to(b_q.dtype)
217
+ # [BK, BC]
218
+ b_k = tl.load(p_k, boundary_check=(0, 1))
219
+ b_k = exp(b_k - b_zn[:, None]).to(b_k.dtype)
220
+ # [BC, BC]
221
+ b_A = tl.dot(b_q, b_k, allow_tf32=False)
222
+ tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
223
+ elif i_i == i_j:
224
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
225
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
226
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
227
+ # [BC, BK]
228
+ b_q = tl.load(p_q, boundary_check=(0, 1))
229
+ b_z = tl.load(p_z, boundary_check=(0, 1))
230
+
231
+ o_i = tl.arange(0, BC)
232
+ o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
233
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
234
+ for j in range(0, BC):
235
+ # [BK,]
236
+ b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
237
+ # [BC,]
238
+ b_A = tl.sum(b_q * exp(b_k[None, :] - b_z) * scale, 1)
239
+ b_A = tl.where(o_i >= j, b_A, 0.)
240
+ tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)
241
+
242
+ p_k = tl.advance(p_k, (K,))
243
+
244
+
245
+ @triton.jit(do_not_specialize=['T'])
246
+ def chunk_abc_fwd_kernel_V(
247
+ q,
248
+ v,
249
+ z,
250
+ h,
251
+ o,
252
+ A,
253
+ scale,
254
+ T,
255
+ K: tl.constexpr,
256
+ V: tl.constexpr,
257
+ BT: tl.constexpr,
258
+ BK: tl.constexpr,
259
+ BV: tl.constexpr,
260
+ NT: tl.constexpr
261
+ ):
262
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
263
+ i_p = tl.maximum(i_t * BT - 1, 0)
264
+
265
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
266
+ for i_k in range(tl.cdiv(K, BK)):
267
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
268
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
269
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
270
+ p_zp = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,))
271
+
272
+ # [BT, BK]
273
+ b_q = tl.load(p_q, boundary_check=(0, 1))
274
+ b_q = (b_q * scale).to(b_q.dtype)
275
+ # [BT, BK]
276
+ b_z = tl.load(p_z, boundary_check=(0, 1))
277
+ # [BT, BK]
278
+ b_zp = tl.load(p_zp, boundary_check=(0,))
279
+ b_q = (b_q * exp(b_zp[None, :] - b_z)).to(b_q.dtype)
280
+ # [BK, BV]
281
+ b_h = tl.load(p_h, boundary_check=(0, 1))
282
+ # works but dkw, owing to divine benevolence
283
+ # [BT, BV]
284
+ if i_k >= 0:
285
+ b_o += tl.dot(b_q, b_h, allow_tf32=False)
286
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
287
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
288
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
289
+ # [BT, BV]
290
+ b_v = tl.load(p_v, boundary_check=(0, 1))
291
+ # [BT, BT]
292
+ b_A = tl.load(p_A, boundary_check=(0, 1))
293
+ b_o += tl.dot(b_A.to(b_v.dtype), b_v, allow_tf32=False)
294
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
295
+
296
+
297
+ @triton.jit(do_not_specialize=['T'])
298
+ def chunk_abc_bwd_kernel_dh(
299
+ q,
300
+ z,
301
+ do,
302
+ dh,
303
+ scale,
304
+ T,
305
+ K: tl.constexpr,
306
+ V: tl.constexpr,
307
+ BT: tl.constexpr,
308
+ BK: tl.constexpr,
309
+ BV: tl.constexpr,
310
+ NT: tl.constexpr,
311
+ NORMK: tl.constexpr
312
+ ):
313
+ i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
314
+
315
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
316
+ b_zp = tl.full([BK if NORMK else BV], float('inf'), dtype=tl.float32)
317
+ for i_t in range(NT - 1, -1, -1):
318
+ i_p = tl.maximum(i_t * BT - 1, 0)
319
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
320
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
321
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
322
+
323
+ # [BK, BT]
324
+ b_q = tl.load(p_q, boundary_check=(0, 1))
325
+ b_q = (b_q * scale).to(b_q.dtype)
326
+ # [BT, BV]
327
+ b_do = tl.load(p_do, boundary_check=(0, 1))
328
+
329
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
330
+ if NORMK:
331
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
332
+ p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,))
333
+ # [BK,]
334
+ b_zc = tl.load(p_zc, boundary_check=(0,))
335
+ b_r, b_zp = exp(b_zc - b_zp), b_zc
336
+ # [BK, BT]
337
+ b_z = tl.load(p_z, boundary_check=(0, 1))
338
+ b_q = (b_q * exp(b_zc[:, None] - b_z)).to(b_q.dtype)
339
+ # [BK, BV]
340
+ b_dh = b_dh * b_r[:, None]
341
+ else:
342
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
343
+ p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,))
344
+ # [BV,]
345
+ b_zc = tl.load(p_zc, boundary_check=(0,))
346
+ b_r, b_zp = exp(b_zc - b_zp), b_zc
347
+ # [BT, BV]
348
+ b_z = tl.load(p_z, boundary_check=(0,))
349
+ b_do = (b_do * exp(b_zc[None, :] - b_z)).to(b_do.dtype)
350
+ # [BK, BV]
351
+ b_dh = b_dh * b_r[None, :]
352
+ # [BK, BV]
353
+ b_dh += tl.dot(b_q, b_do, allow_tf32=False)
354
+
355
+
356
+ @triton.jit(do_not_specialize=['T'])
357
+ def chunk_abc_bwd_kernel_V(
358
+ k,
359
+ v,
360
+ z,
361
+ h,
362
+ A,
363
+ do,
364
+ dh,
365
+ dq,
366
+ dk,
367
+ dv,
368
+ dA,
369
+ scale,
370
+ T,
371
+ K: tl.constexpr,
372
+ V: tl.constexpr,
373
+ BT: tl.constexpr,
374
+ BK: tl.constexpr,
375
+ BV: tl.constexpr,
376
+ NT: tl.constexpr
377
+ ):
378
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
379
+ i_p = tl.maximum(i_t * BT - 1, 0)
380
+ n_bh = tl.num_programs(2)
381
+
382
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
383
+ p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
384
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
385
+
386
+ # [BK,]
387
+ b_zc = tl.load(p_zc, boundary_check=(0,))
388
+ # [BT, BK]
389
+ b_k = tl.load(p_k, boundary_check=(0, 1))
390
+ b_k = exp(b_k - b_zc[None, :]).to(b_k.dtype)
391
+ # [BT, BT]
392
+ b_A = tl.load(p_A, boundary_check=(0, 1))
393
+
394
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
395
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
396
+ b_dA = tl.zeros([BT, BT], dtype=tl.float32)
397
+ for i_v in range(tl.cdiv(V, BV)):
398
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
399
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * V * K, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
400
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
401
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
402
+ p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
403
+
404
+ # [BT, BV]
405
+ b_v = tl.load(p_v, boundary_check=(0, 1))
406
+ # [BV, BK]
407
+ b_h = tl.load(p_h, boundary_check=(0, 1))
408
+ # [BT, BV]
409
+ b_do = tl.load(p_do, boundary_check=(0, 1))
410
+ # [BK, BV]
411
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
412
+
413
+ # [BT, BV]
414
+ b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
415
+ if i_k == 0:
416
+ b_dv += tl.dot(b_A.to(b_do.dtype), b_do, allow_tf32=False)
417
+ b_do = (b_do * scale).to(b_do.dtype)
418
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
419
+ # [BT, BT]
420
+ b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
421
+ # [BT, BK]
422
+ b_dq += tl.dot(b_do, b_h, allow_tf32=False)
423
+ # [BT, BK]
424
+ b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
425
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
426
+ p_zp = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,))
427
+ # [BK,]
428
+ b_zp = tl.load(p_zp, boundary_check=(0,))
429
+ # [BT, BK]
430
+ b_z = tl.load(p_z, boundary_check=(0, 1))
431
+ b_z = exp(b_zp[None, :] - b_z)
432
+ # [BT, BK]
433
+ b_dq = b_dq * b_z
434
+ b_dk = b_dk * b_k
435
+
436
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
437
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
438
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT,), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
439
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
440
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
441
+
442
+ o_i = tl.arange(0, BT)
443
+ m_s = o_i[:, None] >= o_i[None, :]
444
+ # [BT, BT]
445
+ b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
446
+ if i_k == 0:
447
+ tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
448
+
449
+
450
+ @triton.jit(do_not_specialize=['T'])
451
+ def chunk_abc_bwd_kernel_intra_V(
452
+ q,
453
+ k,
454
+ z,
455
+ dA,
456
+ dq,
457
+ dk,
458
+ T,
459
+ K: tl.constexpr,
460
+ BT: tl.constexpr,
461
+ BC: tl.constexpr,
462
+ BK: tl.constexpr,
463
+ NC: tl.constexpr
464
+ ):
465
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
466
+ i_t, i_i = i_c // NC, i_c % NC
467
+
468
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
469
+ p_zn = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
470
+ # [BK,]
471
+ b_zn = tl.load(p_zn, boundary_check=(0,))
472
+ # [BC, BK]
473
+ b_z = tl.load(p_z, boundary_check=(0, 1))
474
+ b_zq = exp(b_zn[None, :] - b_z)
475
+ b_dq = tl.zeros([BC, BK], dtype=tl.float32)
476
+ for i_j in range(0, i_i):
477
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
478
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
479
+ # [BC, BK]
480
+ b_k = tl.load(p_k, boundary_check=(0, 1))
481
+ b_kz = exp(b_k - b_zn[None, :]).to(b_k.dtype)
482
+ # [BC, BC]
483
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
484
+ # [BC, BK]
485
+ b_dq += tl.dot(b_dA, b_kz, allow_tf32=False)
486
+ b_dq *= b_zq
487
+
488
+ o_i = tl.arange(0, BC)
489
+ o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
490
+ m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
491
+ for j in range(0, BC):
492
+ p_kj = tl.make_block_ptr(k + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
493
+ # [BC,]
494
+ b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
495
+ # [BK,]
496
+ b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
497
+ # [BC, BK]
498
+ m_i = o_i[:, None] >= j
499
+ # [BC, BK]
500
+ b_dq += tl.where(m_i, b_dA[:, None] * exp(b_kj[None, :] - b_z), 0.)
501
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
502
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
503
+
504
+ tl.debug_barrier()
505
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
506
+ p_zn = tl.make_block_ptr(z + i_bh * T*K, (T*K,), (1,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
507
+ # [BK,]
508
+ b_zn = tl.load(p_zn, boundary_check=(0,))
509
+ # [BC, BK]
510
+ b_k = tl.load(p_k, boundary_check=(0, 1))
511
+ b_kz = exp(b_k - b_zn[None, :])
512
+ b_dk = tl.zeros([BC, BK], dtype=tl.float32)
513
+ for i_j in range(i_i + 1, NC):
514
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
515
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
516
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
517
+ # [BC, BK]
518
+ b_q = tl.load(p_q, boundary_check=(0, 1))
519
+ b_z = tl.load(p_z, boundary_check=(0, 1))
520
+ b_qz = (b_q * exp(b_zn[None, :] - b_z)).to(b_q.dtype)
521
+ # [BC, BC]
522
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
523
+ # [BC, BK]
524
+ b_dk += tl.dot(tl.trans(b_dA), b_qz, allow_tf32=False)
525
+ b_dk *= b_kz
526
+
527
+ o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
528
+ for j in range(0, BC):
529
+ p_qj = tl.make_block_ptr(q + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
530
+ p_zj = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
531
+ # [BC,]
532
+ b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
533
+ # [BK,]
534
+ b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
535
+ b_zj = tl.load(p_zj, boundary_check=(0,)).to(tl.float32)
536
+ # [BC, BK]
537
+ m_i = o_i[:, None] <= j
538
+ b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_k - b_zj[None, :]), 0.)
539
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
540
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
541
+
542
+
543
+ @triton.jit(do_not_specialize=['T'])
544
+ def chunk_abc_bwd_kernel_intra_K(
545
+ v,
546
+ z,
547
+ do,
548
+ dA,
549
+ scale,
550
+ T,
551
+ V: tl.constexpr,
552
+ BT: tl.constexpr,
553
+ BC: tl.constexpr,
554
+ BV: tl.constexpr,
555
+ NC: tl.constexpr
556
+ ):
557
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
558
+ i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
559
+ n_bh = tl.num_programs(2)
560
+
561
+ if i_i > i_j:
562
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
563
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
564
+ p_zn = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,))
565
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
566
+ p_dA = tl.make_block_ptr(dA+(i_bh+i_v*n_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
567
+ # [BV,]
568
+ b_zn = tl.load(p_zn, boundary_check=(0,))
569
+ # [BC, BV]
570
+ b_z = tl.load(p_z, boundary_check=(0, 1))
571
+ b_do = tl.load(p_do, boundary_check=(0, 1))
572
+ b_do = (b_do * exp(b_zn[None, :] - b_z) * scale).to(b_do.dtype)
573
+ # [BV, BC]
574
+ b_v = tl.load(p_v, boundary_check=(0, 1))
575
+ b_v = exp(b_v - b_zn[:, None]).to(b_v.dtype)
576
+ # [BC, BC]
577
+ b_dA = tl.dot(b_do, b_v, allow_tf32=False)
578
+ tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1))
579
+ elif i_i == i_j:
580
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,))
581
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
582
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
583
+ # [BC, BV]
584
+ b_z = tl.load(p_z, boundary_check=(0, 1))
585
+ b_do = tl.load(p_do, boundary_check=(0, 1)) * scale
586
+
587
+ o_i = tl.arange(0, BC)
588
+ o_A = (i_bh + i_v * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
589
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
590
+ for j in range(0, BC):
591
+ # [BV,]
592
+ b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32)
593
+ # [BC,]
594
+ b_dA = tl.sum(b_do * exp(b_v[None, :] - b_z), 1)
595
+ b_dA = tl.where(o_i >= j, b_dA, 0)
596
+ tl.store(dA + o_A + j, b_dA.to(b_do.dtype), mask=m_A)
597
+
598
+ p_v = tl.advance(p_v, (V,))
599
+
600
+
601
+ @triton.jit(do_not_specialize=['T'])
602
+ def chunk_abc_bwd_kernel_K(
603
+ q,
604
+ k,
605
+ v,
606
+ z,
607
+ h,
608
+ A,
609
+ do,
610
+ dh,
611
+ dq,
612
+ dk,
613
+ dv,
614
+ dA,
615
+ scale,
616
+ T,
617
+ K: tl.constexpr,
618
+ V: tl.constexpr,
619
+ BT: tl.constexpr,
620
+ BK: tl.constexpr,
621
+ BV: tl.constexpr,
622
+ NT: tl.constexpr
623
+ ):
624
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
625
+ i_p = tl.maximum(i_t * BT - 1, 0)
626
+ n_bh = tl.num_programs(2)
627
+
628
+ o_i = tl.arange(0, BT)
629
+ m_s = o_i[:, None] >= o_i[None, :]
630
+
631
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
632
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
633
+ p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh) * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
634
+
635
+ # [BT, BK]
636
+ b_q = tl.load(p_q, boundary_check=(0, 1))
637
+ b_k = tl.load(p_k, boundary_check=(0, 1))
638
+ # [BT, BT]
639
+ b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k), allow_tf32=False)
640
+ b_A = tl.where(m_s, b_A, 0.)
641
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
642
+
643
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
644
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
645
+ for i_v in range(tl.cdiv(V, BV)):
646
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
647
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
648
+ p_zp = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,))
649
+ p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))
650
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
651
+
652
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
653
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
654
+ p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
655
+
656
+ # [BV,]
657
+ b_zp = tl.load(p_zp, boundary_check=(0,))
658
+ b_zc = tl.load(p_zc, boundary_check=(0,))
659
+ # [BT, BV]
660
+ b_v = tl.load(p_v, boundary_check=(0, 1))
661
+ b_v = exp(b_v - b_zc[None, :]).to(b_v.dtype)
662
+ b_z = tl.load(p_z, boundary_check=(0, 1))
663
+ b_z = exp(b_zp[None, :] - b_z)
664
+ # [BV, BK]
665
+ b_h = tl.load(p_h, boundary_check=(0, 1))
666
+ # [BT, BV]
667
+ b_do = tl.load(p_do, boundary_check=(0, 1))
668
+ b_do = (b_do * b_z * scale).to(b_do.dtype)
669
+ # [BK, BV]
670
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
671
+
672
+ # [BT, BK]
673
+ b_dq += tl.dot(b_do, b_h, allow_tf32=False)
674
+ b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
675
+ # [BT, BV]
676
+ b_dv = b_v * tl.dot(b_k, b_dh, allow_tf32=False)
677
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
678
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
679
+ # [BT, BT]
680
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
681
+ # [BT, BK]
682
+ b_dq += tl.dot(b_dA, b_k, allow_tf32=False)
683
+ b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q, allow_tf32=False)
684
+
685
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
686
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
687
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
688
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
689
+
690
+
691
+ @triton.jit(do_not_specialize=['T'])
692
+ def chunk_abc_bwd_kernel_intra_KV(
693
+ v,
694
+ z,
695
+ A,
696
+ do,
697
+ dv,
698
+ T,
699
+ V: tl.constexpr,
700
+ BT: tl.constexpr,
701
+ BC: tl.constexpr,
702
+ BV: tl.constexpr,
703
+ NC: tl.constexpr
704
+ ):
705
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
706
+ i_t, i_i = i_c // NC, i_c % NC
707
+
708
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
709
+ p_zn = tl.make_block_ptr(z + i_bh * T*V, (T*V,), (1,), ((i_t * BT + i_i * BC + BC - 1) * V + i_v * BV,), (BV,), (0,))
710
+ # [BV,]
711
+ b_zn = tl.load(p_zn, boundary_check=(0,))
712
+ # [BC, BV]
713
+ b_v = tl.load(p_v, boundary_check=(0, 1))
714
+ b_dv = tl.zeros([BC, BV], dtype=tl.float32)
715
+ for i_j in range(i_i + 1, NC):
716
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
717
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1))
718
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
719
+ # [BC, BV]
720
+ b_z = tl.load(p_z, boundary_check=(0, 1))
721
+ b_do = tl.load(p_do, boundary_check=(0, 1))
722
+ b_do = (b_do * exp(b_zn[None, :] - b_z)).to(b_do.dtype)
723
+ # [BC, BC]
724
+ b_A = tl.load(p_A, boundary_check=(0, 1))
725
+ b_dv += tl.dot(b_A, b_do, allow_tf32=False)
726
+ b_dv *= exp(b_v - b_zn[None, :])
727
+
728
+ o_i = tl.arange(0, BC)
729
+ for j in range(0, BC):
730
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
731
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T * BT,), (1,), ((i_t * BT + i_i * BC + j) * BT + i_i * BC,), (BC,), (0,))
732
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
733
+ # [BC,]
734
+ b_A = tl.load(p_A, boundary_check=(0,))
735
+ # [BV,]
736
+ b_z = tl.load(p_z, boundary_check=(0,))
737
+ b_do = tl.load(p_do, boundary_check=(0,))
738
+ # [BC, BV]
739
+ m_i = o_i[:, None] <= j
740
+ b_dv += tl.where(m_i, exp(b_v - b_z[None, :]) * b_A[:, None] * b_do[None, :], 0.)
741
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
742
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
743
+
744
+
745
+ @triton.jit(do_not_specialize=['T'])
746
+ def chunk_abc_bwd_kernel_rcum_inter(
747
+ s,
748
+ z,
749
+ ss,
750
+ doo,
751
+ T,
752
+ S: tl.constexpr,
753
+ BT: tl.constexpr,
754
+ BS: tl.constexpr,
755
+ NT: tl.constexpr
756
+ ):
757
+ i_m, i_bh = tl.program_id(0), tl.program_id(1)
758
+
759
+ b_sp = tl.zeros([BS,], dtype=tl.float32)
760
+ b_zp = tl.full([BS,], float('inf'), dtype=tl.float32)
761
+ for i_t in range(NT - 1, -1, -1):
762
+ p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
763
+ p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
764
+ p_zc = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT) * S + i_m * BS,), (BS,), (0,))
765
+ p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
766
+ p_doo = tl.make_block_ptr(doo + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
767
+ # [BS,]
768
+ b_zc = tl.load(p_zc, boundary_check=(0,))
769
+ # [BT, BS]
770
+ b_s = tl.load(p_s, boundary_check=(0, 1))
771
+ b_z = tl.load(p_z, boundary_check=(0, 1))
772
+ b_ss = tl.load(p_ss, boundary_check=(0, 1))
773
+
774
+ b_doo = exp(b_s - b_zp[None, :]) * b_sp[None, :]
775
+ tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1))
776
+ # [BS,]
777
+ b_sp = b_sp * exp(b_zc - b_zp) + tl.sum(b_ss * exp(b_zc[None, :] - b_z), 0)
778
+ b_zp = b_zc
779
+
780
+
781
+ @triton.jit(do_not_specialize=['T'])
782
+ def chunk_abc_bwd_kernel_rcum_intra(
783
+ s,
784
+ z,
785
+ ss,
786
+ doo,
787
+ T,
788
+ S: tl.constexpr,
789
+ BT: tl.constexpr,
790
+ BC: tl.constexpr,
791
+ BS: tl.constexpr,
792
+ NC: tl.constexpr
793
+ ):
794
+ i_s, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
795
+ i_t, i_i = i_c // NC, i_c % NC
796
+
797
+ o_i = tl.arange(0, BC)
798
+ m_o = tl.full([BC, BC], 1., dtype=tl.float32)
799
+
800
+ p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0))
801
+ p_zn = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + BC - 1) * S + i_s * BS,), (BS,), (0,))
802
+ p_doo = tl.make_block_ptr(doo + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0))
803
+ # [BC, BS]
804
+ b_s = tl.load(p_s, boundary_check=(0, 1))
805
+ # [BS,]
806
+ b_zn = tl.load(p_zn, boundary_check=(0,))
807
+
808
+ b_doo = tl.zeros([BC, BS], dtype=tl.float32)
809
+ for i_j in range(i_i + 1, NC):
810
+ p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0))
811
+ p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0))
812
+ # [BC, BS]
813
+ b_z = tl.load(p_z, boundary_check=(0, 1))
814
+ b_ss = tl.load(p_ss, boundary_check=(0, 1))
815
+ # [BC, BS]
816
+ b_doo += b_ss * exp(b_zn[None, :] - b_z)
817
+ b_doo = exp(b_s - b_zn[None, :]) * tl.dot(m_o.to(b_s.dtype), b_doo.to(b_s.dtype), allow_tf32=False)
818
+
819
+ for j in range(0, BC):
820
+ p_z = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,))
821
+ p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,))
822
+ # [BS,]
823
+ b_z = tl.load(p_z, boundary_check=(0,))
824
+ b_ss = tl.load(p_ss, boundary_check=(0,))
825
+ # [BC, BS]
826
+ m_i = o_i[:, None] <= j
827
+ b_doo += tl.where(m_i, exp(b_s - b_z[None, :]) * b_ss[None, :], 0.)
828
+ b_doo += tl.load(p_doo, boundary_check=(0, 1))
829
+ tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1))
830
+
831
+
832
+ class ChunkABCFunction(torch.autograd.Function):
833
+
834
+ @staticmethod
835
+ @input_guard
836
+ def forward(ctx, q, k, v, s, initial_state, output_final_state):
837
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
838
+ BT, BC = 64, 16
839
+ BK = min(64, triton.next_power_of_2(K))
840
+ BV = min(64, triton.next_power_of_2(V))
841
+ BM = min(64, triton.next_power_of_2(M))
842
+ NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
843
+ NV, NM = triton.cdiv(V, BV), triton.cdiv(M, BM)
844
+ num_warps = 4 if BK == 64 else 2
845
+ num_stages = 1
846
+
847
+ def fwd_pre(s, B, H, T, S):
848
+ # keep cumulative normalizer in fp32
849
+ z = torch.empty_like(s, dtype=torch.float)
850
+ grid = (B * H,)
851
+ logcumsumexp_fwd_kernel[grid](
852
+ s, z,
853
+ T=T, S=S
854
+ )
855
+ return z
856
+
857
+ def fwd_inner(q, k, v, z, B, H, T, K, V, BT, BK, BV, NT, normk=False, h0=None, ht=None):
858
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
859
+ h = q.new_empty(B, H, NT * K, V)
860
+ grid = (NV, NK, B * H)
861
+ chunk_abc_fwd_kernel_h[grid](
862
+ k, v, z, h, h0, ht,
863
+ T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
864
+ NORMK=normk,
865
+ USE_INITIAL_STATE=h0 is not None,
866
+ STORE_FINAL_STATE=ht is not None,
867
+ num_warps=num_warps,
868
+ num_stages=num_stages
869
+ )
870
+ return h
871
+
872
+ final_state = None
873
+ if output_final_state:
874
+ final_state = (q.new_empty(B, H, K, M, dtype=torch.float),
875
+ q.new_empty(B, H, M, V, dtype=torch.float))
876
+
877
+ z = fwd_pre(s, B, H, T, M)
878
+ scale = K ** -0.5
879
+ hk = fwd_inner(
880
+ q=q, k=k, v=s, z=z,
881
+ B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
882
+ normk=False,
883
+ h0=initial_state[0] if initial_state is not None else None,
884
+ ht=final_state[0] if final_state is not None else None
885
+ )
886
+ ok1 = torch.empty_like(s)
887
+ Ak = q.new_empty(B, H, T, BT)
888
+ grid = (NM, NT, B * H)
889
+ chunk_abc_fwd_kernel_K[grid](
890
+ q, k, z, hk, ok1, Ak,
891
+ scale=scale,
892
+ T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
893
+ num_warps=num_warps,
894
+ num_stages=num_stages
895
+ )
896
+ ok0 = torch.empty_like(s)
897
+ grid = (NM, NT * NC, B * H)
898
+ chunk_abc_fwd_kernel_intra_K[grid](
899
+ s, z, ok0, Ak,
900
+ T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC,
901
+ num_warps=2,
902
+ num_stages=num_stages
903
+ )
904
+ ok = ok0.add_(ok1)
905
+
906
+ scale = 1.
907
+ # p is kept in fp32 for safe softmax backward
908
+ p = softmax_fwd(ok, dtype=torch.float)
909
+ qv = p.to(q.dtype)
910
+
911
+ scale = 1.
912
+ hv = fwd_inner(
913
+ q=qv, k=s, v=v, z=z,
914
+ B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT,
915
+ normk=True,
916
+ h0=initial_state[1] if initial_state is not None else None,
917
+ ht=final_state[1] if final_state is not None else None
918
+ )
919
+ Av = q.new_zeros(NM, B, H, T, BT)
920
+ grid = (NM, NT * NC * NC, B * H)
921
+ chunk_abc_fwd_kernel_intra_V[grid](
922
+ qv, s, z, Av,
923
+ scale=scale,
924
+ T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC,
925
+ num_warps=2,
926
+ num_stages=num_stages
927
+ )
928
+ Av = Av.sum(0)
929
+ ov = torch.empty_like(v)
930
+ grid = (NV, NT, B * H)
931
+ chunk_abc_fwd_kernel_V[grid](
932
+ qv, v, z, hv, ov, Av,
933
+ scale=scale,
934
+ T=T,
935
+ K=M,
936
+ V=V,
937
+ BT=BT,
938
+ BK=BM,
939
+ BV=BV,
940
+ NT=NT,
941
+ num_warps=num_warps,
942
+ num_stages=num_stages
943
+ )
944
+ ctx.save_for_backward(q, k, v, s, z, ok, p, hk, hv, Av)
945
+ ctx.BT = BT
946
+ return ov, final_state
947
+
948
+ @staticmethod
949
+ @input_guard
950
+ def backward(ctx, dov, dht=None):
951
+ q, k, v, s, z, ok, p, hk, hv, Av = ctx.saved_tensors
952
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
953
+ BT, BC = ctx.BT, 16
954
+ BK = min(64, triton.next_power_of_2(K))
955
+ BV = min(64, triton.next_power_of_2(V))
956
+ BM = min(64, triton.next_power_of_2(M))
957
+ NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
958
+ NK, NM = triton.cdiv(K, BK), triton.cdiv(M, BM)
959
+ num_warps = 4 if BK == 64 else 2
960
+ num_stages = 1
961
+
962
+ def bwd_inner(q, z, do, B, H, T, K, V, BT, BK, BV, NT, scale, normk=False):
963
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
964
+ dh = q.new_empty(B, H, NT * K, V)
965
+ grid = (NK, NV, B * H)
966
+ chunk_abc_bwd_kernel_dh[grid](
967
+ q, z, do, dh,
968
+ scale=scale,
969
+ T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
970
+ NORMK=normk,
971
+ num_warps=num_warps,
972
+ num_stages=num_stages
973
+ )
974
+ return dh
975
+
976
+ def bwd_post(s, z, ss, B, H, T, S, BT, BC, BS, NT, NC, NS):
977
+ doo = torch.empty_like(s)
978
+ grid = (NS, B * H)
979
+ chunk_abc_bwd_kernel_rcum_inter[grid](
980
+ s, z, ss, doo,
981
+ T=T, S=S, BT=BT, BS=BS, NT=NT,
982
+ num_warps=num_warps,
983
+ num_stages=num_stages
984
+ )
985
+ grid = (NS, NT * NC, B * H)
986
+ chunk_abc_bwd_kernel_rcum_intra[grid](
987
+ s, z, ss, doo,
988
+ T=T, S=S, BT=BT, BC=BC, BS=BS, NC=NC,
989
+ num_warps=num_warps,
990
+ num_stages=num_stages
991
+ )
992
+ return doo
993
+
994
+ scale = 1.
995
+ qv = p.to(q.dtype)
996
+ dhv = bwd_inner(
997
+ qv, z, dov,
998
+ B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT,
999
+ scale=scale,
1000
+ normk=True
1001
+ )
1002
+ dp1 = torch.empty_like(p)
1003
+ dsv1 = torch.empty_like(s, dtype=torch.float)
1004
+ dv = v.new_empty(NM, *v.shape)
1005
+ dAv = q.new_zeros(B, H, T, BT)
1006
+ grid = (NM, NT, B * H)
1007
+ chunk_abc_bwd_kernel_V[grid](
1008
+ s, v, z, hv, Av, dov, dhv, dp1, dsv1, dv, dAv,
1009
+ scale=scale,
1010
+ T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT,
1011
+ num_warps=num_warps,
1012
+ num_stages=num_stages
1013
+ )
1014
+ dv = dv.sum(0)
1015
+ dp0 = torch.empty_like(p)
1016
+ dsv0 = s.new_zeros(s.shape, dtype=torch.float)
1017
+ grid = (NM, NT * NC, B * H)
1018
+ chunk_abc_bwd_kernel_intra_V[grid](
1019
+ qv, s, z, dAv, dp0, dsv0,
1020
+ T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC,
1021
+ num_warps=2,
1022
+ num_stages=num_stages
1023
+ )
1024
+ dp = dp1.add_(dp0)
1025
+ dsv = dsv1.add_(dsv0)
1026
+
1027
+ # softmax gradient, equivalent to:
1028
+ # dok = p * (dp - (p * dp).sum(-1, True))
1029
+ dok = softmax_bwd(p, dp, dtype=ok.dtype)
1030
+
1031
+ scale = K ** -0.5
1032
+ dhk = bwd_inner(
1033
+ q, z, dok,
1034
+ B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
1035
+ scale=scale,
1036
+ normk=False
1037
+ )
1038
+ dAk = q.new_zeros(NM, B, H, T, BT)
1039
+ grid = (NM, NT * NC * NC, B * H)
1040
+ chunk_abc_bwd_kernel_intra_K[grid](
1041
+ s, z, dok, dAk,
1042
+ scale=scale,
1043
+ T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC,
1044
+ num_warps=2,
1045
+ num_stages=num_stages
1046
+ )
1047
+ dAk = dAk.sum(0)
1048
+
1049
+ Ak = q.new_zeros(NK, B, H, T, BT)
1050
+ dq = torch.empty_like(q)
1051
+ dk = torch.empty_like(k)
1052
+ dsk1 = s.new_empty(NK, *s.shape, dtype=torch.float)
1053
+ grid = (NK, NT, B * H)
1054
+ chunk_abc_bwd_kernel_K[grid](
1055
+ q, k, s, z, hk, Ak, dok, dhk, dq, dk, dsk1, dAk,
1056
+ scale=scale,
1057
+ T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
1058
+ num_warps=num_warps,
1059
+ num_stages=num_stages
1060
+ )
1061
+ Ak = Ak.sum(0)
1062
+ dsk1 = dsk1.sum(0)
1063
+ dsk0 = torch.empty_like(s, dtype=torch.float)
1064
+ grid = (NM, NT * NC, B * H)
1065
+ chunk_abc_bwd_kernel_intra_KV[grid](
1066
+ s, z, Ak, dok, dsk0,
1067
+ T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC,
1068
+ num_warps=2,
1069
+ num_stages=num_stages
1070
+ )
1071
+ ds = dsv.add_(dsk1.add_(dsk0))
1072
+ ds -= bwd_post(s, z, ok * dok + p * dp, B, H, T, M, BT, BC, BM, NT, NC, NM)
1073
+ ds = ds.to(s.dtype)
1074
+ return dq, dk, dv, ds, None, None
1075
+
1076
+
1077
+ @torch.compiler.disable
1078
+ def chunk_abc(
1079
+ q: torch.Tensor,
1080
+ k: torch.Tensor,
1081
+ v: torch.Tensor,
1082
+ s: torch.Tensor,
1083
+ initial_state: Optional[Tuple[torch.Tensor]] = None,
1084
+ output_final_state: bool = False,
1085
+ head_first: bool = True
1086
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1087
+ r"""
1088
+ Args:
1089
+ q (torch.Tensor):
1090
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
1091
+ k (torch.Tensor):
1092
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
1093
+ v (torch.Tensor):
1094
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
1095
+ s (torch.Tensor):
1096
+ slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`
1097
+ initial_state (Optional[Tuple[torch.Tensor, torch.Tensor]]):
1098
+ Initial states of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `None`.
1099
+ output_final_state (Optional[bool]):
1100
+ Whether to output the final state of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `False`.
1101
+ head_first (Optional[bool]):
1102
+ Whether the inputs are in the head-first format.
1103
+ Default: `True`.
1104
+
1105
+ Returns:
1106
+ o (torch.Tensor):
1107
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
1108
+ final_state (Tuple[torch.Tensor, torch.Tensor]):
1109
+ Final state of shape `[B, H, K, M]` and `[B, H, M, V]` if `output_final_state=True` else `None`.
1110
+ """
1111
+ if not head_first:
1112
+ q, k, v, s = map(lambda x: x.transpose(1, 2), (q, k, v, s))
1113
+ o, final_state = ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state)
1114
+ if not head_first:
1115
+ o = o.transpose(1, 2)
1116
+ return o, final_state
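For reference, a minimal usage sketch of the interface documented above, assuming `fla.ops.abc` re-exports `chunk_abc` (as the `__init__.py` added in this commit suggests); shapes follow the head-first convention from the docstring:

import torch
from fla.ops.abc import chunk_abc  # assumed re-export

B, H, T, K, V, M = 2, 4, 256, 64, 64, 32
q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
s = torch.randn(B, H, T, M, device='cuda', dtype=torch.bfloat16)  # slot representations

o, (hk, hv) = chunk_abc(q, k, v, s, output_final_state=True, head_first=True)
assert o.shape == (B, H, T, V)
assert hk.shape == (B, H, K, M) and hv.shape == (B, H, M, V)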
fla/ops/attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (220 Bytes). View file
 
fla/ops/attn/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (33.1 kB). View file
 
fla/ops/attn/parallel.py ADDED
@@ -0,0 +1,629 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange, reduce
10
+
11
+ from fla.ops.common.utils import prepare_chunk_indices
12
+ from fla.ops.utils.op import exp, log
13
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
22
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
23
+ for num_stages in [2, 3, 4, 5]
24
+ ],
25
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
26
+ )
27
+ @triton.jit
28
+ def parallel_attn_fwd_kernel(
29
+ q,
30
+ k,
31
+ v,
32
+ o,
33
+ lse,
34
+ scale,
35
+ offsets,
36
+ indices,
37
+ T,
38
+ B: tl.constexpr,
39
+ H: tl.constexpr,
40
+ HQ: tl.constexpr,
41
+ G: tl.constexpr,
42
+ K: tl.constexpr,
43
+ V: tl.constexpr,
44
+ BT: tl.constexpr,
45
+ BS: tl.constexpr,
46
+ BK: tl.constexpr,
47
+ BV: tl.constexpr,
48
+ USE_OFFSETS: tl.constexpr
49
+ ):
50
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
51
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
52
+ i_h = i_hq // G
53
+
54
+ if USE_OFFSETS:
55
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
56
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
57
+ T = eos - bos
58
+ else:
59
+ i_n = i_b
60
+ bos, eos = i_n * T, i_n * T + T
61
+
62
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
63
+ p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
64
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
65
+
66
+ # the Q block is kept in the shared memory throughout the whole kernel
67
+ # [BT, BK]
68
+ b_q = tl.load(p_q, boundary_check=(0, 1))
69
+ b_q = (b_q * scale).to(b_q.dtype)
70
+ # [BT, BV]
71
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
72
+
73
+ b_m = tl.full([BT], float('-inf'), dtype=tl.float32)
74
+ b_acc = tl.zeros([BT], dtype=tl.float32)
75
+ for i_s in range(0, i_t * BT, BS):
76
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
77
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
78
+ # [BK, BS]
79
+ b_k = tl.load(p_k, boundary_check=(0, 1))
80
+ # [BS, BV]
81
+ b_v = tl.load(p_v, boundary_check=(0, 1))
82
+ # [BT, BS]
83
+ b_s = tl.dot(b_q, b_k)
84
+
85
+ # [BT, BS]
86
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
87
+ b_r = exp(b_mp - b_m)
88
+ # [BT, BS]
89
+ b_p = exp(b_s - b_m[:, None])
90
+ # [BT]
91
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
92
+ # [BT, BV]
93
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
94
+
95
+ b_mp = b_m
96
+
97
+ # [BT]
98
+ o_q = i_t * BT + tl.arange(0, BT)
99
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
100
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
101
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
102
+
103
+ # [BS]
104
+ o_k = i_s + tl.arange(0, BS)
105
+ # [BK, BS]
106
+ b_k = tl.load(p_k, boundary_check=(0, 1))
107
+ # [BS, BV]
108
+ b_v = tl.load(p_v, boundary_check=(0, 1))
109
+ # [BT, BS]
110
+ b_s = tl.dot(b_q, b_k)
111
+ b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))
112
+
113
+ # [BT]
114
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
115
+ b_r = exp(b_mp - b_m)
116
+ # [BT, BS]
117
+ b_p = exp(b_s - b_m[:, None])
118
+ # [BT]
119
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
120
+ # [BT, BV]
121
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
122
+
123
+ b_mp = b_m
124
+ b_o = b_o / b_acc[:, None]
125
+ b_m += log(b_acc)
126
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
127
+ tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,))
128
+
129
+
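For reference, the forward kernel above follows the standard online-softmax recurrence: it keeps a running row maximum `b_m` and normalizer `b_acc`, rescales the previously accumulated partial output whenever the maximum grows, and stores `lse = m + log(acc)` at the end. A hedged, non-causal PyTorch sketch of the same recurrence (illustrative names, single head):

import torch

def online_softmax_attn(q, k, v, block=64):
    # q: [Tq, K], k: [T, K], v: [T, V]; key/value rows are consumed block by block,
    # mirroring how the kernel streams k/v tiles.
    m = q.new_full((q.shape[0],), float('-inf'))   # running row maximum
    acc = q.new_zeros(q.shape[0])                  # running softmax normalizer
    o = q.new_zeros(q.shape[0], v.shape[-1])       # running (unnormalized) output
    for s in range(0, k.shape[0], block):
        scores = q @ k[s:s + block].t()            # [Tq, block]
        m_new = torch.maximum(m, scores.max(-1).values)
        r = torch.exp(m - m_new)                   # rescales previously accumulated stats
        p = torch.exp(scores - m_new[:, None])     # unnormalized probabilities
        acc = acc * r + p.sum(-1)
        o = o * r[:, None] + p @ v[s:s + block]
        m = m_new
    return o / acc[:, None], m + torch.log(acc)    # normalized output and per-row log-sum-exp

The causal inner loop of the kernel only adds masking of future positions; the rescaling logic itself is unchanged.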
130
+ @triton.jit
131
+ def parallel_attn_bwd_kernel_preprocess(
132
+ o,
133
+ do,
134
+ delta,
135
+ B: tl.constexpr,
136
+ V: tl.constexpr
137
+ ):
138
+ i_n = tl.program_id(0)
139
+ o_d = tl.arange(0, B)
140
+ m_d = o_d < V
141
+
142
+ b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
143
+ b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
144
+ b_delta = tl.sum(b_o * b_do)
145
+
146
+ tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
147
+
148
+
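For reference, the preprocess kernel above computes the usual attention-backward `delta` term: per query position and head, the inner product of the output row with its gradient. A hedged PyTorch equivalent with placeholder tensors:

import torch

o = torch.randn(2, 128, 4, 64)                     # [B, T, HQ, V], placeholder values
do = torch.randn_like(o)
delta_ref = (o.float() * do.float()).sum(dim=-1)   # [B, T, HQ], what the kernel stores per row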
149
+ @triton.heuristics({
150
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
151
+ })
152
+ @triton.autotune(
153
+ configs=[
154
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
155
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
156
+ for num_stages in [2, 3, 4, 5]
157
+ ],
158
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
159
+ )
160
+ @triton.jit(do_not_specialize=['T'])
161
+ def parallel_attn_bwd_kernel_dq(
162
+ q,
163
+ k,
164
+ v,
165
+ lse,
166
+ delta,
167
+ do,
168
+ dq,
169
+ scale,
170
+ offsets,
171
+ indices,
172
+ T,
173
+ B: tl.constexpr,
174
+ H: tl.constexpr,
175
+ HQ: tl.constexpr,
176
+ G: tl.constexpr,
177
+ K: tl.constexpr,
178
+ V: tl.constexpr,
179
+ BT: tl.constexpr,
180
+ BS: tl.constexpr,
181
+ BK: tl.constexpr,
182
+ BV: tl.constexpr,
183
+ USE_OFFSETS: tl.constexpr
184
+ ):
185
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
186
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
187
+ i_h = i_hq // G
188
+
189
+ if USE_OFFSETS:
190
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
191
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
192
+ T = eos - bos
193
+ else:
194
+ i_n = i_b
195
+ bos, eos = i_n * T, i_n * T + T
196
+
197
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
198
+ p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
199
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
200
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
201
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
202
+
203
+ # [BT, BK]
204
+ b_q = tl.load(p_q, boundary_check=(0, 1))
205
+ b_q = (b_q * scale).to(b_q.dtype)
206
+ # [BT, BV]
207
+ b_do = tl.load(p_do, boundary_check=(0, 1))
208
+ # [BT]
209
+ b_lse = tl.load(p_lse, boundary_check=(0,))
210
+ b_delta = tl.load(p_delta, boundary_check=(0,))
211
+
212
+ # [BT, BK]
213
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
214
+ for i_s in range(0, i_t * BT, BS):
215
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
216
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
217
+ # [BK, BS]
218
+ b_k = tl.load(p_k, boundary_check=(0, 1))
219
+ # [BV, BS]
220
+ b_v = tl.load(p_v, boundary_check=(0, 1))
221
+
222
+ # [BT, BS]
223
+ b_s = tl.dot(b_q, b_k)
224
+ b_p = exp(b_s - b_lse[:, None])
225
+
226
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
227
+ b_dp = tl.dot(b_do, b_v)
228
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
229
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
230
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
231
+
232
+ # [BT]
233
+ o_q = i_t * BT + tl.arange(0, BT)
234
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
235
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
236
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
237
+ # [BS]
238
+ o_k = i_s + tl.arange(0, BS)
239
+ # [BK, BS]
240
+ b_k = tl.load(p_k, boundary_check=(0, 1))
241
+ # [BV, BS]
242
+ b_v = tl.load(p_v, boundary_check=(0, 1))
243
+
244
+ # [BT, BS]
245
+ b_s = tl.dot(b_q, b_k)
246
+ b_p = exp(b_s - b_lse[:, None])
247
+ b_p = tl.where(o_q[:, None] >= o_k[None, :], b_p, 0)
248
+
249
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
250
+ b_dp = tl.dot(b_do, b_v)
251
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
252
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
253
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
254
+
255
+ b_dq *= scale
256
+
257
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
258
+
259
+
260
+ @triton.heuristics({
261
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
262
+ })
263
+ @triton.autotune(
264
+ configs=[
265
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
266
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
267
+ for num_stages in [2, 3, 4, 5]
268
+ ],
269
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
270
+ )
271
+ @triton.jit(do_not_specialize=['T'])
272
+ def parallel_attn_bwd_kernel_dkv(
273
+ q,
274
+ k,
275
+ v,
276
+ lse,
277
+ delta,
278
+ do,
279
+ dk,
280
+ dv,
281
+ offsets,
282
+ indices,
283
+ scale,
284
+ T,
285
+ B: tl.constexpr,
286
+ H: tl.constexpr,
287
+ HQ: tl.constexpr,
288
+ G: tl.constexpr,
289
+ K: tl.constexpr,
290
+ V: tl.constexpr,
291
+ BT: tl.constexpr,
292
+ BS: tl.constexpr,
293
+ BK: tl.constexpr,
294
+ BV: tl.constexpr,
295
+ USE_OFFSETS: tl.constexpr
296
+ ):
297
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
298
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
299
+ i_h = i_hq // G
300
+
301
+ if USE_OFFSETS:
302
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
303
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
304
+ T = eos - bos
305
+ else:
306
+ i_n = i_b
307
+ bos, eos = i_n * T, i_n * T + T
308
+
309
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
310
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
311
+ p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
312
+ p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
313
+
314
+ # [BT, BK]
315
+ b_k = tl.load(p_k, boundary_check=(0, 1))
316
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
317
+ # [BT, BV]
318
+ b_v = tl.load(p_v, boundary_check=(0, 1))
319
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
320
+
321
+ o_k = i_t * BT + tl.arange(0, BT)
322
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
323
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
324
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
325
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
326
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
327
+
328
+ # [BS]
329
+ o_q = i_s + tl.arange(0, BS)
330
+ # [BS, BK]
331
+ b_q = tl.load(p_q, boundary_check=(0, 1))
332
+ b_q = (b_q * scale).to(b_q.dtype)
333
+ # [BS, BV]
334
+ b_do = tl.load(p_do, boundary_check=(0, 1))
335
+ # [BS]
336
+ b_lse = tl.load(p_lse, boundary_check=(0,))
337
+ b_delta = tl.load(p_delta, boundary_check=(0,))
338
+ # [BT, BS]
339
+ b_s = tl.dot(b_k, tl.trans(b_q))
340
+ b_p = exp(b_s - b_lse[None, :])
341
+ b_p = tl.where(o_k[:, None] <= o_q[None, :], b_p, 0)
342
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
343
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
344
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
345
+ b_dp = tl.dot(b_v, tl.trans(b_do))
346
+ # [BT, BS]
347
+ b_ds = b_p * (b_dp - b_delta[None, :])
348
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
349
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
350
+
351
+ for i_s in range((i_t + 1) * BT, tl.cdiv(T, BS) * BS, BS):
352
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
353
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
354
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
355
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
356
+
357
+ # [BS]
358
+ o_q = i_s + tl.arange(0, BS)
359
+ # [BS, BK]
360
+ b_q = tl.load(p_q, boundary_check=(0, 1))
361
+ b_q = (b_q * scale).to(b_q.dtype)
362
+ # [BS, BV]
363
+ b_do = tl.load(p_do, boundary_check=(0, 1))
364
+ # [BS]
365
+ b_lse = tl.load(p_lse, boundary_check=(0,))
366
+ b_delta = tl.load(p_delta, boundary_check=(0,))
367
+ # [BT, BS]
368
+ b_s = tl.dot(b_k, tl.trans(b_q))
369
+ b_p = exp(b_s - b_lse[None, :])
370
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
371
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
372
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
373
+ b_dp = tl.dot(b_v, tl.trans(b_do))
374
+ # [BT, BS]
375
+ b_ds = b_p * (b_dp - b_delta[None, :])
376
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
377
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
378
+
379
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
380
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
381
+
382
+
383
+ def parallel_attn_fwd(
384
+ q: torch.Tensor,
385
+ k: torch.Tensor,
386
+ v: torch.Tensor,
387
+ scale: float,
388
+ chunk_size: int = 128,
389
+ offsets: Optional[torch.LongTensor] = None,
390
+ indices: Optional[torch.LongTensor] = None,
391
+ ):
392
+ B, T, H, K, V = *k.shape, v.shape[-1]
393
+ HQ = q.shape[2]
394
+ G = HQ // H
395
+ BT = chunk_size
396
+ if check_shared_mem('hopper', q.device.index):
397
+ BS = min(64, max(16, triton.next_power_of_2(T)))
398
+ BK = min(256, max(16, triton.next_power_of_2(K)))
399
+ BV = min(256, max(16, triton.next_power_of_2(V)))
400
+ elif check_shared_mem('ampere', q.device.index):
401
+ BS = min(32, max(16, triton.next_power_of_2(T)))
402
+ BK = min(256, max(16, triton.next_power_of_2(K)))
403
+ BV = min(128, max(16, triton.next_power_of_2(V)))
404
+ else:
405
+ BS = min(32, max(16, triton.next_power_of_2(T)))
406
+ BK = min(256, max(16, triton.next_power_of_2(K)))
407
+ BV = min(64, max(16, triton.next_power_of_2(V)))
408
+ NK = triton.cdiv(K, BK)
409
+ NV = triton.cdiv(V, BV)
410
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
411
+ assert NK == 1, "The key dimension cannot be larger than 256"
412
+
413
+ o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
414
+ lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
415
+
416
+ grid = (NV, NT, B * HQ)
417
+ parallel_attn_fwd_kernel[grid](
418
+ q=q,
419
+ k=k,
420
+ v=v,
421
+ o=o,
422
+ lse=lse,
423
+ scale=scale,
424
+ offsets=offsets,
425
+ indices=indices,
426
+ B=B,
427
+ T=T,
428
+ H=H,
429
+ HQ=HQ,
430
+ G=G,
431
+ K=K,
432
+ V=V,
433
+ BT=BT,
434
+ BS=BS,
435
+ BK=BK,
436
+ BV=BV,
437
+ )
438
+ return o, lse
439
+
440
+
441
+ def parallel_attn_bwd_preprocess(
442
+ o: torch.Tensor,
443
+ do: torch.Tensor
444
+ ):
445
+ V = o.shape[-1]
446
+ delta = torch.empty_like(o[..., 0], dtype=torch.float32)
447
+ parallel_attn_bwd_kernel_preprocess[(delta.numel(),)](
448
+ o=o,
449
+ do=do,
450
+ delta=delta,
451
+ B=triton.next_power_of_2(V),
452
+ V=V,
453
+ )
454
+ return delta
455
+
456
+
457
+ def parallel_attn_bwd(
458
+ q: torch.Tensor,
459
+ k: torch.Tensor,
460
+ v: torch.Tensor,
461
+ o: torch.Tensor,
462
+ lse: torch.Tensor,
463
+ do: torch.Tensor,
464
+ scale: float = None,
465
+ chunk_size: int = 128,
466
+ offsets: Optional[torch.LongTensor] = None,
467
+ indices: Optional[torch.LongTensor] = None,
468
+ ):
469
+ B, T, H, K, V = *k.shape, v.shape[-1]
470
+ HQ = q.shape[2]
471
+ G = HQ // H
472
+ BT = chunk_size
473
+ BS = max(16, triton.next_power_of_2(T))
474
+ BS = min(32, BS) if check_shared_mem('ampere') else min(16, BS)
475
+ BK = max(16, triton.next_power_of_2(K))
476
+ BV = max(16, triton.next_power_of_2(V))
477
+ NV = triton.cdiv(V, BV)
478
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
479
+
480
+ delta = parallel_attn_bwd_preprocess(o, do)
481
+
482
+ dq = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
483
+ dk = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
484
+ dv = torch.empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float, device=q.device)
485
+ grid = (NV, NT, B * HQ)
486
+ parallel_attn_bwd_kernel_dq[grid](
487
+ q=q,
488
+ k=k,
489
+ v=v,
490
+ lse=lse,
491
+ delta=delta,
492
+ do=do,
493
+ dq=dq,
494
+ offsets=offsets,
495
+ indices=indices,
496
+ scale=scale,
497
+ T=T,
498
+ B=B,
499
+ H=H,
500
+ HQ=HQ,
501
+ G=G,
502
+ K=K,
503
+ V=V,
504
+ BT=BT,
505
+ BS=BS,
506
+ BK=BK,
507
+ BV=BV
508
+ )
509
+ parallel_attn_bwd_kernel_dkv[grid](
510
+ q=q,
511
+ k=k,
512
+ v=v,
513
+ lse=lse,
514
+ delta=delta,
515
+ do=do,
516
+ dk=dk,
517
+ dv=dv,
518
+ offsets=offsets,
519
+ indices=indices,
520
+ scale=scale,
521
+ T=T,
522
+ B=B,
523
+ H=H,
524
+ HQ=HQ,
525
+ G=G,
526
+ K=K,
527
+ V=V,
528
+ BT=BT,
529
+ BS=BS,
530
+ BK=BK,
531
+ BV=BV
532
+ )
533
+ dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
534
+ dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
535
+ return dq, dk, dv
536
+
537
+
538
+ @torch.compile
539
+ class ParallelAttentionFunction(torch.autograd.Function):
540
+
541
+ @staticmethod
542
+ @contiguous
543
+ @autocast_custom_fwd
544
+ def forward(ctx, q, k, v, scale, offsets):
545
+ ctx.dtype = q.dtype
546
+
547
+ chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
548
+ # 2-d indices denoting the offsets of chunks in each sequence
549
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
550
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
551
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
552
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
553
+
554
+ o, lse = parallel_attn_fwd(
555
+ q=q,
556
+ k=k,
557
+ v=v,
558
+ scale=scale,
559
+ chunk_size=chunk_size,
560
+ offsets=offsets,
561
+ indices=indices
562
+ )
563
+ ctx.save_for_backward(q, k, v, o, lse)
564
+ ctx.chunk_size = chunk_size
565
+ ctx.offsets = offsets
566
+ ctx.indices = indices
567
+ ctx.scale = scale
568
+ return o.to(q.dtype)
569
+
570
+ @staticmethod
571
+ @contiguous
572
+ @autocast_custom_bwd
573
+ def backward(ctx, do):
574
+ q, k, v, o, lse = ctx.saved_tensors
575
+ dq, dk, dv = parallel_attn_bwd(
576
+ q=q,
577
+ k=k,
578
+ v=v,
579
+ o=o,
580
+ lse=lse,
581
+ do=do,
582
+ scale=ctx.scale,
583
+ chunk_size=ctx.chunk_size,
584
+ offsets=ctx.offsets,
585
+ indices=ctx.indices
586
+ )
587
+ return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
588
+
589
+
590
+ def parallel_attn(
591
+ q: torch.Tensor,
592
+ k: torch.Tensor,
593
+ v: torch.Tensor,
594
+ scale: Optional[float] = None,
595
+ cu_seqlens: Optional[torch.LongTensor] = None,
596
+ head_first: bool = False
597
+ ) -> torch.Tensor:
598
+ r"""
599
+ Args:
600
+ q (torch.Tensor):
601
+ queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
602
+ k (torch.Tensor):
603
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
604
+ GQA will be applied if HQ is divisible by H.
605
+ v (torch.Tensor):
606
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
607
+ scale (Optional[float]):
608
+ Scale factor for attention scores.
609
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
610
+ cu_seqlens (torch.LongTensor):
611
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
612
+ consistent with the FlashAttention API.
613
+ head_first (Optional[bool]):
614
+ Whether the inputs are in the head-first format. Default: `False`.
615
+
616
+ Returns:
617
+ o (torch.Tensor):
618
+ Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
619
+ """
620
+ if scale is None:
621
+ scale = k.shape[-1] ** -0.5
622
+ if cu_seqlens is not None:
623
+ assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
624
+ if head_first:
625
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
626
+ o = ParallelAttentionFunction.apply(q, k, v, scale, cu_seqlens)
627
+ if head_first:
628
+ o = rearrange(o, 'b t h d -> b h t d')
629
+ return o
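A short usage sketch for `parallel_attn` (import path assumed from this commit's layout), covering the fixed-length path, grouped-query attention, and the variable-length path via `cu_seqlens`:

import torch
from fla.ops.attn.parallel import parallel_attn  # assumed import path

B, T, H, HQ, K, V = 1, 512, 4, 16, 64, 64          # 16 query heads sharing 4 kv heads (GQA)
q = torch.randn(B, T, HQ, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)

o = parallel_attn(q, k, v)                          # [B, T, HQ, V]; scale defaults to K ** -0.5

# variable-length packing: batch size must be 1 and lengths are given as cumulative offsets
cu_seqlens = torch.tensor([0, 200, 512], device='cuda', dtype=torch.long)
o_var = parallel_attn(q, k, v, cu_seqlens=cu_seqlens)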
fla/ops/based/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .fused_chunk import fused_chunk_based
4
+ from .parallel import parallel_based
5
+
6
+ __all__ = [
7
+ 'fused_chunk_based',
8
+ 'parallel_based'
9
+ ]
fla/ops/based/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (286 Bytes). View file
 
fla/ops/based/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (22.4 kB). View file
 
fla/ops/based/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (22.6 kB). View file
 
fla/ops/based/fused_chunk.py ADDED
@@ -0,0 +1,374 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
11
+
12
+
13
+ @triton.jit(do_not_specialize=['T'])
14
+ def fused_chunk_based_fwd_kernel(
15
+ q,
16
+ k,
17
+ v,
18
+ o,
19
+ z,
20
+ scale, # K ** -0.5
21
+ T,
22
+ B: tl.constexpr,
23
+ H: tl.constexpr,
24
+ K: tl.constexpr,
25
+ V: tl.constexpr,
26
+ BT: tl.constexpr,
27
+ BK: tl.constexpr,
28
+ BV: tl.constexpr,
29
+ ):
30
+ # indices
31
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
32
+
33
+ o_i = tl.arange(0, BT)
34
+
35
+ # [BT, BT]
36
+ m_s = o_i[:, None] >= o_i[None, :]
37
+
38
+ # [BV], zero-order taylor expansion
39
+ b_h_0o = tl.zeros([BV], dtype=tl.float32)
40
+ # [BK, BV], first-order taylor expansion
41
+ b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)
42
+ # [BK, BK, BV] second-order taylor expansion
43
+ b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
44
+
45
+ # make block pointers
46
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0))
47
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1))
48
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
49
+ p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
50
+
51
+ p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)
52
+ k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
53
+ k_1o = tl.zeros([1, BK], dtype=tl.float32)
54
+ k_0o = 0
55
+
56
+ for i in range(0, tl.cdiv(T, BT)):
57
+ # [BK, BT]
58
+ b_k = tl.load(p_k, boundary_check=(0, 1))
59
+ # [BK*BK, BT]
60
+ b_k_2o = b_k[:, None, :] * b_k[None, :, :]
61
+ b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)
62
+ # [BT, BV]
63
+ b_v = tl.load(p_v, boundary_check=(0, 1))
64
+ # [BT, BK]
65
+ b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)
66
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
67
+ b_z = tl.zeros([BT], dtype=tl.float32)
68
+
69
+ # interchunk
70
+ # zero-order
71
+ b_o += b_h_0o
72
+ b_z += k_0o
73
+ # first-order
74
+ b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)
75
+ b_z += tl.sum(b_q * k_1o, axis=1)
76
+ # second-order
77
+ b_q_2o = b_q[:, :, None] * b_q[:, None, :]
78
+ b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)
79
+ b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5
80
+ b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5
81
+
82
+ # update running statistics
83
+ k_1o += tl.sum(b_k, axis=1)[None, :]
84
+ k_2o += tl.sum(b_k_2o, axis=1)[None, :]
85
+ k_0o += BT
86
+
87
+ # intrachunk
88
+ # [BT, BT]
89
+ b_s = tl.dot(b_q, b_k, allow_tf32=False)
90
+ b_s = 1 + b_s + 0.5 * b_s * b_s
91
+ b_s = tl.where(m_s, b_s, 0)
92
+ b_z += tl.sum(b_s, axis=1)
93
+ b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
94
+ # [BT, BV]
95
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
96
+ tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=(i * BT + tl.arange(0, BT)) < T)
97
+
98
+ # update hidden state
99
+ # [BK, BV]
100
+ b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)
101
+ b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)
102
+ b_h_0o = b_h_0o + tl.sum(b_v, axis=0)
103
+
104
+ p_q = tl.advance(p_q, (BT, 0))
105
+ p_k = tl.advance(p_k, (0, BT))
106
+ p_v = tl.advance(p_v, (BT, 0))
107
+ p_o = tl.advance(p_o, (BT, 0))
108
+ p_z += BT
109
+
110
+
111
+ # Similar to Algorithm 1 of https://arxiv.org/abs/2006.16236
112
+ @triton.jit
113
+ def fused_chunk_based_bwd_kernel(
114
+ # NV: number of splits in the V dimension. NK: number of splits in the K dimension
115
+ q,
116
+ k,
117
+ v,
118
+ do,
119
+ dz,
120
+ dq,
121
+ dk,
122
+ dv,
123
+ scale, # K ** -0.5
124
+ T,
125
+ B: tl.constexpr,
126
+ H: tl.constexpr,
127
+ K: tl.constexpr,
128
+ V: tl.constexpr,
129
+ BT: tl.constexpr,
130
+ BK: tl.constexpr,
131
+ BV: tl.constexpr,
132
+ ):
133
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
134
+
135
+ o_i = tl.arange(0, BT)
136
+ m_s = o_i[:, None] >= o_i[None, :]
137
+
138
+ # [BV], zero-order taylor expansion
139
+ # b_h_0o = tl.zeros([BV], dtype=tl.float32)
140
+ # [BK, BV], first-order taylor expansion
141
+ b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)
142
+ # [BK, BK, BV] second-order taylor expansion
143
+ b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)
144
+
145
+ k_1o = tl.zeros([1, BK], dtype=tl.float32)
146
+ k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
147
+
148
+ for i in range(0, tl.cdiv(T, BT)):
149
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
150
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
151
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1))
152
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0))
153
+ p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0))
154
+ p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT
155
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
156
+
157
+ # load tensors
158
+ # [BT, BK]
159
+ b_q = tl.load(p_q, boundary_check=(0, 1))
160
+ b_q = (b_q * scale).to(b_q.dtype)
161
+ b_k = tl.load(p_k, boundary_check=(0, 1))
162
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
163
+ b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)
164
+ # [BV, BT]
165
+ b_v = tl.load(p_v, boundary_check=(0, 1))
166
+
167
+ # inter-chunk
168
+ b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)
169
+ if i_v == 0:
170
+ b_dq += b_dz[:, None] * k_1o
171
+ b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5
172
+ if i_v == 0:
173
+ b_dq_2o += (b_dz[:, None] * k_2o) * 0.5
174
+ b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])
175
+ b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)
176
+ b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)
177
+ b_dq *= scale
178
+
179
+ # intra-chunk
180
+ # [BT, BT]
181
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
182
+ if i_v == 0:
183
+ b_ds += b_dz[:, None]
184
+ b_ds = tl.where(m_s, b_ds, 0) * scale
185
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
186
+ b_s = tl.where(m_s, b_s, 0)
187
+ b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)
188
+
189
+ # store
190
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
191
+
192
+ # update hidden state
193
+ # [BT, BK*BK]
194
+ b_k_2o = b_k[:, :, None] * b_k[:, None, :]
195
+ b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
196
+ # [BV, BK*BK]
197
+ b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)
198
+ # [BV, BK]
199
+ b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)
200
+
201
+ if i_v == 0:
202
+ # update running statistics
203
+ k_1o += tl.sum(b_k, axis=0)[None, :]
204
+ k_2o += tl.sum(b_k_2o, axis=0)[None, :]
205
+
206
+ tl.debug_barrier()
207
+ b_h_1o = None
208
+ b_h_2o = None
209
+
210
+ # [BK, BV], first-order taylor expansion
211
+ b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32)
212
+ # [BK, BK, BV] second-order taylor expansion
213
+ b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
214
+ b_dh_0o = tl.zeros([BV], dtype=tl.float32)
215
+ m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]
216
+
217
+ dq_1o = tl.zeros([1, BK], dtype=tl.float32)
218
+ dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32)
219
+
220
+ for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT):
221
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BT), (0, 1))
222
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i, i_k * BK), (BT, BK), (1, 0))
223
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0))
224
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0))
225
+ p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * T*K, (T, K), (K, 1), (i, i_k*BK), (BT, BK), (1, 0))
226
+ p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (i, i_v*BV), (BT, BV), (1, 0))
227
+ p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i
228
+
229
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
230
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
231
+
232
+ b_q = tl.load(p_q, boundary_check=(0, 1))
233
+ b_k = tl.load(p_k, boundary_check=(0, 1))
234
+ b_v = tl.load(p_v, boundary_check=(0, 1))
235
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
236
+ b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T)
237
+ b_q = (b_q * scale).to(b_k.dtype)
238
+
239
+ # intra chunk
240
+ b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
241
+ if i_v == 0:
242
+ b_ds += b_dz[None, :]
243
+ b_ds = tl.where(m_s, b_ds, 0)
244
+ b_s = tl.dot(b_k, b_q, allow_tf32=False)
245
+ b_s2 = 1 + b_s + 0.5 * b_s * b_s
246
+ b_s = tl.where(m_s, b_s, 0)
247
+ b_s2 = tl.where(m_s, b_s2, 0)
248
+ b_ds *= (1+b_s)
249
+
250
+ b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False)
251
+ b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False)
252
+
253
+ # inter chunk
254
+ b_k_2o = b_k[:, :, None] * b_k[:, None, :]
255
+ b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
256
+
257
+ b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False)
258
+ b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False)
259
+ b_dv += b_dh_0o
260
+
261
+ b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False)
262
+
263
+ if i_v == 0:
264
+ b_dk += dq_1o
265
+
266
+ b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), tl.trans(b_v), allow_tf32=False)
267
+ if i_v == 0:
268
+ b_dk_2o += dq_2o
269
+ b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT])
270
+ b_k_fp32 = tl.trans(b_k.to(tl.float32))
271
+ b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0)
272
+ b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1)
273
+ b_dk += tl.trans(b_dk2)
274
+
275
+ # hidden state update
276
+ b_dh_0o += tl.sum(b_do, axis=0)
277
+ b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False)
278
+ b_q_2o = b_q[None, :, :] * b_q[:, None, :]
279
+ b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype)
280
+ b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5
281
+
282
+ if i_v == 0:
283
+ dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :]
284
+ dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None]
285
+
286
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
287
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
288
+
289
+
290
+ class FusedChunkBasedFunction(torch.autograd.Function):
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_fwd
+     def forward(ctx, q, k, v, scale=1):
+         B, H, T, K, V = *k.shape, v.shape[-1]
+
+         BT = 16
+         BK, BV = min(K, 16), min(V, 32)
+         BK, BV = max(BK, 16), max(BV, 16)
+         NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+
+         num_warps = 4
+
+         # the norm of o might explode, so we need to use float32 here
+         o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)
+         z = q.new_empty(NK, B, H, T, dtype=torch.float32)
+
+         grid = (NV, NK, B * H)
+         fused_chunk_based_fwd_kernel[grid](
+             q, k, v, o, z,
+             scale,
+             T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV,
+             num_warps=num_warps,
+         )
+         o = o.sum(0)
+         z = z.sum(0)
+         ctx.save_for_backward(q, k, v)
+         ctx.scale = scale
+         return o.to(q.dtype), z.to(z.dtype)
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_bwd
+     def backward(ctx, do, dz):
+         q, k, v = ctx.saved_tensors
+         B, H, T, K, V = *k.shape, v.shape[-1]
+         scale = ctx.scale
+
+         BT = 16
+         BK, BV = min(K, 16), min(V, 32)
+         BK, BV = max(BK, 16), max(BV, 16)
+         NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+         num_stages = 1
+         num_warps = 4
+
+         dq = q.new_empty(NV, B, H, T, K)
+         dk = q.new_empty(NV, B, H, T, K)
+         dv = q.new_empty(NK, B, H, T, V)
+         grid = (NV, NK, B * H)
+
+         fused_chunk_based_bwd_kernel[grid](
+             q, k, v, do, dz, dq, dk, dv,
+             scale,
+             T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV,
+             num_warps=num_warps,
+             num_stages=num_stages
+         )
+         dq = dq.sum(0)
+         dk = dk.sum(0)
+         dv = dv.sum(0)
+         return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None
+
+
+ def fused_chunk_based(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     scale: Optional[float] = None,
+     use_norm: bool = True,
+     head_first: bool = True
+ ):
+     assert q.shape[-1] <= 16, 'only supports feature dimensions up to 16.'
+     if scale is None:
+         scale = q.shape[-1] ** -0.5
+     if not head_first:
+         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+     o, z = FusedChunkBasedFunction.apply(q, k, v, scale)
+     if use_norm:
+         o = o / (z[..., None] + 1e-6)
+     if not head_first:
+         o = o.transpose(1, 2)
+     return o.to(q.dtype)
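Both kernels in this commit realize Based-style linear attention with the second-order Taylor feature map: causal attention with unnormalized scores 1 + q·k + 0.5·(q·k)², optionally divided by their row sums. The sketch below is a naive O(T²) PyTorch equivalent that can be used to sanity-check `fused_chunk_based` on small inputs. It is illustrative only and not part of the diff: `naive_based_reference` is a hypothetical helper, a head-first (B, H, T, d) layout is assumed, and the default 1/sqrt(d) scale is applied.

# Hypothetical reference (not part of this commit): naive Taylor-map attention.
import torch

def naive_based_reference(q, k, v, use_norm=True):
    # q, k: (B, H, T, K); v: (B, H, T, V); head-first layout assumed
    scale = q.shape[-1] ** -0.5
    s = torch.einsum('bhik,bhjk->bhij', q * scale, k)   # raw q.k scores
    s = 1 + s + 0.5 * s * s                             # 2nd-order Taylor of exp(q.k)
    mask = torch.tril(torch.ones(q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device))
    s = s.masked_fill(~mask, 0)                         # causal masking
    o = torch.einsum('bhij,bhjv->bhiv', s, v)
    if use_norm:
        o = o / (s.sum(-1, keepdim=True) + 1e-6)        # row-sum normalization, as in z
    return o

On small tensors (e.g. T = 64, K = 16) this should agree with `fused_chunk_based(q, k, v)` up to numerical tolerance.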
fla/ops/based/parallel.py ADDED
@@ -0,0 +1,410 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
+
+ # Based: An Educational and Effective Sequence Mixer
+ # https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based
+
+
+ @triton.jit(do_not_specialize=['T'])
+ def parallel_based_fwd_kernel(
+     q,
+     k,
+     v,
+     o,
+     z,
+     scale,
+     T,
+     B: tl.constexpr,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BTL: tl.constexpr,
+     BTS: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+ ):
+     # i_c: chunk index. used for sequence parallelism
+     i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     NV = tl.cdiv(V, BV)
+     i_k = i_kv // NV
+     i_v = i_kv % NV
+
+     p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
+     p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BTS), (0, 1))
+     p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BTS, BV), (1, 0))
+
+     # [BQ, BD] block Q, in the shared memory throughout the whole kernel
+     b_q = tl.load(p_q, boundary_check=(0, 1))
+     b_q = (b_q * scale).to(b_q.dtype)
+     b_o = tl.zeros([BTL, BV], dtype=tl.float32)
+     b_z = tl.zeros([BTL], dtype=tl.float32)
+
+     # Q block and K block have no overlap
+     # no need for mask, thereby saving flops
+     for _ in range(0, i_c * BTL, BTS):
+         # [BK, BTS]
+         b_k = tl.load(p_k, boundary_check=(0, 1))
+
+         # [BTS, BV]
+         b_v = tl.load(p_v, boundary_check=(0, 1))
+         # [BTL, BTS]
+         b_s = tl.dot(b_q, b_k, allow_tf32=False)
+         b_s = 1 + b_s + 0.5 * b_s * b_s
+         b_z += tl.sum(b_s, axis=1)
+
+         # [BQ, BD]
+         b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
+         p_k = tl.advance(p_k, (0, BTS))
+         p_v = tl.advance(p_v, (BTS, 0))
+
+     # rescale interchunk output
+     tl.debug_barrier()
+     o_q = tl.arange(0, BTL)
+     # sync threads, easy for compiler to optimize
+     # tl.debug_barrier()
+
+     o_k = tl.arange(0, BTS)
+     p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))
+     p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))
+     # Q block and K block have overlap. masks required
+     for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
+         # [BK, BTS]
+         b_k = tl.load(p_k, boundary_check=(0, 1))
+         # [BTS, BV]
+         b_v = tl.load(p_v, boundary_check=(0, 1))
+         # [BTL, BTS]
+         m_s = o_q[:, None] >= o_k[None, :]
+         b_s = tl.dot(b_q, b_k, allow_tf32=False)
+         b_s = 1 + b_s + 0.5 * b_s * b_s
+         b_s = tl.where(m_s, b_s, 0)
+         b_z += tl.sum(b_s, axis=1)
+         # [BTL, BV]
+         b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
+
+         p_k = tl.advance(p_k, (0, BTS))
+         p_v = tl.advance(p_v, (BTS, 0))
+         o_k += BTS
+
+     p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
+     p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)
+     tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+     tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c * BTL + tl.arange(0, BTL)) < T))
+
+
+ @triton.jit
+ def _parallel_based_bwd_dq(
+     i_bh,
+     i_c,
+     i_k,
+     i_v,
+     q,
+     k,
+     v,
+     do,
+     dz,
+     dq,
+     scale,
+     T,
+     B: tl.constexpr,
+     H: tl.constexpr,
+     BTL: tl.constexpr,
+     BTS: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+ ):
+     p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
+     p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
+     b_q = tl.load(p_q, boundary_check=(0, 1))
+     b_q = (b_q * scale).to(b_q.dtype)
+
+     b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
+     b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
+     p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BTS, BK), (1, 0))
+     p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, 0), (BV, BTS), (0, 1))
+     p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)
+     b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)
+
+     for _ in range(0, i_c * BTL, BTS):
+         # [BTS, BK]
+         b_k = tl.load(p_k, boundary_check=(0, 1))
+         # [BV, BTS]
+         b_v = tl.load(p_v, boundary_check=(0, 1))
+         # [BTL, BTS]
+         b_ds = tl.dot(b_do, b_v, allow_tf32=False)
+         if i_v == 0:
+             b_ds += b_dz[:, None]
+         b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
+         # [BQ, BD]
+         b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False)
+         p_k = tl.advance(p_k, (BTS, 0))
+         p_v = tl.advance(p_v, (0, BTS))
+
+     b_dq *= scale
+     o_q = tl.arange(0, BTL)
+     o_k = tl.arange(0, BTS)
+     p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))
+     p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))
+     # Q block and K block have overlap. masks required
+     for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
+         # [BTS, BK]
+         b_k = tl.load(p_k, boundary_check=(0, 1))
+         # [BV, BTS]
+         b_v = tl.load(p_v, boundary_check=(0, 1))
+         # [BTL, BTS]
+         m_s = o_q[:, None] >= o_k[None, :]
+         b_ds = tl.dot(b_do, b_v, allow_tf32=False)
+         if i_v == 0:
+             b_ds += b_dz[:, None]
+         b_ds = tl.where(m_s, b_ds, 0) * scale
+         b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
+         b_s = tl.where(m_s, b_s, 0)
+         # [BTL, BK]
+         b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype), b_k, allow_tf32=False)
+         p_k = tl.advance(p_k, (BTS, 0))
+         p_v = tl.advance(p_v, (0, BTS))
+         o_k += BTS
+     p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
+     tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
+     return
+
+
+ @triton.jit
+ def _parallel_based_bwd_dkv(
+     i_bh,
+     i_c,
+     i_k,
+     i_v,
+     q,
+     k,
+     v,
+     do,
+     dz,
+     dk,
+     dv,
+     scale,
+     T,
+     B: tl.constexpr,
+     H: tl.constexpr,
+     BTL: tl.constexpr,
+     BTS: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+ ):
+     # compute dk dv
+     p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
+     p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
+     b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(p_v, boundary_check=(0, 1))
+     b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros([BTL, BV], dtype=tl.float32)
+
+     for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
+         p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BTS), (0, 1))
+         p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v * BV, i), (BV, BTS), (0, 1))
+         p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
+         b_q = tl.load(p_q, boundary_check=(0, 1))  # [BK, BTS]
+         b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)  # [BV, BTS]
+         b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
+         b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale  # [BTL, BTS]
+         b_s2 = 1 + b_s + 0.5 * b_s * b_s
+         b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
+         b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale
+         if i_v == 0:
+             b_ds += b_dz[None, :] * scale
+         b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
+
+     tl.debug_barrier()
+     o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
+     for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
+         p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BTS), (0, 1))
+         p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v * BV, i), (BV, BTS), (0, 1))
+         p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
+         b_q = tl.load(p_q, boundary_check=(0, 1))  # [BD, BQ]
+         b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
+         b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
+         # [BK, BQ]
+         m_s = o_k[:, None] <= o_q[None, :]
+         b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
+         b_s2 = 1 + b_s + 0.5 * b_s * b_s
+         b_s = tl.where(m_s, b_s, 0)
+         b_s2 = tl.where(m_s, b_s2, 0)
+
+         b_ds = tl.dot(b_v, b_do, allow_tf32=False)
+         if i_v == 0:
+             b_ds += b_dz[None, :]
+         b_ds = tl.where(m_s, b_ds, 0) * scale
+         # [BK, BD]
+         b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
+         b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
+         o_q += BTS
+
+     p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
+     p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
+     tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+     tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+     return
+
+
+ @triton.jit(do_not_specialize=['T'])
+ def parallel_based_bwd_kernel(
+     q,
+     k,
+     v,
+     do,
+     dz,
+     dq,
+     dk,
+     dv,
+     scale,
+     T,
+     B: tl.constexpr,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BTL: tl.constexpr,
+     BTS: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+ ):
+     i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     NV = tl.cdiv(V, BV)
+     i_k = i_kv // NV
+     i_v = i_kv % NV
+     _parallel_based_bwd_dq(
+         i_bh, i_c, i_k, i_v,
+         q, k, v, do, dz, dq,
+         scale, T, B, H, BTL, BTS, BK, BV, K, V
+     )
+     tl.debug_barrier()
+     _parallel_based_bwd_dkv(
+         i_bh, i_c, i_k, i_v,
+         q, k, v, do, dz, dk, dv,
+         scale, T, B, H, BTL, BTS, BK, BV, K, V
+     )
+
+
+ class ParallelBasedFunction(torch.autograd.Function):
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_fwd
+     def forward(ctx, q, k, v, scale):
+         BTL, BTS = 128, 32
+         assert BTL % BTS == 0
+         # assert q.shape[-1] % 16 == 0
+         BK = min(128, triton.next_power_of_2(k.shape[-1]))
+         BV = min(128, triton.next_power_of_2(v.shape[-1]))
+         BK, BV = max(BK, 16), max(BV, 16)
+         B, H, T, K, V = *k.shape, v.shape[-1]
+         num_stages = 2
+         num_warps = 4
+         NK = triton.cdiv(K, BK)
+         NV = triton.cdiv(V, BV)
+         grid = (NK * NV, triton.cdiv(T, BTL), B * H)
+
+         assert NK == 1, "NK > 1 is not supported: it would require cross-block synchronization."
+
+         o = torch.empty(NK, B, H, T, V, device=q.device)
+         z = torch.empty(NK, B, H, T, device=q.device)
+         parallel_based_fwd_kernel[grid](
+             q, k, v, o, z,
+             scale,
+             B=B,
+             H=H,
+             T=T,
+             K=K,
+             V=V,
+             BTL=BTL,
+             BTS=BTS,
+             BK=BK,
+             BV=BV,
+             num_warps=num_warps,
+             num_stages=num_stages
+         )
+         ctx.save_for_backward(q, k, v)
+         ctx.scale = scale
+         return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_bwd
+     def backward(ctx, do, dz):
+         q, k, v = ctx.saved_tensors
+         scale = ctx.scale
+         BTL, BTS = 64, 32
+         assert BTL % BTS == 0
+         BK = min(128, triton.next_power_of_2(k.shape[-1]))
+         BV = min(128, triton.next_power_of_2(v.shape[-1]))
+         BK, BV = max(BK, 16), max(BV, 16)
+         B, H, T, K, V = *k.shape, v.shape[-1]
+         num_stages = 2
+         num_warps = 4
+         NK = triton.cdiv(K, BK)
+         NV = triton.cdiv(V, BV)
+         grid = (NK * NV, triton.cdiv(T, BTL), B * H)
+
+         assert NK == 1, "NK > 1 is not supported: it would require cross-block synchronization."
+
+         dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
+         dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
+         dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device)
+
+         parallel_based_bwd_kernel[grid](
+             q, k, v, do, dz, dq, dk, dv,
+             scale,
+             B=B,
+             H=H,
+             T=T,
+             K=K,
+             V=V,
+             BTL=BTL,
+             BTS=BTS,
+             BK=BK,
+             BV=BV,
+             num_warps=num_warps,
+             num_stages=num_stages
+         )
+
+         return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None
+
+
+ triton_parallel_based = ParallelBasedFunction.apply
+
+
+ def parallel_based(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     scale: Optional[float] = None,
+     use_norm: bool = True,
+     head_first: bool = True
+ ):
+     assert q.shape[-1] <= 128, "only supports feature dimensions up to 128"
+     if scale is None:
+         scale = q.shape[-1] ** -0.5
+     if not head_first:
+         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+     o, z = triton_parallel_based(q, k, v, scale)
+     if use_norm:
+         o = o / (z[..., None] + 1e-6)
+     if not head_first:
+         o = o.transpose(1, 2)
+     return o.to(q.dtype)
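A minimal usage sketch for the two entry points added in this commit, assuming a CUDA device and that `fla/ops/based/__init__.py` re-exports both functions (as the accompanying diff suggests); shapes and tolerances below are illustrative, not part of the diff.

# Hypothetical usage (not part of this commit): cross-check the two Based kernels.
import torch
from fla.ops.based import fused_chunk_based, parallel_based

B, H, T, K, V = 2, 4, 256, 16, 64
q = torch.randn(B, H, T, K, device='cuda')
k = torch.randn(B, H, T, K, device='cuda')
v = torch.randn(B, H, T, V, device='cuda')

# parallel form: O(T^2) work, sequence-parallel over chunks of length BTL
o_parallel = parallel_based(q, k, v, use_norm=True, head_first=True)
# fused chunk form: linear in T, restricted to feature dimensions <= 16
o_chunk = fused_chunk_based(q, k, v, use_norm=True, head_first=True)

torch.testing.assert_close(o_parallel, o_chunk, rtol=1e-3, atol=1e-3)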
fla/ops/common/__init__.py ADDED
@@ -0,0 +1 @@
+ # -*- coding: utf-8 -*-
fla/ops/common/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (139 Bytes). View file