run.py script to prompt KernelLLM

#3
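This PR adds a small interactive script for prompting KernelLLM: it reads an nn.Module definition from stdin, embeds it in a Triton-rewriting prompt, and prints the model's generated `ModelNew` code. To try it locally (this assumes a CUDA-capable GPU and the `transformers` package installed), run `python run.py`, paste your module code (including `get_inputs()` and `get_init_inputs()`), and finish with Ctrl+D; empty input falls back to the built-in hinge-loss example, and typing `exit` quits.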
Files changed (1)
  1. run.py (added, +198 −0)
import sys

import torch
from transformers import AutoTokenizer
import transformers

HF_MODEL = "facebook/KernelLLM"

REPL_INSTRUCTIONS = """
You can paste or write your nn.Module code below (and finish with Ctrl+D).
The model will try to optimize it with Triton kernels.

Make sure that you provide `get_inputs()` and `get_init_inputs()` functions so that your model can be run like this:
init_args = get_init_inputs()
model = ModelNew(*init_args)
out = model(*get_inputs())

>>>
"""

DEFAULT_MODEL_CODE = """
import torch
import torch.nn as nn

class Model(nn.Module):
    \"\"\"
    A model that computes Hinge Loss for binary classification tasks.

    Parameters:
        None
    \"\"\"
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, predictions, targets):
        return torch.mean(torch.clamp(1 - predictions * targets, min=0))

batch_size = 128
input_shape = (1,)
dim = 1

def get_inputs():
    return [torch.randn(batch_size, *input_shape), torch.randint(0, 2, (batch_size, 1)).float() * 2 - 1]

def get_init_inputs():
    return []
"""

PROMPT_TEMPLATE = """
<|begin_of_text|>You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups.

You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.

Here's an example to show you the syntax of inline embedding custom operators from the Triton DSL in torch: The example given architecture is:
```
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, a, b):
        return a + b


def get_inputs():
    # randomly generate input tensors based on the model architecture
    a = torch.randn(1, 128).cuda()
    b = torch.randn(1, 128).cuda()
    return [a, b]


def get_init_inputs():
    # randomly generate tensors required for initialization based on the model architecture
    return []

```
The example new arch with custom Triton kernels looks like this:
```
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # Pointer to first input
    y_ptr,  # Pointer to second input
    out_ptr,  # Pointer to output
    n_elements,  # Total number of elements in input/output
    BLOCK_SIZE: tl.constexpr,
):
    # Each program handles a contiguous block of data of size BLOCK_SIZE
    block_start = tl.program_id(0) * BLOCK_SIZE
    # Create a range of offsets [0..BLOCK_SIZE-1]
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask to ensure we don't go out of bounds
    mask = offsets < n_elements
    # Load input values
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
    # Perform the elementwise addition
    out = x + y
    # Store the result
    tl.store(out_ptr + offsets, out, mask=mask)


def triton_add(x: torch.Tensor, y: torch.Tensor):
    \"\"\"
    This function wraps the Triton kernel call. It:
      1. Ensures the inputs are contiguous on GPU.
      2. Calculates the grid (blocks) needed.
      3. Launches the Triton kernel.
    \"\"\"
    assert x.is_cuda and y.is_cuda, "Tensors must be on CUDA."
    x = x.contiguous()
    y = y.contiguous()

    # Prepare output tensor
    out = torch.empty_like(x)

    # Number of elements in the tensor
    n_elements = x.numel()
    BLOCK_SIZE = 128  # Tunable parameter for block size

    # Determine the number of blocks needed
    grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)

    # Launch the Triton kernel
    add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return out


class ModelNew(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, a, b):
        # Instead of "return a + b", call our Triton-based addition
        return triton_add(a, b)

```

You are given the following architecture:
```
{}
```

Optimize the architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code!
"""

def main():
    tokenizer = AutoTokenizer.from_pretrained(HF_MODEL)
    pipeline = transformers.pipeline(
        "text-generation",
        model=HF_MODEL,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    while True:
        print(REPL_INSTRUCTIONS)
        # Read the user's nn.Module code from stdin until EOF (Ctrl+D).
        model_code = sys.stdin.read().strip()
        if model_code.lower() == "exit":
            return

        if not model_code:
            print(f"Using default model code:\n{DEFAULT_MODEL_CODE}")
            model_code = DEFAULT_MODEL_CODE

        # Embed the (user-provided or default) architecture in the prompt template.
        prompt = PROMPT_TEMPLATE.format(model_code)

        response = pipeline(
            prompt,
            do_sample=True,
            top_k=0,
            temperature=0.6,
            top_p=0.95,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            max_length=2048,
            truncation=True,
        )[0]

        print("Response:", response["generated_text"])


if __name__ == "__main__":
    main()
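The script only prints the raw generation. As a rough sketch of how the output could be checked afterwards (not part of this PR; helper names such as `extract_code_block` and `check_correctness` are illustrative, and the sketch assumes the response follows the prompt's contract of a code block defining `ModelNew`):

```
import re

import torch


def extract_code_block(generated_text: str) -> str:
    # The pipeline echoes the prompt, so take the last fenced code block.
    blocks = re.findall(r"```(?:python)?\n(.*?)```", generated_text, re.DOTALL)
    return blocks[-1] if blocks else generated_text


def check_correctness(reference_code: str, generated_text: str, atol: float = 1e-4) -> bool:
    # Execute the reference architecture and the generated one in separate namespaces.
    ref_ns, new_ns = {}, {}
    exec(reference_code, ref_ns)                      # defines Model, get_inputs, get_init_inputs
    exec(extract_code_block(generated_text), new_ns)  # should define ModelNew

    init_args = ref_ns["get_init_inputs"]()
    inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in ref_ns["get_inputs"]()]

    model = ref_ns["Model"](*init_args).cuda()
    model_new = new_ns["ModelNew"](*init_args).cuda()

    with torch.no_grad():
        return torch.allclose(model(*inputs), model_new(*inputs), atol=atol)
```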