{
  description = "Flake for megablocks_moe kernel";

  inputs = {
    # NOTE(review): pinned to a kernel-builder feature branch, presumably the one
    # that adds `customPythonPackages` support — move back to a release/main ref
    # once that work is merged.
    kernel-builder.url = "github:huggingface/kernel-builder/support-custom-python-libraries-in-dev-shell-nixland";

    # Python libraries fetched as plain source trees (`flake = false`) and handed
    # to kernel-builder through `customPythonPackages` in `outputs` below.
    composer = {
      url = "github:mosaicml/composer";
      flake = false;
    };
    stk = {
      url = "github:stanford-futuredata/stk";
      flake = false;
    };
    # TODO: update to build with the correct torch version, then re-enable.
    # grouped_gemm = {
    #   url = "github:tgale96/grouped_gemm";
    #   flake = false;
    # };
  };

  outputs =
    {
      self,
      kernel-builder,
      composer,
      stk,
      # grouped_gemm,
      # `...` lets inputs be declared above without also having to be named in
      # this pattern; without it Nix aborts on an unexpected argument.
      ...
    }:
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;

      # Use the clean short git revision when available, otherwise the
      # dirty-tree revision, otherwise the last-modified date.
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;

      # Map custom Python package names to their source inputs.
      customPythonPackages = {
        inherit composer stk;
        # inherit grouped_gemm;
      };

      # Python dependencies for the kernel's tests. `composer` and `stk` are
      # presumably resolved against `customPythonPackages` above by name —
      # confirm in kernel-builder.
      pythonTestDeps = [
        "tqdm"
        "py-cpuinfo"
        "importlib-metadata"
        "torchmetrics"
        "composer"
        "stk"
        # "grouped_gemm"
        # "yahp" # may be needed for some testing plugin
      ];
    };
}
