danieldk (HF Staff) committed
Commit 9ae1b46 · 1 Parent(s): e3200fb

Build (aarch64)

build/torch26-cxx11-cu126-aarch64-linux/punica_sgmv/__init__.py ADDED
@@ -0,0 +1,172 @@
+ from typing import Optional, Tuple
+ from functools import lru_cache
+
+ import torch
+ import torch.nn.functional as F
+
+ from ._ops import ops
+
+ MIN_SGMV_RANK = 8
+ MIN_RANK_CUSTOM = 16
+ MAX_RANK_CUSTOM = 128
+ SGMV_BLOCK_SIZE = 16
+ BGMV_MAX_RANK = 128
+
+ def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
+     if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
+         return t.transpose(0, 1)
+     return t
+
+ def add_lora_sgmv_cutlass(
+     y: torch.Tensor,
+     x: torch.Tensor,
+     wa_ptr: torch.Tensor,
+     wb_ptr: torch.Tensor,
+     s_start: torch.Tensor,
+     s_end: torch.Tensor,
+     layer_idx: int,
+     lora_rank: int,
+ ):
+     """
+     Semantics:
+         y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])
+
+     Args:
+         y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+         x: Shape: `[B, H1]`. Input vectors.
+         wa_ptr: Shape: `[S]`. DType: torch.int64. Pointers to the weight matrices.\
+             Weight matrix shape: `[num_layers, R, H1]`.
+         wb_ptr: Shape: `[S]`. DType: torch.int64. Pointers to the weight matrices.\
+             Weight matrix shape: `[num_layers, R, H2]`.
+         s_start: Shape: `[S]`, DType: torch.int32. Segment start indices (indptr).
+         s_end: Shape: `[S]`, DType: torch.int32. Segment end indices (indptr).
+         layer_idx: Layer index of the weight matrices.
+     """
+     if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
+         # The custom SGMV shrink kernel only supports ranks 16, 32, 64 and 128.
+         _add_lora_sgmv_cutlass_legacy(y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank)
+         return
+
+     tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
+     tmp2_size = ops.sgmv_cutlass_tmp_size(wa_ptr.size(0))
+     tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
+     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+     ops.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
+     ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)
+
+ def _add_lora_sgmv_cutlass_legacy(
+     y: torch.Tensor,
+     x: torch.Tensor,
+     wa_ptr: torch.Tensor,
+     wb_ptr: torch.Tensor,
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+     lora_rank: int,
+ ):
+     tmp_size = ops.sgmv_cutlass_tmp_size(wa_ptr.size(0))
+     tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
+     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+     ops.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+     ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
+
+ def lora_a_sgmv_cutlass(
+     x: torch.Tensor,
+     tmp: torch.Tensor,
+     wa_ptr: torch.Tensor,
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+     lora_rank: int,
+ ) -> torch.Tensor:
+     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+     if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
+         ops.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+     else:
+         ops.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+     return v
+
+
+ def lora_b_sgmv_cutlass(
+     y: torch.Tensor,
+     v: torch.Tensor,
+     tmp: torch.Tensor,
+     wb_ptr: torch.Tensor,
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+ ):
+     ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
+
+ def add_lora_a_bgmv(
+     v: torch.Tensor,
+     x: torch.Tensor,
+     wa_T_all: torch.Tensor,
+     indicies: torch.LongTensor,
+     layer_idx: int,
+ ):
+     ops.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)
+
+
+ def add_lora_b_bgmv(
+     y: torch.Tensor,
+     v: torch.Tensor,
+     wb_T_all: torch.Tensor,
+     indicies: torch.LongTensor,
+     layer_idx: int,
+ ):
+     ops.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)
+
+
+ def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
+     """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
+     # Tensor parallelism divides the effective rank by world_size,
+     # so scale the minimum rank to offset that effect.
+     min_rank = MIN_SGMV_RANK * world_size
+     return pad_to_min_rank(t, dim, min_rank)
+
+ def pad_to_min_rank(t: torch.Tensor, dim: int, min_rank: int) -> torch.Tensor:
+     # If we're at or below the minimum rank, pad up to the minimum rank;
+     # otherwise, pad to the nearest multiple of the block size.
+     current_rank = t.size(dim)
+     target_rank = (
+         min_rank
+         if current_rank <= min_rank
+         else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
+     )
+     if current_rank == target_rank:
+         return t
+
+     pad_size = target_rank - current_rank
+
+     # See the complicated pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+     pad = [0, 0] * t.dim()
+     pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
+     pad = tuple(pad)
+
+     return F.pad(t, pad, mode="constant", value=0.0)
+
+ def use_cutlass_shrink(lora_rank: int) -> bool:
+     return lora_rank < MIN_RANK_CUSTOM
+
+ @lru_cache(maxsize=1)
+ def get_tmp_tensor(device: torch.device) -> torch.Tensor:
+     return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)
+
+ @lru_cache(maxsize=32)
+ def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
+     tmp_size = ops.sgmv_cutlass_tmp_size(size)
+     return torch.empty((tmp_size,), dtype=torch.uint8, device=device)
+
+ def get_tmp_expand_size(size: int) -> int:
+     return ops.sgmv_cutlass_tmp_size(size)
+
+
+ def get_tmp_tensors(nsegments: int, lora_rank: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+     if use_cutlass_shrink(lora_rank):
+         tmp = get_tmp_tensor_for_size(nsegments, device)
+         return tmp, tmp
+     else:
+         tmp_shrink = get_tmp_tensor(device)
+         tmp_expand = get_tmp_tensor_for_size(nsegments, device)
+         return tmp_shrink, tmp_expand
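
A minimal pure-PyTorch restatement of the semantics documented in the `add_lora_sgmv_cutlass` docstring above may help readers. This sketch replaces the device-pointer arguments (`wa_ptr`, `wb_ptr`) with hypothetical Python lists of weight tensors; it is illustrative only, not the kernel's actual implementation:

import torch

def add_lora_sgmv_reference(y, x, wa, wb, s_start, s_end, layer_idx):
    # Hypothetical reference semantics, with wa/wb as lists of S tensors of
    # shape [num_layers, R, H1] / [num_layers, R, H2] instead of raw pointers.
    for i in range(len(wa)):
        lo, hi = int(s_start[i]), int(s_end[i])
        if lo >= hi:
            continue  # empty segment
        a = wa[i][layer_idx]  # [R, H1]
        b = wb[i][layer_idx]  # [R, H2]
        # Shrink to rank R, then expand: y[lo:hi] += x[lo:hi] @ a.T @ b
        y[lo:hi] += (x[lo:hi] @ a.T) @ b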
build/torch26-cxx11-cu126-aarch64-linux/punica_sgmv/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _punica_sgmv_ad0ac7e_dirty
+ ops = torch.ops._punica_sgmv_ad0ac7e_dirty
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix the op name with the kernel's private namespace.
+     """
+     return f"_punica_sgmv_ad0ac7e_dirty::{op_name}"
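
A brief usage sketch for `add_op_namespace_prefix`, assuming the built `punica_sgmv` package is on the import path; the `register_fake` call is an assumed use case for such a qualified name, not something this commit contains:

import torch
from punica_sgmv._ops import ops, add_op_namespace_prefix

qualname = add_op_namespace_prefix("sgmv_shrink")
print(qualname)  # _punica_sgmv_ad0ac7e_dirty::sgmv_shrink

# A qualified name like this can be used, e.g., to attach a meta/fake kernel
# for torch.compile tracing (assumed usage):
# @torch.library.register_fake(qualname)
# def _sgmv_shrink_fake(v, x, wa_ptr, s_start, s_end, tmp, layer_idx):
#     return None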
build/torch26-cxx11-cu126-aarch64-linux/punica_sgmv/_punica_sgmv_ad0ac7e_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9fb84288c2a868d46ec95e015ef56c1c661b46f9c8158dde3809569b973062af
+ size 14311192
build/torch26-cxx98-cu126-aarch64-linux/punica_sgmv/__init__.py ADDED
@@ -0,0 +1,172 @@
+ from typing import Optional, Tuple
+ from functools import lru_cache
+
+ import torch
+ import torch.nn.functional as F
+
+ from ._ops import ops
+
+ MIN_SGMV_RANK = 8
+ MIN_RANK_CUSTOM = 16
+ MAX_RANK_CUSTOM = 128
+ SGMV_BLOCK_SIZE = 16
+ BGMV_MAX_RANK = 128
+
+ def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
+     if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
+         return t.transpose(0, 1)
+     return t
+
+ def add_lora_sgmv_cutlass(
+     y: torch.Tensor,
+     x: torch.Tensor,
+     wa_ptr: torch.Tensor,
+     wb_ptr: torch.Tensor,
+     s_start: torch.Tensor,
+     s_end: torch.Tensor,
+     layer_idx: int,
+     lora_rank: int,
+ ):
+     """
+     Semantics:
+         y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])
+
+     Args:
+         y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+         x: Shape: `[B, H1]`. Input vectors.
+         wa_ptr: Shape: `[S]`. DType: torch.int64. Pointers to the weight matrices.\
+             Weight matrix shape: `[num_layers, R, H1]`.
+         wb_ptr: Shape: `[S]`. DType: torch.int64. Pointers to the weight matrices.\
+             Weight matrix shape: `[num_layers, R, H2]`.
+         s_start: Shape: `[S]`, DType: torch.int32. Segment start indices (indptr).
+         s_end: Shape: `[S]`, DType: torch.int32. Segment end indices (indptr).
+         layer_idx: Layer index of the weight matrices.
+     """
+     if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
+         # The custom SGMV shrink kernel only supports ranks 16, 32, 64 and 128.
+         _add_lora_sgmv_cutlass_legacy(y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank)
+         return
+
+     tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
+     tmp2_size = ops.sgmv_cutlass_tmp_size(wa_ptr.size(0))
+     tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
+     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+     ops.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
+     ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)
+
+ def _add_lora_sgmv_cutlass_legacy(
+     y: torch.Tensor,
+     x: torch.Tensor,
+     wa_ptr: torch.Tensor,
+     wb_ptr: torch.Tensor,
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+     lora_rank: int,
+ ):
+     tmp_size = ops.sgmv_cutlass_tmp_size(wa_ptr.size(0))
+     tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
+     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+     ops.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+     ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
+
+ def lora_a_sgmv_cutlass(
+     x: torch.Tensor,
+     tmp: torch.Tensor,
+     wa_ptr: torch.Tensor,
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+     lora_rank: int,
+ ) -> torch.Tensor:
+     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+     if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
+         ops.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+     else:
+         ops.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+     return v
+
+
+ def lora_b_sgmv_cutlass(
+     y: torch.Tensor,
+     v: torch.Tensor,
+     tmp: torch.Tensor,
+     wb_ptr: torch.Tensor,
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+ ):
+     ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
+
+ def add_lora_a_bgmv(
+     v: torch.Tensor,
+     x: torch.Tensor,
+     wa_T_all: torch.Tensor,
+     indicies: torch.LongTensor,
+     layer_idx: int,
+ ):
+     ops.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)
+
+
+ def add_lora_b_bgmv(
+     y: torch.Tensor,
+     v: torch.Tensor,
+     wb_T_all: torch.Tensor,
+     indicies: torch.LongTensor,
+     layer_idx: int,
+ ):
+     ops.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)
+
+
+ def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
+     """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
+     # Tensor parallelism divides the effective rank by world_size,
+     # so scale the minimum rank to offset that effect.
+     min_rank = MIN_SGMV_RANK * world_size
+     return pad_to_min_rank(t, dim, min_rank)
+
+ def pad_to_min_rank(t: torch.Tensor, dim: int, min_rank: int) -> torch.Tensor:
+     # If we're at or below the minimum rank, pad up to the minimum rank;
+     # otherwise, pad to the nearest multiple of the block size.
+     current_rank = t.size(dim)
+     target_rank = (
+         min_rank
+         if current_rank <= min_rank
+         else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
+     )
+     if current_rank == target_rank:
+         return t
+
+     pad_size = target_rank - current_rank
+
+     # See the complicated pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+     pad = [0, 0] * t.dim()
+     pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
+     pad = tuple(pad)
+
+     return F.pad(t, pad, mode="constant", value=0.0)
+
+ def use_cutlass_shrink(lora_rank: int) -> bool:
+     return lora_rank < MIN_RANK_CUSTOM
+
+ @lru_cache(maxsize=1)
+ def get_tmp_tensor(device: torch.device) -> torch.Tensor:
+     return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)
+
+ @lru_cache(maxsize=32)
+ def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
+     tmp_size = ops.sgmv_cutlass_tmp_size(size)
+     return torch.empty((tmp_size,), dtype=torch.uint8, device=device)
+
+ def get_tmp_expand_size(size: int) -> int:
+     return ops.sgmv_cutlass_tmp_size(size)
+
+
+ def get_tmp_tensors(nsegments: int, lora_rank: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+     if use_cutlass_shrink(lora_rank):
+         tmp = get_tmp_tensor_for_size(nsegments, device)
+         return tmp, tmp
+     else:
+         tmp_shrink = get_tmp_tensor(device)
+         tmp_expand = get_tmp_tensor_for_size(nsegments, device)
+         return tmp_shrink, tmp_expand
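
The padding rule in `pad_rank`/`pad_to_min_rank` above is easiest to see with concrete numbers. A small sketch, assuming the built `punica_sgmv` package is importable:

import torch
from punica_sgmv import pad_rank

lora_a = torch.randn(10, 4096)  # rank 10 along dim 0
# 10 > MIN_SGMV_RANK (8), so round up to the next multiple of SGMV_BLOCK_SIZE (16):
print(pad_rank(lora_a, dim=0, world_size=1).shape)  # torch.Size([16, 4096])

small = torch.randn(6, 4096)
# 6 <= 8, so pad up to the minimum rank:
print(pad_rank(small, dim=0, world_size=1).shape)   # torch.Size([8, 4096])
# With tensor parallelism the minimum scales: 8 * world_size = 16:
print(pad_rank(small, dim=0, world_size=2).shape)   # torch.Size([16, 4096])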
build/torch26-cxx98-cu126-aarch64-linux/punica_sgmv/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _punica_sgmv_ad0ac7e_dirty
+ ops = torch.ops._punica_sgmv_ad0ac7e_dirty
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix the op name with the kernel's private namespace.
+     """
+     return f"_punica_sgmv_ad0ac7e_dirty::{op_name}"
build/torch26-cxx98-cu126-aarch64-linux/punica_sgmv/_punica_sgmv_ad0ac7e_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:530a51beb6f591c58e8fe13afd427204bbf39572648fbc2befd5d5d63358b4cb
+ size 14307968
build/torch27-cxx11-cu126-aarch64-linux/punica_sgmv/__init__.py ADDED
@@ -0,0 +1,172 @@
+ from typing import Optional, Tuple
+ from functools import lru_cache
+
+ import torch
+ import torch.nn.functional as F
+
+ from ._ops import ops
+
+ MIN_SGMV_RANK = 8
+ MIN_RANK_CUSTOM = 16
+ MAX_RANK_CUSTOM = 128
+ SGMV_BLOCK_SIZE = 16
+ BGMV_MAX_RANK = 128
+
+ def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
+     if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
+         return t.transpose(0, 1)
+     return t
+
+ def add_lora_sgmv_cutlass(
+     y: torch.Tensor,
+     x: torch.Tensor,
+     wa_ptr: torch.Tensor,
+     wb_ptr: torch.Tensor,
+     s_start: torch.Tensor,
+     s_end: torch.Tensor,
+     layer_idx: int,
+     lora_rank: int,
+ ):
+     """
+     Semantics:
+         y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])
+
+     Args:
+         y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+         x: Shape: `[B, H1]`. Input vectors.
+         wa_ptr: Shape: `[S]`. DType: torch.int64. Pointers to the weight matrices.\
+             Weight matrix shape: `[num_layers, R, H1]`.
+         wb_ptr: Shape: `[S]`. DType: torch.int64. Pointers to the weight matrices.\
+             Weight matrix shape: `[num_layers, R, H2]`.
+         s_start: Shape: `[S]`, DType: torch.int32. Segment start indices (indptr).
+         s_end: Shape: `[S]`, DType: torch.int32. Segment end indices (indptr).
+         layer_idx: Layer index of the weight matrices.
+     """
+     if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
+         # The custom SGMV shrink kernel only supports ranks 16, 32, 64 and 128.
+         _add_lora_sgmv_cutlass_legacy(y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank)
+         return
+
+     tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
+     tmp2_size = ops.sgmv_cutlass_tmp_size(wa_ptr.size(0))
+     tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
+     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+     ops.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
+     ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)
+
+ def _add_lora_sgmv_cutlass_legacy(
+     y: torch.Tensor,
+     x: torch.Tensor,
+     wa_ptr: torch.Tensor,
+     wb_ptr: torch.Tensor,
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+     lora_rank: int,
+ ):
+     tmp_size = ops.sgmv_cutlass_tmp_size(wa_ptr.size(0))
+     tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
+     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+     ops.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+     ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
+
+ def lora_a_sgmv_cutlass(
+     x: torch.Tensor,
+     tmp: torch.Tensor,
+     wa_ptr: torch.Tensor,
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+     lora_rank: int,
+ ) -> torch.Tensor:
+     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+     if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
+         ops.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+     else:
+         ops.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+     return v
+
+
+ def lora_b_sgmv_cutlass(
+     y: torch.Tensor,
+     v: torch.Tensor,
+     tmp: torch.Tensor,
+     wb_ptr: torch.Tensor,
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+ ):
+     ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
+
+ def add_lora_a_bgmv(
+     v: torch.Tensor,
+     x: torch.Tensor,
+     wa_T_all: torch.Tensor,
+     indicies: torch.LongTensor,
+     layer_idx: int,
+ ):
+     ops.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)
+
+
+ def add_lora_b_bgmv(
+     y: torch.Tensor,
+     v: torch.Tensor,
+     wb_T_all: torch.Tensor,
+     indicies: torch.LongTensor,
+     layer_idx: int,
+ ):
+     ops.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)
+
+
+ def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
+     """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
+     # Tensor parallelism divides the effective rank by world_size,
+     # so scale the minimum rank to offset that effect.
+     min_rank = MIN_SGMV_RANK * world_size
+     return pad_to_min_rank(t, dim, min_rank)
+
+ def pad_to_min_rank(t: torch.Tensor, dim: int, min_rank: int) -> torch.Tensor:
+     # If we're at or below the minimum rank, pad up to the minimum rank;
+     # otherwise, pad to the nearest multiple of the block size.
+     current_rank = t.size(dim)
+     target_rank = (
+         min_rank
+         if current_rank <= min_rank
+         else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
+     )
+     if current_rank == target_rank:
+         return t
+
+     pad_size = target_rank - current_rank
+
+     # See the complicated pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+     pad = [0, 0] * t.dim()
+     pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
+     pad = tuple(pad)
+
+     return F.pad(t, pad, mode="constant", value=0.0)
+
+ def use_cutlass_shrink(lora_rank: int) -> bool:
+     return lora_rank < MIN_RANK_CUSTOM
+
+ @lru_cache(maxsize=1)
+ def get_tmp_tensor(device: torch.device) -> torch.Tensor:
+     return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)
+
+ @lru_cache(maxsize=32)
+ def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
+     tmp_size = ops.sgmv_cutlass_tmp_size(size)
+     return torch.empty((tmp_size,), dtype=torch.uint8, device=device)
+
+ def get_tmp_expand_size(size: int) -> int:
+     return ops.sgmv_cutlass_tmp_size(size)
+
+
+ def get_tmp_tensors(nsegments: int, lora_rank: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+     if use_cutlass_shrink(lora_rank):
+         tmp = get_tmp_tensor_for_size(nsegments, device)
+         return tmp, tmp
+     else:
+         tmp_shrink = get_tmp_tensor(device)
+         tmp_expand = get_tmp_tensor_for_size(nsegments, device)
+         return tmp_shrink, tmp_expand
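
Taken together, the helpers above suggest a two-pass call pattern: pick temporary buffers with `get_tmp_tensors`, shrink with `lora_a_sgmv_cutlass`, then expand into the output with `lora_b_sgmv_cutlass`. A hedged sketch follows; the pointer and segment tensors (`wa_ptr`, `wb_ptr`, `s_start`, `s_end`) are assumed to be prepared by the caller from the adapter weights, which this commit does not show:

import torch
from punica_sgmv import (
    get_tmp_tensors,
    lora_a_sgmv_cutlass,
    lora_b_sgmv_cutlass,
    use_cutlass_shrink,
)

device = torch.device("cuda")
B, H1, H2, rank, nsegments, layer_idx = 8, 4096, 4096, 16, 2, 0

x = torch.randn(B, H1, dtype=torch.float16, device=device)
y = torch.zeros(B, H2, dtype=torch.float16, device=device)

# Rank 16 lies in [MIN_RANK_CUSTOM, MAX_RANK_CUSTOM], so the custom shrink
# kernel is used and the two buffers differ:
assert not use_cutlass_shrink(rank)
tmp_shrink, tmp_expand = get_tmp_tensors(nsegments, rank, device)

# wa_ptr, wb_ptr, s_start, s_end elided: they hold the weights' data
# pointers and the per-segment token ranges.
# v = lora_a_sgmv_cutlass(x, tmp_shrink, wa_ptr, s_start, s_end, layer_idx, rank)
# lora_b_sgmv_cutlass(y, v, tmp_expand, wb_ptr, s_start, s_end, layer_idx)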
build/torch27-cxx11-cu126-aarch64-linux/punica_sgmv/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _punica_sgmv_ad0ac7e_dirty
+ ops = torch.ops._punica_sgmv_ad0ac7e_dirty
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix the op name with the kernel's private namespace.
+     """
+     return f"_punica_sgmv_ad0ac7e_dirty::{op_name}"
build/torch27-cxx11-cu126-aarch64-linux/punica_sgmv/_punica_sgmv_ad0ac7e_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9b1526b236ab1acc48ece52360332f9e7fbf261e18e4f700b404aa6dbe45240e
+ size 14311416
build/torch27-cxx11-cu128-aarch64-linux/punica_sgmv/__init__.py ADDED
@@ -0,0 +1,172 @@
+ from typing import Optional, Tuple
+ from functools import lru_cache
+
+ import torch
+ import torch.nn.functional as F
+
+ from ._ops import ops
+
+ MIN_SGMV_RANK = 8
+ MIN_RANK_CUSTOM = 16
+ MAX_RANK_CUSTOM = 128
+ SGMV_BLOCK_SIZE = 16
+ BGMV_MAX_RANK = 128
+
+ def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
+     if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
+         return t.transpose(0, 1)
+     return t
+
+ def add_lora_sgmv_cutlass(
+     y: torch.Tensor,
+     x: torch.Tensor,
+     wa_ptr: torch.Tensor,
+     wb_ptr: torch.Tensor,
+     s_start: torch.Tensor,
+     s_end: torch.Tensor,
+     layer_idx: int,
+     lora_rank: int,
+ ):
+     """
+     Semantics:
+         y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])
+
+     Args:
+         y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+         x: Shape: `[B, H1]`. Input vectors.
+         wa_ptr: Shape: `[S]`. DType: torch.int64. Pointers to the weight matrices.\
+             Weight matrix shape: `[num_layers, R, H1]`.
+         wb_ptr: Shape: `[S]`. DType: torch.int64. Pointers to the weight matrices.\
+             Weight matrix shape: `[num_layers, R, H2]`.
+         s_start: Shape: `[S]`, DType: torch.int32. Segment start indices (indptr).
+         s_end: Shape: `[S]`, DType: torch.int32. Segment end indices (indptr).
+         layer_idx: Layer index of the weight matrices.
+     """
+     if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
+         # The custom SGMV shrink kernel only supports ranks 16, 32, 64 and 128.
+         _add_lora_sgmv_cutlass_legacy(y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank)
+         return
+
+     tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
+     tmp2_size = ops.sgmv_cutlass_tmp_size(wa_ptr.size(0))
+     tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
+     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+     ops.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
+     ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)
+
+ def _add_lora_sgmv_cutlass_legacy(
+     y: torch.Tensor,
+     x: torch.Tensor,
+     wa_ptr: torch.Tensor,
+     wb_ptr: torch.Tensor,
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+     lora_rank: int,
+ ):
+     tmp_size = ops.sgmv_cutlass_tmp_size(wa_ptr.size(0))
+     tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
+     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+     ops.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+     ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
+
+ def lora_a_sgmv_cutlass(
+     x: torch.Tensor,
+     tmp: torch.Tensor,
+     wa_ptr: torch.Tensor,
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+     lora_rank: int,
+ ) -> torch.Tensor:
+     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+     if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
+         ops.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+     else:
+         ops.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+     return v
+
+
+ def lora_b_sgmv_cutlass(
+     y: torch.Tensor,
+     v: torch.Tensor,
+     tmp: torch.Tensor,
+     wb_ptr: torch.Tensor,
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+ ):
+     ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
+
+ def add_lora_a_bgmv(
+     v: torch.Tensor,
+     x: torch.Tensor,
+     wa_T_all: torch.Tensor,
+     indicies: torch.LongTensor,
+     layer_idx: int,
+ ):
+     ops.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)
+
+
+ def add_lora_b_bgmv(
+     y: torch.Tensor,
+     v: torch.Tensor,
+     wb_T_all: torch.Tensor,
+     indicies: torch.LongTensor,
+     layer_idx: int,
+ ):
+     ops.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)
+
+
+ def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
+     """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
+     # Tensor parallelism divides the effective rank by world_size,
+     # so scale the minimum rank to offset that effect.
+     min_rank = MIN_SGMV_RANK * world_size
+     return pad_to_min_rank(t, dim, min_rank)
+
+ def pad_to_min_rank(t: torch.Tensor, dim: int, min_rank: int) -> torch.Tensor:
+     # If we're at or below the minimum rank, pad up to the minimum rank;
+     # otherwise, pad to the nearest multiple of the block size.
+     current_rank = t.size(dim)
+     target_rank = (
+         min_rank
+         if current_rank <= min_rank
+         else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
+     )
+     if current_rank == target_rank:
+         return t
+
+     pad_size = target_rank - current_rank
+
+     # See the complicated pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+     pad = [0, 0] * t.dim()
+     pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
+     pad = tuple(pad)
+
+     return F.pad(t, pad, mode="constant", value=0.0)
+
+ def use_cutlass_shrink(lora_rank: int) -> bool:
+     return lora_rank < MIN_RANK_CUSTOM
+
+ @lru_cache(maxsize=1)
+ def get_tmp_tensor(device: torch.device) -> torch.Tensor:
+     return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)
+
+ @lru_cache(maxsize=32)
+ def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
+     tmp_size = ops.sgmv_cutlass_tmp_size(size)
+     return torch.empty((tmp_size,), dtype=torch.uint8, device=device)
+
+ def get_tmp_expand_size(size: int) -> int:
+     return ops.sgmv_cutlass_tmp_size(size)
+
+
+ def get_tmp_tensors(nsegments: int, lora_rank: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+     if use_cutlass_shrink(lora_rank):
+         tmp = get_tmp_tensor_for_size(nsegments, device)
+         return tmp, tmp
+     else:
+         tmp_shrink = get_tmp_tensor(device)
+         tmp_expand = get_tmp_tensor_for_size(nsegments, device)
+         return tmp_shrink, tmp_expand
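
One last sketch, for `orient_for_rank`: the custom shrink kernel expects the LoRA A matrix transposed, while the cutlass fallback takes it as-is. Assuming the built package is importable:

import torch
from punica_sgmv import orient_for_rank

wa = torch.randn(16, 4096)            # [R, H1] with R = 16
# Rank 16 takes the custom-kernel path, so the matrix is transposed:
print(orient_for_rank(wa, 16).shape)  # torch.Size([4096, 16])

wa_low = torch.randn(8, 4096)         # rank 8 falls back to cutlass
print(orient_for_rank(wa_low, 8).shape)  # torch.Size([8, 4096]) -- unchanged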
build/torch27-cxx11-cu128-aarch64-linux/punica_sgmv/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _punica_sgmv_ad0ac7e_dirty
+ ops = torch.ops._punica_sgmv_ad0ac7e_dirty
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix the op name with the kernel's private namespace.
+     """
+     return f"_punica_sgmv_ad0ac7e_dirty::{op_name}"
build/torch27-cxx11-cu128-aarch64-linux/punica_sgmv/_punica_sgmv_ad0ac7e_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1681b08e9b39010e07ed293cd48fa54910a8811c01bd06acc235742660efb766
+ size 22831040