tgaddair committed
Commit 3cc04dc · verified · 1 Parent(s): bb6fda5

Update README.md

Files changed (1):
  1. README.md +113 -3
README.md CHANGED

---
license: apache-2.0
---

# How to use this model

```python
tl_methods = [
    'PropagateNan', 'TRITON_MAX_TENSOR_NUMEL', 'abs', 'advance', 'arange',
    'argmax', 'argmin', 'associative_scan', 'atomic_add', 'atomic_and',
    'atomic_cas', 'atomic_max', 'atomic_min', 'atomic_or', 'atomic_xchg',
    'atomic_xor', 'bfloat16', 'block_type', 'broadcast', 'broadcast_to',
    'cast', 'cat', 'cdiv', 'ceil', 'clamp', 'const', 'const_pointer_type',
    'constexpr', 'cos', 'cumprod', 'cumsum', 'debug_barrier', 'device_assert',
    'device_print', 'div_rn', 'dot', 'dtype', 'erf', 'exp', 'exp2',
    'expand_dims', 'fdiv', 'flip', 'float16', 'float32', 'float64',
    'float8e4b15', 'float8e4b8', 'float8e4nv', 'float8e5', 'float8e5b16',
    'floor', 'fma', 'full', 'function_type', 'histogram',
    'inline_asm_elementwise', 'int1', 'int16', 'int32', 'int64', 'int8',
    'interleave', 'join', 'load', 'log', 'log2', 'make_block_ptr', 'max',
    'max_constancy', 'max_contiguous', 'maximum', 'min', 'minimum',
    'multiple_of', 'num_programs', 'pair_uniform_to_normal', 'permute',
    'philox', 'pi32_t', 'pointer_type', 'program_id', 'rand', 'rand4x',
    'randint', 'randint4x', 'randn', 'randn4x', 'range', 'ravel', 'reduce',
    'reshape', 'rsqrt', 'sigmoid', 'sin', 'softmax', 'sort', 'split', 'sqrt',
    'sqrt_rn', 'static_assert', 'static_print', 'static_range', 'store',
    'str_to_ty', 'sum', 'swizzle2d', 'tensor', 'trans', 'uint16', 'uint32',
    'uint64', 'uint8', 'uint_to_uniform_float', 'umulhi', 'view', 'void',
    'where', 'xor_sum', 'zeros', 'zeros_like'
]


def get_user_prompt(name, pytorch_impl):
    prompt = f"""Convert this PyTorch module implementation into an equivalent Triton kernel:

<torch_code>
{pytorch_impl}
</torch_code>

The Triton kernel should:
1. Import torch, triton, and triton.language as tl and other necessary modules
2. Use @triton.jit decorator on the kernel implementation (not the entrypoint function)
3. Have proper grid and block sizes
4. Use a mask in the load/store operations
5. Use typed constants (tl.constexpr)
6. Handle tensor dimensions correctly
7. Return output matching PyTorch's implementation
8. Do not include any test code in your response, only the Triton kernel implementation and entrypoint function

The triton.language (tl) module supports the following methods: {", ".join(tl_methods)}

The entrypoint function must be named: {name}_triton
The Triton kernel implementation (called by the entrypoint) must be named: {name}_kernel

No computation logic should be done within the entrypoint function. All computation logic should be done within the Triton kernel implementation.

The final generated code in the response must start with <triton_code> and end with </triton_code> tags."""

    return prompt


SYSTEM_PROMPT = """You are a helpful assistant that converts PyTorch code into Triton kernels."""

# `name` is the PyTorch module's class name and `code` its source, both defined elsewhere
messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": get_user_prompt(name, code)},
]

...
```
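
The trailing `...` above is where inference goes. A minimal sketch of that step is shown below, assuming this repository hosts a standard `transformers` chat model; `model_id` is a placeholder for the actual repo id, and the generation settings are illustrative, not prescribed by this model card:

```python
# Illustrative sketch only; assumes a standard transformers causal LM with a chat template.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "<this-repo-id>"  # placeholder: substitute this repository's model id

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# `messages` is built exactly as shown above
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
output = model.generate(input_ids, max_new_tokens=2048)
response = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)

# Per the prompt, the generated kernel is wrapped in <triton_code>...</triton_code> tags
start = response.find("<triton_code>") + len("<triton_code>")
end = response.find("</triton_code>")
triton_code = response[start:end].strip()
```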

Example PyTorch code (from KernelBench):

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Simple model that performs a LeakyReLU activation.
    """
    def __init__(self, negative_slope: float = 0.01):
        """
        Initializes the LeakyReLU module.

        Args:
            negative_slope (float, optional): The negative slope of the activation function. Defaults to 0.01.
        """
        super(Model, self).__init__()
        self.negative_slope = negative_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies LeakyReLU activation to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of any shape.

        Returns:
            torch.Tensor: Output tensor with LeakyReLU applied, same shape as input.
        """
        return torch.nn.functional.leaky_relu(x, negative_slope=self.negative_slope)

batch_size = 16
dim = 16384

def get_inputs():
    x = torch.randn(batch_size, dim)
    return [x]

def get_init_inputs():
    return []  # No special initialization inputs needed
```
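
For reference, the code inside the `<triton_code>` tags for this example might look like the following. This is an illustrative sketch, not this model's verbatim output; it assumes `name` is `Model` (so the entrypoint is `Model_triton` and the kernel is `Model_kernel`) and a contiguous CUDA input:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def Model_kernel(x_ptr, out_ptr, n_elements, negative_slope, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # mask guards the final partial block
    x = tl.load(x_ptr + offsets, mask=mask)
    # LeakyReLU: x where positive, negative_slope * x otherwise
    out = tl.where(x > 0, x, negative_slope * x)
    tl.store(out_ptr + offsets, out, mask=mask)


def Model_triton(x: torch.Tensor, negative_slope: float = 0.01) -> torch.Tensor:
    # Entrypoint: allocates the output and launches the kernel; no computation here
    out = torch.empty_like(x)
    n_elements = x.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    Model_kernel[grid](x, out, n_elements, negative_slope, BLOCK_SIZE=BLOCK_SIZE)
    return out
```

Note how the sketch follows the prompt's constraints: masked loads and stores, a `tl.constexpr` block size, and all computation inside the `@triton.jit` kernel. An elementwise op like LeakyReLU can treat the `(batch_size, dim)` input as a flat 1-D array, which is why a single `n_elements` bound suffices.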