medmekk HF Staff committed on
Commit
07fba08
·
verified ·
1 Parent(s): 84a23cd

Upload custom kernels

Browse files
README.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: bsd-3-clause
3
+ tags:
4
+ - kernel
5
+ ---
6
+
7
+ ## triton-layer-norm
8
+
9
+ Triton layer norm [from flash-attention](https://github.com/Dao-AILab/flash-attention).
build-output/torch-universal/rmsnorm/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .rmsnorm import rms_norm_fn
2
+
3
+ from . import layers
4
+
5
+ __all__ = ["layers", "rms_norm_fn"]
build-output/torch-universal/rmsnorm/layers.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from .rmsnorm import rms_norm_fn
5
+
6
+
7
class RMSNorm(nn.Module):
    """RMS normalization layer whose forward pass calls ``rms_norm_fn``.

    The class defines no ``__init__``; it only declares the two attributes
    that ``forward`` reads, so the instance it runs on must provide them.
    """

    # Passed to ``rms_norm_fn`` as the ``weight`` argument.
    weight: torch.Tensor
    # Passed to ``rms_norm_fn`` as ``eps``.
    variance_epsilon: float

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize ``hidden_states`` with no bias, residual, or dropout."""
        normalized = rms_norm_fn(
            hidden_states,
            self.weight,
            bias=None,
            residual=None,
            eps=self.variance_epsilon,
            dropout_p=0.0,
            prenorm=False,
            residual_in_fp32=False,
        )
        return normalized  # type: ignore
22
+
23
+
24
+ __all__ = ["RMSNorm"]
build-output/torch-universal/rmsnorm/rmsnorm.py ADDED
@@ -0,0 +1,988 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # Implement dropout + residual + layer_norm / rms_norm.
3
+
4
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
+ # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
+
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch.amp import custom_fwd, custom_bwd
14
+
15
+ import triton
16
+ import triton.language as tl
17
+
18
+
19
def _ref_float_or_none(tensor):
    """Return ``tensor.float()``, or ``None`` when ``tensor`` is ``None``."""
    return tensor.float() if tensor is not None else None


def _ref_masked_dropout(t, dropout_p, mask):
    """Apply dropout to ``t``: use ``mask`` when given, else ``F.dropout``."""
    if mask is not None:
        # Entries where the mask is False are zeroed; the rest are rescaled.
        return t.masked_fill(~mask, 0.0) / (1.0 - dropout_p)
    return F.dropout(t, p=dropout_p)


def _ref_fused_input(
    x, residual, x1, dropout_p, rowscale, dropout_mask, dropout_mask1
):
    """Build the tensor that is normalized by the reference implementations.

    Applies, in order: ``rowscale`` scaling, dropout on ``x`` and ``x1``,
    the ``x + x1`` sum, and the residual add (cast back to ``x``'s dtype).
    """
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        # x is processed before x1 so the random draws of F.dropout keep
        # the same order as the original implementation.
        x = _ref_masked_dropout(x, dropout_p, dropout_mask)
        if x1 is not None:
            x1 = _ref_masked_dropout(x1, dropout_p, dropout_mask1)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    return x


def layer_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    """Pure-PyTorch reference for dropout + residual + LayerNorm.

    Args:
        x: input tensor, normalized over its last dimension.
        weight, bias: affine parameters of the first output (``bias`` may be None).
        residual: optional tensor added to the input before normalization.
        x1: optional second input, summed with ``x`` (after dropout).
        weight1, bias1: optional affine parameters of a second output.
        eps: epsilon passed to ``F.layer_norm``.
        dropout_p: dropout probability; dropout is skipped when it is 0.
        rowscale: optional per-row scale multiplied into ``x``; must be None
            when ``x1`` is given.
        prenorm: when True, also return the pre-normalization tensor.
        dropout_mask, dropout_mask1: optional boolean keep-masks used in
            place of random dropout for ``x`` and ``x1``.
        upcast: when True, compute in float32; outputs are cast back to the
            dtype ``x`` had on entry.

    Returns:
        ``out``; or ``(out, out1)`` when ``weight1`` is given; with the
        pre-normalization tensor appended when ``prenorm`` is True.
    """
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = _ref_float_or_none(bias)
        residual = _ref_float_or_none(residual)
        x1 = _ref_float_or_none(x1)
        weight1 = _ref_float_or_none(weight1)
        bias1 = _ref_float_or_none(bias1)
    x = _ref_fused_input(
        x, residual, x1, dropout_p, rowscale, dropout_mask, dropout_mask1
    )
    out = F.layer_norm(
        x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
    ).to(dtype)
    if weight1 is None:
        return out if not prenorm else (out, x)
    out1 = F.layer_norm(
        x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
    ).to(dtype)
    return (out, out1) if not prenorm else (out, out1, x)


def rms_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    """Pure-PyTorch reference for dropout + residual + RMSNorm.

    Takes the same arguments as ``layer_norm_ref``. The input is scaled by
    ``1 / sqrt(mean(x**2, dim=-1) + eps)`` instead of being passed through
    ``F.layer_norm``.

    Returns:
        ``out``; or ``(out, out1)`` when ``weight1`` is given; with the
        pre-normalization tensor appended when ``prenorm`` is True.
    """
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = _ref_float_or_none(bias)
        residual = _ref_float_or_none(residual)
        x1 = _ref_float_or_none(x1)
        weight1 = _ref_float_or_none(weight1)
        bias1 = _ref_float_or_none(bias1)
    x = _ref_fused_input(
        x, residual, x1, dropout_p, rowscale, dropout_mask, dropout_mask1
    )
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    x_hat = x * rstd
    out = (x_hat * weight + bias if bias is not None else x_hat * weight).to(dtype)
    if weight1 is None:
        return out if not prenorm else (out, x)
    out1 = (x_hat * weight1 + bias1 if bias1 is not None else x_hat * weight1).to(
        dtype
    )
    return (out, out1) if not prenorm else (out, out1, x)
128
+
129
+
130
@triton.autotune(
    configs=[triton.Config({}, num_warps=nw) for nw in (1, 2, 4, 8, 16, 32)],
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# HAS_BIAS and HAS_RESIDUAL are passed explicitly by the launcher instead of
# being derived through heuristics.
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,  # pointer to the optional second input
    W1,  # pointer to the optional second set of weights
    B1,  # pointer to the optional second set of biases
    Y1,  # pointer to the optional second output
    RESIDUAL_OUT,  # pointer to where the pre-normalization sum is stored
    ROWSCALE,  # pointer to the per-row scale
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,  # pointer to where the dropout keep-mask is stored
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Each program handles the row of X and Y selected by its program id.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row

    # Column offsets of this row; columns at or beyond N are masked off.
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N

    # Build the value to normalize: rowscale -> dropout -> + x1 -> + residual.
    x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        keep_mask = (
            tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        )
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=mask)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=mask, other=0.0).to(tl.float32)
        # The entries for x1 are read at an offset of M rows from those of x.
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
                > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=mask)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=mask, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        tl.store(RESIDUAL_OUT + cols, x, mask=mask)

    # Compute mean and variance (RMS norm skips the mean subtraction).
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(mask, x - mean, 0.0)
    else:
        xbar = tl.where(mask, x, 0.0)
    var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)

    # Normalize and apply linear transformation
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        # Second output: the same normalized row with a second weight/bias.
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)
258
+
259
+
260
def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    dropout_p=0.0,
    rowscale=None,
    out_dtype=None,
    residual_dtype=None,
    is_rms_norm=False,
    return_dropout_mask=False,
    out=None,
    residual_out=None,
):
    """Validate inputs, allocate outputs and launch the fused forward kernel.

    ``x`` must be 2D ``(M, N)`` with unit stride in the last dimension.

    Returns:
        ``(out, y1, mean, rstd, residual_out, seeds, dropout_mask,
        dropout_mask1)``. ``y1`` is None without ``weight1``; ``mean`` is
        None for RMS norm; the fifth entry is ``x`` itself when no separate
        residual output was needed; ``seeds`` is None without dropout; the
        masks are None unless requested (``dropout_mask1`` also needs ``x1``).

    Raises:
        RuntimeError: if a row does not fit in one fused block (64KB).
    """

    def _row_stride(tensor):
        # Row stride of an optional tensor; 0 stands in for "absent".
        return tensor.stride(0) if tensor is not None else 0

    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if x1 is not None:
        assert x1.shape == x.shape
        assert rowscale is None
        assert x1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)

    has_dropout = dropout_p > 0.0
    # With a second input, seeds and masks hold 2 * M rows (x first, then x1).
    rng_rows = M if x1 is None else 2 * M

    # allocate output
    if out is None:
        out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    else:
        assert out.shape == x.shape
    assert out.stride(-1) == 1
    if weight1 is not None:
        y1 = torch.empty_like(out)
        assert y1.stride(-1) == 1
    else:
        y1 = None

    # A separate residual output is only needed when the pre-normalization
    # value differs from x (in content or in dtype).
    needs_residual_out = (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or has_dropout
        or rowscale is not None
        or x1 is not None
    )
    if needs_residual_out:
        if residual_out is None:
            residual_out = torch.empty(
                M,
                N,
                device=x.device,
                dtype=residual_dtype if residual_dtype is not None else x.dtype,
            )
        else:
            assert residual_out.shape == x.shape
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None

    mean = None
    if not is_rms_norm:
        mean = torch.empty((M,), dtype=torch.float32, device=x.device)
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

    seeds = None
    if has_dropout:
        seeds = torch.randint(
            2**32, (rng_rows,), device=x.device, dtype=torch.int64
        )
    dropout_mask = None
    if return_dropout_mask and has_dropout:
        dropout_mask = torch.empty(rng_rows, N, device=x.device, dtype=torch.bool)

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            out,
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            y1,
            residual_out,
            rowscale,
            seeds,
            dropout_mask,
            mean,
            rstd,
            x.stride(0),
            out.stride(0),
            _row_stride(residual),
            _row_stride(residual_out),
            _row_stride(x1),
            _row_stride(y1),
            M,
            N,
            eps,
            dropout_p,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            bias is not None,
            has_dropout,
            dropout_mask is not None,
            rowscale is not None,
        )

    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    dropout_mask1 = None
    if dropout_mask is not None and x1 is not None:
        # First half of the rows belongs to x, second half to x1.
        dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
    return (
        out,
        y1,
        mean,
        rstd,
        residual_out if residual_out is not None else x,
        seeds,
        dropout_mask,
        dropout_mask1,
    )
407
+
408
+
409
@triton.autotune(
    configs=[triton.Config({}, num_warps=nw) for nw in (1, 2, 4, 8, 16, 32)],
    key=[
        "N",
        "HAS_DRESIDUAL",
        "STORE_DRESIDUAL",
        "IS_RMS_NORM",
        "HAS_BIAS",
        "HAS_DROPOUT",
    ],
)
# HAS_BIAS, HAS_DRESIDUAL and STORE_DRESIDUAL are passed explicitly by the
# launcher instead of being derived through heuristics.
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,  # pointer to the incoming residual gradient
    W1,  # pointer to the second set of weights
    DY1,  # pointer to the gradient of the second output
    DX1,  # pointer to the gradient of the second input
    DW1,  # pointer to the partial sum of second weights gradient
    DB1,  # pointer to the partial sum of second biases gradient
    DRESIDUAL_IN,  # pointer to where the pre-dropout gradient is stored
    ROWSCALE,  # pointer to the per-row scale
    SEEDS,  # dropout seeds for each row
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dy1_row,
    stride_dx1_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_DY1: tl.constexpr,
    HAS_DX1: tl.constexpr,
    HAS_B1: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Each program handles a contiguous block of rows_per_program rows.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    # Do not early exit if row_start >= M, because we need to write DW and DB
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N

    # Advance every row-indexed pointer to the first row of this block.
    X += row_start * stride_x_row
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * stride_dres_row
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    if HAS_DY1:
        DY1 += row_start * stride_dy1_row
    if HAS_DX1:
        DX1 += row_start * stride_dx1_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row

    # Parameters are loaded once and reused for every row of the block.
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    if HAS_DY1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)

    # Per-program accumulators for the parameter gradients.
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_DY1:
        dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
        if HAS_B1:
            db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)

    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if HAS_DY1:
            dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.0)
        if RECOMPUTE_OUTPUT:
            y = xhat * w + b if HAS_BIAS else xhat * w
            tl.store(Y + cols, y, mask=mask)
        wdy = w * dy
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if HAS_DY1:
            wdy += w1 * dy1
            dw1 += dy1 * xhat
            if HAS_B1:
                db1 += dy1
        c1 = tl.sum(xhat * wdy, axis=0) / N
        if not IS_RMS_NORM:
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            # Stored before the dropout / rowscale factors below are applied.
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        if HAS_DX1:
            if HAS_DROPOUT:
                # Same seed offset (M + row) that the forward kernel uses for x1.
                keep_mask = (
                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
                    > dropout_p
                )
                dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
            else:
                dx1 = dx
            tl.store(DX1 + cols, dx1, mask=mask)
        if HAS_DROPOUT:
            keep_mask = (
                tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
                > dropout_p
            )
            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + row).to(tl.float32)
            dx *= rowscale
        tl.store(DX + cols, dx, mask=mask)

        # Step every row-indexed pointer to the next row.
        X += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
        if HAS_DY1:
            DY1 += stride_dy1_row
        if HAS_DX1:
            DX1 += stride_dx1_row

    # Each program writes its partial sums to its own row of the buffers.
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
    if HAS_DY1:
        tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
        if HAS_B1:
            tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
591
+
592
+
593
def _layer_norm_bwd(
    dy,
    x,
    weight,
    bias,
    eps,
    mean,
    rstd,
    dresidual=None,
    dy1=None,
    weight1=None,
    bias1=None,
    seeds=None,
    dropout_p=0.0,
    rowscale=None,
    has_residual=False,
    has_x1=False,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,
):
    """Validate inputs, allocate gradients and launch the backward kernel.

    ``x`` and ``dy`` must be 2D ``(M, N)`` with unit stride in the last
    dimension. The kernel runs one program per multiprocessor of ``x``'s
    device; each program writes partial weight/bias gradients into its own
    row of a scratch buffer, and the rows are summed here afterwards.

    Returns:
        ``(dx, dw, db, dresidual_in, dx1, dw1, db1)``, with ``y`` appended
        when ``recompute_output`` is True. Entries that do not apply are None.

    Raises:
        RuntimeError: if a row does not fit in one fused block (64KB).
    """

    def _row_stride(tensor):
        # Row stride of an optional tensor; 0 stands in for "absent".
        return tensor.stride(0) if tensor is not None else 0

    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if dresidual is not None:
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if dy1 is not None:
        assert weight1 is not None
        assert dy1.shape == dy.shape
        assert dy1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if seeds is not None:
        assert seeds.is_contiguous()
        assert seeds.shape == (M if not has_x1 else M * 2,)
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    # A separate residual gradient is only needed when it differs from dx.
    dresidual_in = (
        torch.empty_like(x)
        if has_residual
        and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
        else None
    )
    dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
    y = (
        torch.empty(M, N, dtype=dy.dtype, device=dy.device)
        if recompute_output
        else None
    )
    if recompute_output:
        assert (
            weight1 is None
        ), "recompute_output is not supported with parallel LayerNorm"

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    _dw1 = torch.empty_like(_dw) if weight1 is not None else None
    # Allocated explicitly rather than with empty_like(_db): _db is None when
    # bias is None, which made empty_like raise whenever only bias1 was given.
    _db1 = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias1.device)
        if bias1 is not None
        else None
    )
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](
            x,
            weight,
            bias,
            y,
            dy,
            dx,
            _dw,
            _db,
            dresidual,
            weight1,
            dy1,
            dx1,
            _dw1,
            _db1,
            dresidual_in,
            rowscale,
            seeds,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            _row_stride(dresidual),
            _row_stride(dy1),
            _row_stride(dx1),
            _row_stride(dresidual_in),
            M,
            N,
            eps,
            dropout_p,
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            bias is not None,
            dropout_p > 0.0,
        )
    # Reduce the per-program partial sums and cast to the parameter dtypes.
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
    db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
        dresidual_in = dx
    # Without dropout the gradient of x1 is the same tensor as dx.
    if has_x1 and dropout_p == 0.0:
        dx1 = dx
    return (
        (dx, dw, db, dresidual_in, dx1, dw1, db1)
        if not recompute_output
        else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
    )
736
+
737
+
738
def _flatten_to_rows(tensor):
    """Reshape ``tensor`` to 2D and make its last dimension unit-stride."""
    tensor = tensor.reshape(-1, tensor.shape[-1])
    if tensor.stride(-1) != 1:
        tensor = tensor.contiguous()
    return tensor


class LayerNormFn(torch.autograd.Function):
    """Autograd wrapper around the fused dropout + residual + norm kernels."""

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
        return_dropout_mask=False,
        out=None,
        residual_out=None,
    ):
        """Run the forward kernel on ``x`` flattened to 2D.

        Returns ``y`` alone, or a tuple that starts with ``y`` followed, in
        this order, by ``y1`` (when ``weight1`` is given), the
        pre-normalization tensor (when ``prenorm``), and the two dropout
        masks (when ``return_dropout_mask``). All tensors are reshaped back
        to the shape of the original ``x``.
        """
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = _flatten_to_rows(x)
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = _flatten_to_rows(residual)
        if x1 is not None:
            assert x1.shape == x_shape_og
            assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
            x1 = _flatten_to_rows(x1)
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        if weight1 is not None:
            weight1 = weight1.contiguous()
        if bias1 is not None:
            bias1 = bias1.contiguous()
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()
        if residual is not None:
            residual_dtype = residual.dtype
        elif residual_in_fp32:
            residual_dtype = torch.float32
        else:
            residual_dtype = None
        if out is not None:
            out = out.reshape(-1, out.shape[-1])
        if residual_out is not None:
            residual_out = residual_out.reshape(-1, residual_out.shape[-1])
        fwd_outputs = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            residual,
            x1,
            weight1,
            bias1,
            dropout_p=dropout_p,
            rowscale=rowscale,
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
            return_dropout_mask=return_dropout_mask,
            out=out,
            residual_out=residual_out,
        )
        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
            fwd_outputs
        )
        # The pre-normalization tensor is what backward reads back as ``x``.
        ctx.save_for_backward(
            residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
        )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.has_x1 = x1 is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype

        def _restore_shape(tensor):
            # Back to the caller's original shape; None passes through.
            return tensor.reshape(x_shape_og) if tensor is not None else None

        y = y.reshape(x_shape_og)
        y1 = _restore_shape(y1)
        residual_out = _restore_shape(residual_out)
        dropout_mask = _restore_shape(dropout_mask)
        dropout_mask1 = _restore_shape(dropout_mask1)

        # Assemble the outputs in their fixed order.
        outputs = [y]
        if weight1 is not None:
            outputs.append(y1)
        if prenorm:
            outputs.append(residual_out)
        if return_dropout_mask:
            outputs.extend((dropout_mask, dropout_mask1))
        return outputs[0] if len(outputs) == 1 else tuple(outputs)

    @staticmethod
    def backward(ctx, dy, *args):
        """Compute gradients for the 16 positional inputs of ``forward``.

        ``args`` holds the gradients of the extra forward outputs: first
        ``dy1`` when ``weight1`` was given, then the gradient of the
        pre-normalization tensor when ``prenorm`` was set.
        """
        x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
        dy = _flatten_to_rows(dy)
        assert dy.shape == x.shape
        dy1 = None
        if weight1 is not None:
            dy1, args = args[0], args[1:]
            dy1 = _flatten_to_rows(dy1)
            assert dy1.shape == x.shape
        dresidual = None
        if ctx.prenorm:
            dresidual = _flatten_to_rows(args[0])
            assert dresidual.shape == x.shape
        dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual=dresidual,
            dy1=dy1,
            weight1=weight1,
            bias1=bias1,
            seeds=seeds,
            dropout_p=ctx.dropout_p,
            rowscale=rowscale,
            has_residual=ctx.has_residual,
            has_x1=ctx.has_x1,
            is_rms_norm=ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        original_shape = ctx.x_shape_og
        # Gradients for x, weight, bias, residual, x1, weight1, bias1; the
        # remaining nine forward arguments get None.
        return (
            dx.reshape(original_shape),
            dw,
            db,
            dresidual_in.reshape(original_shape) if ctx.has_residual else None,
            dx1.reshape(original_shape) if dx1 is not None else None,
            dw1,
            db1,
            *([None] * 9),
        )
914
+
915
+
916
def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
    return_dropout_mask=False,
    out=None,
    residual_out=None,
):
    """Apply ``LayerNormFn`` to the arguments (layer norm unless ``is_rms_norm``).

    Every argument is forwarded positionally, in the order ``LayerNormFn``
    expects, and the result of ``LayerNormFn.apply`` is returned unchanged.
    """
    positional_args = (
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
        return_dropout_mask,
        out,
        residual_out,
    )
    return LayerNormFn.apply(*positional_args)
952
+
953
+
954
def rms_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
    out=None,
    residual_out=None,
):
    """Apply ``LayerNormFn`` with ``is_rms_norm`` fixed to True.

    Every argument is forwarded positionally, in the order ``LayerNormFn``
    expects, and the result of ``LayerNormFn.apply`` is returned unchanged.
    """
    is_rms_norm = True
    positional_args = (
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
        return_dropout_mask,
        out,
        residual_out,
    )
    return LayerNormFn.apply(*positional_args)
build.toml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [general]
2
+ name = "rmsnorm"
3
+
4
+ [torch]
5
+ universal = true
flake.nix ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Nix flake for this kernel repository. All flake outputs are produced by
# the `genFlakeOutputs` helper from the kernel-builder input.
{
  description = "Flake for Triton layer norm kernels";

  inputs = {
    kernel-builder.url = "github:huggingface/kernel-builder";
  };

  outputs =
    {
      self,
      kernel-builder,
    }:
    kernel-builder.lib.genFlakeOutputs {
      # Build from the directory containing this flake.
      path = ./.;
      # `a.b or c` selects attribute `b` if it exists and falls back to `c`
      # otherwise, so the first attribute of `self` that is present is used.
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
    };
}
torch-ext/rmsnorm/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .rmsnorm import rms_norm_fn
2
+
3
+ from . import layers
4
+
5
+ __all__ = ["layers", "rms_norm_fn"]
torch-ext/rmsnorm/layers.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from .rmsnorm import rms_norm_fn
5
+
6
+
7
class RMSNorm(nn.Module):
    """RMSNorm layer whose forward pass is delegated to ``rms_norm_fn``.

    The class only declares the attributes it reads (``weight`` and
    ``variance_epsilon``); it defines no ``__init__`` of its own, so both
    attributes must be set on the instance before ``forward`` is called.
    """

    weight: torch.Tensor
    variance_epsilon: float

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize ``hidden_states`` with no bias, residual or dropout."""
        norm_options = dict(
            bias=None,
            residual=None,
            eps=self.variance_epsilon,
            dropout_p=0.0,
            prenorm=False,
            residual_in_fp32=False,
        )
        return rms_norm_fn(hidden_states, self.weight, **norm_options)  # type: ignore
22
+
23
+
24
+ __all__ = ["RMSNorm"]
torch-ext/rmsnorm/rmsnorm.py ADDED
@@ -0,0 +1,988 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # Implement dropout + residual + layer_norm / rms_norm.
3
+
4
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
+ # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
+
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch.amp import custom_fwd, custom_bwd
14
+
15
+ import triton
16
+ import triton.language as tl
17
+
18
+
19
def layer_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    """Pure-PyTorch reference for (rowscale/dropout +) residual + LayerNorm.

    Steps, in order: optional upcast of all tensors to float32, optional
    per-row scaling of ``x``, optional dropout on ``x`` and ``x1`` (using the
    supplied masks when given, ``F.dropout`` otherwise), addition of ``x1``
    and ``residual``, then ``F.layer_norm`` over the last dimension.

    Returns:
        ``out`` cast back to the dtype ``x`` had on entry; additionally
        ``out1`` (normalized with ``weight1``/``bias1``) when ``weight1`` is
        given, and the pre-normalization sum as the last element when
        ``prenorm`` is true.
    """
    result_dtype = x.dtype

    def _to_float(t):
        return t.float() if t is not None else None

    if upcast:
        x = x.float()
        weight = weight.float()
        bias, residual, x1, weight1, bias1 = (
            _to_float(t) for t in (bias, residual, x1, weight1, bias1)
        )
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is None:
            x = F.dropout(x, p=dropout_p)
        else:
            # Kept elements are rescaled by 1 / (1 - p).
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        if x1 is not None:
            if dropout_mask1 is None:
                x1 = F.dropout(x1, p=dropout_p)
            else:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)

    def _normalize(w, b):
        normed = F.layer_norm(x.to(w.dtype), x.shape[-1:], weight=w, bias=b, eps=eps)
        return normed.to(result_dtype)

    out = _normalize(weight, bias)
    if weight1 is None:
        return (out, x) if prenorm else out
    out1 = _normalize(weight1, bias1)
    return (out, out1, x) if prenorm else (out, out1)
72
+
73
+
74
def rms_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    """Pure-PyTorch reference for (rowscale/dropout +) residual + RMSNorm.

    The input pipeline (optional upcast, rowscale, dropout, ``x1`` and
    ``residual`` addition) is followed by scaling with
    ``1 / sqrt(mean(x**2, dim=-1) + eps)`` and the affine ``weight``/``bias``.

    Returns:
        ``out`` cast back to the dtype ``x`` had on entry; additionally
        ``out1`` (using ``weight1``/``bias1``) when ``weight1`` is given, and
        the pre-normalization sum as the last element when ``prenorm`` is true.
    """
    result_dtype = x.dtype

    def _to_float(t):
        return t.float() if t is not None else None

    if upcast:
        x = x.float()
        weight = weight.float()
        bias, residual, x1, weight1, bias1 = (
            _to_float(t) for t in (bias, residual, x1, weight1, bias1)
        )
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is None:
            x = F.dropout(x, p=dropout_p)
        else:
            # Kept elements are rescaled by 1 / (1 - p).
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        if x1 is not None:
            if dropout_mask1 is None:
                x1 = F.dropout(x1, p=dropout_p)
            else:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)

    mean_square = x.square().mean(dim=-1, keepdim=True)
    rstd = 1 / torch.sqrt(mean_square + eps)

    def _affine(w, b):
        scaled = x * rstd * w
        if b is not None:
            scaled = scaled + b
        return scaled.to(result_dtype)

    out = _affine(weight, bias)
    if weight1 is None:
        return (out, x) if prenorm else out
    out1 = _affine(weight1, bias1)
    return (out, out1, x) if prenorm else (out, out1)
128
+
129
+
130
# Fused forward kernel: one program instance processes one row of X.
# Per row it optionally applies a row scale and dropout, adds X1 and RESIDUAL,
# optionally stores that pre-normalization sum to RESIDUAL_OUT, computes the
# row statistics (mean unless IS_RMS_NORM, and rstd), and writes the
# normalized, affine-transformed row to Y (and to Y1 when W1 is given).
# All arithmetic is done on values converted to float32 after loading.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
# The three flags below are derived from whether the matching pointer is None;
# the remaining constexpr flags are passed explicitly by the launcher.
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,
    W1,
    B1,
    Y1,
    RESIDUAL_OUT,  # pointer to the residual
    ROWSCALE,
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    # Columns at or beyond N are padding: they are loaded as 0 and masked on store.
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        # An element is kept when its random value exceeds dropout_p; kept
        # elements are rescaled by 1 / (1 - dropout_p).
        keep_mask = (
            tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        )
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        # The second branch's row scale, seed and mask live at offset M
        # (i.e. in the second half of the respective buffers).
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
                > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        # Persist the pre-normalization sum.
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        # Zero the padding columns so they do not contribute to the variance.
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        # RMS norm: no mean subtraction; Mean is never written in this branch.
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        # Second output: the same normalized row with a second affine transform.
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)
258
+
259
+
260
def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    dropout_p=0.0,
    rowscale=None,
    out_dtype=None,
    residual_dtype=None,
    is_rms_norm=False,
    return_dropout_mask=False,
    out=None,
    residual_out=None,
):
    """Check inputs, allocate buffers and launch the fused forward kernel.

    ``x`` must be 2D ``(M, N)`` with a unit-stride last dimension; the other
    tensor arguments are checked against that shape with assertions.

    Returns:
        An 8-tuple ``(out, y1, mean, rstd, residual_out, seeds, dropout_mask,
        dropout_mask1)``. ``y1`` is None without ``weight1``; ``mean`` is None
        for RMS norm; the fifth element is ``x`` itself when no separate
        pre-normalization buffer was needed; ``seeds`` and the masks are None
        unless dropout (and mask return) is active.

    Raises:
        RuntimeError: if a row of ``N`` elements does not fit in one block
            (more than ``65536 // x.element_size()`` elements).
    """
    has_residual = residual is not None
    has_x1 = x1 is not None
    use_dropout = dropout_p > 0.0

    # A supplied residual dictates the dtype of the pre-normalization buffer.
    if has_residual:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if has_residual:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if has_x1:
        assert x1.shape == x.shape
        assert rowscale is None
        assert x1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)

    # Output buffers: reuse the caller's when provided, allocate otherwise.
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert out.stride(-1) == 1
    y1 = None
    if weight1 is not None:
        y1 = torch.empty_like(out)
        assert y1.stride(-1) == 1

    # A separate pre-normalization buffer is only needed when the value being
    # normalized differs from x (in content or in dtype).
    wants_other_dtype = residual_dtype is not None and residual_dtype != x.dtype
    needs_residual_out = (
        has_residual
        or wants_other_dtype
        or use_dropout
        or rowscale is not None
        or has_x1
    )
    if not needs_residual_out:
        residual_out = None
    else:
        if residual_out is not None:
            assert residual_out.shape == x.shape
        else:
            residual_out = torch.empty(
                M,
                N,
                device=x.device,
                dtype=residual_dtype if residual_dtype is not None else x.dtype,
            )
        assert residual_out.stride(-1) == 1

    mean = None
    if not is_rms_norm:
        mean = torch.empty((M,), dtype=torch.float32, device=x.device)
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

    # With x1 present, seeds and masks hold M entries for x followed by M for x1.
    dropout_rows = 2 * M if has_x1 else M
    seeds = None
    if use_dropout:
        seeds = torch.randint(
            2**32, (dropout_rows,), device=x.device, dtype=torch.int64
        )
    dropout_mask = None
    if return_dropout_mask and use_dropout:
        dropout_mask = torch.empty(
            dropout_rows, N, device=x.device, dtype=torch.bool
        )

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    def _row_stride(t):
        return t.stride(0) if t is not None else 0

    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            out,
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            y1,
            residual_out,
            rowscale,
            seeds,
            dropout_mask,
            mean,
            rstd,
            x.stride(0),
            out.stride(0),
            _row_stride(residual),
            _row_stride(residual_out),
            _row_stride(x1),
            _row_stride(y1),
            M,
            N,
            eps,
            dropout_p,
            is_rms_norm,
            BLOCK_N,
            has_residual,
            residual_out is not None,
            bias is not None,
            use_dropout,
            dropout_mask is not None,
            rowscale is not None,
        )

    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    dropout_mask1 = None
    if dropout_mask is not None and has_x1:
        dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
    prenorm_value = residual_out if residual_out is not None else x
    return (
        out,
        y1,
        mean,
        rstd,
        prenorm_value,
        seeds,
        dropout_mask,
        dropout_mask1,
    )
407
+
408
+
409
# Fused backward kernel. Each program instance handles a contiguous range of
# `rows_per_program` rows: for every row it computes dx (and optionally dx1 and
# the residual gradient), while accumulating per-program partial sums of the
# weight/bias gradients. Those partial sums are written to row `row_block_id`
# of DW/DB/DW1/DB1 and must be reduced over programs by the caller.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=[
        "N",
        "HAS_DRESIDUAL",
        "STORE_DRESIDUAL",
        "IS_RMS_NORM",
        "HAS_BIAS",
        "HAS_DROPOUT",
    ],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
# The flags below are derived from whether the matching pointer is None;
# the remaining constexpr flags are passed explicitly by the launcher.
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    W1,
    DY1,
    DX1,
    DW1,
    DB1,
    DRESIDUAL_IN,
    ROWSCALE,
    SEEDS,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dy1_row,
    stride_dx1_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_DY1: tl.constexpr,
    HAS_DX1: tl.constexpr,
    HAS_B1: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    # Do not early exit if row_start >= M, because we need to write DW and DB
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    # Advance every row-indexed pointer to this program's first row.
    X += row_start * stride_x_row
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * stride_dres_row
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    if HAS_DY1:
        DY1 += row_start * stride_dy1_row
    if HAS_DX1:
        DX1 += row_start * stride_dx1_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row
    # Parameters are loaded once and reused for every row of this program.
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    if HAS_DY1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
    # float32 accumulators for this program's partial parameter gradients.
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_DY1:
        dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
        if HAS_B1:
            db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if HAS_DY1:
            dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        # Zero the padding columns so they do not contribute to the reductions below.
        xhat = tl.where(mask, xhat, 0.0)
        if RECOMPUTE_OUTPUT:
            y = xhat * w + b if HAS_BIAS else xhat * w
            tl.store(Y + cols, y, mask=mask)
        wdy = w * dy
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if HAS_DY1:
            # The second output shares xhat, so its gradient adds into wdy.
            wdy += w1 * dy1
            dw1 += dy1 * xhat
            if HAS_B1:
                db1 += dy1
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            # RMS norm has no mean term, hence no c2 correction.
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        # DRESIDUAL_IN receives the gradient before dropout/rowscale are applied.
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        if HAS_DX1:
            if HAS_DROPOUT:
                # Regenerate the x1 dropout mask from its seed (stored at offset M).
                keep_mask = (
                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
                    > dropout_p
                )
                dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
            else:
                dx1 = dx
            tl.store(DX1 + cols, dx1, mask=mask)
        if HAS_DROPOUT:
            # Regenerate the x dropout mask from the per-row seed.
            keep_mask = (
                tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
                > dropout_p
            )
            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + row).to(tl.float32)
            dx *= rowscale
        tl.store(DX + cols, dx, mask=mask)

        # Step every row-indexed pointer to the next row.
        X += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
        if HAS_DY1:
            DY1 += stride_dy1_row
        if HAS_DX1:
            DX1 += stride_dx1_row
    # Write this program's partial sums; programs with no rows write zeros.
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
    if HAS_DY1:
        tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
        if HAS_B1:
            tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
591
+
592
+
593
def _layer_norm_bwd(
    dy,
    x,
    weight,
    bias,
    eps,
    mean,
    rstd,
    dresidual=None,
    dy1=None,
    weight1=None,
    bias1=None,
    seeds=None,
    dropout_p=0.0,
    rowscale=None,
    has_residual=False,
    has_x1=False,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,
):
    """Check inputs, allocate buffers and launch the fused backward kernel.

    ``x`` is the 2D ``(M, N)`` pre-normalization value saved by the forward
    pass and ``dy`` the gradient of the normalized output; the other tensor
    arguments are checked against that shape with assertions.

    Returns:
        ``(dx, dw, db, dresidual_in, dx1, dw1, db1)``, with the recomputed
        output ``y`` appended as an eighth element when ``recompute_output``
        is true. Entries whose corresponding input is absent are None.

    Raises:
        RuntimeError: if a row of ``N`` elements does not fit in one block
            (more than ``65536 // x.element_size()`` elements).
    """
    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if dresidual is not None:
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if dy1 is not None:
        assert weight1 is not None
        assert dy1.shape == dy.shape
        assert dy1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if seeds is not None:
        assert seeds.is_contiguous()
        assert seeds.shape == (M if not has_x1 else M * 2,)
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    # A separate residual-gradient buffer is only needed when it cannot alias dx.
    dresidual_in = (
        torch.empty_like(x)
        if has_residual
        and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
        else None
    )
    # Without dropout dx1 equals dx, so it is aliased after the kernel instead.
    dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
    y = (
        torch.empty(M, N, dtype=dy.dtype, device=dy.device)
        if recompute_output
        else None
    )
    if recompute_output:
        assert (
            weight1 is None
        ), "recompute_output is not supported with parallel LayerNorm"

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # One program per SM; each writes one row of partial parameter gradients.
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    _dw1 = torch.empty_like(_dw) if weight1 is not None else None
    # Allocate directly rather than via empty_like(_db): _db is None when the
    # first branch has no bias, which would make empty_like raise a TypeError
    # even though the kernel handles HAS_B1 independently of HAS_BIAS.
    _db1 = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias1.device)
        if bias1 is not None
        else None
    )
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](
            x,
            weight,
            bias,
            y,
            dy,
            dx,
            _dw,
            _db,
            dresidual,
            weight1,
            dy1,
            dx1,
            _dw1,
            _db1,
            dresidual_in,
            rowscale,
            seeds,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dy1.stride(0) if dy1 is not None else 0,
            dx1.stride(0) if dx1 is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            bias is not None,
            dropout_p > 0.0,
        )
    # Reduce the per-program partial sums and cast to the parameter dtypes.
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
    db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
        dresidual_in = dx
    if has_x1 and dropout_p == 0.0:
        dx1 = dx
    return (
        (dx, dw, db, dresidual_in, dx1, dw1, db1)
        if not recompute_output
        else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
    )
736
+
737
+
738
class LayerNormFn(torch.autograd.Function):
    """Autograd function around ``_layer_norm_fwd`` / ``_layer_norm_bwd``.

    Inputs of any leading shape are flattened to 2D for the kernels and the
    results are reshaped back to the original shape of ``x``.
    """

    @staticmethod
    def _as_rows(t):
        """Flatten all leading dims and ensure a unit-stride last dim."""
        t = t.reshape(-1, t.shape[-1])
        return t if t.stride(-1) == 1 else t.contiguous()

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
        return_dropout_mask=False,
        out=None,
        residual_out=None,
    ):
        """Run the fused forward pass and stash what backward needs.

        Returns ``y`` alone, or a tuple that starts with ``y`` and appends, in
        this order: ``y1`` when ``weight1`` is given, the pre-normalization
        value when ``prenorm`` is true, and both dropout masks when
        ``return_dropout_mask`` is true.
        """
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = LayerNormFn._as_rows(x)
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = LayerNormFn._as_rows(residual)
        if x1 is not None:
            assert x1.shape == x_shape_og
            assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
            x1 = LayerNormFn._as_rows(x1)
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        if weight1 is not None:
            weight1 = weight1.contiguous()
        if bias1 is not None:
            bias1 = bias1.contiguous()
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()

        # A given residual fixes the dtype; otherwise fp32 only on request.
        if residual is not None:
            residual_dtype = residual.dtype
        elif residual_in_fp32:
            residual_dtype = torch.float32
        else:
            residual_dtype = None

        if out is not None:
            out = out.reshape(-1, out.shape[-1])
        if residual_out is not None:
            residual_out = residual_out.reshape(-1, residual_out.shape[-1])

        fwd_results = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            residual,
            x1,
            weight1,
            bias1,
            dropout_p=dropout_p,
            rowscale=rowscale,
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
            return_dropout_mask=return_dropout_mask,
            out=out,
            residual_out=residual_out,
        )
        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
            fwd_results
        )

        # residual_out (the value that was actually normalized) plays the role
        # of "x" in the backward pass.
        ctx.save_for_backward(
            residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
        )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.has_x1 = x1 is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype

        y = y.reshape(x_shape_og)
        if y1 is not None:
            y1 = y1.reshape(x_shape_og)
        if residual_out is not None:
            residual_out = residual_out.reshape(x_shape_og)
        if dropout_mask is not None:
            dropout_mask = dropout_mask.reshape(x_shape_og)
        if dropout_mask1 is not None:
            dropout_mask1 = dropout_mask1.reshape(x_shape_og)

        # Assemble the outputs in their fixed order.
        outputs = [y]
        if weight1 is not None:
            outputs.append(y1)
        if prenorm:
            outputs.append(residual_out)
        if return_dropout_mask:
            outputs.extend((dropout_mask, dropout_mask1))
        return outputs[0] if len(outputs) == 1 else tuple(outputs)

    @staticmethod
    def backward(ctx, dy, *args):
        """Compute gradients for the tensor inputs of ``forward``.

        ``args`` holds the gradients of the extra forward outputs: first the
        one for ``y1`` (when ``weight1`` was given), then the one for the
        pre-normalization output (when ``prenorm`` was set). Non-tensor and
        buffer arguments of ``forward`` receive ``None``.
        """
        x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
        dy = LayerNormFn._as_rows(dy)
        assert dy.shape == x.shape

        next_grad = 0
        dy1 = None
        if weight1 is not None:
            dy1 = LayerNormFn._as_rows(args[next_grad])
            next_grad += 1
            assert dy1.shape == x.shape
        dresidual = None
        if ctx.prenorm:
            dresidual = LayerNormFn._as_rows(args[next_grad])
            assert dresidual.shape == x.shape

        dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            dy1,
            weight1,
            bias1,
            seeds,
            ctx.dropout_p,
            rowscale,
            ctx.has_residual,
            ctx.has_x1,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )

        original_shape = ctx.x_shape_og
        tensor_grads = (
            dx.reshape(original_shape),
            dw,
            db,
            dresidual_in.reshape(original_shape) if ctx.has_residual else None,
            dx1.reshape(original_shape) if dx1 is not None else None,
            dw1,
            db1,
        )
        # The nine remaining forward arguments (eps ... residual_out) get no gradient.
        return tensor_grads + (None,) * 9
914
+
915
+
916
def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
    return_dropout_mask=False,
    out=None,
    residual_out=None,
):
    """Functional entry point for ``LayerNormFn``.

    Every argument is handed to ``LayerNormFn.apply`` positionally, in the
    order of this signature (``apply`` is called without keyword arguments
    here). Whatever ``LayerNormFn.apply`` returns is returned unchanged.
    """
    positional_args = (
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
        return_dropout_mask,
        out,
        residual_out,
    )
    return LayerNormFn.apply(*positional_args)
952
+
953
+
954
def rms_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
    out=None,
    residual_out=None,
):
    """RMSNorm flavour of ``layer_norm_fn``.

    Forwards its arguments positionally to ``LayerNormFn.apply`` with the
    ``is_rms_norm`` slot hard-wired to ``True``; the result of ``apply`` is
    returned unchanged.
    """
    # The argument between residual_in_fp32 and return_dropout_mask is the
    # is_rms_norm flag of LayerNormFn; this wrapper always enables it.
    is_rms_norm = True
    positional_args = (
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
        return_dropout_mask,
        out,
        residual_out,
    )
    return LayerNormFn.apply(*positional_args)