How does this actually work? #5
by alpindale - opened
Are there any limitations to what sort of PyTorch code it can write Triton kernels for? For example:
$ cat infer.py
from kernelllm import KernelLLM
# Initialize the model
model = KernelLLM()
# Define your PyTorch module
pytorch_code = '''
import torch
def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
    return logits
'''
# Generate optimized Triton code
optimized_code = model.generate_triton(pytorch_code, max_new_tokens=512)
print(optimized_code)
Output:
$ python infer.py
Loading checkpoint shards: 100%|████████████████████████████████████████████████| 4/4 [00:06<00:00, 1.56s/it]
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code!
import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import math as tl_math
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
@triton.jit
def triton_poi_fused__softmax_0(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xnumel = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x1 = xindex // 4
    tmp0 = tl.load(in_ptr0 + x2, xmask)
    tmp1 = tl.load(in_ptr0 + 4 * x1, xmask, eviction_policy='evict_last')
    tmp2 = tl.load(in_ptr0 + (1 + 4 * x1), xmask, eviction_policy='evict_last')
    tmp4 = tl.load(in_ptr0 + (2 + 4 * x1), xmask, eviction_policy='evict_last')
    tmp6 = tl.load(in_ptr0 + (3 + 4 * x1), xmask, eviction_policy='evict_last')
    tmp3 = triton_helpers.maximum(tmp1, tmp2)
    tmp5 = triton_helpers.maximum(tmp3, tmp4)
    tmp7 = triton_helpers.maximum(tmp5, tmp6)
    tmp8 = tmp0 - tmp7
    tmp9 = tl_math.exp(tmp8)
    tl.store(out_ptr0 + x2, tmp9, xmask)
@triton.jit
def triton_poi_fused__softmax_max_mul_1(in_ptr0, in_ptr1, out
This does not seem to be what I asked for.
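For reference, the usage example on the model card wraps the PyTorch code in an nn.Module subclass named Model, plus get_inputs() / get_init_inputs() helpers, which matches the "architecture named Model" prompt fragment echoed above. Below is a sketch of the same min-p filter rewritten in that format; the tensor shapes in get_inputs() and the larger max_new_tokens are my guesses, not values from the model card (the run above was also cut off mid-kernel at 512 tokens):

from kernelllm import KernelLLM

model = KernelLLM()

# Same min-p filtering logic, rewritten in the Model/get_inputs format used
# by the model card example; shapes below are illustrative guesses.
pytorch_code = '''
import torch
import torch.nn as nn

class Model(nn.Module):
    """Applies min-p filtering to a batch of logits."""

    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
        probs = torch.softmax(logits, dim=-1)
        top_probs, _ = probs.max(dim=-1, keepdim=True)
        scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
        tokens_to_remove = probs < scaled_min_p
        # Out-of-place masked_fill: in-place ops on inputs can trip up
        # compilers/tracers, so the sketch avoids them.
        return logits.masked_fill(tokens_to_remove, -float("inf"))

def get_inputs():
    return [torch.randn(8, 32000), torch.rand(8)]

def get_init_inputs():
    return []
'''

# Raised from 512: the run above stopped mid-signature, so the token budget
# was likely too small for a complete second kernel.
optimized_code = model.generate_triton(pytorch_code, max_new_tokens=2048)
print(optimized_code)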