Felladrin committed
Commit f0bd677 · verified · 1 Parent(s): 920ec42

Upload kernelllm.py with huggingface_hub
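For context, uploads like the one named in this commit message are typically performed with the `huggingface_hub` client. A minimal sketch, assuming an access token is already configured (e.g. via `huggingface-cli login`); the `repo_id` below is a placeholder, since the target repository is not shown in this view:

```
from huggingface_hub import HfApi

api = HfApi()  # assumes credentials are configured, e.g. via `huggingface-cli login`
api.upload_file(
    path_or_fileobj="kernelllm.py",  # local file to upload
    path_in_repo="kernelllm.py",     # destination path inside the repo
    repo_id="<target-repo-id>",      # hypothetical placeholder; not recorded in this diff
    commit_message="Upload kernelllm.py with huggingface_hub",
)
```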

Files changed (1)
  1. kernelllm.py +331 -0
kernelllm.py ADDED
@@ -0,0 +1,331 @@
+ """
+ KernelLLM
+
+ This script provides a simple interface for the KernelLLM model.
+ It allows users to input PyTorch models and let KernelLLM attempt to implement the corresponding Triton kernels.
+
+ The KernelLLM class provides two types of methods:
+ 1. Methods that instruct the model with a suitable prompt to generate Triton kernels.
+ 2. "raw" methods that allow the user to interact with the model directly, without any additional prompt wrapping.
+
+ For best results, use the `generate_triton` method to instruct the model the way it was trained.
+
+ Example usage:
+ To run the script from the command line:
+     CUDA_VISIBLE_DEVICES=0 python kernelllm.py
+
+ To use the class in an interactive Python session:
+     $ ipython
+     from kernelllm import KernelLLM
+     model = KernelLLM()
+     model.generate_triton("<your torch module here>", max_new_tokens=128)
+
+     # or stream output directly
+     model.stream_raw("<your text prompt>", max_new_tokens=128)
+
+
+ Full example:
+ ```
+ # Generate Triton-optimized code for a PyTorch model:
+ from kernelllm import KernelLLM
+
+ model = KernelLLM()
+ pytorch_code = '''
+ import torch
+ import torch.nn as nn
+
+ class Model(nn.Module):
+     def __init__(self):
+         super(Model, self).__init__()
+
+     def forward(self, x):
+         return x * 2
+
+ def get_inputs():
+     return [torch.randn(1, 128).cuda()]
+
+ def get_init_inputs():
+     return []
+ '''
+ optimized_code = model.generate_triton(pytorch_code, max_new_tokens=512)
+ print(optimized_code)
+ ```
+ """
+
+ import sys
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+
+ HF_MODEL = "facebook/KernelLLM"
+
+ REPL_INSTRUCTIONS = """
+ You can paste or write your nn.Module code below (and finish with Ctrl+D).
+ The model will try to optimize it with Triton kernels.
+
+ Make sure that you provide `get_inputs()` and `get_init_inputs()` functions such that your model can be run like this:
+     init_inputs = get_init_inputs()
+     model = Model(*init_inputs)
+     out = model(*get_inputs())
+
+ >>>"""
+
+ DEFAULT_MODEL_CODE = """
+ import torch
+ import torch.nn as nn
+
+ class Model(nn.Module):
+     \"\"\"
+     A model that computes Hinge Loss for binary classification tasks.
+
+     Parameters:
+         None
+     \"\"\"
+     def __init__(self):
+         super(Model, self).__init__()
+
+     def forward(self, predictions, targets):
+         return torch.mean(torch.clamp(1 - predictions * targets, min=0))
+
+ batch_size = 128
+ input_shape = (1,)
+ dim = 1
+
+ def get_inputs():
+     return [torch.randn(batch_size, *input_shape), torch.randint(0, 2, (batch_size, 1)).float() * 2 - 1]
+
+ def get_init_inputs():
+     return []
+ """
+
+ PROMPT_TEMPLATE = """
+ <|begin_of_text|>You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups.
+
+ You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.
+
+
+ Here's an example to show you the syntax of inline embedding custom operators from the Triton DSL in torch: The example given architecture is:
+ ```
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class Model(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+
+     def forward(self, a, b):
+         return a + b
+
+
+ def get_inputs():
+     # randomly generate input tensors based on the model architecture
+     a = torch.randn(1, 128).cuda()
+     b = torch.randn(1, 128).cuda()
+     return [a, b]
+
+
+ def get_init_inputs():
+     # randomly generate tensors required for initialization based on the model architecture
+     return []
+
+ ```
+ The example new arch with custom Triton kernels looks like this:
+ ```
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def add_kernel(
+     x_ptr,  # Pointer to first input
+     y_ptr,  # Pointer to second input
+     out_ptr,  # Pointer to output
+     n_elements,  # Total number of elements in input/output
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     # Each program handles a contiguous block of data of size BLOCK_SIZE
+     block_start = tl.program_id(0) * BLOCK_SIZE
+     # Create a range of offsets [0..BLOCK_SIZE-1]
+     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+     # Mask to ensure we don't go out of bounds
+     mask = offsets < n_elements
+     # Load input values
+     x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
+     y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
+     # Perform the elementwise addition
+     out = x + y
+     # Store the result
+     tl.store(out_ptr + offsets, out, mask=mask)
+
+
+ def triton_add(x: torch.Tensor, y: torch.Tensor):
+     \"\"\"
+     This function wraps the Triton kernel call. It:
+     1. Ensures the inputs are contiguous on GPU.
+     2. Calculates the grid (blocks) needed.
+     3. Launches the Triton kernel.
+     \"\"\"
+     assert x.is_cuda and y.is_cuda, "Tensors must be on CUDA."
+     x = x.contiguous()
+     y = y.contiguous()
+
+     # Prepare output tensor
+     out = torch.empty_like(x)
+
+     # Number of elements in the tensor
+     n_elements = x.numel()
+     BLOCK_SIZE = 128  # Tunable parameter for block size
+
+     # Determine the number of blocks needed
+     grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)
+
+     # Launch the Triton kernel
+     add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=BLOCK_SIZE)
+     return out
+
+
+ class ModelNew(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+
+     def forward(self, a, b):
+         # Instead of "return a + b", call our Triton-based addition
+         return triton_add(a, b)
+
+ ```
+
+ You are given the following architecture:
+ ```
+ {}
+ ```
+
+ Optimize the architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code!
+ """
+
+
+ class KernelLLM:
+     """
+     A simple wrapper around the KernelLLM model for generating Triton kernels. It makes it easy
+     to prompt the model and provides a streamed REPL interface for interacting with it.
+     """
+
+     def __init__(
+         self,
+         model_name: str = HF_MODEL,
+         device: str = "cuda" if torch.cuda.is_available() else "cpu",
+     ):
+         self.model_name = model_name
+         self.device = device
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             self.model_name, torch_dtype=torch.float16
+         )
+         self.model.to(self.device)
+
+     def generate_raw(
+         self, prompt: str, temperature: float = 0.6, max_new_tokens: int = 2048
+     ) -> str:
+         """
+         Generate text from the model using the given prompt verbatim.
+
+         Args:
+             prompt (str): The prompt to generate text from.
+             temperature (float): The temperature to use for sampling.
+             max_new_tokens (int): The maximum number of new tokens to generate.
+         Returns:
+             str: The generated text.
+         """
+         inputs = self.tokenizer([prompt], return_tensors="pt")
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+         outputs = self.model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_k=0,
+             top_p=0.95,
+             do_sample=True,
+             eos_token_id=self.tokenizer.eos_token_id,
+         )
+         text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return text[len(prompt) :].strip()
+
+     def stream_raw(self, prompt: str, max_new_tokens: int = 2048):
+         """
+         Stream and print text from the model using the given prompt verbatim.
+
+         Args:
+             prompt (str): The prompt to generate text from.
+             max_new_tokens (int): The maximum number of new tokens to generate.
+         """
+         inputs = self.tokenizer([prompt], return_tensors="pt")
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+         streamer = TextStreamer(
+             self.tokenizer, skip_prompt=True, skip_special_tokens=True
+         )
+         self.model.generate(
+             **inputs,
+             streamer=streamer,
+             max_new_tokens=max_new_tokens,
+             do_sample=True,
+             top_k=0,
+             top_p=0.95,
+             temperature=0.6,
+             eos_token_id=self.tokenizer.eos_token_id,
+         )
+
280
+
281
+ def generate_triton(
282
+ self, code: str, temperature: float = 0.6, max_new_tokens: int = 2048
283
+ ) -> str:
284
+ """
285
+ Generate Triton for the given torch module.
286
+
287
+ The input code should be a python module that contains a torch Model(nn.Module) class and
288
+ `get_inputs()` and `get_init_inputs()` functions such that your model can be run like this
289
+ ```
290
+ args, kwargs = get_inputs()
291
+ model = Model(*args, **kwargs)
292
+ out = model(get_inputs())
293
+ ```
294
+
295
+ Args:
296
+ code (str): The torch code to generate Triton for.
297
+ temperature (float): The temperature to use for sampling.
298
+ max_new_tokens (int): The maximum length of the generated text.
299
+ Returns:
300
+ str: The generated Triton module.
301
+ """
302
+ prompt = PROMPT_TEMPLATE.format(code)
303
+ return self.generate_raw(prompt, temperature, max_new_tokens)
304
+
+     def run_repl(self):
+         """
+         Run a REPL for the model. The user can input code and the model will try to optimize it with Triton kernels.
+         """
+         while True:
+             code = ""
+             try:
+                 print(REPL_INSTRUCTIONS)
+                 code = sys.stdin.read().strip()
+                 if code.lower() == "exit":
+                     return
+             except EOFError:
+                 pass
+
+             if not code:
+                 print(f"Using default prompt:\n{DEFAULT_MODEL_CODE}\n")
+                 code = DEFAULT_MODEL_CODE
+             # Build the prompt from whichever code was selected (user-provided or default).
+             prompt = PROMPT_TEMPLATE.format(code)
+
+             try:
+                 self.stream_raw(prompt)
+             except KeyboardInterrupt:
+                 print("Aborting...")
+
+
+ if __name__ == "__main__":
+     kernel_llm = KernelLLM()
+     kernel_llm.run_repl()
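For reference, a minimal sketch of the "raw" path described in the module docstring, assuming the file above is importable as `kernelllm` and a CUDA device is available: it builds the same prompt that `generate_triton` uses, but sends it through `generate_raw` verbatim.

```
from kernelllm import DEFAULT_MODEL_CODE, PROMPT_TEMPLATE, KernelLLM

model = KernelLLM()

# Wrap the bundled example module in the training prompt, exactly as generate_triton() does.
prompt = PROMPT_TEMPLATE.format(DEFAULT_MODEL_CODE)

# Call the raw generation path, so the prompt is passed to the model unchanged.
triton_code = model.generate_raw(prompt, temperature=0.6, max_new_tokens=512)
print(triton_code)
```

This is equivalent to `model.generate_triton(DEFAULT_MODEL_CODE, max_new_tokens=512)`, but makes the prompt construction explicit, which is the natural starting point for experimenting with prompt variations.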