Upload folder using huggingface_hub
- README.md +3 -1
- readme_example.py +202 -0
README.md CHANGED

@@ -4,4 +4,6 @@ tags:
 - kernel
 ---
 
-# batch_invariant_kernel
+# batch_invariant_kernel
+
+To try out the kernel, run the example code in readme_example.py.
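The script added below declares its dependencies with PEP 723 inline metadata (the # /// script block at the top of readme_example.py), so it can presumably be run directly with uv via `uv run readme_example.py`; installing torch, numpy, and kernels manually and launching it with python should work as well. A CUDA-capable GPU is assumed either way, since the script hard-codes device = "cuda".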
readme_example.py ADDED
@@ -0,0 +1,202 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch",
#     "numpy",
#     "kernels",
# ]
# ///

import torch
from kernels import get_kernel

# Load batch_invariant_kernel via kernels library
batch_invariant_kernel = get_kernel("gagan3012/batch_invariant_kernel")

# Set device and seed for reproducibility
device = "cuda"
torch.manual_seed(42)
torch.cuda.manual_seed(42)

print("🚀 Testing batch_invariant_kernel from Hugging Face Hub")
print(f"✅ CUDA is available. Using device: {torch.cuda.get_device_name()}")

# Test 1: Matrix Multiplication
print("\n" + "=" * 60)
print("🧪 Test 1: Persistent Matrix Multiplication")
print("=" * 60)

# Parameters for matrix multiplication
M, K, N = 512, 256, 1024
a = torch.randn(M, K, device=device, dtype=torch.float32)
b = torch.randn(K, N, device=device, dtype=torch.float32)
bias = torch.randn(N, device=device, dtype=torch.float32)

print(f"Matrix A shape: {a.shape}")
print(f"Matrix B shape: {b.shape}")
print(f"Bias shape: {bias.shape}")

# Run matrix multiplication without bias
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

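# Optional warmup (an addition, not in the original script): the first launch of
# a freshly fetched kernel typically includes one-time compilation/initialization
# overhead, so warming up keeps the timings below representative.
_ = batch_invariant_kernel.matmul_persistent(a, b)
torch.cuda.synchronize()
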
start_event.record()
output_no_bias = batch_invariant_kernel.matmul_persistent(a, b)
end_event.record()
torch.cuda.synchronize()
time_no_bias = start_event.elapsed_time(end_event)

print(f"\nMatrix multiplication (no bias) completed!")
print(f"Output shape: {output_no_bias.shape}")
print(f"Execution time: {time_no_bias:.3f} ms")

# Run matrix multiplication with bias
start_event.record()
output_with_bias = batch_invariant_kernel.matmul_persistent(a, b, bias)
end_event.record()
torch.cuda.synchronize()
time_with_bias = start_event.elapsed_time(end_event)

print(f"\nMatrix multiplication (with bias) completed!")
print(f"Output shape: {output_with_bias.shape}")
print(f"Execution time: {time_with_bias:.3f} ms")

# Verify correctness
expected_no_bias = torch.mm(a, b)
expected_with_bias = torch.mm(a, b) + bias

max_diff_no_bias = torch.max(torch.abs(output_no_bias - expected_no_bias)).item()
max_diff_with_bias = torch.max(torch.abs(output_with_bias - expected_with_bias)).item()

print(f"Max difference (no bias): {max_diff_no_bias:.6f}")
print(f"Max difference (with bias): {max_diff_with_bias:.6f}")

# Test 2: Log Softmax
print("\n" + "=" * 60)
print("🧪 Test 2: Log Softmax")
print("=" * 60)

# Parameters for log softmax (typical attention dimensions)
batch_size = 4
seq_len = 512
vocab_size = 32000

logits = torch.randn(
    batch_size, seq_len, vocab_size, device=device, dtype=torch.float32
)
print(f"Input logits shape: {logits.shape}")

# Run log softmax
start_event.record()
log_probs = batch_invariant_kernel.log_softmax(logits, dim=-1)
end_event.record()
torch.cuda.synchronize()
time_log_softmax = start_event.elapsed_time(end_event)

print(f"\nLog softmax completed!")
print(f"Output shape: {log_probs.shape}")
print(f"Execution time: {time_log_softmax:.3f} ms")

# Verify correctness
expected_log_probs = torch.log_softmax(logits, dim=-1)
max_diff_log_softmax = torch.max(torch.abs(log_probs - expected_log_probs)).item()
print(f"Max difference vs PyTorch: {max_diff_log_softmax:.6f}")

# Test 3: Mean Reduction
print("\n" + "=" * 60)
print("🧪 Test 3: Mean Dimension Reduction")
print("=" * 60)

# Parameters for mean reduction (typical layer norm dimensions)
batch_size = 8
seq_len = 256
hidden_size = 768

hidden_states = torch.randn(
    batch_size, seq_len, hidden_size, device=device, dtype=torch.float32
)
print(f"Input hidden states shape: {hidden_states.shape}")

# Test reduction along different dimensions
for dim in [0, 1, 2]:
    start_event.record()
    mean_output = batch_invariant_kernel.mean_dim(hidden_states, dim=dim, keepdim=False)
    end_event.record()
    torch.cuda.synchronize()
    time_mean = start_event.elapsed_time(end_event)

    expected_mean = torch.mean(hidden_states, dim=dim, keepdim=False)
    max_diff_mean = torch.max(torch.abs(mean_output - expected_mean)).item()

    print(f"\nMean reduction along dim {dim}:")
    print(f"  Output shape: {mean_output.shape}")
    print(f"  Execution time: {time_mean:.3f} ms")
    print(f"  Max difference vs PyTorch: {max_diff_mean:.6f}")

# Test 4: End-to-End Attention-like Computation
print("\n" + "=" * 60)
print("🧪 Test 4: End-to-End Attention-like Computation")
print("=" * 60)

# Simulate a simple attention computation using our kernels
batch_size = 4
seq_len = 128
hidden_size = 512
num_heads = 8
head_dim = hidden_size // num_heads

# Input embeddings
x = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=torch.float32)

# Weight matrices for Q, K, V projections
w_q = torch.randn(hidden_size, hidden_size, device=device, dtype=torch.float32)
w_k = torch.randn(hidden_size, hidden_size, device=device, dtype=torch.float32)
w_v = torch.randn(hidden_size, hidden_size, device=device, dtype=torch.float32)
w_o = torch.randn(hidden_size, hidden_size, device=device, dtype=torch.float32)

print(f"Input shape: {x.shape}")
print("Computing Q, K, V projections using batch_invariant matmul...")

# Reshape for batch matrix multiplication
x_flat = x.view(-1, hidden_size)  # (batch_size * seq_len, hidden_size)

start_event.record()

# Compute Q, K, V using our custom matmul
q_flat = batch_invariant_kernel.matmul_persistent(x_flat, w_q)
k_flat = batch_invariant_kernel.matmul_persistent(x_flat, w_k)
v_flat = batch_invariant_kernel.matmul_persistent(x_flat, w_v)

# Reshape to multi-head format
q = q_flat.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
k = k_flat.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
v = v_flat.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

# Compute attention scores
scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim**0.5)

# Apply softmax using our custom log_softmax (convert to softmax)
log_attn_weights = batch_invariant_kernel.log_softmax(scores, dim=-1)
attn_weights = torch.exp(log_attn_weights)

# Apply attention to values
attn_output = torch.matmul(attn_weights, v)

# Reshape and apply output projection
attn_output = (
    attn_output.transpose(1, 2).contiguous().view(batch_size * seq_len, hidden_size)
)
final_output = batch_invariant_kernel.matmul_persistent(attn_output, w_o)
final_output = final_output.view(batch_size, seq_len, hidden_size)

end_event.record()
torch.cuda.synchronize()
total_time = start_event.elapsed_time(end_event)

print(f"\nEnd-to-end attention computation completed!")
print(f"Final output shape: {final_output.shape}")
print(f"Total execution time: {total_time:.3f} ms")
print(
    f"Output tensor stats - Mean: {final_output.mean().item():.4f}, Std: {final_output.std().item():.4f}"
)

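Not part of the uploaded files, but a natural follow-up to the tests above: a minimal sketch of how one might probe the batch-invariance property that gives the kernel its name, assuming matmul_persistent is intended to return bitwise-identical results for a given row regardless of how many other rows share the batch. It only reuses the repository id and the matmul_persistent call already exercised in readme_example.py; the filename and tensor sizes are illustrative.

# batch_invariance_check.py (hypothetical filename) - a sketch, not in the repo
import torch
from kernels import get_kernel

batch_invariant_kernel = get_kernel("gagan3012/batch_invariant_kernel")
device = "cuda"
torch.manual_seed(0)

w = torch.randn(256, 1024, device=device, dtype=torch.float32)
row = torch.randn(1, 256, device=device, dtype=torch.float32)

# Same row computed alone and embedded at position 0 of a larger batch.
out_single = batch_invariant_kernel.matmul_persistent(row, w)
big_batch = torch.cat(
    [row, torch.randn(511, 256, device=device, dtype=torch.float32)], dim=0
)
out_batched = batch_invariant_kernel.matmul_persistent(big_batch, w)[:1]

# Bitwise equality is the batch-invariance claim; a stock torch.mm does not
# in general guarantee it across batch sizes.
print("batch invariant:", torch.equal(out_single, out_batched))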