drbh committed
Commit f2b2454 · Parent: 0518250

fix: add readme example script

Files changed (1): scripts/readme_example.py (+109, -0)
scripts/readme_example.py ADDED
@@ -0,0 +1,109 @@
+ # /// script
+ # dependencies = [
+ #   "numpy",
+ #   "torch",
+ #   "kernels"
+ # ]
+ # ///
+ import torch
+ from kernels import get_kernel
+
+ # Setup
+ torch.manual_seed(42)
+ flash_attn = get_kernel("kernels-community/flash-attn")
+ device = torch.device("cuda")
+
+ print("Flash Attention functions:", [i for i in dir(flash_attn) if i.startswith("mha")])
+
+ # Create test tensors
+ B, S, H, D = 2, 5, 4, 8  # batch, seq_len, heads, head_dim
+ q = k = v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
+
+ # Reference implementation using PyTorch SDPA
+ def reference_attention(query, key, value, causal=False):
+     query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+     with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
+         out = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=causal)
+     return out.transpose(1, 2).contiguous()
+
+ # 1. Standard attention
+ print("\n1. Standard attention:")
+ out_ref = reference_attention(q, k, v)
+ out_flash = flash_attn.mha_fwd(
+     q=q,
+     k=k,
+     v=v,
+     is_causal=False,
+     softmax_scale=1.0 / (D ** 0.5),  # scale factor
+ )[0]
+ print(f"Reference output: {out_ref.shape}")
+ print(f"Flash output: {out_flash.shape}")
+ print(f"Outputs close: {torch.allclose(out_flash, out_ref, atol=1e-2, rtol=1e-3)}")
+
+ # 2. Causal attention (for autoregressive models)
+ print("\n2. Causal attention:")
+
+ out_ref_causal = reference_attention(q, k, v, causal=True)
+ out_causal = flash_attn.mha_fwd(
+     q=q,
+     k=k,
+     v=v,
+     is_causal=True,
+     softmax_scale=1.0 / (D ** 0.5),  # scale factor
+ )[0]
+ print(f"Reference causal output: {out_ref_causal.shape}")
+ print(f"Flash causal output: {out_causal.shape}")
+ print(f"Outputs close: {torch.allclose(out_causal, out_ref_causal, atol=1e-2, rtol=1e-3)}")
+
+ def var_reference_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=False):
+     batch_size = cu_seqlens_q.shape[0] - 1
+     # Return output in packed format (same as flash attention)
+     total_tokens_q = q.shape[0]
+     out = torch.zeros((total_tokens_q, q.shape[1], q.shape[2]), device=q.device, dtype=q.dtype)
+
+     for b in range(batch_size):
+         start_q, end_q = cu_seqlens_q[b], cu_seqlens_q[b + 1]
+         start_k, end_k = cu_seqlens_k[b], cu_seqlens_k[b + 1]
+
+         # Extract slices for this batch
+         q_slice = q[start_q:end_q]  # Shape: (seq_len_q, H, D)
+         k_slice = k[start_k:end_k]  # Shape: (seq_len_k, H, D)
+         v_slice = v[start_k:end_k]  # Shape: (seq_len_k, H, D)
+
+         # Add batch dimension for reference_attention
+         q_slice = q_slice.unsqueeze(0)  # Shape: (1, seq_len_q, H, D)
+         k_slice = k_slice.unsqueeze(0)  # Shape: (1, seq_len_k, H, D)
+         v_slice = v_slice.unsqueeze(0)  # Shape: (1, seq_len_k, H, D)
+
+         # Compute attention and remove batch dimension
+         attn_out = reference_attention(q_slice, k_slice, v_slice, causal=causal)
+         attn_out = attn_out.squeeze(0)  # Shape: (seq_len_q, H, D)
+
+         # Place result in output tensor (packed format)
+         out[start_q:end_q] = attn_out
+
+     return out
+
+ # 3. Variable length sequences (packed format)
+ print("\n3. Variable length sequences:")
+ # Pack sequences of lengths [3,4,3] for q and [4,5,3] for k into single tensors
+ q_var = torch.randn(10, H, D, device=device, dtype=torch.float16)  # total_q=10
+ k_var = v_var = torch.randn(12, H, D, device=device, dtype=torch.float16)  # total_k=12
+ cu_q = torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int32)  # cumulative sequence lengths
+ cu_k = torch.tensor([0, 4, 9, 12], device=device, dtype=torch.int32)
+
+ out_var_ref = var_reference_attention(q_var, k_var, v_var, cu_q, cu_k, max_seqlen_q=4, max_seqlen_k=5, causal=False)
+ # Flash attention varlen kernel on the same packed tensors
+ out_var = flash_attn.mha_varlen_fwd(
+     q=q_var,
+     k=k_var,
+     v=v_var,
+     cu_seqlens_q=cu_q,
+     cu_seqlens_k=cu_k,
+     max_seqlen_q=4,
+     max_seqlen_k=5,
+     softmax_scale=1.0 / (D ** 0.5),  # scale factor
+ )[0]
+ print(f"Variable length output: {out_var.shape}")
+ print(f"Reference variable length output: {out_var_ref.shape}")
+ print(f"Outputs close: {torch.allclose(out_var, out_var_ref, atol=1e-2, rtol=1e-3)}")
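
Since the script opens with an inline "# /// script" dependency block (PEP 723 inline script metadata), it can typically be run directly with a metadata-aware runner such as uv, e.g. uv run scripts/readme_example.py, on a machine with a CUDA GPU available.

The varlen path works on packed tensors plus cumulative-sequence-length offsets. As a minimal sketch (not part of the committed script; make_cu_seqlens is a hypothetical helper name), the cu_q and cu_k tensors used above can be derived from per-sequence lengths like so:

import torch

def make_cu_seqlens(lengths, device="cpu"):
    # Prefix sum of per-sequence lengths with a leading 0: entry b is the start
    # offset of sequence b inside the packed (total_tokens, H, D) tensor.
    lengths = torch.as_tensor(lengths, dtype=torch.int32, device=device)
    zero = torch.zeros(1, dtype=torch.int32, device=device)
    return torch.cat([zero, lengths.cumsum(0, dtype=torch.int32)])

print(make_cu_seqlens([3, 4, 3]))  # [0, 3, 7, 10], matching cu_q above
print(make_cu_seqlens([4, 5, 3]))  # [0, 4, 9, 12], matching cu_k above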