Kernels
github-actions[bot] committed
Commit f2471cd · 1 Parent(s): f19f8f4

Add built binary [skip-build]

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py +2 -1
  2. build/torch27-cxx11-cu118-x86_64-linux/activation/{_activation_20250907180255.abi3.so → _activation_53ed492_dirty.abi3.so} +2 -2
  3. build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py +3 -3
  4. build/torch27-cxx11-cu118-x86_64-linux/activation/fused_add_rms_norm_meta.py +199 -0
  5. build/torch27-cxx11-cu118-x86_64-linux/activation/parallel_style.py +50 -0
  6. build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py +47 -20
  7. build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm_meta.py +164 -0
  8. build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py +2 -1
  9. build/{torch27-cxx11-cu118-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so} +2 -2
  10. build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +0 -3
  11. build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +0 -3
  12. build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py +3 -3
  13. build/torch27-cxx11-cu126-x86_64-linux/activation/fused_add_rms_norm_meta.py +199 -0
  14. build/torch27-cxx11-cu126-x86_64-linux/activation/parallel_style.py +50 -0
  15. build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py +47 -20
  16. build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm_meta.py +164 -0
  17. build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py +2 -1
  18. build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so +0 -3
  19. build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so +3 -0
  20. build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +0 -3
  21. build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +0 -3
  22. build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py +3 -3
  23. build/torch27-cxx11-cu128-x86_64-linux/activation/fused_add_rms_norm_meta.py +199 -0
  24. build/torch27-cxx11-cu128-x86_64-linux/activation/parallel_style.py +50 -0
  25. build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py +47 -20
  26. build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm_meta.py +164 -0
  27. build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py +2 -1
  28. build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so +0 -3
  29. build/{torch27-cxx11-cu118-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so → torch27-cxx11-rocm63-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so} +2 -2
  30. build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +0 -3
  31. build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +0 -3
  32. build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py +3 -3
  33. build/torch27-cxx11-rocm63-x86_64-linux/activation/fused_add_rms_norm_meta.py +199 -0
  34. build/torch27-cxx11-rocm63-x86_64-linux/activation/parallel_style.py +50 -0
  35. build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py +47 -20
  36. build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm_meta.py +164 -0
  37. build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py +2 -1
  38. build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so +0 -3
  39. build/{torch27-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so → torch28-cxx11-cu126-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so} +2 -2
  40. build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +0 -3
  41. build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +0 -3
  42. build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py +3 -3
  43. build/torch28-cxx11-cu126-x86_64-linux/activation/fused_add_rms_norm_meta.py +199 -0
  44. build/torch28-cxx11-cu126-x86_64-linux/activation/parallel_style.py +50 -0
  45. build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py +47 -20
  46. build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm_meta.py +164 -0
  47. build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py +2 -1
  48. build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so +0 -3
  49. build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so +3 -0
  50. build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +0 -3
build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py CHANGED
@@ -1,6 +1,6 @@
 import torch
 
-from . import layers
+from . import layers, parallel_style
 from ._ops import ops
 from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
 from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -48,5 +48,6 @@ __all__ = [
     "rms_norm",
     "fused_add_rms_norm",
     "layers",
+    "parallel_style",
     "ops",
 ]
build/torch27-cxx11-cu118-x86_64-linux/activation/{_activation_20250907180255.abi3.so → _activation_53ed492_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d21a85bf21aa74f1281541e658acfd4f4326d902efe3578b059eccf054443284
-size 8089696
+oid sha256:80267a0391fa4cb22aa3eb04b05d8214c2bfaed968b714185bc20214596072e3
+size 8618232
build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_e5e2eeb_dirty
-ops = torch.ops._activation_e5e2eeb_dirty
+from . import _activation_53ed492_dirty
+ops = torch.ops._activation_53ed492_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_e5e2eeb_dirty::{op_name}"
+    return f"_activation_53ed492_dirty::{op_name}"
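The _ops.py shim is what keeps callers independent of the build hash: ops is bound to the freshly built torch.ops._activation_53ed492_dirty namespace, and add_op_namespace_prefix produces fully qualified op names. A minimal sketch of that indirection, assuming the platform-matching wheel from this commit is importable as the activation package:

import torch
from activation._ops import ops, add_op_namespace_prefix  # assumed import path

# Both routes resolve to the same registered custom op; nothing outside _ops.py
# needs to spell out the hash-suffixed namespace.
assert add_op_namespace_prefix("rms_norm") == "_activation_53ed492_dirty::rms_norm"
assert ops.rms_norm is torch.ops._activation_53ed492_dirty.rms_norm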
build/torch27-cxx11-cu118-x86_64-linux/activation/fused_add_rms_norm_meta.py ADDED
@@ -0,0 +1,199 @@
+from collections.abc import Sequence
+
+import torch
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
+                                                 RuntimeSchemaInfo)
+from torch.distributed.tensor._ops.utils import (generate_redistribute_costs,
+                                                 register_op_strategy)
+from torch.distributed.tensor.placement_types import (Placement, Replicate,
+                                                       Shard)
+
+from ._ops import ops
+
+
+def register_fused_add_rms_norm_meta():
+    """Dummy function to register the meta functions.
+    Registration happens at import time by the decorators below.
+    """
+    pass
+
+
+def _replicate_dims_start_at(placements: Sequence[Placement],
+                             start_dim: int = 0) -> tuple[Placement, ...]:
+    new_placements: list[Placement] = []
+    for p in placements:
+        if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
+            new_placements.append(Replicate())  # make it replicate
+        else:
+            new_placements.append(p)  # keep the placement
+    return tuple(new_placements)
+
+
+@register_op_strategy(ops.fused_add_rms_norm.default,
+                      schema_info=RuntimeSchemaInfo(1))
+def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+
+    assert len(op_schema.args_schema) == 4
+    (
+        input_strategy,
+        residual_strategy,
+        weight_strategy,
+        _,  # eps
+    ) = op_schema.args_schema
+
+    assert isinstance(input_strategy, OpStrategy)
+    assert isinstance(residual_strategy, OpStrategy)
+    assert isinstance(weight_strategy, OpStrategy)
+
+    lengths = {
+        "input": len(input_strategy.strategies),
+        "residual": len(residual_strategy.strategies),
+        "weight": len(weight_strategy.strategies),
+    }
+    assert len(set(
+        lengths.values())) == 1, f"Strategy length mismatch: {lengths}"
+
+    last_dim = input_strategy.ndim - 1
+    strategy = OpStrategy([])
+    for input, residual, weight in zip(input_strategy.strategies,
+                                       residual_strategy.strategies,
+                                       weight_strategy.strategies):
+
+        input_src = input.output_spec
+        residual_src = residual.output_spec
+        weight_src = weight.output_spec
+
+        assert isinstance(input_src, DTensorSpec)
+        assert isinstance(residual_src, DTensorSpec)
+        assert isinstance(weight_src, DTensorSpec)
+
+        redistribute_costs = []
+
+        # Input can be sharded in any dim except the last dim.
+        input_tgt = DTensorSpec(
+            mesh=mesh,
+            placements=_replicate_dims_start_at(input_src.placements,
+                                                last_dim),
+            tensor_meta=input_src.tensor_meta,
+        )
+        redistribute_costs.append(
+            generate_redistribute_costs(input_strategy, input_tgt))
+
+        # Residual add must have the same sharding as input.
+        residual_tgt = input_tgt
+        redistribute_costs.append(
+            generate_redistribute_costs(residual_strategy, residual_tgt))
+
+        # Weight cannot be sharded, so always replicate it.
+        weight_tgt = DTensorSpec(
+            mesh=mesh,
+            placements=(Replicate(), ),
+            tensor_meta=weight_src.tensor_meta,
+        )
+        redistribute_costs.append(
+            generate_redistribute_costs(weight_strategy, weight_tgt))
+
+        strategy.strategies.append(
+            OpSpec(
+                output_specs=[input_tgt, input_tgt],
+                input_specs=[input_tgt, residual_tgt, weight_tgt],
+                redistribute_cost=redistribute_costs,
+            ))
+    return strategy
+
+
+@register_op_strategy(ops.fused_add_rms_norm_backward.default,
+                      schema_info=RuntimeSchemaInfo(2))
+def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+
+    assert len(op_schema.args_schema) == 6
+    (
+        output_grad_strategy,
+        add_output_grad_strategy,
+        add_output_strategy,
+        weight_strategy,
+        _,  # eps
+        need_input_grad,  # need_input_grad
+    ) = op_schema.args_schema
+
+    assert isinstance(output_grad_strategy, OpStrategy)
+    assert isinstance(add_output_grad_strategy, OpStrategy)
+    assert isinstance(add_output_strategy, OpStrategy)
+    assert isinstance(weight_strategy, OpStrategy)
+
+    lengths = {
+        "output_grad": len(output_grad_strategy.strategies),
+        "add_output_grad": len(add_output_grad_strategy.strategies),
+        "add_output": len(add_output_strategy.strategies),
+        "weight": len(weight_strategy.strategies),
+    }
+    assert len(set(
+        lengths.values())) == 1, f"Strategy length mismatch: {lengths}"
+
+    zipped = zip(
+        output_grad_strategy.strategies,
+        add_output_grad_strategy.strategies,
+        add_output_strategy.strategies,
+        weight_strategy.strategies,
+    )
+
+    last_dim = output_grad_strategy.ndim - 1
+    strategy = OpStrategy([])
+    for output_grad, add_output_grad, add_output, weight in zipped:
+        output_grad_src = output_grad.output_spec
+        add_output_grad_src = add_output_grad.output_spec
+        add_output_src = add_output.output_spec
+        weight_src = weight.output_spec
+
+        assert isinstance(output_grad_src, DTensorSpec)
+        assert isinstance(add_output_grad_src, DTensorSpec)
+        assert isinstance(add_output_src, DTensorSpec)
+        assert isinstance(weight_src, DTensorSpec)
+
+        redistribute_costs = []
+
+        # output grad can be sharded in any dim except the last dim.
+        output_grad_tgt = DTensorSpec(
+            mesh=mesh,
+            placements=_replicate_dims_start_at(output_grad_src.placements,
+                                                last_dim),
+            tensor_meta=output_grad_src.tensor_meta,
+        )
+        redistribute_costs.append(
+            generate_redistribute_costs(output_grad_strategy, output_grad_tgt))
+
+        # add_output_grad must have the same sharding as output_grad.
+        add_output_grad_tgt = output_grad_tgt
+        redistribute_costs.append(
+            generate_redistribute_costs(add_output_grad_strategy,
+                                        add_output_grad_tgt))
+
+        # add_output must have the same sharding as output_grad.
+        add_output_tgt = output_grad_tgt
+        redistribute_costs.append(
+            generate_redistribute_costs(add_output_strategy, add_output_tgt))
+
+        # Weight cannot be sharded, so always replicate it.
+        weight_tgt = DTensorSpec(
+            mesh=mesh,
+            placements=(Replicate(), ),
+            tensor_meta=weight_src.tensor_meta,
+        )
+        redistribute_costs.append(
+            generate_redistribute_costs(weight_strategy, weight_tgt))
+
+        strategy.strategies.append(
+            OpSpec(
+                output_specs=[
+                    output_grad_tgt if need_input_grad else None, weight_tgt
+                ],
+                input_specs=[
+                    output_grad_tgt, add_output_grad_tgt, add_output_tgt,
+                    weight_tgt
+                ],
+                redistribute_cost=redistribute_costs,
+            ))
+    return strategy
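The two strategies above teach DTensor's sharding propagation that fused_add_rms_norm tolerates sharding on every dimension except the normalized (last) one and that the weight must stay replicated. A rough usage sketch under assumptions not in this diff (torch >= 2.8, an initialized 2-GPU process group, and the built package importable as activation); tensor shapes and the mesh are illustrative:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

from activation import ops  # re-exported, hash-suffixed op namespace (assumed path)

mesh = init_device_mesh("cuda", (2,))
hidden = distribute_tensor(torch.randn(4, 128, 64, device="cuda"), mesh, [Shard(1)])
residual = distribute_tensor(torch.randn(4, 128, 64, device="cuda"), mesh, [Shard(1)])
weight = distribute_tensor(torch.ones(64, device="cuda"), mesh, [Replicate()])

# Sequence-dim sharding is kept as-is; only a mismatched placement (a Partial or
# a Shard on the last dim) would be redistributed, per the registered strategy.
out, added = ops.fused_add_rms_norm(hidden, residual, weight, 1e-6)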
build/torch27-cxx11-cu118-x86_64-linux/activation/parallel_style.py ADDED
@@ -0,0 +1,50 @@
+from abc import ABC, abstractmethod
+from functools import partial
+from typing import Any, Optional, Union
+
+import torch
+import torch.nn as nn
+from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard,
+                                      distribute_module, distribute_tensor)
+from torch.distributed.tensor.parallel import SequenceParallel
+from torch.distributed.tensor.placement_types import Placement
+
+
+class ResidualSequenceParallel(SequenceParallel):
+    """ Consider the case where we have a residual connection across a sequence parallel layer."""
+
+    @staticmethod
+    def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
+        input_tensor = inputs[0]
+        residual_tensor = inputs[1]
+
+        assert isinstance(input_tensor,
+                          DTensor) == isinstance(residual_tensor, DTensor)
+        assert isinstance(input_tensor,
+                          torch.Tensor) == isinstance(residual_tensor,
+                                                      torch.Tensor)
+
+        if isinstance(input_tensor, DTensor):
+            # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it
+            if input_tensor.placements != sequence_sharding:
+                input_tensor = input_tensor.redistribute(
+                    placements=sequence_sharding, async_op=True)
+            if residual_tensor.placements != sequence_sharding:
+                residual_tensor = residual_tensor.redistribute(
+                    placements=sequence_sharding, async_op=True)
+            return input_tensor, residual_tensor
+
+        elif isinstance(input_tensor, torch.Tensor):
+            # assume the input passed in already sharded on the sequence dim and create the DTensor
+            return DTensor.from_local(input_tensor,
+                                      device_mesh,
+                                      sequence_sharding,
+                                      run_check=False), DTensor.from_local(
+                                          residual_tensor,
+                                          device_mesh,
+                                          sequence_sharding,
+                                          run_check=False)
+        else:
+            raise ValueError(
+                f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}"
+            )
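ResidualSequenceParallel extends SequenceParallel for modules whose forward takes two tensors (hidden states plus residual) and puts both on the sequence-sharded layout before the fused kernel runs. A sketch of how it could be wired into a tensor-parallel plan; the Block module and the "norm" plan key are hypothetical stand-ins, not part of this commit:

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from activation import FusedAddRMSNormFunction  # assumed import path
from activation.parallel_style import ResidualSequenceParallel


class Block(nn.Module):
    # Placeholder layer: forward takes (hidden, residual), like the fused add + RMSNorm.
    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, hidden, residual):
        return FusedAddRMSNormFunction.apply(hidden, residual, self.weight, 1e-6)


mesh = init_device_mesh("cuda", (2,))
model = nn.ModuleDict({"norm": Block(64)}).cuda()

# Both forward inputs of "norm" get sequence-dim placements (redistributed if
# needed), mirroring what SequenceParallel does for its single input.
parallelize_module(model, mesh, {"norm": ResidualSequenceParallel()})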
build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py CHANGED
@@ -1,4 +1,7 @@
+from collections.abc import Sequence
+
 import torch
+from packaging import version
 
 from ._ops import ops
 
@@ -8,9 +11,7 @@ class RMSNormFunction(torch.autograd.Function):
     # Note that forward, setup_context, and backward are @staticmethods
     @staticmethod
     def forward(input, weight, eps):
-        output = torch.empty_like(input)
-        ops.rms_norm(output, input, weight, eps)
-        return output
+        return ops.rms_norm(input, weight, eps)
 
     @staticmethod
     # inputs is a Tuple of all of the inputs passed to forward.
@@ -26,13 +27,8 @@ class RMSNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps
 
-        input_grad = torch.empty_like(
-            input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(
-            weight) if ctx.needs_input_grad[1] else None
-
-        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
-                              weight, eps)
+        input_grad, weight_grad = ops.rms_norm_backward(
+            output_grad, input, weight, eps)
 
         return input_grad, weight_grad, None
 
@@ -42,10 +38,8 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
     # Note that forward, setup_context, and backward are @staticmethods
     @staticmethod
    def forward(input, residual, weight, eps):
-        output = torch.empty_like(input)
-        add_output = torch.empty_like(input)
-        ops.fused_add_rms_norm(output, add_output, input, residual, weight,
-                               eps)
+        output, add_output = ops.fused_add_rms_norm(input, residual, weight,
+                                                    eps)
         return output, add_output
 
     @staticmethod
@@ -65,14 +59,47 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
         need_in = ctx.needs_input_grad[0]
         need_res = ctx.needs_input_grad[1]
 
-        grad = torch.empty_like(output_grad) if need_in or need_res else None
+        # TODO(ai-system): kernels currently do not support no input gradients
+        assert need_in or need_res, "Not implemented for no input gradients yet"
 
-        weight_grad = torch.empty_like(
-            weight) if ctx.needs_input_grad[2] else None
-
-        ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output,
-                                        weight, eps)
+        grad, weight_grad = ops.fused_add_rms_norm_backward(
+            output_grad,
+            add_output_grad,
+            add_output,
+            weight,
+            eps,
+            need_input_grad=need_in or need_res)
         input_grad = grad if need_in else None
         residual_grad = grad if need_res else None
 
         return input_grad, residual_grad, weight_grad, None
+
+
+@torch.library.register_fake(ops.rms_norm.default)
+def rms_norm_abstract(x, weight, eps):
+    return torch.empty_like(x)
+
+
+@torch.library.register_fake(ops.rms_norm_backward.default)
+def rms_norm_backward_abstract(output_grad, x, weight, eps):
+    return torch.empty_like(x), torch.empty_like(weight)
+
+
+@torch.library.register_fake(ops.fused_add_rms_norm.default)
+def fused_add_rms_norm_abstract(x, residual, weight, eps):
+    return torch.empty_like(x), torch.empty_like(x)
+
+
+@torch.library.register_fake(ops.fused_add_rms_norm_backward.default)
+def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad,
+                                         add_output, weight, eps,
+                                         need_input_grad: bool):
+    return torch.empty_like(
+        output_grad) if need_input_grad else None, torch.empty_like(weight)
+
+
+if version.parse(torch.__version__) >= version.parse("2.8"):
+    from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta
+    from .rms_norm_meta import register_rms_norm_meta
+    register_fused_add_rms_norm_meta()
+    register_rms_norm_meta()
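The reworked autograd wrappers now call the functional custom ops directly (output allocation moves into the op bindings), and the register_fake entries give each op a shape-only implementation so it can be traced with meta/fake tensors and torch.compile. A small sketch exercising the public RMSNormFunction, assuming a CUDA build of this wheel importable as activation; shapes and eps are illustrative:

import torch
from activation import RMSNormFunction  # exported from activation/__init__.py

x = torch.randn(4, 128, device="cuda", requires_grad=True)
w = torch.randn(128, device="cuda", requires_grad=True)

y = RMSNormFunction.apply(x, w, 1e-6)   # forward: ops.rms_norm(x, w, eps)
y.sum().backward()                      # backward: ops.rms_norm_backward(...)

# Thanks to the fake registrations, the same op also traces without launching
# real kernels, e.g. when wrapped in torch.compile.
compiled = torch.compile(lambda a, b: RMSNormFunction.apply(a, b, 1e-6))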
build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm_meta.py ADDED
@@ -0,0 +1,164 @@
+from collections.abc import Sequence
+
+import torch
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
+                                                 RuntimeSchemaInfo)
+from torch.distributed.tensor._ops.utils import (generate_redistribute_costs,
+                                                 register_op_strategy)
+from torch.distributed.tensor.placement_types import (Placement, Replicate,
+                                                       Shard)
+
+from ._ops import ops
+
+
+def register_rms_norm_meta():
+    """Dummy function to register the meta functions.
+    Registration happens at import time by the decorators below.
+    """
+    pass
+
+
+def _replicate_dims_start_at(placements: Sequence[Placement],
+                             start_dim: int = 0) -> tuple[Placement, ...]:
+    new_placements: list[Placement] = []
+    for p in placements:
+        if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
+            new_placements.append(Replicate())  # make it replicate
+        else:
+            new_placements.append(p)  # keep the placement
+    return tuple(new_placements)
+
+
+@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1))
+def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+
+    assert len(op_schema.args_schema) == 3
+    (
+        input_strategy,
+        weight_strategy,
+        _,  # eps
+    ) = op_schema.args_schema
+
+    assert isinstance(input_strategy, OpStrategy)
+    assert isinstance(weight_strategy, OpStrategy)
+
+    assert len(input_strategy.strategies) == len(weight_strategy.strategies)
+
+    last_dim = input_strategy.ndim - 1
+    strategy = OpStrategy([])
+    for input, weight in zip(input_strategy.strategies,
+                             weight_strategy.strategies):
+        input_src = input.output_spec
+        weight_src = weight.output_spec
+
+        assert isinstance(input_src, DTensorSpec)
+        assert isinstance(weight_src, DTensorSpec)
+
+        redistribute_costs = []
+
+        # Input can be sharded in any dim except the last dim.
+        input_tgt = DTensorSpec(
+            mesh=mesh,
+            placements=_replicate_dims_start_at(input_src.placements,
+                                                last_dim),
+            tensor_meta=input_src.tensor_meta,
+        )
+        redistribute_costs.append(
+            generate_redistribute_costs(input_strategy, input_tgt))
+
+        # Weight cannot be sharded, so always replicate it.
+        weight_tgt = DTensorSpec(
+            mesh=mesh,
+            placements=(Replicate(), ),
+            tensor_meta=weight_src.tensor_meta,
+        )
+        redistribute_costs.append(
+            generate_redistribute_costs(weight_strategy, weight_tgt))
+
+        strategy.strategies.append(
+            OpSpec(
+                output_specs=input_tgt,
+                input_specs=[input_tgt, weight_tgt],
+                redistribute_cost=redistribute_costs,
+            ))
+    return strategy
+
+
+@register_op_strategy(ops.rms_norm_backward.default,
+                      schema_info=RuntimeSchemaInfo(1))
+def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+
+    assert len(op_schema.args_schema) == 4
+    (
+        output_grad_strategy,
+        input_strategy,
+        weight_strategy,
+        _,  # eps
+    ) = op_schema.args_schema
+
+    assert isinstance(output_grad_strategy, OpStrategy)
+    assert isinstance(input_strategy, OpStrategy)
+    assert isinstance(weight_strategy, OpStrategy)
+
+    lengths = {
+        "output_grad": len(output_grad_strategy.strategies),
+        "input": len(input_strategy.strategies),
+        "weight": len(weight_strategy.strategies),
+    }
+
+    assert len(set(
+        lengths.values())) == 1, f"Strategies length mismatch {lengths}"
+
+    zipped = zip(
+        output_grad_strategy.strategies,
+        input_strategy.strategies,
+        weight_strategy.strategies,
+    )
+
+    last_dim = input_strategy.ndim - 1
+    strategy = OpStrategy([])
+    for output_grad, input, weight in zipped:
+        output_grad_src = output_grad.output_spec
+        input_src = input.output_spec
+        weight_src = weight.output_spec
+
+        assert isinstance(output_grad_src, DTensorSpec)
+        assert isinstance(input_src, DTensorSpec)
+        assert isinstance(weight_src, DTensorSpec)
+
+        redistribute_costs = []
+
+        # Output grad can be sharded in any dim except the last dim.
+        output_grad_tgt = DTensorSpec(
+            mesh=mesh,
+            placements=_replicate_dims_start_at(output_grad_src.placements,
+                                                last_dim),
+            tensor_meta=output_grad_src.tensor_meta,
+        )
+        redistribute_costs.append(
+            generate_redistribute_costs(output_grad_strategy, output_grad_tgt))
+
+        # Input must have the same sharding as output grad.
+        input_tgt = output_grad_tgt
+        redistribute_costs.append(
+            generate_redistribute_costs(input_strategy, input_tgt))
+
+        # Weight cannot be sharded, so always replicate it.
+        weight_tgt = DTensorSpec(
+            mesh=mesh,
+            placements=(Replicate(), ),
+            tensor_meta=weight_src.tensor_meta,
+        )
+        redistribute_costs.append(
+            generate_redistribute_costs(weight_strategy, weight_tgt))
+
+        strategy.strategies.append(
+            OpSpec(
+                output_specs=[input_tgt, weight_tgt],
+                input_specs=[output_grad_tgt, input_tgt, weight_tgt],
+                redistribute_cost=redistribute_costs,
+            ))
+    return strategy
build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py CHANGED
(same change as build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py above)
build/{torch27-cxx11-cu118-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ec9ea7edc8b27f7983e20d615ab470cef6b82975afc214becfddfd05a867a839
-size 8600336
+oid sha256:ef6e4eb51daac20f0d7ed9825052ecca9d8451825784c87d58fa69092c145f35
+size 8793008
build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5d3511410cdc288d2fafc500223ed2e625e360f50fa341809cf892fb2c822924
-size 8779000
build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:caffcadbb99fbaa27e8a81d5ef508f2e1a798e7626d618c3cf5b0d387d2c8686
-size 4618624
build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py CHANGED
(same change as build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py above)
build/torch27-cxx11-cu126-x86_64-linux/activation/fused_add_rms_norm_meta.py ADDED
(new file; contents identical to build/torch27-cxx11-cu118-x86_64-linux/activation/fused_add_rms_norm_meta.py above)
build/torch27-cxx11-cu126-x86_64-linux/activation/parallel_style.py ADDED
(new file; contents identical to build/torch27-cxx11-cu118-x86_64-linux/activation/parallel_style.py above)
build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py CHANGED
(same change as build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py above)
build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm_meta.py ADDED
(new file; contents identical to build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm_meta.py above)
build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py CHANGED
(same change as build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py above)
build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0bf0d2ab5ff5520704e0b0c959b61d0043d360cfd4335950e69677873a87e436
-size 12792112
build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a0699647f4c0bfc57711e8488dfa3864e7cfdf9119fb743fdaafcb2cbd2cea2c
+size 13836872
build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:25efc9c32e4bd6609a8326025aad861cbf79b544893755fe44519c9df7224c40
-size 13818872
build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:3b7c6ece8e8d316c4cc5fe46b1cec4422b2f61e9bb7240af71a2b4a35975d8e6
-size 6676528
build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py CHANGED
(same change as build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py above)
build/torch27-cxx11-cu128-x86_64-linux/activation/fused_add_rms_norm_meta.py ADDED
(new file; contents identical to build/torch27-cxx11-cu118-x86_64-linux/activation/fused_add_rms_norm_meta.py above)
build/torch27-cxx11-cu128-x86_64-linux/activation/parallel_style.py ADDED
@@ -0,0 +1,50 @@
1
+ from abc import ABC, abstractmethod
2
+ from functools import partial
3
+ from typing import Any, Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard,
8
+ distribute_module, distribute_tensor)
9
+ from torch.distributed.tensor.parallel import SequenceParallel
10
+ from torch.distributed.tensor.placement_types import Placement
11
+
12
+
13
+ class ResidualSequenceParallel(SequenceParallel):
14
+ """ Consider the case where we have a residual connection across a sequence parallel layer."""
15
+
16
+ @staticmethod
17
+ def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
18
+ input_tensor = inputs[0]
19
+ residual_tensor = inputs[1]
20
+
21
+ assert isinstance(input_tensor,
22
+ DTensor) == isinstance(residual_tensor, DTensor)
23
+ assert isinstance(input_tensor,
24
+ torch.Tensor) == isinstance(residual_tensor,
25
+ torch.Tensor)
26
+
27
+ if isinstance(input_tensor, DTensor):
28
+ # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it
29
+ if input_tensor.placements != sequence_sharding:
30
+ input_tensor = input_tensor.redistribute(
31
+ placements=sequence_sharding, async_op=True)
32
+ if residual_tensor.placements != sequence_sharding:
33
+ residual_tensor = residual_tensor.redistribute(
34
+ placements=sequence_sharding, async_op=True)
35
+ return input_tensor, residual_tensor
36
+
37
+ elif isinstance(input_tensor, torch.Tensor):
38
+ # assume the input passed in is already sharded on the sequence dim and create the DTensor
39
+ return DTensor.from_local(input_tensor,
40
+ device_mesh,
41
+ sequence_sharding,
42
+ run_check=False), DTensor.from_local(
43
+ residual_tensor,
44
+ device_mesh,
45
+ sequence_sharding,
46
+ run_check=False)
47
+ else:
48
+ raise ValueError(
49
+ f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}"
50
+ )
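A minimal sketch of how the ResidualSequenceParallel style added above might be wired into a tensor-parallel plan. Everything except the style class itself is an assumption for illustration (the toy module, the "norm" sub-module name, the activation.parallel_style import path, and the two-process torchrun launch are not part of this commit):

# Hedged usage sketch; assumes a launch such as:
#   torchrun --nproc-per-node=2 residual_sp_demo.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from activation.parallel_style import ResidualSequenceParallel  # assumed install path


class ToyAddNorm(nn.Module):
    # Hypothetical stand-in for a fused add + norm layer taking (hidden, residual).
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, hidden, residual):
        summed = hidden + residual
        return summed * self.weight, summed


class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = ToyAddNorm(dim)

    def forward(self, hidden, residual):
        return self.norm(hidden, residual)


dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

block = ToyBlock(16)
# The style shards both inputs on the sequence dim (dim 1) and keeps the
# residual on the same placement as the activation, as the hooks above enforce.
parallelize_module(block, mesh, {"norm": ResidualSequenceParallel()})

hidden = torch.randn(2, 8, 16)    # local shard of (batch, seq, dim)
residual = torch.randn(2, 8, 16)
out, added = block(hidden, residual)  # DTensor outputs sharded on the sequence dim
dist.destroy_process_group()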
build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py CHANGED
@@ -1,4 +1,7 @@
 
 
1
  import torch
 
2
 
3
  from ._ops import ops
4
 
@@ -8,9 +11,7 @@ class RMSNormFunction(torch.autograd.Function):
8
  # Note that forward, setup_context, and backward are @staticmethods
9
  @staticmethod
10
  def forward(input, weight, eps):
11
- output = torch.empty_like(input)
12
- ops.rms_norm(output, input, weight, eps)
13
- return output
14
 
15
  @staticmethod
16
  # inputs is a Tuple of all of the inputs passed to forward.
@@ -26,13 +27,8 @@ class RMSNormFunction(torch.autograd.Function):
26
  input, weight = ctx.saved_tensors
27
  eps = ctx.eps
28
 
29
- input_grad = torch.empty_like(
30
- input) if ctx.needs_input_grad[0] else None
31
- weight_grad = torch.empty_like(
32
- weight) if ctx.needs_input_grad[1] else None
33
-
34
- ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
35
- weight, eps)
36
 
37
  return input_grad, weight_grad, None
38
 
@@ -42,10 +38,8 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
42
  # Note that forward, setup_context, and backward are @staticmethods
43
  @staticmethod
44
  def forward(input, residual, weight, eps):
45
- output = torch.empty_like(input)
46
- add_output = torch.empty_like(input)
47
- ops.fused_add_rms_norm(output, add_output, input, residual, weight,
48
- eps)
49
  return output, add_output
50
 
51
  @staticmethod
@@ -65,14 +59,47 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
65
  need_in = ctx.needs_input_grad[0]
66
  need_res = ctx.needs_input_grad[1]
67
 
68
- grad = torch.empty_like(output_grad) if need_in or need_res else None
 
69
 
70
- weight_grad = torch.empty_like(
71
- weight) if ctx.needs_input_grad[2] else None
72
-
73
- ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output,
74
- weight, eps)
 
 
75
  input_grad = grad if need_in else None
76
  residual_grad = grad if need_res else None
77
 
78
  return input_grad, residual_grad, weight_grad, None
1
+ from collections.abc import Sequence
2
+
3
  import torch
4
+ from packaging import version
5
 
6
  from ._ops import ops
7
 
 
11
  # Note that forward, setup_context, and backward are @staticmethods
12
  @staticmethod
13
  def forward(input, weight, eps):
14
+ return ops.rms_norm(input, weight, eps)
 
 
15
 
16
  @staticmethod
17
  # inputs is a Tuple of all of the inputs passed to forward.
 
27
  input, weight = ctx.saved_tensors
28
  eps = ctx.eps
29
 
30
+ input_grad, weight_grad = ops.rms_norm_backward(
31
+ output_grad, input, weight, eps)
 
 
 
 
 
32
 
33
  return input_grad, weight_grad, None
34
 
 
38
  # Note that forward, setup_context, and backward are @staticmethods
39
  @staticmethod
40
  def forward(input, residual, weight, eps):
41
+ output, add_output = ops.fused_add_rms_norm(input, residual, weight,
42
+ eps)
 
 
43
  return output, add_output
44
 
45
  @staticmethod
 
59
  need_in = ctx.needs_input_grad[0]
60
  need_res = ctx.needs_input_grad[1]
61
 
62
+ # TODO(ai-system): kernels currently do not support no input gradients
63
+ assert need_in or need_res, "Not implemented for no input gradients yet"
64
 
65
+ grad, weight_grad = ops.fused_add_rms_norm_backward(
66
+ output_grad,
67
+ add_output_grad,
68
+ add_output,
69
+ weight,
70
+ eps,
71
+ need_input_grad=need_in or need_res)
72
  input_grad = grad if need_in else None
73
  residual_grad = grad if need_res else None
74
 
75
  return input_grad, residual_grad, weight_grad, None
76
+
77
+
78
+ @torch.library.register_fake(ops.rms_norm.default)
79
+ def rms_norm_abstract(x, weight, eps):
80
+ return torch.empty_like(x)
81
+
82
+
83
+ @torch.library.register_fake(ops.rms_norm_backward.default)
84
+ def rms_norm_backward_abstract(output_grad, x, weight, eps):
85
+ return torch.empty_like(x), torch.empty_like(weight)
86
+
87
+
88
+ @torch.library.register_fake(ops.fused_add_rms_norm.default)
89
+ def fused_add_rms_norm_abstract(x, residual, weight, eps):
90
+ return torch.empty_like(x), torch.empty_like(x)
91
+
92
+
93
+ @torch.library.register_fake(ops.fused_add_rms_norm_backward.default)
94
+ def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad,
95
+ add_output, weight, eps,
96
+ need_input_grad: bool):
97
+ return torch.empty_like(
98
+ output_grad) if need_input_grad else None, torch.empty_like(weight)
99
+
100
+
101
+ if version.parse(torch.__version__) >= version.parse("2.8"):
102
+ from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta
103
+ from .rms_norm_meta import register_rms_norm_meta
104
+ register_fused_add_rms_norm_meta()
105
+ register_rms_norm_meta()
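For readers without the built extension, here is a plain-PyTorch sketch of the semantics the rewritten wrappers appear to expect from ops.rms_norm and ops.fused_add_rms_norm, inferred from the call signatures above; this reference code is not part of the commit and the exact kernel numerics may differ:

import torch

def rms_norm_ref(x, weight, eps):
    # y = x / sqrt(mean(x^2 over the last dim) + eps) * weight
    rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() * rms).to(x.dtype) * weight

def fused_add_rms_norm_ref(x, residual, weight, eps):
    # add_output = x + residual, then RMS-normalize the sum.
    add_output = x + residual
    return rms_norm_ref(add_output, weight, eps), add_output

x = torch.randn(2, 4, 8, dtype=torch.float16)
res = torch.randn_like(x)
w = torch.ones(8, dtype=torch.float16)
out, added = fused_add_rms_norm_ref(x, res, w, eps=1e-6)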
build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm_meta.py ADDED
@@ -0,0 +1,164 @@
1
+ from collections.abc import Sequence
2
+
3
+ import torch
4
+ from torch.distributed.tensor._dtensor_spec import DTensorSpec
5
+ from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
6
+ RuntimeSchemaInfo)
7
+ from torch.distributed.tensor._ops.utils import (generate_redistribute_costs,
8
+ register_op_strategy)
9
+ from torch.distributed.tensor.placement_types import (Placement, Replicate,
10
+ Shard)
11
+
12
+ from ._ops import ops
13
+
14
+
15
+ def register_rms_norm_meta():
16
+ """Dummy function to register the meta functions.
17
+ Registration happens at import time by the decorators below.
18
+ """
19
+ pass
20
+
21
+
22
+ def _replicate_dims_start_at(placements: Sequence[Placement],
23
+ start_dim: int = 0) -> tuple[Placement, ...]:
24
+ new_placements: list[Placement] = []
25
+ for p in placements:
26
+ if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
27
+ new_placements.append(Replicate()) # make it replicate
28
+ else:
29
+ new_placements.append(p) # keep the placement
30
+ return tuple(new_placements)
31
+
32
+
33
+ @register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1))
34
+ def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
35
+ mesh = op_schema.get_mesh_from_args()
36
+
37
+ assert len(op_schema.args_schema) == 3
38
+ (
39
+ input_strategy,
40
+ weight_strategy,
41
+ _, # eps
42
+ ) = op_schema.args_schema
43
+
44
+ assert isinstance(input_strategy, OpStrategy)
45
+ assert isinstance(weight_strategy, OpStrategy)
46
+
47
+ assert len(input_strategy.strategies) == len(weight_strategy.strategies)
48
+
49
+ last_dim = input_strategy.ndim - 1
50
+ strategy = OpStrategy([])
51
+ for input, weight in zip(input_strategy.strategies,
52
+ weight_strategy.strategies):
53
+ input_src = input.output_spec
54
+ weight_src = weight.output_spec
55
+
56
+ assert isinstance(input_src, DTensorSpec)
57
+ assert isinstance(weight_src, DTensorSpec)
58
+
59
+ redistribute_costs = []
60
+
61
+ # Input can be sharded in any dim except the last dim.
62
+ input_tgt = DTensorSpec(
63
+ mesh=mesh,
64
+ placements=_replicate_dims_start_at(input_src.placements,
65
+ last_dim),
66
+ tensor_meta=input_src.tensor_meta,
67
+ )
68
+ redistribute_costs.append(
69
+ generate_redistribute_costs(input_strategy, input_tgt))
70
+
71
+ # Weight cannot be sharded, so always replicate it.
72
+ weight_tgt = DTensorSpec(
73
+ mesh=mesh,
74
+ placements=(Replicate(), ),
75
+ tensor_meta=weight_src.tensor_meta,
76
+ )
77
+ redistribute_costs.append(
78
+ generate_redistribute_costs(weight_strategy, weight_tgt))
79
+
80
+ strategy.strategies.append(
81
+ OpSpec(
82
+ output_specs=input_tgt,
83
+ input_specs=[input_tgt, weight_tgt],
84
+ redistribute_cost=redistribute_costs,
85
+ ))
86
+ return strategy
87
+
88
+
89
+ @register_op_strategy(ops.rms_norm_backward.default,
90
+ schema_info=RuntimeSchemaInfo(1))
91
+ def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
92
+ mesh = op_schema.get_mesh_from_args()
93
+
94
+ assert len(op_schema.args_schema) == 4
95
+ (
96
+ output_grad_strategy,
97
+ input_strategy,
98
+ weight_strategy,
99
+ _, # eps
100
+ ) = op_schema.args_schema
101
+
102
+ assert isinstance(output_grad_strategy, OpStrategy)
103
+ assert isinstance(input_strategy, OpStrategy)
104
+ assert isinstance(weight_strategy, OpStrategy)
105
+
106
+ lengths = {
107
+ "output_grad": len(output_grad_strategy.strategies),
108
+ "input": len(input_strategy.strategies),
109
+ "weight": len(weight_strategy.strategies),
110
+ }
111
+
112
+ assert len(set(
113
+ lengths.values())) == 1, f"Strategies length mismatch {lengths}"
114
+
115
+ zipped = zip(
116
+ output_grad_strategy.strategies,
117
+ input_strategy.strategies,
118
+ weight_strategy.strategies,
119
+ )
120
+
121
+ last_dim = input_strategy.ndim - 1
122
+ strategy = OpStrategy([])
123
+ for output_grad, input, weight in zipped:
124
+ output_grad_src = output_grad.output_spec
125
+ input_src = input.output_spec
126
+ weight_src = weight.output_spec
127
+
128
+ assert isinstance(output_grad_src, DTensorSpec)
129
+ assert isinstance(input_src, DTensorSpec)
130
+ assert isinstance(weight_src, DTensorSpec)
131
+
132
+ redistribute_costs = []
133
+
134
+ # Output grad can be sharded in any dim except the last dim.
135
+ output_grad_tgt = DTensorSpec(
136
+ mesh=mesh,
137
+ placements=_replicate_dims_start_at(output_grad_src.placements,
138
+ last_dim),
139
+ tensor_meta=output_grad_src.tensor_meta,
140
+ )
141
+ redistribute_costs.append(
142
+ generate_redistribute_costs(output_grad_strategy, output_grad_tgt))
143
+
144
+ # Input must have the same sharding as output grad.
145
+ input_tgt = output_grad_tgt
146
+ redistribute_costs.append(
147
+ generate_redistribute_costs(input_strategy, input_tgt))
148
+
149
+ # Weight cannot be sharded, so always replicate it.
150
+ weight_tgt = DTensorSpec(
151
+ mesh=mesh,
152
+ placements=(Replicate(), ),
153
+ tensor_meta=weight_src.tensor_meta,
154
+ )
155
+ redistribute_costs.append(
156
+ generate_redistribute_costs(weight_strategy, weight_tgt))
157
+
158
+ strategy.strategies.append(
159
+ OpSpec(
160
+ output_specs=[input_tgt, weight_tgt],
161
+ input_specs=[output_grad_tgt, input_tgt, weight_tgt],
162
+ redistribute_cost=redistribute_costs,
163
+ ))
164
+ return strategy
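Both strategies above rely on _replicate_dims_start_at, which keeps shards on leading dims but converts partial placements and shards on dims >= start_dim to Replicate. A small self-contained illustration of that rule (the helper is re-implemented inline so the snippet runs without the built package):

from torch.distributed.tensor import Replicate, Shard

def replicate_dims_start_at(placements, start_dim=0):
    # Same rule as the helper in the file above.
    return tuple(
        Replicate()
        if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim)
        else p
        for p in placements
    )

# A 3-D input sharded on batch (dim 0) and hidden (dim 2), with last_dim = 2:
print(replicate_dims_start_at((Shard(0), Shard(2)), start_dim=2))
# expected: (Shard(dim=0), Replicate()) -- the last-dim shard gets replicated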
build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
 
3
- from . import layers
4
  from ._ops import ops
5
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
6
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -48,5 +48,6 @@ __all__ = [
48
  "rms_norm",
49
  "fused_add_rms_norm",
50
  "layers",
 
51
  "ops",
52
  ]
 
1
  import torch
2
 
3
+ from . import layers, parallel_style
4
  from ._ops import ops
5
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
6
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
 
48
  "rms_norm",
49
  "fused_add_rms_norm",
50
  "layers",
51
+ "parallel_style",
52
  "ops",
53
  ]
build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:640322a8fac8fd9d8e9f195a3034c4ee0f81ee1acf897fd7c482a84ce47a1bec
3
- size 4160688
build/{torch27-cxx11-cu118-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so → torch27-cxx11-rocm63-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bd84c828d4c15e96d65d6c8f0eb7a945ee8167d92e978b2ebce03eeaf41e7fce
3
- size 4405112
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d973bad96565705f9e27514a9dbfb37343d0220da4a3ae7156b1cf6a27813643
3
+ size 2773952
build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c80d05690547f2842d416ebb85c9f830370373bc7e6c54ba08eec61b3690280f
3
- size 4386744
build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4be173820e2a4bf4b6b8de6b63faf6544b599d9b0583f650a940adaef4a048b3
3
- size 2899184
build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _activation_e5e2eeb_dirty
3
- ops = torch.ops._activation_e5e2eeb_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_activation_e5e2eeb_dirty::{op_name}"
 
1
  import torch
2
+ from . import _activation_53ed492_dirty
3
+ ops = torch.ops._activation_53ed492_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_activation_53ed492_dirty::{op_name}"
build/torch27-cxx11-rocm63-x86_64-linux/activation/fused_add_rms_norm_meta.py ADDED
@@ -0,0 +1,199 @@
1
+ from collections.abc import Sequence
2
+
3
+ import torch
4
+ from torch.distributed.tensor._dtensor_spec import DTensorSpec
5
+ from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
6
+ RuntimeSchemaInfo)
7
+ from torch.distributed.tensor._ops.utils import (generate_redistribute_costs,
8
+ register_op_strategy)
9
+ from torch.distributed.tensor.placement_types import (Placement, Replicate,
10
+ Shard)
11
+
12
+ from ._ops import ops
13
+
14
+
15
+ def register_fused_add_rms_norm_meta():
16
+ """Dummy function to register the meta functions.
17
+ Registration happens at import time by the decorators below.
18
+ """
19
+ pass
20
+
21
+
22
+ def _replicate_dims_start_at(placements: Sequence[Placement],
23
+ start_dim: int = 0) -> tuple[Placement, ...]:
24
+ new_placements: list[Placement] = []
25
+ for p in placements:
26
+ if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
27
+ new_placements.append(Replicate()) # make it replicate
28
+ else:
29
+ new_placements.append(p) # keep the placement
30
+ return tuple(new_placements)
31
+
32
+
33
+ @register_op_strategy(ops.fused_add_rms_norm.default,
34
+ schema_info=RuntimeSchemaInfo(1))
35
+ def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
36
+ mesh = op_schema.get_mesh_from_args()
37
+
38
+ assert len(op_schema.args_schema) == 4
39
+ (
40
+ input_strategy,
41
+ residual_strategy,
42
+ weight_strategy,
43
+ _, # eps
44
+ ) = op_schema.args_schema
45
+
46
+ assert isinstance(input_strategy, OpStrategy)
47
+ assert isinstance(residual_strategy, OpStrategy)
48
+ assert isinstance(weight_strategy, OpStrategy)
49
+
50
+ lengths = {
51
+ "input": len(input_strategy.strategies),
52
+ "residual": len(residual_strategy.strategies),
53
+ "weight": len(weight_strategy.strategies),
54
+ }
55
+ assert len(set(
56
+ lengths.values())) == 1, f"Strategy length mismatch: {lengths}"
57
+
58
+ last_dim = input_strategy.ndim - 1
59
+ strategy = OpStrategy([])
60
+ for input, residual, weight in zip(input_strategy.strategies,
61
+ residual_strategy.strategies,
62
+ weight_strategy.strategies):
63
+
64
+ input_src = input.output_spec
65
+ residual_src = residual.output_spec
66
+ weight_src = weight.output_spec
67
+
68
+ assert isinstance(input_src, DTensorSpec)
69
+ assert isinstance(residual_src, DTensorSpec)
70
+ assert isinstance(weight_src, DTensorSpec)
71
+
72
+ redistribute_costs = []
73
+
74
+ # Input can be sharded in any dim except the last dim.
75
+ input_tgt = DTensorSpec(
76
+ mesh=mesh,
77
+ placements=_replicate_dims_start_at(input_src.placements,
78
+ last_dim),
79
+ tensor_meta=input_src.tensor_meta,
80
+ )
81
+ redistribute_costs.append(
82
+ generate_redistribute_costs(input_strategy, input_tgt))
83
+
84
+ # Residual add must have the same sharding as input.
85
+ residual_tgt = input_tgt
86
+ redistribute_costs.append(
87
+ generate_redistribute_costs(residual_strategy, residual_tgt))
88
+
89
+ # Weight cannot be sharded, so always replicate it.
90
+ weight_tgt = DTensorSpec(
91
+ mesh=mesh,
92
+ placements=(Replicate(), ),
93
+ tensor_meta=weight_src.tensor_meta,
94
+ )
95
+ redistribute_costs.append(
96
+ generate_redistribute_costs(weight_strategy, weight_tgt))
97
+
98
+ strategy.strategies.append(
99
+ OpSpec(
100
+ output_specs=[input_tgt, input_tgt],
101
+ input_specs=[input_tgt, residual_tgt, weight_tgt],
102
+ redistribute_cost=redistribute_costs,
103
+ ))
104
+ return strategy
105
+
106
+
107
+ @register_op_strategy(ops.fused_add_rms_norm_backward.default,
108
+ schema_info=RuntimeSchemaInfo(2))
109
+ def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
110
+ mesh = op_schema.get_mesh_from_args()
111
+
112
+ assert len(op_schema.args_schema) == 6
113
+ (
114
+ output_grad_strategy,
115
+ add_output_grad_strategy,
116
+ add_output_strategy,
117
+ weight_strategy,
118
+ _, # eps
119
+ need_input_grad, # need_input_grad
120
+ ) = op_schema.args_schema
121
+
122
+ assert isinstance(output_grad_strategy, OpStrategy)
123
+ assert isinstance(add_output_grad_strategy, OpStrategy)
124
+ assert isinstance(add_output_strategy, OpStrategy)
125
+ assert isinstance(weight_strategy, OpStrategy)
126
+
127
+ lengths = {
128
+ "output_grad": len(output_grad_strategy.strategies),
129
+ "add_output_grad": len(add_output_grad_strategy.strategies),
130
+ "add_output": len(add_output_strategy.strategies),
131
+ "weight": len(weight_strategy.strategies),
132
+ }
133
+ assert len(set(
134
+ lengths.values())) == 1, f"Strategy length mismatch: {lengths}"
135
+
136
+ zipped = zip(
137
+ output_grad_strategy.strategies,
138
+ add_output_grad_strategy.strategies,
139
+ add_output_strategy.strategies,
140
+ weight_strategy.strategies,
141
+ )
142
+
143
+ last_dim = output_grad_strategy.ndim - 1
144
+ strategy = OpStrategy([])
145
+ for output_grad, add_output_grad, add_output, weight in zipped:
146
+ output_grad_src = output_grad.output_spec
147
+ add_output_grad_src = add_output_grad.output_spec
148
+ add_output_src = add_output.output_spec
149
+ weight_src = weight.output_spec
150
+
151
+ assert isinstance(output_grad_src, DTensorSpec)
152
+ assert isinstance(add_output_grad_src, DTensorSpec)
153
+ assert isinstance(add_output_src, DTensorSpec)
154
+ assert isinstance(weight_src, DTensorSpec)
155
+
156
+ redistribute_costs = []
157
+
158
+ # output grad can be sharded in any dim except the last dim.
159
+ output_grad_tgt = DTensorSpec(
160
+ mesh=mesh,
161
+ placements=_replicate_dims_start_at(output_grad_src.placements,
162
+ last_dim),
163
+ tensor_meta=output_grad_src.tensor_meta,
164
+ )
165
+ redistribute_costs.append(
166
+ generate_redistribute_costs(output_grad_strategy, output_grad_tgt))
167
+
168
+ # add_output_grad must have the same sharding as output_grad.
169
+ add_output_grad_tgt = output_grad_tgt
170
+ redistribute_costs.append(
171
+ generate_redistribute_costs(add_output_grad_strategy,
172
+ add_output_grad_tgt))
173
+
174
+ # add_output must have the same sharding as output_grad.
175
+ add_output_tgt = output_grad_tgt
176
+ redistribute_costs.append(
177
+ generate_redistribute_costs(add_output_strategy, add_output_tgt))
178
+
179
+ # Weight cannot be sharded, so always replicate it.
180
+ weight_tgt = DTensorSpec(
181
+ mesh=mesh,
182
+ placements=(Replicate(), ),
183
+ tensor_meta=weight_src.tensor_meta,
184
+ )
185
+ redistribute_costs.append(
186
+ generate_redistribute_costs(weight_strategy, weight_tgt))
187
+
188
+ strategy.strategies.append(
189
+ OpSpec(
190
+ output_specs=[
191
+ output_grad_tgt if need_input_grad else None, weight_tgt
192
+ ],
193
+ input_specs=[
194
+ output_grad_tgt, add_output_grad_tgt, add_output_tgt,
195
+ weight_tgt
196
+ ],
197
+ redistribute_cost=redistribute_costs,
198
+ ))
199
+ return strategy
build/torch27-cxx11-rocm63-x86_64-linux/activation/parallel_style.py ADDED
@@ -0,0 +1,50 @@
1
+ from abc import ABC, abstractmethod
2
+ from functools import partial
3
+ from typing import Any, Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard,
8
+ distribute_module, distribute_tensor)
9
+ from torch.distributed.tensor.parallel import SequenceParallel
10
+ from torch.distributed.tensor.placement_types import Placement
11
+
12
+
13
+ class ResidualSequenceParallel(SequenceParallel):
14
+ """ Consider the case where we have a residual connection across a sequence parallel layer."""
15
+
16
+ @staticmethod
17
+ def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
18
+ input_tensor = inputs[0]
19
+ residual_tensor = inputs[1]
20
+
21
+ assert isinstance(input_tensor,
22
+ DTensor) == isinstance(residual_tensor, DTensor)
23
+ assert isinstance(input_tensor,
24
+ torch.Tensor) == isinstance(residual_tensor,
25
+ torch.Tensor)
26
+
27
+ if isinstance(input_tensor, DTensor):
28
+ # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it
29
+ if input_tensor.placements != sequence_sharding:
30
+ input_tensor = input_tensor.redistribute(
31
+ placements=sequence_sharding, async_op=True)
32
+ if residual_tensor.placements != sequence_sharding:
33
+ residual_tensor = residual_tensor.redistribute(
34
+ placements=sequence_sharding, async_op=True)
35
+ return input_tensor, residual_tensor
36
+
37
+ elif isinstance(input_tensor, torch.Tensor):
38
+ # assume the input passed in is already sharded on the sequence dim and create the DTensor
39
+ return DTensor.from_local(input_tensor,
40
+ device_mesh,
41
+ sequence_sharding,
42
+ run_check=False), DTensor.from_local(
43
+ residual_tensor,
44
+ device_mesh,
45
+ sequence_sharding,
46
+ run_check=False)
47
+ else:
48
+ raise ValueError(
49
+ f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}"
50
+ )
build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py CHANGED
@@ -1,4 +1,7 @@
 
 
1
  import torch
 
2
 
3
  from ._ops import ops
4
 
@@ -8,9 +11,7 @@ class RMSNormFunction(torch.autograd.Function):
8
  # Note that forward, setup_context, and backward are @staticmethods
9
  @staticmethod
10
  def forward(input, weight, eps):
11
- output = torch.empty_like(input)
12
- ops.rms_norm(output, input, weight, eps)
13
- return output
14
 
15
  @staticmethod
16
  # inputs is a Tuple of all of the inputs passed to forward.
@@ -26,13 +27,8 @@ class RMSNormFunction(torch.autograd.Function):
26
  input, weight = ctx.saved_tensors
27
  eps = ctx.eps
28
 
29
- input_grad = torch.empty_like(
30
- input) if ctx.needs_input_grad[0] else None
31
- weight_grad = torch.empty_like(
32
- weight) if ctx.needs_input_grad[1] else None
33
-
34
- ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
35
- weight, eps)
36
 
37
  return input_grad, weight_grad, None
38
 
@@ -42,10 +38,8 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
42
  # Note that forward, setup_context, and backward are @staticmethods
43
  @staticmethod
44
  def forward(input, residual, weight, eps):
45
- output = torch.empty_like(input)
46
- add_output = torch.empty_like(input)
47
- ops.fused_add_rms_norm(output, add_output, input, residual, weight,
48
- eps)
49
  return output, add_output
50
 
51
  @staticmethod
@@ -65,14 +59,47 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
65
  need_in = ctx.needs_input_grad[0]
66
  need_res = ctx.needs_input_grad[1]
67
 
68
- grad = torch.empty_like(output_grad) if need_in or need_res else None
 
69
 
70
- weight_grad = torch.empty_like(
71
- weight) if ctx.needs_input_grad[2] else None
72
-
73
- ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output,
74
- weight, eps)
 
 
75
  input_grad = grad if need_in else None
76
  residual_grad = grad if need_res else None
77
 
78
  return input_grad, residual_grad, weight_grad, None
1
+ from collections.abc import Sequence
2
+
3
  import torch
4
+ from packaging import version
5
 
6
  from ._ops import ops
7
 
 
11
  # Note that forward, setup_context, and backward are @staticmethods
12
  @staticmethod
13
  def forward(input, weight, eps):
14
+ return ops.rms_norm(input, weight, eps)
 
 
15
 
16
  @staticmethod
17
  # inputs is a Tuple of all of the inputs passed to forward.
 
27
  input, weight = ctx.saved_tensors
28
  eps = ctx.eps
29
 
30
+ input_grad, weight_grad = ops.rms_norm_backward(
31
+ output_grad, input, weight, eps)
 
 
 
 
 
32
 
33
  return input_grad, weight_grad, None
34
 
 
38
  # Note that forward, setup_context, and backward are @staticmethods
39
  @staticmethod
40
  def forward(input, residual, weight, eps):
41
+ output, add_output = ops.fused_add_rms_norm(input, residual, weight,
42
+ eps)
 
 
43
  return output, add_output
44
 
45
  @staticmethod
 
59
  need_in = ctx.needs_input_grad[0]
60
  need_res = ctx.needs_input_grad[1]
61
 
62
+ # TODO(ai-system): kernels currently do not support no input gradients
63
+ assert need_in or need_res, "Not implemented for no input gradients yet"
64
 
65
+ grad, weight_grad = ops.fused_add_rms_norm_backward(
66
+ output_grad,
67
+ add_output_grad,
68
+ add_output,
69
+ weight,
70
+ eps,
71
+ need_input_grad=need_in or need_res)
72
  input_grad = grad if need_in else None
73
  residual_grad = grad if need_res else None
74
 
75
  return input_grad, residual_grad, weight_grad, None
76
+
77
+
78
+ @torch.library.register_fake(ops.rms_norm.default)
79
+ def rms_norm_abstract(x, weight, eps):
80
+ return torch.empty_like(x)
81
+
82
+
83
+ @torch.library.register_fake(ops.rms_norm_backward.default)
84
+ def rms_norm_backward_abstract(output_grad, x, weight, eps):
85
+ return torch.empty_like(x), torch.empty_like(weight)
86
+
87
+
88
+ @torch.library.register_fake(ops.fused_add_rms_norm.default)
89
+ def fused_add_rms_norm_abstract(x, residual, weight, eps):
90
+ return torch.empty_like(x), torch.empty_like(x)
91
+
92
+
93
+ @torch.library.register_fake(ops.fused_add_rms_norm_backward.default)
94
+ def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad,
95
+ add_output, weight, eps,
96
+ need_input_grad: bool):
97
+ return torch.empty_like(
98
+ output_grad) if need_input_grad else None, torch.empty_like(weight)
99
+
100
+
101
+ if version.parse(torch.__version__) >= version.parse("2.8"):
102
+ from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta
103
+ from .rms_norm_meta import register_rms_norm_meta
104
+ register_fused_add_rms_norm_meta()
105
+ register_rms_norm_meta()
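The register_fake registrations above give the compiled ops shape-only implementations so they can be traced under FakeTensorMode or torch.compile without launching the real kernels. A generic, self-contained illustration of the same mechanism with a made-up op (assumes torch >= 2.4; none of the names below belong to this package):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

@torch.library.register_fake("demo::scale")
def _(x, factor):
    # Fake/meta implementation: only shapes, dtypes and devices matter here.
    return torch.empty_like(x)

with FakeTensorMode():
    t = torch.empty(4, 8)
    out = torch.ops.demo.scale(t, 2.0)
    print(out.shape)  # torch.Size([4, 8]) without running the real kernel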
build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm_meta.py ADDED
@@ -0,0 +1,164 @@
1
+ from collections.abc import Sequence
2
+
3
+ import torch
4
+ from torch.distributed.tensor._dtensor_spec import DTensorSpec
5
+ from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
6
+ RuntimeSchemaInfo)
7
+ from torch.distributed.tensor._ops.utils import (generate_redistribute_costs,
8
+ register_op_strategy)
9
+ from torch.distributed.tensor.placement_types import (Placement, Replicate,
10
+ Shard)
11
+
12
+ from ._ops import ops
13
+
14
+
15
+ def register_rms_norm_meta():
16
+ """Dummy function to register the meta functions.
17
+ Registration happens at import time by the decorators below.
18
+ """
19
+ pass
20
+
21
+
22
+ def _replicate_dims_start_at(placements: Sequence[Placement],
23
+ start_dim: int = 0) -> tuple[Placement, ...]:
24
+ new_placements: list[Placement] = []
25
+ for p in placements:
26
+ if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
27
+ new_placements.append(Replicate()) # make it replicate
28
+ else:
29
+ new_placements.append(p) # keep the placement
30
+ return tuple(new_placements)
31
+
32
+
33
+ @register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1))
34
+ def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
35
+ mesh = op_schema.get_mesh_from_args()
36
+
37
+ assert len(op_schema.args_schema) == 3
38
+ (
39
+ input_strategy,
40
+ weight_strategy,
41
+ _, # eps
42
+ ) = op_schema.args_schema
43
+
44
+ assert isinstance(input_strategy, OpStrategy)
45
+ assert isinstance(weight_strategy, OpStrategy)
46
+
47
+ assert len(input_strategy.strategies) == len(weight_strategy.strategies)
48
+
49
+ last_dim = input_strategy.ndim - 1
50
+ strategy = OpStrategy([])
51
+ for input, weight in zip(input_strategy.strategies,
52
+ weight_strategy.strategies):
53
+ input_src = input.output_spec
54
+ weight_src = weight.output_spec
55
+
56
+ assert isinstance(input_src, DTensorSpec)
57
+ assert isinstance(weight_src, DTensorSpec)
58
+
59
+ redistribute_costs = []
60
+
61
+ # Input can be sharded in any dim except the last dim.
62
+ input_tgt = DTensorSpec(
63
+ mesh=mesh,
64
+ placements=_replicate_dims_start_at(input_src.placements,
65
+ last_dim),
66
+ tensor_meta=input_src.tensor_meta,
67
+ )
68
+ redistribute_costs.append(
69
+ generate_redistribute_costs(input_strategy, input_tgt))
70
+
71
+ # Weight cannot be sharded, so always replicate it.
72
+ weight_tgt = DTensorSpec(
73
+ mesh=mesh,
74
+ placements=(Replicate(), ),
75
+ tensor_meta=weight_src.tensor_meta,
76
+ )
77
+ redistribute_costs.append(
78
+ generate_redistribute_costs(weight_strategy, weight_tgt))
79
+
80
+ strategy.strategies.append(
81
+ OpSpec(
82
+ output_specs=input_tgt,
83
+ input_specs=[input_tgt, weight_tgt],
84
+ redistribute_cost=redistribute_costs,
85
+ ))
86
+ return strategy
87
+
88
+
89
+ @register_op_strategy(ops.rms_norm_backward.default,
90
+ schema_info=RuntimeSchemaInfo(1))
91
+ def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
92
+ mesh = op_schema.get_mesh_from_args()
93
+
94
+ assert len(op_schema.args_schema) == 4
95
+ (
96
+ output_grad_strategy,
97
+ input_strategy,
98
+ weight_strategy,
99
+ _, # eps
100
+ ) = op_schema.args_schema
101
+
102
+ assert isinstance(output_grad_strategy, OpStrategy)
103
+ assert isinstance(input_strategy, OpStrategy)
104
+ assert isinstance(weight_strategy, OpStrategy)
105
+
106
+ lengths = {
107
+ "output_grad": len(output_grad_strategy.strategies),
108
+ "input": len(input_strategy.strategies),
109
+ "weight": len(weight_strategy.strategies),
110
+ }
111
+
112
+ assert len(set(
113
+ lengths.values())) == 1, f"Strategies length mismatch {lengths}"
114
+
115
+ zipped = zip(
116
+ output_grad_strategy.strategies,
117
+ input_strategy.strategies,
118
+ weight_strategy.strategies,
119
+ )
120
+
121
+ last_dim = input_strategy.ndim - 1
122
+ strategy = OpStrategy([])
123
+ for output_grad, input, weight in zipped:
124
+ output_grad_src = output_grad.output_spec
125
+ input_src = input.output_spec
126
+ weight_src = weight.output_spec
127
+
128
+ assert isinstance(output_grad_src, DTensorSpec)
129
+ assert isinstance(input_src, DTensorSpec)
130
+ assert isinstance(weight_src, DTensorSpec)
131
+
132
+ redistribute_costs = []
133
+
134
+ # Output grad can be sharded in any dim except the last dim.
135
+ output_grad_tgt = DTensorSpec(
136
+ mesh=mesh,
137
+ placements=_replicate_dims_start_at(output_grad_src.placements,
138
+ last_dim),
139
+ tensor_meta=output_grad_src.tensor_meta,
140
+ )
141
+ redistribute_costs.append(
142
+ generate_redistribute_costs(output_grad_strategy, output_grad_tgt))
143
+
144
+ # Input must have the same sharding as output grad.
145
+ input_tgt = output_grad_tgt
146
+ redistribute_costs.append(
147
+ generate_redistribute_costs(input_strategy, input_tgt))
148
+
149
+ # Weight cannot be sharded, so always replicate it.
150
+ weight_tgt = DTensorSpec(
151
+ mesh=mesh,
152
+ placements=(Replicate(), ),
153
+ tensor_meta=weight_src.tensor_meta,
154
+ )
155
+ redistribute_costs.append(
156
+ generate_redistribute_costs(weight_strategy, weight_tgt))
157
+
158
+ strategy.strategies.append(
159
+ OpSpec(
160
+ output_specs=[input_tgt, weight_tgt],
161
+ input_specs=[output_grad_tgt, input_tgt, weight_tgt],
162
+ redistribute_cost=redistribute_costs,
163
+ ))
164
+ return strategy
build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
 
3
- from . import layers
4
  from ._ops import ops
5
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
6
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -48,5 +48,6 @@ __all__ = [
48
  "rms_norm",
49
  "fused_add_rms_norm",
50
  "layers",
 
51
  "ops",
52
  ]
 
1
  import torch
2
 
3
+ from . import layers, parallel_style
4
  from ._ops import ops
5
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
6
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
 
48
  "rms_norm",
49
  "fused_add_rms_norm",
50
  "layers",
51
+ "parallel_style",
52
  "ops",
53
  ]
build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1768d8d5072ac06d937cb5332988c6b3bfaa191f72d1369a22d2c577e9a3bca2
3
- size 8215280
build/{torch27-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so → torch28-cxx11-cu126-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:74d4955271509451b946495da75f69a0f978e7258b8303fe3c077e585c0d3e6a
3
- size 8272456
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c301db3d37625ebf0cecf016948ec18fbeddb497acca8c870d2d8eff0a1d1203
3
+ size 8735952
build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:440f5c17a7ddaf73c506bbc84fd1405e2e188b8ceaf4977910608be6b91e89bf
3
- size 8730200
build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:cb222449350310f90f7271f34fcf9052c9eec28021fee0348130a8f239a97bf4
3
- size 4571976
build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _activation_e5e2eeb_dirty
3
- ops = torch.ops._activation_e5e2eeb_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_activation_e5e2eeb_dirty::{op_name}"
 
1
  import torch
2
+ from . import _activation_53ed492_dirty
3
+ ops = torch.ops._activation_53ed492_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_activation_53ed492_dirty::{op_name}"
build/torch28-cxx11-cu126-x86_64-linux/activation/fused_add_rms_norm_meta.py ADDED
@@ -0,0 +1,199 @@
1
+ from collections.abc import Sequence
2
+
3
+ import torch
4
+ from torch.distributed.tensor._dtensor_spec import DTensorSpec
5
+ from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
6
+ RuntimeSchemaInfo)
7
+ from torch.distributed.tensor._ops.utils import (generate_redistribute_costs,
8
+ register_op_strategy)
9
+ from torch.distributed.tensor.placement_types import (Placement, Replicate,
10
+ Shard)
11
+
12
+ from ._ops import ops
13
+
14
+
15
+ def register_fused_add_rms_norm_meta():
16
+ """Dummy function to register the meta functions.
17
+ Registration happens at import time by the decorators below.
18
+ """
19
+ pass
20
+
21
+
22
+ def _replicate_dims_start_at(placements: Sequence[Placement],
23
+ start_dim: int = 0) -> tuple[Placement, ...]:
24
+ new_placements: list[Placement] = []
25
+ for p in placements:
26
+ if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
27
+ new_placements.append(Replicate()) # make it replicate
28
+ else:
29
+ new_placements.append(p) # keep the placement
30
+ return tuple(new_placements)
31
+
32
+
33
+ @register_op_strategy(ops.fused_add_rms_norm.default,
34
+ schema_info=RuntimeSchemaInfo(1))
35
+ def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
36
+ mesh = op_schema.get_mesh_from_args()
37
+
38
+ assert len(op_schema.args_schema) == 4
39
+ (
40
+ input_strategy,
41
+ residual_strategy,
42
+ weight_strategy,
43
+ _, # eps
44
+ ) = op_schema.args_schema
45
+
46
+ assert isinstance(input_strategy, OpStrategy)
47
+ assert isinstance(residual_strategy, OpStrategy)
48
+ assert isinstance(weight_strategy, OpStrategy)
49
+
50
+ lengths = {
51
+ "input": len(input_strategy.strategies),
52
+ "residual": len(residual_strategy.strategies),
53
+ "weight": len(weight_strategy.strategies),
54
+ }
55
+ assert len(set(
56
+ lengths.values())) == 1, f"Strategy length mismatch: {lengths}"
57
+
58
+ last_dim = input_strategy.ndim - 1
59
+ strategy = OpStrategy([])
60
+ for input, residual, weight in zip(input_strategy.strategies,
61
+ residual_strategy.strategies,
62
+ weight_strategy.strategies):
63
+
64
+ input_src = input.output_spec
65
+ residual_src = residual.output_spec
66
+ weight_src = weight.output_spec
67
+
68
+ assert isinstance(input_src, DTensorSpec)
69
+ assert isinstance(residual_src, DTensorSpec)
70
+ assert isinstance(weight_src, DTensorSpec)
71
+
72
+ redistribute_costs = []
73
+
74
+ # Input can be sharded in any dim except the last dim.
75
+ input_tgt = DTensorSpec(
76
+ mesh=mesh,
77
+ placements=_replicate_dims_start_at(input_src.placements,
78
+ last_dim),
79
+ tensor_meta=input_src.tensor_meta,
80
+ )
81
+ redistribute_costs.append(
82
+ generate_redistribute_costs(input_strategy, input_tgt))
83
+
84
+ # Residual add must have the same sharding as input.
85
+ residual_tgt = input_tgt
86
+ redistribute_costs.append(
87
+ generate_redistribute_costs(residual_strategy, residual_tgt))
88
+
89
+ # Weight cannot be sharded, so always replicate it.
90
+ weight_tgt = DTensorSpec(
91
+ mesh=mesh,
92
+ placements=(Replicate(), ),
93
+ tensor_meta=weight_src.tensor_meta,
94
+ )
95
+ redistribute_costs.append(
96
+ generate_redistribute_costs(weight_strategy, weight_tgt))
97
+
98
+ strategy.strategies.append(
99
+ OpSpec(
100
+ output_specs=[input_tgt, input_tgt],
101
+ input_specs=[input_tgt, residual_tgt, weight_tgt],
102
+ redistribute_cost=redistribute_costs,
103
+ ))
104
+ return strategy
105
+
106
+
107
+ @register_op_strategy(ops.fused_add_rms_norm_backward.default,
108
+ schema_info=RuntimeSchemaInfo(2))
109
+ def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
110
+ mesh = op_schema.get_mesh_from_args()
111
+
112
+ assert len(op_schema.args_schema) == 6
113
+ (
114
+ output_grad_strategy,
115
+ add_output_grad_strategy,
116
+ add_output_strategy,
117
+ weight_strategy,
118
+ _, # eps
119
+ need_input_grad, # need_input_grad
120
+ ) = op_schema.args_schema
121
+
122
+ assert isinstance(output_grad_strategy, OpStrategy)
123
+ assert isinstance(add_output_grad_strategy, OpStrategy)
124
+ assert isinstance(add_output_strategy, OpStrategy)
125
+ assert isinstance(weight_strategy, OpStrategy)
126
+
127
+ lengths = {
128
+ "output_grad": len(output_grad_strategy.strategies),
129
+ "add_output_grad": len(add_output_grad_strategy.strategies),
130
+ "add_output": len(add_output_strategy.strategies),
131
+ "weight": len(weight_strategy.strategies),
132
+ }
133
+ assert len(set(
134
+ lengths.values())) == 1, f"Strategy length mismatch: {lengths}"
135
+
136
+ zipped = zip(
137
+ output_grad_strategy.strategies,
138
+ add_output_grad_strategy.strategies,
139
+ add_output_strategy.strategies,
140
+ weight_strategy.strategies,
141
+ )
142
+
143
+ last_dim = output_grad_strategy.ndim - 1
144
+ strategy = OpStrategy([])
145
+ for output_grad, add_output_grad, add_output, weight in zipped:
146
+ output_grad_src = output_grad.output_spec
147
+ add_output_grad_src = add_output_grad.output_spec
148
+ add_output_src = add_output.output_spec
149
+ weight_src = weight.output_spec
150
+
151
+ assert isinstance(output_grad_src, DTensorSpec)
152
+ assert isinstance(add_output_grad_src, DTensorSpec)
153
+ assert isinstance(add_output_src, DTensorSpec)
154
+ assert isinstance(weight_src, DTensorSpec)
155
+
156
+ redistribute_costs = []
157
+
158
+ # output grad can be sharded in any dim except the last dim.
159
+ output_grad_tgt = DTensorSpec(
160
+ mesh=mesh,
161
+ placements=_replicate_dims_start_at(output_grad_src.placements,
162
+ last_dim),
163
+ tensor_meta=output_grad_src.tensor_meta,
164
+ )
165
+ redistribute_costs.append(
166
+ generate_redistribute_costs(output_grad_strategy, output_grad_tgt))
167
+
168
+ # add_output_grad must have the same sharding as output_grad.
169
+ add_output_grad_tgt = output_grad_tgt
170
+ redistribute_costs.append(
171
+ generate_redistribute_costs(add_output_grad_strategy,
172
+ add_output_grad_tgt))
173
+
174
+ # add_output must have the same sharding as output_grad.
175
+ add_output_tgt = output_grad_tgt
176
+ redistribute_costs.append(
177
+ generate_redistribute_costs(add_output_strategy, add_output_tgt))
178
+
179
+ # Weight cannot be sharded, so always replicate it.
180
+ weight_tgt = DTensorSpec(
181
+ mesh=mesh,
182
+ placements=(Replicate(), ),
183
+ tensor_meta=weight_src.tensor_meta,
184
+ )
185
+ redistribute_costs.append(
186
+ generate_redistribute_costs(weight_strategy, weight_tgt))
187
+
188
+ strategy.strategies.append(
189
+ OpSpec(
190
+ output_specs=[
191
+ output_grad_tgt if need_input_grad else None, weight_tgt
192
+ ],
193
+ input_specs=[
194
+ output_grad_tgt, add_output_grad_tgt, add_output_tgt,
195
+ weight_tgt
196
+ ],
197
+ redistribute_cost=redistribute_costs,
198
+ ))
199
+ return strategy
build/torch28-cxx11-cu126-x86_64-linux/activation/parallel_style.py ADDED
@@ -0,0 +1,50 @@
1
+ from abc import ABC, abstractmethod
2
+ from functools import partial
3
+ from typing import Any, Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard,
8
+ distribute_module, distribute_tensor)
9
+ from torch.distributed.tensor.parallel import SequenceParallel
10
+ from torch.distributed.tensor.placement_types import Placement
11
+
12
+
13
+ class ResidualSequenceParallel(SequenceParallel):
14
+ """ Consider the case where we have a residual connection across a sequence parallel layer."""
15
+
16
+ @staticmethod
17
+ def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
18
+ input_tensor = inputs[0]
19
+ residual_tensor = inputs[1]
20
+
21
+ assert isinstance(input_tensor,
22
+ DTensor) == isinstance(residual_tensor, DTensor)
23
+ assert isinstance(input_tensor,
24
+ torch.Tensor) == isinstance(residual_tensor,
25
+ torch.Tensor)
26
+
27
+ if isinstance(input_tensor, DTensor):
28
+ # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it
29
+ if input_tensor.placements != sequence_sharding:
30
+ input_tensor = input_tensor.redistribute(
31
+ placements=sequence_sharding, async_op=True)
32
+ if residual_tensor.placements != sequence_sharding:
33
+ residual_tensor = residual_tensor.redistribute(
34
+ placements=sequence_sharding, async_op=True)
35
+ return input_tensor, residual_tensor
36
+
37
+ elif isinstance(input_tensor, torch.Tensor):
38
+ # assume the input passed in is already sharded on the sequence dim and create the DTensor
39
+ return DTensor.from_local(input_tensor,
40
+ device_mesh,
41
+ sequence_sharding,
42
+ run_check=False), DTensor.from_local(
43
+ residual_tensor,
44
+ device_mesh,
45
+ sequence_sharding,
46
+ run_check=False)
47
+ else:
48
+ raise ValueError(
49
+ f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}"
50
+ )
build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py CHANGED
@@ -1,4 +1,7 @@
 
 
1
  import torch
 
2
 
3
  from ._ops import ops
4
 
@@ -8,9 +11,7 @@ class RMSNormFunction(torch.autograd.Function):
8
  # Note that forward, setup_context, and backward are @staticmethods
9
  @staticmethod
10
  def forward(input, weight, eps):
11
- output = torch.empty_like(input)
12
- ops.rms_norm(output, input, weight, eps)
13
- return output
14
 
15
  @staticmethod
16
  # inputs is a Tuple of all of the inputs passed to forward.
@@ -26,13 +27,8 @@ class RMSNormFunction(torch.autograd.Function):
26
  input, weight = ctx.saved_tensors
27
  eps = ctx.eps
28
 
29
- input_grad = torch.empty_like(
30
- input) if ctx.needs_input_grad[0] else None
31
- weight_grad = torch.empty_like(
32
- weight) if ctx.needs_input_grad[1] else None
33
-
34
- ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
35
- weight, eps)
36
 
37
  return input_grad, weight_grad, None
38
 
@@ -42,10 +38,8 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
42
  # Note that forward, setup_context, and backward are @staticmethods
43
  @staticmethod
44
  def forward(input, residual, weight, eps):
45
- output = torch.empty_like(input)
46
- add_output = torch.empty_like(input)
47
- ops.fused_add_rms_norm(output, add_output, input, residual, weight,
48
- eps)
49
  return output, add_output
50
 
51
  @staticmethod
@@ -65,14 +59,47 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
65
  need_in = ctx.needs_input_grad[0]
66
  need_res = ctx.needs_input_grad[1]
67
 
68
- grad = torch.empty_like(output_grad) if need_in or need_res else None
 
69
 
70
- weight_grad = torch.empty_like(
71
- weight) if ctx.needs_input_grad[2] else None
72
-
73
- ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output,
74
- weight, eps)
 
 
75
  input_grad = grad if need_in else None
76
  residual_grad = grad if need_res else None
77
 
78
  return input_grad, residual_grad, weight_grad, None
1
+ from collections.abc import Sequence
2
+
3
  import torch
4
+ from packaging import version
5
 
6
  from ._ops import ops
7
 
 
11
  # Note that forward, setup_context, and backward are @staticmethods
12
  @staticmethod
13
  def forward(input, weight, eps):
14
+ return ops.rms_norm(input, weight, eps)
 
 
15
 
16
  @staticmethod
17
  # inputs is a Tuple of all of the inputs passed to forward.
 
27
  input, weight = ctx.saved_tensors
28
  eps = ctx.eps
29
 
30
+ input_grad, weight_grad = ops.rms_norm_backward(
31
+ output_grad, input, weight, eps)
 
 
 
 
 
32
 
33
  return input_grad, weight_grad, None
34
 
 
38
  # Note that forward, setup_context, and backward are @staticmethods
39
  @staticmethod
40
  def forward(input, residual, weight, eps):
41
+ output, add_output = ops.fused_add_rms_norm(input, residual, weight,
42
+ eps)
 
 
43
  return output, add_output
44
 
45
  @staticmethod
 
59
  need_in = ctx.needs_input_grad[0]
60
  need_res = ctx.needs_input_grad[1]
61
 
62
+ # TODO(ai-system): kernels currently do not support no input gradients
63
+ assert need_in or need_res, "Not implemented for no input gradients yet"
64
 
65
+ grad, weight_grad = ops.fused_add_rms_norm_backward(
66
+ output_grad,
67
+ add_output_grad,
68
+ add_output,
69
+ weight,
70
+ eps,
71
+ need_input_grad=need_in or need_res)
72
  input_grad = grad if need_in else None
73
  residual_grad = grad if need_res else None
74
 
75
  return input_grad, residual_grad, weight_grad, None
76
+
77
+
78
+ @torch.library.register_fake(ops.rms_norm.default)
79
+ def rms_norm_abstract(x, weight, eps):
80
+ return torch.empty_like(x)
81
+
82
+
83
+ @torch.library.register_fake(ops.rms_norm_backward.default)
84
+ def rms_norm_backward_abstract(output_grad, x, weight, eps):
85
+ return torch.empty_like(x), torch.empty_like(weight)
86
+
87
+
88
+ @torch.library.register_fake(ops.fused_add_rms_norm.default)
89
+ def fused_add_rms_norm_abstract(x, residual, weight, eps):
90
+ return torch.empty_like(x), torch.empty_like(x)
91
+
92
+
93
+ @torch.library.register_fake(ops.fused_add_rms_norm_backward.default)
94
+ def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad,
95
+ add_output, weight, eps,
96
+ need_input_grad: bool):
97
+ return torch.empty_like(
98
+ output_grad) if need_input_grad else None, torch.empty_like(weight)
99
+
100
+
101
+ if version.parse(torch.__version__) >= version.parse("2.8"):
102
+ from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta
103
+ from .rms_norm_meta import register_rms_norm_meta
104
+ register_fused_add_rms_norm_meta()
105
+ register_rms_norm_meta()
build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm_meta.py ADDED
@@ -0,0 +1,164 @@
1
+ from collections.abc import Sequence
2
+
3
+ import torch
4
+ from torch.distributed.tensor._dtensor_spec import DTensorSpec
5
+ from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
6
+ RuntimeSchemaInfo)
7
+ from torch.distributed.tensor._ops.utils import (generate_redistribute_costs,
8
+ register_op_strategy)
9
+ from torch.distributed.tensor.placement_types import (Placement, Replicate,
10
+ Shard)
11
+
12
+ from ._ops import ops
13
+
14
+
15
+ def register_rms_norm_meta():
16
+ """Dummy function to register the meta functions.
17
+ Registration happens at import time by the decorators below.
18
+ """
19
+ pass
20
+
21
+
22
+ def _replicate_dims_start_at(placements: Sequence[Placement],
23
+ start_dim: int = 0) -> tuple[Placement, ...]:
24
+ new_placements: list[Placement] = []
25
+ for p in placements:
26
+ if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
27
+ new_placements.append(Replicate()) # make it replicate
28
+ else:
29
+ new_placements.append(p) # keep the placement
30
+ return tuple(new_placements)
31
+
32
+
33
+ @register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1))
34
+ def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
35
+ mesh = op_schema.get_mesh_from_args()
36
+
37
+ assert len(op_schema.args_schema) == 3
38
+ (
39
+ input_strategy,
40
+ weight_strategy,
41
+ _, # eps
42
+ ) = op_schema.args_schema
43
+
44
+ assert isinstance(input_strategy, OpStrategy)
45
+ assert isinstance(weight_strategy, OpStrategy)
46
+
47
+ assert len(input_strategy.strategies) == len(weight_strategy.strategies)
48
+
49
+ last_dim = input_strategy.ndim - 1
50
+ strategy = OpStrategy([])
51
+ for input, weight in zip(input_strategy.strategies,
52
+ weight_strategy.strategies):
53
+ input_src = input.output_spec
54
+ weight_src = weight.output_spec
55
+
56
+ assert isinstance(input_src, DTensorSpec)
57
+ assert isinstance(weight_src, DTensorSpec)
58
+
59
+ redistribute_costs = []
60
+
61
+ # Input can be sharded in any dim except the last dim.
62
+ input_tgt = DTensorSpec(
63
+ mesh=mesh,
64
+ placements=_replicate_dims_start_at(input_src.placements,
65
+ last_dim),
66
+ tensor_meta=input_src.tensor_meta,
67
+ )
68
+ redistribute_costs.append(
69
+ generate_redistribute_costs(input_strategy, input_tgt))
70
+
71
+ # Weight cannot be sharded, so always replicate it.
72
+ weight_tgt = DTensorSpec(
73
+ mesh=mesh,
74
+ placements=(Replicate(), ),
75
+ tensor_meta=weight_src.tensor_meta,
76
+ )
77
+ redistribute_costs.append(
78
+ generate_redistribute_costs(weight_strategy, weight_tgt))
79
+
80
+ strategy.strategies.append(
81
+ OpSpec(
82
+ output_specs=input_tgt,
83
+ input_specs=[input_tgt, weight_tgt],
84
+ redistribute_cost=redistribute_costs,
85
+ ))
86
+ return strategy
87
+
88
+
89
+ @register_op_strategy(ops.rms_norm_backward.default,
90
+ schema_info=RuntimeSchemaInfo(1))
91
+ def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
92
+ mesh = op_schema.get_mesh_from_args()
93
+
94
+ assert len(op_schema.args_schema) == 4
95
+ (
96
+ output_grad_strategy,
97
+ input_strategy,
98
+ weight_strategy,
99
+ _, # eps
100
+ ) = op_schema.args_schema
101
+
102
+ assert isinstance(output_grad_strategy, OpStrategy)
103
+ assert isinstance(input_strategy, OpStrategy)
104
+ assert isinstance(weight_strategy, OpStrategy)
105
+
106
+ lengths = {
107
+ "output_grad": len(output_grad_strategy.strategies),
108
+ "input": len(input_strategy.strategies),
109
+ "weight": len(weight_strategy.strategies),
110
+ }
111
+
112
+ assert len(set(
113
+ lengths.values())) == 1, f"Strategies length mismatch {lengths}"
114
+
115
+ zipped = zip(
116
+ output_grad_strategy.strategies,
117
+ input_strategy.strategies,
118
+ weight_strategy.strategies,
119
+ )
120
+
121
+ last_dim = input_strategy.ndim - 1
122
+ strategy = OpStrategy([])
123
+ for output_grad, input, weight in zipped:
124
+ output_grad_src = output_grad.output_spec
125
+ input_src = input.output_spec
126
+ weight_src = weight.output_spec
127
+
128
+ assert isinstance(output_grad_src, DTensorSpec)
129
+ assert isinstance(input_src, DTensorSpec)
130
+ assert isinstance(weight_src, DTensorSpec)
131
+
132
+ redistribute_costs = []
133
+
134
+ # Output grad can be sharded in any dim except the last dim.
135
+ output_grad_tgt = DTensorSpec(
136
+ mesh=mesh,
137
+ placements=_replicate_dims_start_at(output_grad_src.placements,
138
+ last_dim),
139
+ tensor_meta=output_grad_src.tensor_meta,
140
+ )
141
+ redistribute_costs.append(
142
+ generate_redistribute_costs(output_grad_strategy, output_grad_tgt))
143
+
144
+ # Input must have the same sharding as output grad.
145
+ input_tgt = output_grad_tgt
146
+ redistribute_costs.append(
147
+ generate_redistribute_costs(input_strategy, input_tgt))
148
+
149
+ # Weight cannot be sharded, so always replicate it.
150
+ weight_tgt = DTensorSpec(
151
+ mesh=mesh,
152
+ placements=(Replicate(), ),
153
+ tensor_meta=weight_src.tensor_meta,
154
+ )
155
+ redistribute_costs.append(
156
+ generate_redistribute_costs(weight_strategy, weight_tgt))
157
+
158
+ strategy.strategies.append(
159
+ OpSpec(
160
+ output_specs=[input_tgt, weight_tgt],
161
+ input_specs=[output_grad_tgt, input_tgt, weight_tgt],
162
+ redistribute_cost=redistribute_costs,
163
+ ))
164
+ return strategy
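The comments in these strategies say the input may be sharded on any dim except the last, while the weight is always replicated. A single-process sketch of what those placements look like as DTensors (the gloo process group, address/port, and CPU mesh are assumptions for illustration only):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

# One-process group, just to build a 1-element mesh for illustration.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

x = distribute_tensor(torch.randn(4, 8, 16), mesh, [Shard(0)])  # batch-sharded input: allowed
w = distribute_tensor(torch.ones(16), mesh, [Replicate()])      # weight: always replicated
print(x.placements, w.placements)  # (Shard(dim=0),) (Replicate(),)
dist.destroy_process_group()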
build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
 
3
- from . import layers
4
  from ._ops import ops
5
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
6
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -48,5 +48,6 @@ __all__ = [
48
  "rms_norm",
49
  "fused_add_rms_norm",
50
  "layers",
 
51
  "ops",
52
  ]
 
1
  import torch
2
 
3
+ from . import layers, parallel_style
4
  from ._ops import ops
5
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
6
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
 
48
  "rms_norm",
49
  "fused_add_rms_norm",
50
  "layers",
51
+ "parallel_style",
52
  "ops",
53
  ]
build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:37a572bd877980ab8c0331ca5682191cb5a2b1f05bc69ea493a9e24f7728ba3f
3
- size 12730840
build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f7879c74d91f2412bbf5524cd107dea64edeeeabf1dd496eeefa627d2e7143c
3
+ size 13775752
build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1dfb6d468f9cef0239d4ea47f0a247fa721befc5b8db86e1cddfc25f1814b67a
3
- size 13770064