drbh committed
Commit · 3224250
Parent(s): a585153

feat: vendor grouped gemm

Files changed:
- build.toml +4 -0
- csrc/grouped_gemm/fill_arguments.cuh +141 -0
- csrc/grouped_gemm/grouped_gemm.cu +567 -0
- csrc/grouped_gemm/grouped_gemm.h +20 -0
- csrc/grouped_gemm/ops.cu +11 -0
- tests/ops_test.py +170 -0
- tests/test_gg.py +57 -0
- torch-ext/megablocks/__init__.py +9 -5
- torch-ext/megablocks/grouped_gemm/__init__.py +2 -0
- torch-ext/megablocks/grouped_gemm/backend.py +32 -0
- torch-ext/megablocks/grouped_gemm/ops.py +33 -0
- torch-ext/megablocks/grouped_gemm_util.py +8 -3
- torch-ext/megablocks/layers/__init__.py +1 -1
- torch-ext/torch_binding.cpp +12 -0
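
(Not part of the diff: a quick orientation.) The Python entry point this commit vendors is exercised by the new tests below. A minimal usage sketch, assuming a CUDA-enabled megablocks build with these sources compiled in; the dimension names are illustrative:

import torch
import megablocks

# One packed (tokens, hidden_in) activation matrix and one weight matrix per expert.
num_experts, tokens_per_expert, hidden_in, hidden_out = 4, 128, 256, 512
a = torch.randn(num_experts * tokens_per_expert, hidden_in,
                device="cuda", dtype=torch.bfloat16)
b = torch.randn(num_experts, hidden_in, hidden_out,
                device="cuda", dtype=torch.bfloat16)
# Number of rows of `a` routed to each expert (int64; CPU tensor in the default path).
batch_sizes = torch.tensor([tokens_per_expert] * num_experts)

# Autograd-aware entry point added by this commit.
out = megablocks.gg_ops.gmm(a, b, batch_sizes)  # shape (tokens, hidden_out)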
build.toml CHANGED
@@ -35,4 +35,8 @@ src = [
     "csrc/new_replicate.h",
     "csrc/new_sort.h",
     "csrc/new_sort.cu",
+    # vendored grouped gemm
+    "csrc/grouped_gemm/fill_arguments.cuh",
+    "csrc/grouped_gemm/grouped_gemm.cu",
+    "csrc/grouped_gemm/grouped_gemm.h",
 ]
csrc/grouped_gemm/fill_arguments.cuh ADDED
@@ -0,0 +1,141 @@
+#pragma once
+
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <cub/cub.cuh>
+#include <cutlass/bfloat16.h>
+#include <cutlass/gemm_coord.h>
+
+namespace grouped_gemm {
+
+constexpr int kDynamicDim = -1;
+constexpr int kMaxExperts = 512;
+
+struct GemmProblem {
+  ::cutlass::gemm::GemmCoord dims;
+  int64_t lda, ldb, ldc;
+  // All offsets are in elements.
+  int64_t a_offset, b_offset, c_offset;
+};
+
+// TODO: revisit `ExtractGemmProblemK` struct
+// struct ExtractGemmProblemK {
+//   __device__ ::cuda::std::tuple<int&> operator()(GemmProblem& problem) const {
+//     return {problem.dims.k()};
+//   }
+// };
+
+template <
+  // If `k` is dynamic, we sort the problems by `k` in descending order.
+  // Otherwise, `m` is dynamic, and no sorting happens.
+  bool kDynamicK,
+  typename ElementA, typename ElementB, typename ElementC,
+  typename LayoutA, typename LayoutB, typename LayoutC,
+  typename Args
+>
+__global__ void FillArguments(
+    int num_experts, const int64_t* batch_sizes,
+    ElementA* ptr_a, ElementB* ptr_b, ElementC* ptr_c,
+    Args args, ::cutlass::gemm::GemmCoord dims
+) {
+  const int expert_idx = threadIdx.x;
+  const int batch_size = expert_idx < num_experts ? batch_sizes[expert_idx] : -1;
+
+  if (kDynamicK) {
+    assert(dims.k() == kDynamicDim);
+    dims.k() = batch_size;
+  } else {
+    assert(dims.m() == kDynamicDim);
+    dims.m() = batch_size;
+  }
+
+  using BlockScan = cub::BlockScan<int, kMaxExperts>;
+  using BlockSort = cub::BlockRadixSort<int, kMaxExperts, 1, GemmProblem>;
+
+  union SharedMemory {
+    typename BlockScan::TempStorage scan_storage;
+    typename BlockSort::TempStorage sort_storage;
+  };
+  __shared__ SharedMemory shared_memory;
+
+  int dynamic_dim = kDynamicK ? dims.k() : dims.m();
+  int dynamic_dim_cumsum;
+  BlockScan(shared_memory.scan_storage).ExclusiveSum(dynamic_dim, dynamic_dim_cumsum);
+  __syncthreads();
+
+  // We have to use `GemmProblem[1]` here instead of just `GemmProblem` because `SortDescending()` expects
+  // `KeyT (&)[ITEMS_PER_THREAD]` for the `keys` argument (i.e., `GemmProblem (&keys)[1]` in our case).
+  GemmProblem problem[1] = {
+    GemmProblem {
+      .dims = dims,
+      .lda = LayoutA::packed({dims.m(), dims.k()}).stride(0),
+      .ldb = LayoutB::packed({dims.k(), dims.n()}).stride(0),
+      .ldc = LayoutC::packed({dims.m(), dims.n()}).stride(0),
+      .a_offset = kDynamicK
+          ? (dims.m() * dynamic_dim_cumsum)
+          : (dynamic_dim_cumsum * dims.k()),
+      .b_offset = (kDynamicK ? dynamic_dim_cumsum : expert_idx * dims.k()) * dims.n(),
+      .c_offset = (kDynamicK ? expert_idx * dims.m() : dynamic_dim_cumsum) * dims.n(),
+    },
+  };
+
+  if constexpr (kDynamicK) {
+    // Sort by k dimension in descending order.
+    // We need to extract the key (k value) for sorting.
+    int k_keys[1] = { problem[0].dims.k() };
+
+    BlockSort(shared_memory.sort_storage).SortDescending(k_keys, problem);
+
+    // TODO: revisit original impl without `__syncthreads()`
+    // BlockSort(shared_memory.sort_storage).SortDescending(problem, ExtractGemmProblemK{});
+    // Quoting the CUB documentation (https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockRadixSort.html):
+    // > A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective's temporary storage [...]
+    // > is **to be reused or repurposed**.
+    // We don't need `__syncthreads()` here, since we don't do either of these things.
+  }
+
+  if (expert_idx < num_experts) {
+    args.problem_sizes[expert_idx] = problem[0].dims;
+    args.lda[expert_idx] = problem[0].lda;
+    args.ldb[expert_idx] = problem[0].ldb;
+    args.ldc[expert_idx] = problem[0].ldc;
+
+    args.ptr_A[expert_idx] = ptr_a + problem[0].a_offset;
+    args.ptr_B[expert_idx] = ptr_b + problem[0].b_offset;
+    args.ptr_C[expert_idx] = ptr_c + problem[0].c_offset;
+  }
+}
+
+template <typename Args>
+__global__ void ZeroOutK0Outputs(int num_experts, Args args) {
+  const int64_t start_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t delta = (int64_t)gridDim.x * blockDim.x;
+  for (int ei = 0; ei < num_experts; ++ei) {
+    auto& dims = args.problem_sizes[ei];
+    // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
+    // Until a fix is available on the CUTLASS side, handle these problems by ourselves:
+    // * (here) set the output to zero
+    // * (in `IgnoreK0Problems`) make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
+    if (dims.k() == 0) {
+      // Assume packed layout, run a grid-strided loop over the output.
+      int64_t total_elems = (int64_t)dims.m() * dims.n();
+      auto* out = args.ptr_C[ei];
+      for (int64_t idx = start_idx; idx < total_elems; idx += delta) {
+        out[idx] = {};
+      }
+    }
+  }
+}
+
+template <typename Args>
+__global__ void IgnoreK0Problems(int num_experts, Args args) {
+  const int expert_idx = threadIdx.x;
+  if (expert_idx < num_experts) {
+    auto& dims = args.problem_sizes[expert_idx];
+    if (dims.k() == 0) {
+      dims.m() = 0;
+      dims.n() = 0;
+    }
+  }
+}
+
+} // namespace grouped_gemm
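
(Not part of the diff: an illustration.) `FillArguments` assigns one expert per thread and uses a CUB exclusive prefix sum over the dynamic dimension to turn per-expert batch sizes into element offsets into the packed `a`/`b`/`c` buffers. A host-side sketch of the same arithmetic for the non-transposed case (`kDynamicK == false`, so `m` is dynamic); the numbers are made up for the example:

# Host-side illustration of the offset math in FillArguments (m dynamic).
batch_sizes = [3, 1, 2]   # rows of `a` routed to each expert
k, n = 4, 5               # fixed contraction / output dimensions

cumsum = 0
for expert, m in enumerate(batch_sizes):
    a_offset = cumsum * k      # `a` is packed as (sum(m_i), k)
    b_offset = expert * k * n  # `b` is batched as (num_experts, k, n)
    c_offset = cumsum * n      # `c` is packed as (sum(m_i), n)
    print(expert, (m, n, k), a_offset, b_offset, c_offset)
    cumsum += m                # exclusive prefix sum, as cub::BlockScan computes it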
csrc/grouped_gemm/grouped_gemm.cu ADDED
@@ -0,0 +1,567 @@
+#include "grouped_gemm.h"
+#include "fill_arguments.cuh"
+
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <c10/util/BFloat16.h>
+#include <c10/cuda/CUDAStream.h>
+#include <cub/cub.cuh>
+#include <torch/torch.h>
+
+#include "cutlass/bfloat16.h"
+#include "cutlass/complex.h"
+#include "cutlass/gemm/kernel/gemm_grouped.h"
+#include "cutlass/gemm/kernel/default_gemm_grouped.h"
+#include "cutlass/gemm/device/gemm_grouped.h"
+
+#include <type_traits>
+
+namespace grouped_gemm {
+
+#define CUDA_CALL(code)                      \
+  do {                                       \
+    cudaError_t status = code;               \
+    std::string err = cudaGetErrorString(status); \
+    TORCH_CHECK(status == cudaSuccess, err); \
+  } while (0)
+
+#define CUBLAS_CALL(code)                                            \
+  do {                                                               \
+    cublasStatus_t status = code;                                    \
+    TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "CuBLAS Error");    \
+  } while (0)
+
+#define GROUPED_GEMM_STRINGIFY_HELPER(x) #x
+#define GROUPED_GEMM_STRINGIFY(x) \
+  GROUPED_GEMM_STRINGIFY_HELPER(x)
+
+template <bool trans>
+using GroupedGemmInputLayout = std::conditional_t<trans, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;
+
+using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration<
+  ::cutlass::arch::OpClassTensorOp,
+  ::cutlass::arch::Sm80,
+  ::cutlass::bfloat16_t,
+  ::cutlass::bfloat16_t,
+  ::cutlass::bfloat16_t,
+  float
+>;
+
+// TODO(tgale): Update this for SM90 when it's supported by CUTLASS.
+template <bool trans_a, bool trans_b>
+using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
+  // A operand.
+  ::cutlass::bfloat16_t,
+  GroupedGemmInputLayout<trans_a>,
+  ::cutlass::ComplexTransform::kNone,
+  GroupedGemmConfig::kAlignmentA,
+  // B operand.
+  ::cutlass::bfloat16_t,
+  GroupedGemmInputLayout<trans_b>,
+  ::cutlass::ComplexTransform::kNone,
+  GroupedGemmConfig::kAlignmentB,
+  // C operand.
+  ::cutlass::bfloat16_t,
+  ::cutlass::layout::RowMajor,
+  float,
+  ::cutlass::arch::OpClassTensorOp,
+  ::cutlass::arch::Sm80,
+  GroupedGemmConfig::ThreadblockShape,
+  GroupedGemmConfig::WarpShape,
+  GroupedGemmConfig::InstructionShape,
+  GroupedGemmConfig::EpilogueOutputOp,
+  // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
+  // This parameter is passed in at present to match the APIs of other kernels. The parameter
+  // is unused within the kernel.
+  ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
+  // TODO(tgale): Tune this for SM90.
+  GroupedGemmConfig::kStages>::GemmKernel;
+
+template <bool trans_a, bool trans_b>
+using GemmGrouped = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernel<trans_a, trans_b>>;
+
+template <typename T>
+torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device) {
+  size_t bytes = x.size() * sizeof(T);
+  auto options = torch::TensorOptions().dtype(torch::kInt8).device(device);
+  torch::Tensor out = torch::empty(bytes, options);
+
+  CUDA_CALL(cudaMemcpyAsync(out.data_ptr(),
+                            x.data(), bytes,
+                            cudaMemcpyHostToDevice,
+                            c10::cuda::getCurrentCUDAStream()));
+  return out;
+}
+
+template <typename T>
+static void ReorderArray(T* data, const std::vector<size_t>& indices) {
+  // For now, simply create a copy of the data and then copy over to the original.
+  std::vector<T> copy(data, data + indices.size());
+  for (size_t i = 0; i < indices.size(); ++i) {
+    data[i] = copy.at(indices[i]);
+  }
+}
+
+template <typename T>
+torch::Tensor TypedEmpty(size_t numel, const torch::Device& device) {
+  return torch::empty(numel * sizeof(T), torch::dtype(torch::kInt8).device(device));
+}
+
+struct RawGemmArguments {
+  torch::Tensor lda, ldb, ldc, ptr_a, ptr_b, ptr_c, problem_sizes;
+  int threadblock_count{};
+};
+
+template <
+  typename Gemm,
+  typename ElementA, typename ElementB, typename ElementC
+>
+RawGemmArguments MakeArgumentsOnDevice(int num_experts, const torch::Device& device) {
+  TORCH_CHECK(
+    num_experts <= kMaxExperts,
+    "At most ", kMaxExperts,
+    " experts are supported when batch_sizes is a CUDA tensor, but got ", num_experts
+  );
+
+  return RawGemmArguments {
+    .lda = TypedEmpty<int64_t>(num_experts, device),
+    .ldb = TypedEmpty<int64_t>(num_experts, device),
+    .ldc = TypedEmpty<int64_t>(num_experts, device),
+    .ptr_a = TypedEmpty<ElementA*>(num_experts, device),
+    .ptr_b = TypedEmpty<ElementB*>(num_experts, device),
+    .ptr_c = TypedEmpty<ElementC*>(num_experts, device),
+    .problem_sizes = TypedEmpty<cutlass::gemm::GemmCoord>(num_experts, device),
+
+    // We don't know the problem dimensions on the host, so we just base the number of threadblocks on occupancy here.
+    .threadblock_count = Gemm::sufficient(),
+  };
+}
+
+template <
+  bool kDynamicK,
+  typename Gemm,
+  typename ElementA, typename ElementB, typename ElementC,
+  typename LayoutA, typename LayoutB, typename LayoutC
+>
+RawGemmArguments MakeArgumentsOnHost(torch::Tensor a,
+                                     torch::Tensor b,
+                                     torch::Tensor c,
+                                     torch::Tensor batch_sizes,
+                                     ::cutlass::gemm::GemmCoord coord_template,
+                                     int64_t num_experts) {
+  std::vector<::cutlass::gemm::GemmCoord> problem_sizes_host(num_experts);
+
+  // Create the host arrays of leading dimension data and pointer data.
+  std::vector<int64_t> lda_host(num_experts), ldb_host(num_experts), ldc_host(num_experts);
+  int64_t elements_a = 0, elements_b = 0, elements_c = 0;
+
+  std::vector<ElementA *> ptr_a_host(num_experts), ptr_b_host(num_experts), ptr_c_host(num_experts);
+
+  for (int i = 0; i < num_experts; ++i) {
+    auto& problem = problem_sizes_host[i];
+    problem = coord_template;
+    (kDynamicK ? problem.k() : problem.m()) = batch_sizes.data_ptr<int64_t>()[i];
+
+    lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
+    ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
+    ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
+
+    ptr_a_host[i] = (ElementA*)a.data_ptr() + elements_a;
+    ptr_b_host[i] = (ElementB*)b.data_ptr() + elements_b;
+    ptr_c_host[i] = (ElementC*)c.data_ptr() + elements_c;
+
+    elements_a += problem.m() * problem.k();
+    elements_b += problem.k() * problem.n();
+    elements_c += problem.m() * problem.n();
+
+    if (problem.k() == 0) {
+      // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
+      // Until a fix is available on the CUTLASS side, handle these problems by ourselves:
+      // * set the output to zero with `cudaMemsetAsync()`
+      // * make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
+      CUDA_CALL(cudaMemsetAsync(ptr_c_host[i],
+                                0,
+                                problem.m() * problem.n() * sizeof(ElementC),
+                                c10::cuda::getCurrentCUDAStream()));
+
+      problem.m() = 0;
+      problem.n() = 0;
+    }
+  }
+
+  // Only sort the problems when the `k` dimensions differ across experts.
+  if (kDynamicK) {
+    std::vector<size_t> indices(num_experts);
+    std::iota(indices.begin(), indices.end(), 0);
+    std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_host](size_t i, size_t j) {
+      return problem_sizes_host[i].k() > problem_sizes_host[j].k();
+    });
+
+    ReorderArray(problem_sizes_host.data(), indices);
+    ReorderArray(lda_host.data(), indices);
+    ReorderArray(ldb_host.data(), indices);
+    ReorderArray(ldc_host.data(), indices);
+    ReorderArray(ptr_a_host.data(), indices);
+    ReorderArray(ptr_b_host.data(), indices);
+    ReorderArray(ptr_c_host.data(), indices);
+  }
+
+  // Copy the problem sizes, pointers and leading dimension data to the device.
+  return RawGemmArguments {
+    .lda = CopyToDevice(lda_host, a.device()),
+    .ldb = CopyToDevice(ldb_host, a.device()),
+    .ldc = CopyToDevice(ldc_host, a.device()),
+    .ptr_a = CopyToDevice(ptr_a_host, a.device()),
+    .ptr_b = CopyToDevice(ptr_b_host, a.device()),
+    .ptr_c = CopyToDevice(ptr_c_host, a.device()),
+    .problem_sizes = CopyToDevice(problem_sizes_host, a.device()),
+
+    // We know the problem dimensions on the host, so we can calculate the number of threadblocks based on that.
+    .threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts),
+  };
+}
+
+template <
+  bool kDynamicK,
+  typename Gemm,
+  typename ElementA, typename ElementB, typename ElementC,
+  typename LayoutA, typename LayoutB, typename LayoutC
+>
+typename Gemm::Arguments MakeArguments(torch::Tensor a,
+                                       torch::Tensor b,
+                                       torch::Tensor c,
+                                       torch::Tensor batch_sizes,
+                                       ::cutlass::gemm::GemmCoord coord_template,
+                                       int64_t num_experts) {
+  RawGemmArguments raw_args;
+  if (batch_sizes.is_cuda()) {
+    raw_args = MakeArgumentsOnDevice<
+      Gemm, ElementA, ElementB, ElementC
+    >(num_experts, a.device());
+  } else {
+    raw_args = MakeArgumentsOnHost<
+      kDynamicK,
+      Gemm,
+      ElementA, ElementB, ElementC,
+      LayoutA, LayoutB, LayoutC
+    >(a, b, c, batch_sizes, coord_template, num_experts);
+  }
+
+  printf("Using %d threadblocks for grouped GEMM.\n", raw_args.threadblock_count);
+  // Validate the result.
+  if (!raw_args.threadblock_count) {
+    TORCH_CHECK(false, "Grouped GEMM execution not possible with HW");
+  }
+
+  typename Gemm::EpilogueOutputOp::Params epilogue_op(/*alpha=*/1.0f, /*beta=*/0.0f);
+  // We currently always use `GroupScheduleMode::kDeviceOnly`, which doesn't use `host_problem_sizes` at all,
+  // so we can safely pass `nullptr` for `host_problem_sizes`.
+  // TODO(tgale): Experiment with `GroupScheduleMode::kHostPrecompute` for `batch_sizes.is_cpu()`, where we
+  // know the problem dimensions on the host.
+  typename Gemm::Arguments arguments((cutlass::gemm::GemmCoord*)raw_args.problem_sizes.data_ptr(),
+                                     (int)num_experts,
+                                     (int)raw_args.threadblock_count,
+                                     epilogue_op,
+                                     (ElementA**)raw_args.ptr_a.data_ptr(),
+                                     (ElementB**)raw_args.ptr_b.data_ptr(),
+                                     (ElementC**)raw_args.ptr_c.data_ptr(),
+                                     (ElementC**)raw_args.ptr_c.data_ptr(),
+                                     /*lda=*/(int64_t*)raw_args.lda.data_ptr(),
+                                     /*ldb=*/(int64_t*)raw_args.ldb.data_ptr(),
+                                     /*ldc=*/(int64_t*)raw_args.ldc.data_ptr(),
+                                     /*ldd=*/(int64_t*)raw_args.ldc.data_ptr(),
+                                     /*host_problem_sizes=*/nullptr);
+  return arguments;
+}
+
+template <
+  bool trans_a,
+  typename ElementA, typename ElementB, typename ElementC,
+  typename LayoutA, typename LayoutB, typename LayoutC,
+  typename Arguments
+>
+void FillCutlassArguments(int num_experts,
+                          torch::Tensor batch_sizes,
+                          torch::Tensor a,
+                          torch::Tensor b,
+                          torch::Tensor c,
+                          const Arguments& arguments,
+                          ::cutlass::gemm::GemmCoord coord_template) {
+  // Convert the batch sizes to the format CUTLASS understands on the device.
+  // Use a single block here because:
+  // * the number of elements to process is microscopically small
+  // * we don't need any additional global memory
+  FillArguments<
+    /*kDynamicK*/trans_a,
+    ElementA, ElementB, ElementC,
+    LayoutA, LayoutB, LayoutC
+  ><<<1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream()>>>(
+    num_experts, batch_sizes.data_ptr<int64_t>(),
+    (ElementA*)a.data_ptr(), (ElementB*)b.data_ptr(), (ElementC*)c.data_ptr(),
+    arguments, coord_template
+  );
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template <typename Args>
+void RemoveK0Problems(int num_experts, const Args& arguments) {
+  // For zeroing out the outputs (which might be arbitrarily large), we want to use
+  // as many threadblocks as possible in order to hit the maximum possible global memory bandwidth.
+  // `arguments.threadblock_count`, which we will use for the grouped GEMM proper,
+  // should be a good approximation for this.
+  // When the `k=0` case is fixed in CUTLASS, we can completely remove this function.
+  ZeroOutK0Outputs<><<<
+    arguments.threadblock_count, at::cuda::detail::CUDA_NUM_THREADS, 0, c10::cuda::getCurrentCUDAStream()
+  >>>(
+    num_experts, arguments
+  );
+  IgnoreK0Problems<><<<
+    1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream()
+  >>>(
+    num_experts, arguments
+  );
+}
+
+template <bool trans_a, bool trans_b>
+torch::Tensor CutlassGroupedGemm(torch::Tensor a,
+                                 torch::Tensor b,
+                                 torch::Tensor c,
+                                 torch::Tensor batch_sizes,
+                                 ::cutlass::gemm::GemmCoord coord_template) {
+  using Gemm = GemmGrouped<trans_a, trans_b>;
+  using LayoutA = typename Gemm::LayoutA;
+  using LayoutB = typename Gemm::LayoutB;
+  using LayoutC = typename Gemm::LayoutC;
+
+  using ElementA = typename Gemm::ElementA;
+  using ElementB = typename Gemm::ElementB;
+  using ElementC = typename Gemm::ElementC;
+
+  Gemm gemm;
+  int64_t num_experts = batch_sizes.size(0);
+  auto arguments = MakeArguments<
+    /*kDynamicK*/trans_a,
+    Gemm,
+    ElementA, ElementB, ElementC,
+    LayoutA, LayoutB, LayoutC
+  >(a, b, c, batch_sizes, coord_template, num_experts);
+  int64_t workspace_size = gemm.get_workspace_size(arguments);
+  auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device());
+  torch::Tensor workspace = torch::empty(workspace_size, options);
+
+  if (batch_sizes.is_cuda()) {
+    FillCutlassArguments<
+      trans_a,
+      ElementA, ElementB, ElementC,
+      LayoutA, LayoutB, LayoutC
+    >(num_experts, batch_sizes, a, b, c, arguments, coord_template);
+
+    RemoveK0Problems<>(num_experts, arguments);
+  }
+
+  // Initialize the kernel.
+  if(gemm.initialize(arguments, workspace.data_ptr()) != cutlass::Status::kSuccess) {
+    TORCH_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM");
+  }
+
+  // Execute the kernel in the current stream.
+  if(gemm.run(c10::cuda::getCurrentCUDAStream()) != cutlass::Status::kSuccess) {
+    TORCH_CHECK(false, "Failed to run CUTLASS Grouped GEMM");
+  }
+  return c;
+}
+
+void CublasGemm(c10::BFloat16 *a, int64_t a_rows, int64_t a_cols, bool trans_a,
+                c10::BFloat16 *b, int64_t b_rows, int64_t b_cols, bool trans_b,
+                c10::BFloat16 *c, int64_t c_rows, int64_t c_cols) {
+  int m = trans_b ? b_rows : b_cols;
+  int k = trans_b ? b_cols : b_rows;
+  int n = trans_a ? a_cols : a_rows;
+
+  int lda = trans_a ? n : k;
+  int ldb = trans_b ? k : m;
+  cublasOperation_t transpose_a = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t transpose_b = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+  float alpha = 1.0, beta = 0.0;
+  CUBLAS_CALL(cublasGemmEx(at::cuda::getCurrentCUDABlasHandle(),
+                           transpose_b, transpose_a,
+                           m, n, k, &alpha,
+                           b, CUDA_R_16BF, ldb,
+                           a, CUDA_R_16BF, lda,
+                           &beta,
+                           c, CUDA_R_16BF, c_cols, CUDA_R_32F,
+                           CUBLAS_GEMM_DEFAULT));
+}
+
+void CublasGroupedGemm(torch::Tensor a,
+                       torch::Tensor b,
+                       torch::Tensor c,
+                       torch::Tensor batch_sizes,
+                       bool trans_b) {
+  int64_t bs = batch_sizes.size(0), k = a.size(1);
+  int64_t n = trans_b ? b.size(1) : b.size(2);
+  int64_t b_rows = b.size(1), b_cols = b.size(2);
+  c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
+  c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
+  c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
+  for (int i = 0; i < bs; ++i) {
+    int64_t m = batch_sizes.data_ptr<int64_t>()[i];
+    CublasGemm(a_ptr, m, k, /*trans_a=*/false,
+               b_ptr, b_rows, b_cols, trans_b,
+               c_ptr, m, n);
+    a_ptr += m * k;
+    b_ptr += b_rows * b_cols;
+    c_ptr += m * n;
+  }
+}
+
+void CublasGroupedGemmVariableK(torch::Tensor a,
+                                torch::Tensor b,
+                                torch::Tensor c,
+                                torch::Tensor batch_sizes) {
+  int64_t bs = batch_sizes.size(0), m = a.size(1), n = b.size(1);
+  c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
+  c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
+  c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
+  for (int i = 0; i < bs; ++i) {
+    int64_t k = batch_sizes.data_ptr<int64_t>()[i];
+    CublasGemm(a_ptr, k, m, /*trans_a=*/true,
+               b_ptr, k, n, /*trans_b=*/false,
+               c_ptr, m, n);
+    a_ptr += k * m;
+    b_ptr += k * n;
+    c_ptr += m * n;
+  }
+}
+
+void GroupedGemmVariableK(torch::Tensor a,
+                          torch::Tensor b,
+                          torch::Tensor c,
+                          torch::Tensor batch_sizes) {
+  // We expect a CUDA tensor with two dimensions and shape
+  // (tokens, hidden_out) for 'b'.
+  TORCH_CHECK(b.is_cuda());
+  TORCH_CHECK(b.ndimension() == 2);
+  TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
+
+  // Validate the dimensions.
+  int64_t tokens = a.size(0), num_experts = batch_sizes.size(0);
+  int64_t m = a.size(1), n = b.size(1);
+
+  // Validate that we have the same contraction dimension.
+  TORCH_CHECK(tokens == b.size(0));
+
+  // Validate the output shape.
+  TORCH_CHECK(c.is_cuda());
+  TORCH_CHECK(c.ndimension() == 3);
+  TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
+  TORCH_CHECK(c.size(0) == num_experts);
+  TORCH_CHECK(c.size(1) == m);
+  TORCH_CHECK(c.size(2) == n);
+
+  // Run the computation.
+  CublasGroupedGemmVariableK(a, b, c, batch_sizes);
+}
+
+// NOTE: We only support dynamic group sizes for the 'a' tensor. Tensor 'b' is
+// assumed to be batched with fixed sized batches.
+//
+// TODO(tgale): Validate alignment is true for every batch element.
+void GroupedGemm(torch::Tensor a,
+                 torch::Tensor b,
+                 torch::Tensor c,
+                 torch::Tensor batch_sizes,
+                 bool trans_a, bool trans_b) {
+  // NOTE: We only support 'trans_a' or 'trans_b', not both.
+  TORCH_CHECK(!(trans_a && trans_b));
+
+#if !defined(GROUPED_GEMM_CUTLASS)
+  // No way to run cuBLAS kernels if the problem dimensions are not known on the host.
+  TORCH_CHECK(batch_sizes.is_cpu());
+#else
+  // CUTLASS can handle both CPU- and CUDA-resident problem dimensions.
+  TORCH_CHECK(batch_sizes.is_cuda() || batch_sizes.is_cpu());
+#endif
+  TORCH_CHECK(batch_sizes.ndimension() == 1);
+  TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64);
+
+  // We expect a CUDA tensor with two dimensions and shape
+  // (tokens, hidden_in) for 'a'.
+  TORCH_CHECK(a.is_cuda());
+  TORCH_CHECK(a.ndimension() == 2);
+  TORCH_CHECK(a.scalar_type() == torch::kBFloat16);
+
+#if !defined(GROUPED_GEMM_CUTLASS)
+  if (trans_a) {
+    // If we can't use CUTLASS for the transposed cases, defer to the variable 'k' helper using cuBLAS
+    // for the rest of the op.
+    GroupedGemmVariableK(a, b, c, batch_sizes);
+    return;
+  }
+#endif
+
+  TORCH_CHECK(b.is_cuda());
+  TORCH_CHECK(c.is_cuda());
+  TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
+  TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
+
+  // The expected shapes of 'b' and 'c' are:
+  //   * when 'trans_a' is set: b=(tokens, hidden_out), c=(num_experts, hidden_in, hidden_out)
+  //   * when 'trans_b' is set: b=(num_experts, hidden_out, hidden_in), c=(tokens, hidden_out)
+  //   * otherwise: b=(num_experts, hidden_in, hidden_out), c=(tokens, hidden_out)
+  size_t hidden_in{}, hidden_out{};
+  if (trans_a) {
+    hidden_in = a.size(1);
+    hidden_out = b.size(1);
+
+    TORCH_CHECK(b.ndimension() == 2);
+    TORCH_CHECK(c.ndimension() == 3);
+    TORCH_CHECK(b.size(0) == a.size(0));
+    TORCH_CHECK(c.size(0) == batch_sizes.size(0));
+    TORCH_CHECK(c.size(1) == hidden_in);
+    TORCH_CHECK(c.size(2) == hidden_out);
+  } else {
+    TORCH_CHECK(b.ndimension() == 3);
+    TORCH_CHECK(c.ndimension() == 2);
+
+    // Validate the contraction dimensions match.
+    int64_t tokens = a.size(0), num_experts = b.size(0);
+    hidden_in = trans_b ? b.size(2) : b.size(1);
+    hidden_out = trans_b ? b.size(1) : b.size(2);
+    TORCH_CHECK(hidden_in == a.size(1));
+
+    // Validate that we have one size per expert.
+    TORCH_CHECK(batch_sizes.size(0) == num_experts);
+  }
+
+  // NOTE: We support transposition through the 'trans_b' flag.
+  TORCH_CHECK(a.is_contiguous());
+  TORCH_CHECK(b.is_contiguous());
+  TORCH_CHECK(c.is_contiguous());
+
+#if !defined(GROUPED_GEMM_CUTLASS)
+  CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
+  return;
+#else
+  // The `coord_template` argument contains `kDynamicDim` as one of its dimensions
+  // as a placeholder. This placeholder is later expanded into the actual dimension
+  // for every element of the batch, either on the host or on the device
+  // (if we can't do it on the host).
+  const auto coord_template = trans_a
+      ? cutlass::gemm::GemmCoord(hidden_in, hidden_out, kDynamicDim)
+      : cutlass::gemm::GemmCoord(kDynamicDim, hidden_out, hidden_in);
+  if (trans_a) {
+    CutlassGroupedGemm<true, false>(a, b, c, batch_sizes, coord_template);
+    return;
+  }
+  if (trans_b) {
+    CutlassGroupedGemm<false, true>(a, b, c, batch_sizes, coord_template);
+    return;
+  }
+  CutlassGroupedGemm<false, false>(a, b, c, batch_sizes, coord_template);
+  return;
+#endif
+}
+
+} // namespace grouped_gemm
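
(Not part of the diff: a shape summary.) The `TORCH_CHECK`s in `GroupedGemm` above reduce to three supported layouts, with `tokens = batch_sizes.sum()`. The `trans_a=True` path is the variable-`k` case used by the new `testGroupedGemm_ZeroK` test; a minimal sketch mirroring that test, with illustrative sizes:

# default:       a=(tokens, hidden_in), b=(experts, hidden_in, hidden_out) -> c=(tokens, hidden_out)
# trans_b=True:  a=(tokens, hidden_in), b=(experts, hidden_out, hidden_in) -> c=(tokens, hidden_out)
# trans_a=True:  a=(tokens, hidden_in), b=(tokens, hidden_out)             -> c=(experts, hidden_in, hidden_out)
import torch
import megablocks

tokens, hidden_in, hidden_out, experts = 192, 128, 128, 4
a = torch.ones(tokens, hidden_in, device="cuda", dtype=torch.bfloat16)
b = torch.ones(tokens, hidden_out, device="cuda", dtype=torch.bfloat16)
c = torch.empty(experts, hidden_in, hidden_out, device="cuda", dtype=torch.bfloat16)
batch_sizes = torch.tensor([0, 128, 0, 64])  # per-expert row counts, sums to `tokens`

# trans_a=True accumulates one (hidden_in x hidden_out) product per expert,
# the path dispatched to CutlassGroupedGemm<true, false> above.
megablocks.gg_backend.gmm(a, b, batch_sizes, trans_a=True, c=c)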
csrc/grouped_gemm/grouped_gemm.h ADDED
@@ -0,0 +1,20 @@
+#pragma once
+
+// // Set default if not already defined
+// #ifndef GROUPED_GEMM_CUTLASS
+// #define GROUPED_GEMM_CUTLASS 0
+// #endif
+
+// #include <torch/extension.h>
+#include <torch/torch.h>
+
+namespace grouped_gemm {
+
+void GroupedGemm(torch::Tensor a,
+                 torch::Tensor b,
+                 torch::Tensor c,
+                 torch::Tensor batch_sizes,
+                 bool trans_a, bool trans_b);
+
+} // namespace grouped_gemm
+
csrc/grouped_gemm/ops.cu ADDED
@@ -0,0 +1,11 @@
+#include "grouped_gemm.h"
+
+#include <torch/extension.h>
+
+namespace grouped_gemm {
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("gmm", &GroupedGemm, "Grouped GEMM.");
+}
+
+} // namespace grouped_gemm
tests/ops_test.py ADDED
@@ -0,0 +1,170 @@
+import unittest
+import itertools
+
+from absl.testing import parameterized
+import megablocks
+import numpy as np
+import torch
+
+
+def allclose(x, y, pct=2.0):
+    mask = torch.isclose(x, y, rtol=1e-5)
+    pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+    if pct_diff > pct:
+        print(x[torch.logical_not(mask)], y[torch.logical_not(mask)])
+        print("{:.2f}% of values not close.".format(pct_diff))
+        return False
+    return True
+
+
+def add_flags(x):
+    out = []
+    for y in x:
+        for trans_b in (False, True):
+            out.append(y + (trans_b, False))
+
+            # TODO: Revisit enabling batch_sizes_on_device
+            # for batch_sizes_on_device in (False, True):
+            #     out.append(y + (trans_b, batch_sizes_on_device))
+    return out
+
+
+_TEST_PROBLEMS = add_flags((
+    (1, 128, 128, 128),
+    (8, 128, 128, 128),
+    (16, 128, 128, 128),
+    (1, 128, 256, 512),
+    (8, 128, 256, 512),
+    (16, 128, 256, 512),
+))
+
+
+def randn(bs, x, y):
+    out = (torch.rand(bs, x, y) - 0.5 * 2) / (y * x)
+    return out.cuda().to(torch.bfloat16)
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+    batch_sizes = batch_sizes.cpu().numpy()
+
+    out = []
+    start = 0
+    for i, size in enumerate(batch_sizes):
+        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
+        out.append(a[start:start + size, :] @ rhs)
+        start += size
+    return torch.cat(out)
+
+
+@parameterized.parameters(*_TEST_PROBLEMS)
+class OpsTest(parameterized.TestCase):
+
+    def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
+        torch.manual_seed(0)
+        a = randn(z, m, k).view(-1, k)
+        b = randn(z, n, k) if trans_b else randn(z, k, n)
+        batch_sizes = torch.tensor([m] * z)
+        if batch_sizes_on_device:
+            batch_sizes = batch_sizes.cuda()
+
+        a.requires_grad_(True)
+        b.requires_grad_(True)
+        a_ref = a.detach().clone().requires_grad_(True)
+        b_ref = b.detach().clone().requires_grad_(True)
+
+        # out = ops.gmm(a, b, batch_sizes, trans_b)
+        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
+        # print("out", out)
+        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
+        self.assertTrue(allclose(out, expected_out))
+
+        # Check gradients.
+        out.sum().backward()
+        expected_out.sum().backward()
+        self.assertTrue(allclose(a.grad, a_ref.grad))
+        self.assertTrue(allclose(b.grad, b_ref.grad))
+
+    def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
+        torch.manual_seed(0)
+        a = randn(z, m, k).view(-1, k)
+        b = randn(z, n, k) if trans_b else randn(z, k, n)
+
+        dist = torch.rand(z, )
+        dist /= dist.sum()
+        batch_sizes = (dist * m).to(torch.long)
+        error = m * z - batch_sizes.sum()
+        batch_sizes[-1] += error
+        assert batch_sizes.sum() == (m * z)
+        if batch_sizes_on_device:
+            batch_sizes = batch_sizes.cuda()
+
+        a.requires_grad_(True)
+        b.requires_grad_(True)
+        a_ref = a.detach().clone().requires_grad_(True)
+        b_ref = b.detach().clone().requires_grad_(True)
+
+        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
+        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
+        self.assertTrue(allclose(out, expected_out))
+
+        # Check gradients.
+        out.sum().backward()
+        expected_out.sum().backward()
+        self.assertTrue(allclose(a.grad, a_ref.grad))
+
+        # TODO: Review to ensure that the gradients are correct.
+        # self.assertTrue(allclose(b.grad, b_ref.grad))
+
+
+# @parameterized.parameters(False, True)
+@parameterized.parameters(False, False)
+class EdgeCasesTest(unittest.TestCase):
+
+    def testGroupedGemm_ZeroSize(self, batch_sizes_on_device):
+        torch.manual_seed(0)
+        m = 16384
+        k = 4096
+        n = 14336
+        num_experts = 8
+
+        a = randn(num_experts, m // num_experts, k).view(-1, k)
+        b = randn(num_experts, k, n)
+        batch_sizes = torch.tensor([219, 2246, 5, 8103, 1, 1117, 4693, 0]).to(torch.long)
+        if batch_sizes_on_device:
+            batch_sizes = batch_sizes.cuda()
+
+        a.requires_grad_(True)
+        b.requires_grad_(True)
+        a_ref = a.detach().clone().requires_grad_(True)
+        b_ref = b.detach().clone().requires_grad_(True)
+
+        out = megablocks.gg_ops.gmm(a, b, batch_sizes)
+        expected_out = gmm(a_ref, b_ref, batch_sizes)
+        self.assertTrue(allclose(out, expected_out))
+
+        # Check gradients.
+        out.sum().backward()
+        expected_out.sum().backward()
+        self.assertTrue(allclose(a.grad, a_ref.grad))
+        self.assertTrue(allclose(b.grad, b_ref.grad))
+
+    def testGroupedGemm_ZeroK(self, batch_sizes_on_device):
+        sz = 128
+        total_tokens = 192
+
+        a = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
+        b = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
+        c = torch.ones(4, sz, sz).cuda().to(torch.bfloat16)
+        batch_sizes = torch.tensor([0, 128, 0, 64]).to(torch.long)
+        if batch_sizes_on_device:
+            batch_sizes = batch_sizes.cuda()
+
+        megablocks.gg_backend.gmm(a, b, batch_sizes, trans_a=True, c=c)
+        self.assertTrue((c[0] == 0).all())
+        self.assertTrue((c[1] == 128).all())
+        self.assertTrue((c[2] == 0).all())
+        self.assertTrue((c[3] == 64).all())
+
+
+if __name__ == '__main__':
+    unittest.main()
tests/test_gg.py ADDED
@@ -0,0 +1,57 @@
+import torch
+import megablocks
+
+
+def randn(bs, x, y):
+    out = (torch.rand(bs, x, y) - 0.5 * 2) / (y * x)
+    return out.cuda().to(torch.bfloat16)
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+    batch_sizes = batch_sizes.cpu().numpy()
+
+    out = []
+    start = 0
+    for i, size in enumerate(batch_sizes):
+        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
+        out.append(a[start : start + size, :] @ rhs)
+        start += size
+    return torch.cat(out)
+
+
+def test_gmm():
+    z = 1
+    m = 128
+    n = 128
+    k = 128
+    trans_b = False
+    batch_sizes_on_device = False
+    # TODO: fix to enable batch_sizes_on_device
+    # batch_sizes_on_device = True
+
+    torch.manual_seed(0)
+    a = randn(z, m, k).view(-1, k)
+    b = randn(z, n, k) if trans_b else randn(z, k, n)
+    batch_sizes = torch.tensor([m] * z)
+    if batch_sizes_on_device:
+        batch_sizes = batch_sizes.cuda()
+
+    a.requires_grad_(True)
+    b.requires_grad_(True)
+    a_ref = a.detach().clone().requires_grad_(True)
+    b_ref = b.detach().clone().requires_grad_(True)
+
+    # out = ops.gmm(a, b, batch_sizes, trans_b)
+    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
+    print("out", out)
+
+    expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
+
+    assert torch.allclose(out, expected_out, atol=1e-3), f"Expected {expected_out}, got {out}"
+
+    out.sum().backward()
+
+    expected_out.sum().backward()
+    assert torch.allclose(a.grad, a_ref.grad, atol=1e-3), f"Expected {a_ref.grad}, got {a.grad}"
+    assert torch.allclose(b.grad, b_ref.grad, atol=1e-3), f"Expected {b_ref.grad}, got {b.grad}"
+    print("Test passed successfully!")
torch-ext/megablocks/__init__.py CHANGED
@@ -5,11 +5,15 @@ import torch
 
 from ._ops import ops
 
-from …
-from …
-
-
-from …
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from .layers.arguments import Arguments
+from .layers.dmoe import ParallelDroplessMLP, dMoE
+from .layers.glu import SparseGLU
+from .layers.mlp import MLP, SparseMLP
+from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss
 
 # This section contains the direct kernel exports (not inlcuded in the original code)
 def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
torch-ext/megablocks/grouped_gemm/__init__.py ADDED
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
torch-ext/megablocks/grouped_gemm/backend.py ADDED
@@ -0,0 +1,32 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+from megablocks._ops import ops as backend  # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+    assert not (trans_a and trans_b)
+    assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+    assert a.ndim == 2, "Expected 2d tensor for 'a'"
+    assert b.ndim == (2 if trans_a else 3)
+
+    shape = (
+        (batch_sizes.shape[0], a.shape[1], b.shape[1])
+        if trans_a else
+        (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+    )
+    return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+    if c is None:
+        c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+    return c
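
(Not part of the diff: a note on output allocation.) When `c` is omitted, `_allocate_output` derives the output shape from the `trans_a`/`trans_b` flags. A minimal sketch with illustrative sizes, assuming the same megablocks build:

import torch
import megablocks

a = torch.randn(256, 64, device="cuda", dtype=torch.bfloat16)    # (tokens, hidden_in)
b = torch.randn(4, 64, 32, device="cuda", dtype=torch.bfloat16)  # (experts, hidden_in, hidden_out)
batch_sizes = torch.tensor([64, 64, 64, 64])

out = megablocks.gg_backend.gmm(a, b, batch_sizes)   # c allocated as (256, 32)

grad_out = torch.randn(256, 32, device="cuda", dtype=torch.bfloat16)
wgrad = megablocks.gg_backend.gmm(a, grad_out, batch_sizes,
                                  trans_a=True)      # c allocated as (4, 64, 32)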
torch-ext/megablocks/grouped_gemm/ops.py ADDED
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, a, b, batch_sizes, trans_b):
+        ctx.save_for_backward(a, b, batch_sizes)
+        ctx.trans_b = trans_b
+        return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+    @staticmethod
+    def backward(ctx, grad):
+        grad = grad.contiguous()
+        a, b, batch_sizes = ctx.saved_tensors
+        trans_b = ctx.trans_b
+
+        agrad = None
+        if ctx.needs_input_grad[0]:
+            agrad = backend.gmm(
+                grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+        bgrad = None
+        if ctx.needs_input_grad[1]:
+            lhs, rhs = (grad, a) if trans_b else (a, grad)
+            bgrad = backend.gmm(
+                lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+        return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+    return GroupedGemm.apply(a, b, batch_sizes, trans_b)
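
(Not part of the diff: context on the backward pass.) `GroupedGemm.backward` reuses the same grouped GEMM with flipped transpose flags: the input gradient comes from `grad @ B^T` per expert (via `trans_b=not trans_b`) and the weight gradient from `A^T @ grad` per expert (via `trans_a=True`). A short end-to-end sketch in the spirit of the new tests:

import torch
import megablocks

torch.manual_seed(0)
experts, m, k, n = 4, 32, 16, 8
a = torch.randn(experts * m, k, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(experts, k, n, device="cuda", dtype=torch.bfloat16, requires_grad=True)
batch_sizes = torch.tensor([m] * experts)

out = megablocks.gg_ops.gmm(a, b, batch_sizes)
out.sum().backward()               # gradients flow through GroupedGemm.backward above
print(a.grad.shape, b.grad.shape)  # (experts*m, k), (experts, k, n)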
torch-ext/megablocks/grouped_gemm_util.py CHANGED
@@ -4,7 +4,8 @@ import warnings
 
 _grouped_gemm_is_available: bool = False
 try:
-    import grouped_gemm
+    # import grouped_gemm
+    pass
     _grouped_gemm_is_available = True
 except ImportError as error:
     warnings.warn('Grouped GEMM not available.')
@@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available():
     assert _grouped_gemm_is_available, msg
 
 
-backend = grouped_gemm.backend if grouped_gemm_is_available() else None
-ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend as ops
+from .grouped_gemm import ops as backend
torch-ext/megablocks/layers/__init__.py CHANGED
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # from megablocks.layers.dmoe import dMoE
-from …
+from .moe import MoE
 
 __all__ = [
     'MoE',
torch-ext/torch_binding.cpp CHANGED
@@ -9,6 +9,8 @@
 #include "new_replicate.h"
 #include "new_sort.h"
 
+#include "grouped_gemm/grouped_gemm.h"
+
 // void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
 torch::Tensor exclusive_cumsum_wrapper(torch::Tensor x, int64_t dim, torch::Tensor out) {
   megablocks::exclusive_cumsum(x, dim, out);
@@ -70,6 +72,12 @@ torch::Tensor sort_wrapper(torch::Tensor x, int64_t end_bit, torch::Tensor x_out
   return x_out;
 }
 
+// GroupedGemm operation
+torch::Tensor gmm(torch::Tensor a, torch::Tensor b, torch::Tensor c, torch::Tensor batch_sizes, bool trans_a, bool trans_b) {
+  grouped_gemm::GroupedGemm(a, b, c, batch_sizes, trans_a, trans_b);
+  return c;
+}
+
 // Reference implementation:
 //
 // m.def("exclusive_cumsum", &exclusive_cumsum, "batched exclusive cumsum.");
@@ -101,6 +109,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   ops.def("sort(Tensor x, int end_bit, Tensor x_out, Tensor iota_out) -> Tensor(x_out)");
   ops.impl("sort", torch::kCUDA, &sort_wrapper);
+
+  // Register the gmm GroupedGemm operation
+  ops.def("gmm(Tensor (a!) a, Tensor (b!) b, Tensor(c!) c, Tensor batch_sizes, bool trans_a, bool trans_b) -> Tensor(c!)");
+  ops.impl("gmm", torch::kCUDA, &gmm);
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
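
(Not part of the diff: calling the registered op directly.) Once the extension is built, the `gmm` schema registered above is reachable through the generated `_ops` module, which is exactly what the vendored `grouped_gemm/backend.py` wraps. A minimal sketch with illustrative sizes:

import torch
from megablocks._ops import ops  # binding module wrapping this extension

a = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
b = torch.randn(2, 64, 32, device="cuda", dtype=torch.bfloat16)
c = torch.empty(128, 32, device="cuda", dtype=torch.bfloat16)
batch_sizes = torch.tensor([64, 64])

# Matches the registered schema:
# gmm(Tensor(a!) a, Tensor(b!) b, Tensor(c!) c, Tensor batch_sizes, bool trans_a, bool trans_b) -> Tensor(c!)
ops.gmm(a, b, c, batch_sizes, False, False)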