Build (aarch64)
- build/torch26-cxx11-cu126-aarch64-linux/punica_sgmv/__init__.py +172 -0
- build/torch26-cxx11-cu126-aarch64-linux/punica_sgmv/_ops.py +9 -0
- build/torch26-cxx11-cu126-aarch64-linux/punica_sgmv/_punica_sgmv_ad0ac7e_dirty.abi3.so +3 -0
- build/torch26-cxx98-cu126-aarch64-linux/punica_sgmv/__init__.py +172 -0
- build/torch26-cxx98-cu126-aarch64-linux/punica_sgmv/_ops.py +9 -0
- build/torch26-cxx98-cu126-aarch64-linux/punica_sgmv/_punica_sgmv_ad0ac7e_dirty.abi3.so +3 -0
- build/torch27-cxx11-cu126-aarch64-linux/punica_sgmv/__init__.py +172 -0
- build/torch27-cxx11-cu126-aarch64-linux/punica_sgmv/_ops.py +9 -0
- build/torch27-cxx11-cu126-aarch64-linux/punica_sgmv/_punica_sgmv_ad0ac7e_dirty.abi3.so +3 -0
- build/torch27-cxx11-cu128-aarch64-linux/punica_sgmv/__init__.py +172 -0
- build/torch27-cxx11-cu128-aarch64-linux/punica_sgmv/_ops.py +9 -0
- build/torch27-cxx11-cu128-aarch64-linux/punica_sgmv/_punica_sgmv_ad0ac7e_dirty.abi3.so +3 -0
build/torch26-cxx11-cu126-aarch64-linux/punica_sgmv/__init__.py
ADDED
@@ -0,0 +1,172 @@
from typing import Optional, Tuple
from functools import lru_cache

import torch
import torch.nn.functional as F

from ._ops import ops

MIN_SGMV_RANK = 8
MIN_RANK_CUSTOM = 16
MAX_RANK_CUSTOM = 128
SGMV_BLOCK_SIZE = 16
BGMV_MAX_RANK = 128


def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
    if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
        return t.transpose(0, 1)
    return t


def add_lora_sgmv_cutlass(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_ptr: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.Tensor,
    s_end: torch.Tensor,
    layer_idx: int,
    lora_rank: int,
):
    """
    Semantics:
        y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])

    Args:
        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
        x: Shape: `[B, H1]`. Input vectors.
        wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
            Weight matrix shape: `[num_layers, R, H1]`.
        wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
            Weight matrix shape: `[num_layers, R, H2]`.
        s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices.
        s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices.
        layer_idx: Layer index of the weight matrices.
    """
    if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
        # Custom SGMV shrink only supports rank 16, 32, 64, 128
        _add_lora_sgmv_cutlass_legacy(y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank)
        return

    tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
    tmp2_size = ops.sgmv_cutlass_tmp_size(wa_ptr.size(0))
    tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    ops.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
    ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)


def _add_lora_sgmv_cutlass_legacy(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_ptr: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
):
    tmp_size = ops.sgmv_cutlass_tmp_size(wa_ptr.size(0))
    tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    ops.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)


def lora_a_sgmv_cutlass(
    x: torch.Tensor,
    tmp: torch.Tensor,
    wa_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
) -> torch.Tensor:
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
        ops.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    else:
        ops.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    return v


def lora_b_sgmv_cutlass(
    y: torch.Tensor,
    v: torch.Tensor,
    tmp: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
):
    ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)


def add_lora_a_bgmv(
    v: torch.Tensor,
    x: torch.Tensor,
    wa_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
):
    ops.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)


def add_lora_b_bgmv(
    y: torch.Tensor,
    v: torch.Tensor,
    wb_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
):
    ops.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)


def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
    """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
    # Tensor parallelism divides the effective rank by world_size,
    # so we scale the minimum rank to offset that effect.
    min_rank = MIN_SGMV_RANK * world_size
    return pad_to_min_rank(t, dim, min_rank)


def pad_to_min_rank(t: torch.Tensor, dim: int, min_rank: int) -> torch.Tensor:
    # If we're at or below the min rank, pad up to the min rank;
    # otherwise, pad to the nearest multiple of the block size.
    current_rank = t.size(dim)
    target_rank = (
        min_rank
        if current_rank <= min_rank
        else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
    )
    if current_rank == target_rank:
        return t

    pad_size = target_rank - current_rank

    # See the complicated pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
    pad = [0, 0] * t.dim()
    pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
    pad = tuple(pad)

    return F.pad(t, pad, mode="constant", value=0.0)


def use_cutlass_shrink(lora_rank: int) -> bool:
    return lora_rank < MIN_RANK_CUSTOM


@lru_cache(maxsize=1)
def get_tmp_tensor(device: torch.device) -> torch.Tensor:
    return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)


@lru_cache(maxsize=32)
def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
    tmp_size = ops.sgmv_cutlass_tmp_size(size)
    return torch.empty((tmp_size,), dtype=torch.uint8, device=device)


def get_tmp_expand_size(size: int) -> int:
    return ops.sgmv_cutlass_tmp_size(size)


def get_tmp_tensors(nsegments: int, lora_rank: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    if use_cutlass_shrink(lora_rank):
        tmp = get_tmp_tensor_for_size(nsegments, device)
        return tmp, tmp
    else:
        tmp_shrink = get_tmp_tensor(device)
        tmp_expand = get_tmp_tensor_for_size(nsegments, device)
        return tmp_shrink, tmp_expand
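For orientation, a minimal usage sketch of the padding and scratch-buffer helpers defined above follows. It is not part of this commit: it assumes a CUDA device and that the built punica_sgmv package is importable, and the shapes, ranks, and segment counts are illustrative only.

import torch
import punica_sgmv as sgmv

device = torch.device("cuda")

# Pad a rank-10 LoRA weight along the rank dim: above MIN_SGMV_RANK, so it is
# rounded up to the next multiple of SGMV_BLOCK_SIZE (16).
lora_a = torch.randn(10, 4096, dtype=torch.float16, device=device)
lora_a = sgmv.pad_rank(lora_a, dim=0, world_size=1)
print(lora_a.shape)  # torch.Size([16, 4096])

# Scratch buffers for a batch with 3 adapter segments at rank 16: the custom
# shrink kernel applies here (use_cutlass_shrink(16) is False), so the shrink
# and expand stages get different temporary buffers.
tmp_shrink, tmp_expand = sgmv.get_tmp_tensors(nsegments=3, lora_rank=16, device=device)
print(sgmv.use_cutlass_shrink(16))  # False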
build/torch26-cxx11-cu126-aarch64-linux/punica_sgmv/_ops.py
ADDED
@@ -0,0 +1,9 @@
import torch

from . import _punica_sgmv_ad0ac7e_dirty

ops = torch.ops._punica_sgmv_ad0ac7e_dirty


def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_punica_sgmv_ad0ac7e_dirty::{op_name}"
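As a brief illustration (again not part of the commit), the helper above simply builds the namespace-qualified op name in the "namespace::op" form that torch.library-style utilities expect, while the ops object exposes the compiled kernels directly; the op name used below is one of the ops called from __init__.py and assumes the built package is importable:

from punica_sgmv._ops import ops, add_op_namespace_prefix

print(add_op_namespace_prefix("sgmv_shrink"))
# _punica_sgmv_ad0ac7e_dirty::sgmv_shrink

# The compiled kernels are reachable directly through the `ops` namespace object, e.g.
# ops.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)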
build/torch26-cxx11-cu126-aarch64-linux/punica_sgmv/_punica_sgmv_ad0ac7e_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9fb84288c2a868d46ec95e015ef56c1c661b46f9c8158dde3809569b973062af
size 14311192
build/torch26-cxx98-cu126-aarch64-linux/punica_sgmv/__init__.py
ADDED
@@ -0,0 +1,172 @@
(Contents identical to build/torch26-cxx11-cu126-aarch64-linux/punica_sgmv/__init__.py shown above.)
build/torch26-cxx98-cu126-aarch64-linux/punica_sgmv/_ops.py
ADDED
@@ -0,0 +1,9 @@
(Contents identical to build/torch26-cxx11-cu126-aarch64-linux/punica_sgmv/_ops.py shown above.)
build/torch26-cxx98-cu126-aarch64-linux/punica_sgmv/_punica_sgmv_ad0ac7e_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:530a51beb6f591c58e8fe13afd427204bbf39572648fbc2befd5d5d63358b4cb
size 14307968
build/torch27-cxx11-cu126-aarch64-linux/punica_sgmv/__init__.py
ADDED
@@ -0,0 +1,172 @@
(Contents identical to build/torch26-cxx11-cu126-aarch64-linux/punica_sgmv/__init__.py shown above.)
build/torch27-cxx11-cu126-aarch64-linux/punica_sgmv/_ops.py
ADDED
@@ -0,0 +1,9 @@
(Contents identical to build/torch26-cxx11-cu126-aarch64-linux/punica_sgmv/_ops.py shown above.)
build/torch27-cxx11-cu126-aarch64-linux/punica_sgmv/_punica_sgmv_ad0ac7e_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9b1526b236ab1acc48ece52360332f9e7fbf261e18e4f700b404aa6dbe45240e
size 14311416
build/torch27-cxx11-cu128-aarch64-linux/punica_sgmv/__init__.py
ADDED
@@ -0,0 +1,172 @@
(Contents identical to build/torch26-cxx11-cu126-aarch64-linux/punica_sgmv/__init__.py shown above.)
build/torch27-cxx11-cu128-aarch64-linux/punica_sgmv/_ops.py
ADDED
@@ -0,0 +1,9 @@
(Contents identical to build/torch26-cxx11-cu126-aarch64-linux/punica_sgmv/_ops.py shown above.)
build/torch27-cxx11-cu128-aarch64-linux/punica_sgmv/_punica_sgmv_ad0ac7e_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1681b08e9b39010e07ed293cd48fa54910a8811c01bd06acc235742660efb766
size 22831040