kernel
danieldk HF Staff committed on
Commit
1a6ab32
·
1 Parent(s): 8915dbd

Enable Torch 2.8 build

Browse files
Files changed (3) hide show
  1. build.toml +59 -43
  2. flake.lock +79 -27
  3. flake.nix +1 -1
build.toml CHANGED
@@ -1,61 +1,77 @@
1
  [general]
2
  name = "moe"
 
3
 
4
  [torch]
 
 
 
 
 
5
  src = [
6
- "core/scalar_type.hpp",
7
- "torch-ext/torch_binding.cpp",
8
- "torch-ext/torch_binding.h",
9
  ]
10
- include = ["."]
11
- pyext = ["py", "json"]
12
 
13
- [kernel.fp8]
14
- src = [
15
- "cuda_compat.h",
16
- "dispatch_utils.h",
17
- "fp8/amd/hip_float8.h",
18
- "fp8/amd/hip_float8_impl.h",
19
- "fp8/common.cu",
20
- "fp8/common.cuh",
21
- "fp8/vectorization.cuh",
 
 
22
  ]
23
- include = ["."]
24
  depends = ["torch"]
25
-
26
-
27
- [kernel.moe]
28
  src = [
29
- "cuda_compat.h",
30
- "dispatch_utils.h",
31
- "moe/moe_align_sum_kernels.cu",
32
- "moe/moe_wna16.cu",
33
- "moe/moe_wna16_utils.h",
34
- "moe/topk_softmax_kernels.cu",
 
 
 
 
35
  ]
36
- depends = ["torch"]
37
 
38
- [kernel.moe-marlin]
39
- cuda-capabilities = ["8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0"]
 
40
  src = [
41
- "core/exception.hpp",
42
- "core/scalar_type.hpp",
43
- "marlin-moe/marlin_moe_ops.cu",
44
- "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4.cu",
45
- "marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu",
46
- "marlin-moe/marlin_kernels/marlin_moe_kernel.h",
47
- "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4.h",
48
- "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.h",
49
- "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu",
50
- "marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.h",
51
  ]
52
- include = ["."]
53
- depends = ["torch"]
54
 
55
- [kernel.activation]
 
 
 
56
  src = [
57
- "activation/activation_kernels.cu",
58
- "activation/cuda_compat.h",
59
- "activation/dispatch_utils.h",
 
 
 
 
60
  ]
 
 
 
61
  depends = ["torch"]
 
 
 
 
 
 
 
 
 
1
  [general]
2
  name = "moe"
3
+ universal = false
4
 
5
  [torch]
6
+ include = ["."]
7
+ pyext = [
8
+ "py",
9
+ "json",
10
+ ]
11
  src = [
12
+ "core/scalar_type.hpp",
13
+ "torch-ext/torch_binding.cpp",
14
+ "torch-ext/torch_binding.h",
15
  ]
 
 
16
 
17
+ [kernel.moe-marlin]
18
+ backend = "cuda"
19
+ cuda-capabilities = [
20
+ "8.0",
21
+ "8.6",
22
+ "8.7",
23
+ "8.9",
24
+ "9.0",
25
+ "10.0",
26
+ "10.1",
27
+ "12.0",
28
  ]
 
29
  depends = ["torch"]
30
+ include = ["."]
 
 
31
  src = [
32
+ "core/exception.hpp",
33
+ "core/scalar_type.hpp",
34
+ "marlin-moe/marlin_moe_ops.cu",
35
+ "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4.cu",
36
+ "marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu",
37
+ "marlin-moe/marlin_kernels/marlin_moe_kernel.h",
38
+ "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4.h",
39
+ "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.h",
40
+ "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu",
41
+ "marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.h",
42
  ]
 
43
 
44
+ [kernel.activation]
45
+ backend = "cuda"
46
+ depends = ["torch"]
47
  src = [
48
+ "activation/activation_kernels.cu",
49
+ "activation/cuda_compat.h",
50
+ "activation/dispatch_utils.h",
 
 
 
 
 
 
 
51
  ]
 
 
52
 
53
+ [kernel.fp8]
54
+ backend = "cuda"
55
+ depends = ["torch"]
56
+ include = ["."]
57
  src = [
58
+ "cuda_compat.h",
59
+ "dispatch_utils.h",
60
+ "fp8/amd/hip_float8.h",
61
+ "fp8/amd/hip_float8_impl.h",
62
+ "fp8/common.cu",
63
+ "fp8/common.cuh",
64
+ "fp8/vectorization.cuh",
65
  ]
66
+
67
+ [kernel.moe]
68
+ backend = "cuda"
69
  depends = ["torch"]
70
+ src = [
71
+ "cuda_compat.h",
72
+ "dispatch_utils.h",
73
+ "moe/moe_align_sum_kernels.cu",
74
+ "moe/moe_wna16.cu",
75
+ "moe/moe_wna16_utils.h",
76
+ "moe/topk_softmax_kernels.cu",
77
+ ]
flake.lock CHANGED
@@ -1,6 +1,21 @@
1
  {
2
  "nodes": {
3
  "flake-compat": {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  "locked": {
5
  "lastModified": 1733328505,
6
  "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
@@ -33,61 +48,83 @@
33
  "type": "github"
34
  }
35
  },
36
- "kernel-builder": {
37
  "inputs": {
38
- "flake-compat": "flake-compat",
39
- "flake-utils": "flake-utils",
40
- "nixpkgs": "nixpkgs",
41
- "rocm-nix": "rocm-nix"
42
  },
43
  "locked": {
44
- "lastModified": 1744976941,
45
- "narHash": "sha256-+csrhVaT6Mj2j1FM7P2BDITvf1Xwj2AKdMm0IKZK340=",
46
- "owner": "huggingface",
47
- "repo": "kernel-builder",
48
- "rev": "0a278c2e9aaf6003a4ec6fe35c7158624762de5a",
49
  "type": "github"
50
  },
51
  "original": {
52
- "owner": "huggingface",
53
- "repo": "kernel-builder",
54
  "type": "github"
55
  }
56
  },
57
- "nixpkgs": {
 
 
 
 
 
58
  "locked": {
59
- "lastModified": 1743559129,
60
- "narHash": "sha256-7gpAWsENV3tY2HmeHYQ2MoQxGpys+jQWnkS/BHAMXVk=",
61
- "owner": "nixos",
62
- "repo": "nixpkgs",
63
- "rev": "adae22bea8bcc0aa2fd6e8732044660fb7755f5e",
64
  "type": "github"
65
  },
66
  "original": {
67
- "owner": "nixos",
68
- "ref": "nixos-unstable-small",
69
- "repo": "nixpkgs",
70
  "type": "github"
71
  }
72
  },
73
- "rocm-nix": {
74
  "inputs": {
 
 
 
75
  "nixpkgs": [
76
  "kernel-builder",
 
77
  "nixpkgs"
78
  ]
79
  },
80
  "locked": {
81
- "lastModified": 1743085847,
82
- "narHash": "sha256-uWG29p+nhZmGRV1LffWwRGjwtPIXeu1F0YTQbXgB+GU=",
83
  "owner": "huggingface",
84
- "repo": "rocm-nix",
85
- "rev": "245cdc9bfb4bfafa818711c5f5e0b889afe1ba39",
86
  "type": "github"
87
  },
88
  "original": {
89
  "owner": "huggingface",
90
- "repo": "rocm-nix",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  "type": "github"
92
  }
93
  },
@@ -110,6 +147,21 @@
110
  "repo": "default",
111
  "type": "github"
112
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  }
114
  },
115
  "root": "root",
 
1
  {
2
  "nodes": {
3
  "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1747046372,
6
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-compat_2": {
19
  "locked": {
20
  "lastModified": 1733328505,
21
  "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
 
48
  "type": "github"
49
  }
50
  },
51
+ "flake-utils_2": {
52
  "inputs": {
53
+ "systems": "systems_2"
 
 
 
54
  },
55
  "locked": {
56
+ "lastModified": 1731533236,
57
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
+ "owner": "numtide",
59
+ "repo": "flake-utils",
60
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
  "type": "github"
62
  },
63
  "original": {
64
+ "owner": "numtide",
65
+ "repo": "flake-utils",
66
  "type": "github"
67
  }
68
  },
69
+ "hf-nix": {
70
+ "inputs": {
71
+ "flake-compat": "flake-compat_2",
72
+ "flake-utils": "flake-utils_2",
73
+ "nixpkgs": "nixpkgs"
74
+ },
75
  "locked": {
76
+ "lastModified": 1753354560,
77
+ "narHash": "sha256-vmOfRmr0Qm/IbZTWB2sBn+UFrABSTTA/cTg+m27Yt/E=",
78
+ "owner": "huggingface",
79
+ "repo": "hf-nix",
80
+ "rev": "7f2aceda2a2e72cd573bdb25e5c0667fd75f89d3",
81
  "type": "github"
82
  },
83
  "original": {
84
+ "owner": "huggingface",
85
+ "repo": "hf-nix",
 
86
  "type": "github"
87
  }
88
  },
89
+ "kernel-builder": {
90
  "inputs": {
91
+ "flake-compat": "flake-compat",
92
+ "flake-utils": "flake-utils",
93
+ "hf-nix": "hf-nix",
94
  "nixpkgs": [
95
  "kernel-builder",
96
+ "hf-nix",
97
  "nixpkgs"
98
  ]
99
  },
100
  "locked": {
101
+ "lastModified": 1753354632,
102
+ "narHash": "sha256-31SX3Raiyx0qCuY9JSlx9ZZgxljeUxvW+JdujjxbofQ=",
103
  "owner": "huggingface",
104
+ "repo": "kernel-builder",
105
+ "rev": "524b628fd8e58525dbd28455bffb0628092c5265",
106
  "type": "github"
107
  },
108
  "original": {
109
  "owner": "huggingface",
110
+ "ref": "torch-2.8",
111
+ "repo": "kernel-builder",
112
+ "type": "github"
113
+ }
114
+ },
115
+ "nixpkgs": {
116
+ "locked": {
117
+ "lastModified": 1752785354,
118
+ "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
119
+ "owner": "nixos",
120
+ "repo": "nixpkgs",
121
+ "rev": "d38025438a6ee456758dc03188ca6873a415463b",
122
+ "type": "github"
123
+ },
124
+ "original": {
125
+ "owner": "nixos",
126
+ "repo": "nixpkgs",
127
+ "rev": "d38025438a6ee456758dc03188ca6873a415463b",
128
  "type": "github"
129
  }
130
  },
 
147
  "repo": "default",
148
  "type": "github"
149
  }
150
+ },
151
+ "systems_2": {
152
+ "locked": {
153
+ "lastModified": 1681028828,
154
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
155
+ "owner": "nix-systems",
156
+ "repo": "default",
157
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
158
+ "type": "github"
159
+ },
160
+ "original": {
161
+ "owner": "nix-systems",
162
+ "repo": "default",
163
+ "type": "github"
164
+ }
165
  }
166
  },
167
  "root": "root",
flake.nix CHANGED
@@ -2,7 +2,7 @@
2
  description = "Flake for activation kernels";
3
 
4
  inputs = {
5
- kernel-builder.url = "github:huggingface/kernel-builder";
6
  };
7
 
8
  outputs =
 
2
  description = "Flake for activation kernels";
3
 
4
  inputs = {
5
+ kernel-builder.url = "github:huggingface/kernel-builder/torch-2.8";
6
  };
7
 
8
  outputs =