How to use this model

# Names that the user prompt advertises as supported by triton.language (tl).
# Kept as one whitespace-separated blob and split into a list of strings;
# the order below is the order in which the names appear in the prompt.
tl_methods = """
    PropagateNan TRITON_MAX_TENSOR_NUMEL abs advance arange
    argmax argmin associative_scan atomic_add atomic_and
    atomic_cas atomic_max atomic_min atomic_or atomic_xchg
    atomic_xor bfloat16 block_type broadcast broadcast_to
    cast cat cdiv ceil clamp const const_pointer_type
    constexpr cos cumprod cumsum debug_barrier device_assert
    device_print div_rn dot dtype erf exp exp2
    expand_dims fdiv flip float16 float32 float64
    float8e4b15 float8e4b8 float8e4nv float8e5 float8e5b16
    floor fma full function_type histogram
    inline_asm_elementwise int1 int16 int32 int64 int8
    interleave join load log log2 make_block_ptr max
    max_constancy max_contiguous maximum min minimum
    multiple_of num_programs pair_uniform_to_normal permute
    philox pi32_t pointer_type program_id rand rand4x
    randint randint4x randn randn4x range ravel reduce
    reshape rsqrt sigmoid sin softmax sort split sqrt
    sqrt_rn static_assert static_print static_range store
    str_to_ty sum swizzle2d tensor trans uint16 uint32
    uint64 uint8 uint_to_uniform_float umulhi view void
    where xor_sum zeros zeros_like
""".split()


def get_user_prompt(name, pytorch_impl):
    """Build the user-turn prompt requesting a Triton port of a PyTorch module.

    Args:
        name: Base identifier. The prompt requires the entrypoint to be named
            ``{name}_triton`` and the kernel ``{name}_kernel``.
        pytorch_impl: PyTorch source embedded between ``<torch_code>`` tags.

    Returns:
        The fully formatted prompt string.
    """
    # Render the advertised tl helpers once, as a comma-separated list.
    supported_methods = ", ".join(tl_methods)

    return f"""Convert this PyTorch module implementation into an equivalent Triton kernel:

<torch_code>
{pytorch_impl}
</torch_code>

The Triton kernel should:
1. Import torch, triton, and triton.language as tl and other necessary modules
2. Use @triton.jit decorator on the kernel implementation (not the entrypoint function)
3. Have proper grid and block sizes
4. Use a mask in the load/store operations
5. Use typed constants (tl.constexpr)
6. Handle tensor dimensions correctly
7. Return output matching PyTorch's implementation
8. Do not include any test code in your response, only the Triton kernel implementation and entrypoint function

The triton.language (tl) module supports the following methods: {supported_methods}

The entrypoint function must be named: {name}_triton
The Triton kernel implementation (called by the entrypoint) must be named: {name}_kernel

No computation logic should be done within the entrypoint function. All computation logic should be done within the Triton kernel implementation.

The final generated code in the response must start with <triton_code> and end with </triton_code> tags."""


# System turn: fixes the assistant's role as a PyTorch-to-Triton converter.
SYSTEM_PROMPT = "You are a helpful assistant that converts PyTorch code into Triton kernels."

# NOTE(review): `name` and `code` are not defined in this snippet; presumably
# the caller supplies the kernel base name and the PyTorch source -- confirm.
messages = [
    dict(role=role, content=content)
    for role, content in (
        ("system", SYSTEM_PROMPT),
        ("user", get_user_prompt(name, code)),
    )
]

...

Example PyTorch code (from KernelBench):

import torch
import torch.nn as nn

class Model(nn.Module):
    """Minimal module whose forward pass is a single LeakyReLU activation."""

    def __init__(self, negative_slope: float = 0.01):
        """Record the slope used for negative inputs.

        Args:
            negative_slope (float, optional): Multiplier applied to negative
                values by the activation. Defaults to 0.01.
        """
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run LeakyReLU over ``x`` with the stored negative slope.

        Args:
            x (torch.Tensor): Input tensor of any shape.

        Returns:
            torch.Tensor: Activated tensor with the same shape as ``x``.
        """
        slope = self.negative_slope
        return nn.functional.leaky_relu(x, negative_slope=slope)

# Dimensions of the sample input produced by get_inputs(): (batch_size, dim).
batch_size = 16
dim = 16384


def get_inputs():
    """Return a one-element list holding a random (batch_size, dim) tensor."""
    return [torch.randn(batch_size, dim)]

def get_init_inputs():
    """Return an empty list: no initialization inputs are needed."""
    init_args = []  # No special initialization inputs needed
    return init_args
Downloads last month
20
Safetensors
Model size
32.8B params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for predibase/Predibase-T2T-32B-GRPO

Quantizations
1 model

Collection including predibase/Predibase-T2T-32B-GRPO