Update README.md
README.md
---
license: apache-2.0
---

# How to use this model

```python
tl_methods = [
    'PropagateNan', 'TRITON_MAX_TENSOR_NUMEL', 'abs', 'advance', 'arange',
    'argmax', 'argmin', 'associative_scan', 'atomic_add', 'atomic_and',
    'atomic_cas', 'atomic_max', 'atomic_min', 'atomic_or', 'atomic_xchg',
    'atomic_xor', 'bfloat16', 'block_type', 'broadcast', 'broadcast_to',
    'cast', 'cat', 'cdiv', 'ceil', 'clamp', 'const', 'const_pointer_type',
    'constexpr', 'cos', 'cumprod', 'cumsum', 'debug_barrier', 'device_assert',
    'device_print', 'div_rn', 'dot', 'dtype', 'erf', 'exp', 'exp2',
    'expand_dims', 'fdiv', 'flip', 'float16', 'float32', 'float64',
    'float8e4b15', 'float8e4b8', 'float8e4nv', 'float8e5', 'float8e5b16',
    'floor', 'fma', 'full', 'function_type', 'histogram',
    'inline_asm_elementwise', 'int1', 'int16', 'int32', 'int64', 'int8',
    'interleave', 'join', 'load', 'log', 'log2', 'make_block_ptr', 'max',
    'max_constancy', 'max_contiguous', 'maximum', 'min', 'minimum',
    'multiple_of', 'num_programs', 'pair_uniform_to_normal', 'permute',
    'philox', 'pi32_t', 'pointer_type', 'program_id', 'rand', 'rand4x',
    'randint', 'randint4x', 'randn', 'randn4x', 'range', 'ravel', 'reduce',
    'reshape', 'rsqrt', 'sigmoid', 'sin', 'softmax', 'sort', 'split', 'sqrt',
    'sqrt_rn', 'static_assert', 'static_print', 'static_range', 'store',
    'str_to_ty', 'sum', 'swizzle2d', 'tensor', 'trans', 'uint16', 'uint32',
    'uint64', 'uint8', 'uint_to_uniform_float', 'umulhi', 'view', 'void',
    'where', 'xor_sum', 'zeros', 'zeros_like'
]


def get_user_prompt(name, pytorch_impl):
    prompt = f"""Convert this PyTorch module implementation into an equivalent Triton kernel:

<torch_code>
{pytorch_impl}
</torch_code>

The Triton kernel should:
1. Import torch, triton, and triton.language as tl and other necessary modules
2. Use @triton.jit decorator on the kernel implementation (not the entrypoint function)
3. Have proper grid and block sizes
4. Use a mask in the load/store operations
5. Use typed constants (tl.constexpr)
6. Handle tensor dimensions correctly
7. Return output matching PyTorch's implementation
8. Do not include any test code in your response, only the Triton kernel implementation and entrypoint function

The triton.language (tl) module supports the following methods: {", ".join(tl_methods)}

The entrypoint function must be named: {name}_triton
The Triton kernel implementation (called by the entrypoint) must be named: {name}_kernel

No computation logic should be done within the entrypoint function. All computation logic should be done within the Triton kernel implementation.

The final generated code in the response must start with <triton_code> and end with </triton_code> tags."""

    return prompt


SYSTEM_PROMPT = """You are a helpful assistant that converts PyTorch code into Triton kernels."""

# `name` is the module's short name (e.g. "leaky_relu") and `code` is the
# PyTorch source of the module to convert, e.g. taken from KernelBench.
messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": get_user_prompt(name, code)},
]

...
```
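The snippet above only builds the chat messages; the elided step is ordinary chat-style generation. As a minimal sketch, assuming the model works with the standard `transformers` chat-template API (`MODEL_ID` below is a placeholder for this repository's id, not a real identifier):

```python
# Sketch only: assumes the standard transformers chat-template API.
# MODEL_ID is a placeholder for this repository's id.
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "<this-repo-id>"  # placeholder

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto", device_map="auto")

# Render the system/user messages with the model's chat template.
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

output_ids = model.generate(input_ids, max_new_tokens=2048)

# Decode only the newly generated tokens; the Triton code can then be
# extracted from between the <triton_code> and </triton_code> tags.
completion = tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)
print(completion)
```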

Example PyTorch code (from KernelBench):

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    """
    Simple model that performs a LeakyReLU activation.
    """
    def __init__(self, negative_slope: float = 0.01):
        """
        Initializes the LeakyReLU module.

        Args:
            negative_slope (float, optional): The negative slope of the activation function. Defaults to 0.01.
        """
        super(Model, self).__init__()
        self.negative_slope = negative_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies LeakyReLU activation to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of any shape.

        Returns:
            torch.Tensor: Output tensor with LeakyReLU applied, same shape as input.
        """
        return torch.nn.functional.leaky_relu(x, negative_slope=self.negative_slope)


batch_size = 16
dim = 16384


def get_inputs():
    x = torch.randn(batch_size, dim)
    return [x]


def get_init_inputs():
    return []  # No special initialization inputs needed
```
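
For this LeakyReLU example (name `leaky_relu`), a response satisfying the prompt's constraints would look roughly like the sketch below. This is an illustrative hand-written kernel, not a captured sample from the model:

```python
# Illustrative sketch of output meeting the prompt's constraints
# (entrypoint leaky_relu_triton, kernel leaky_relu_kernel, masked
# load/store, tl.constexpr block size); hand-written, not model output.
import torch
import triton
import triton.language as tl


@triton.jit
def leaky_relu_kernel(x_ptr, out_ptr, n_elements, negative_slope,
                      BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the final partial block
    x = tl.load(x_ptr + offsets, mask=mask)
    result = tl.where(x >= 0, x, x * negative_slope)
    tl.store(out_ptr + offsets, result, mask=mask)


def leaky_relu_triton(x: torch.Tensor, negative_slope: float = 0.01) -> torch.Tensor:
    # Entrypoint: allocates the output and launches the kernel; all
    # computation happens inside leaky_relu_kernel.
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    leaky_relu_kernel[grid](x, out, n_elements, negative_slope, BLOCK_SIZE=1024)
    return out
```

Generated kernels can be checked for parity against the original module, e.g. by comparing `Model()(x)` with `leaky_relu_triton(x)` on a CUDA tensor via `torch.allclose`.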