gagan3012 committed · verified
Commit a2a20b8 · 1 Parent(s): e6010fe

Upload folder using huggingface_hub

Files changed (2):
  1. README.md +3 -1
  2. readme_example.py +202 -0
README.md CHANGED
@@ -4,4 +4,6 @@ tags:
  - kernel
  ---
 
- # batch_invariant_kernel
+ # batch_invariant_kernel
+
+ To try out the kernel, run the example in `readme_example.py`.
readme_example.py ADDED
@@ -0,0 +1,202 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch",
#     "numpy",
#     "kernels",
# ]
# ///

import torch
from kernels import get_kernel

# Load batch_invariant_kernel via the kernels library
batch_invariant_kernel = get_kernel("gagan3012/batch_invariant_kernel")
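# get_kernel() fetches the kernel from the Hugging Face Hub (caching it
# locally) and returns a module exposing its callable ops.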

# Set device and seed for reproducibility
device = "cuda"
torch.manual_seed(42)
torch.cuda.manual_seed(42)
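# Seeding makes this run repeatable; the kernels themselves target the stronger
# property that a row's result does not depend on the batch it is computed in.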

print("🚀 Testing batch_invariant_kernel from Hugging Face Hub")
if not torch.cuda.is_available():
    raise SystemExit("❌ CUDA is required for this example.")
print(f"✅ CUDA is available. Using device: {torch.cuda.get_device_name()}")

# Test 1: Matrix Multiplication
print("\n" + "=" * 60)
print("🧪 Test 1: Persistent Matrix Multiplication")
print("=" * 60)
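# A "persistent" matmul launches a fixed grid (roughly one block per SM) that
# loops over output tiles, so the tiling and reduction order stay the same
# regardless of input shape, which is presumably what makes it batch-invariant.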

# Parameters for matrix multiplication
M, K, N = 512, 256, 1024
a = torch.randn(M, K, device=device, dtype=torch.float32)
b = torch.randn(K, N, device=device, dtype=torch.float32)
bias = torch.randn(N, device=device, dtype=torch.float32)

print(f"Matrix A shape: {a.shape}")
print(f"Matrix B shape: {b.shape}")
print(f"Bias shape: {bias.shape}")

# Run matrix multiplication without bias
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
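# Warm-up (not in the original script): the first call to a GPU kernel often
# triggers one-time JIT compilation, which would otherwise dominate the timing.
_ = batch_invariant_kernel.matmul_persistent(a, b)
torch.cuda.synchronize()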

start_event.record()
output_no_bias = batch_invariant_kernel.matmul_persistent(a, b)
end_event.record()
torch.cuda.synchronize()
time_no_bias = start_event.elapsed_time(end_event)

print("\nMatrix multiplication (no bias) completed!")
print(f"Output shape: {output_no_bias.shape}")
print(f"Execution time: {time_no_bias:.3f} ms")

# Run matrix multiplication with bias
start_event.record()
output_with_bias = batch_invariant_kernel.matmul_persistent(a, b, bias)
end_event.record()
torch.cuda.synchronize()
time_with_bias = start_event.elapsed_time(end_event)

print("\nMatrix multiplication (with bias) completed!")
print(f"Output shape: {output_with_bias.shape}")
print(f"Execution time: {time_with_bias:.3f} ms")

# Verify correctness against PyTorch
expected_no_bias = torch.mm(a, b)
expected_with_bias = torch.mm(a, b) + bias

max_diff_no_bias = torch.max(torch.abs(output_no_bias - expected_no_bias)).item()
max_diff_with_bias = torch.max(torch.abs(output_with_bias - expected_with_bias)).item()

print(f"Max difference (no bias): {max_diff_no_bias:.6f}")
print(f"Max difference (with bias): {max_diff_with_bias:.6f}")
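# Loose sanity bound (the threshold is an assumption, not from the original
# script): float32 matmuls that differ only in summation order should agree
# to well under 1e-2 at these sizes.
assert max_diff_no_bias < 1e-2 and max_diff_with_bias < 1e-2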

# Test 2: Log Softmax
print("\n" + "=" * 60)
print("🧪 Test 2: Log Softmax")
print("=" * 60)
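# Log-softmax over the vocabulary is the step right before sampling in LLM
# inference; batch-size-dependent numerics here are a common source of
# nondeterministic generations, which batch-invariant kernels aim to remove.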

# Parameters for log softmax (typical language-model head dimensions)
batch_size = 4
seq_len = 512
vocab_size = 32000

logits = torch.randn(
    batch_size, seq_len, vocab_size, device=device, dtype=torch.float32
)
print(f"Input logits shape: {logits.shape}")

# Run log softmax
start_event.record()
log_probs = batch_invariant_kernel.log_softmax(logits, dim=-1)
end_event.record()
torch.cuda.synchronize()
time_log_softmax = start_event.elapsed_time(end_event)

print("\nLog softmax completed!")
print(f"Output shape: {log_probs.shape}")
print(f"Execution time: {time_log_softmax:.3f} ms")

# Verify correctness
expected_log_probs = torch.log_softmax(logits, dim=-1)
max_diff_log_softmax = torch.max(torch.abs(log_probs - expected_log_probs)).item()
print(f"Max difference vs PyTorch: {max_diff_log_softmax:.6f}")
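# Extra check (not in the original script): exponentiated log-probs should
# sum to ~1.0 along the vocab dimension.
prob_sums = log_probs.exp().sum(dim=-1)
print(f"Probability sums: min={prob_sums.min().item():.6f}, max={prob_sums.max().item():.6f}")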

# Test 3: Mean Reduction
print("\n" + "=" * 60)
print("🧪 Test 3: Mean Dimension Reduction")
print("=" * 60)

# Parameters for mean reduction (typical layer-norm dimensions)
batch_size = 8
seq_len = 256
hidden_size = 768

hidden_states = torch.randn(
    batch_size, seq_len, hidden_size, device=device, dtype=torch.float32
)
print(f"Input hidden states shape: {hidden_states.shape}")

# Test reduction along different dimensions
for dim in [0, 1, 2]:
    start_event.record()
    mean_output = batch_invariant_kernel.mean_dim(hidden_states, dim=dim, keepdim=False)
    end_event.record()
    torch.cuda.synchronize()
    time_mean = start_event.elapsed_time(end_event)

    expected_mean = torch.mean(hidden_states, dim=dim, keepdim=False)
    max_diff_mean = torch.max(torch.abs(mean_output - expected_mean)).item()

    print(f"\nMean reduction along dim {dim}:")
    print(f"  Output shape: {mean_output.shape}")
    print(f"  Execution time: {time_mean:.3f} ms")
    print(f"  Max difference vs PyTorch: {max_diff_mean:.6f}")
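# Batch-invariance spot check (an addition, not in the original script): the
# first row's mean along the hidden dim should be bitwise identical whether it
# is computed alone or as part of the full batch.
full_batch = batch_invariant_kernel.mean_dim(hidden_states, dim=2, keepdim=False)
single_row = batch_invariant_kernel.mean_dim(hidden_states[:1], dim=2, keepdim=False)
print(f"\nBatch-invariant along dim 2: {torch.equal(full_batch[:1], single_row)}")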

# Test 4: End-to-End Attention-like Computation
print("\n" + "=" * 60)
print("🧪 Test 4: End-to-End Attention-like Computation")
print("=" * 60)
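# This composes the custom matmul and log_softmax into a full multi-head
# self-attention block (no causal mask); the 4D score and value products
# still go through torch.matmul.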

# Simulate a simple attention computation using our kernels
batch_size = 4
seq_len = 128
hidden_size = 512
num_heads = 8
head_dim = hidden_size // num_heads

# Input embeddings
x = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=torch.float32)

# Weight matrices for the Q, K, V and output projections
w_q = torch.randn(hidden_size, hidden_size, device=device, dtype=torch.float32)
w_k = torch.randn(hidden_size, hidden_size, device=device, dtype=torch.float32)
w_v = torch.randn(hidden_size, hidden_size, device=device, dtype=torch.float32)
w_o = torch.randn(hidden_size, hidden_size, device=device, dtype=torch.float32)

print(f"Input shape: {x.shape}")
print("Computing Q, K, V projections using batch_invariant matmul...")

# Flatten batch and sequence dims for the 2D matmul kernel
x_flat = x.view(-1, hidden_size)  # (batch_size * seq_len, hidden_size)

start_event.record()

# Compute Q, K, V using our custom matmul
q_flat = batch_invariant_kernel.matmul_persistent(x_flat, w_q)
k_flat = batch_invariant_kernel.matmul_persistent(x_flat, w_k)
v_flat = batch_invariant_kernel.matmul_persistent(x_flat, w_v)

# Reshape to multi-head format: (batch, heads, seq, head_dim)
q = q_flat.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
k = k_flat.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
v = v_flat.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

# Compute scaled attention scores
scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim**0.5)

# Apply softmax using our custom log_softmax (convert to softmax)
log_attn_weights = batch_invariant_kernel.log_softmax(scores, dim=-1)
attn_weights = torch.exp(log_attn_weights)
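# exp(log_softmax(x)) equals softmax(x); going through log-space reuses the
# custom kernel while keeping the normalization numerically stable.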

# Apply attention to values
attn_output = torch.matmul(attn_weights, v)

# Reshape and apply output projection
attn_output = (
    attn_output.transpose(1, 2).contiguous().view(batch_size * seq_len, hidden_size)
)
final_output = batch_invariant_kernel.matmul_persistent(attn_output, w_o)
final_output = final_output.view(batch_size, seq_len, hidden_size)

end_event.record()
torch.cuda.synchronize()
total_time = start_event.elapsed_time(end_event)

print("\nEnd-to-end attention computation completed!")
print(f"Final output shape: {final_output.shape}")
print(f"Total execution time: {total_time:.3f} ms")
print(
    f"Output tensor stats - Mean: {final_output.mean().item():.4f}, Std: {final_output.std().item():.4f}"
)
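The `# /// script` header is PEP 723 inline metadata, so the example can be run directly with a compatible runner such as `uv run readme_example.py` (assuming `uv` is installed); the declared dependencies are resolved automatically.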