kernel
drbh committed
Commit 56449c1
1 Parent(s): 09eec95

feat: bump api on readme

Files changed (2)
  1. README.md +8 -14
  2. scripts/readme_example.py +3 -8
README.md CHANGED
@@ -30,8 +30,6 @@ torch.manual_seed(42)
 flash_attn = get_kernel("kernels-community/flash-attn")
 device = torch.device("cuda")
 
-print("Flash Attention functions:", [i for i in dir(flash_attn) if i.startswith("mha")])
-
 # Create test tensors
 B, S, H, D = 2, 5, 4, 8  # batch, seq_len, heads, head_dim
 q = k = v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
@@ -46,12 +44,11 @@ def reference_attention(query, key, value, causal=False):
 # 1. Standard attention
 print("\n1. Standard attention:")
 out_ref = reference_attention(q, k, v)
-out_flash = flash_attn.mha_fwd(
+out_flash = flash_attn.fwd(
     q=q,
     k=k,
     v=v,
     is_causal=False,
-    softmax_scale=1.0 / (D ** 0.5),  # scale factor
 )[0]
 print(f"Reference output: {out_ref.shape}")
 print(f"Flash output: {out_flash.shape}")
@@ -61,12 +58,11 @@ print(f"Outputs close: {torch.allclose(out_flash, out_ref, atol=1e-2, rtol=1e-3)
 print("\n2. Causal attention:")
 
 out_ref_causal = reference_attention(q, k, v, causal=True)
-out_causal = flash_attn.mha_fwd(
+out_causal = flash_attn.fwd(
     q=q,
     k=k,
     v=v,
     is_causal=True,
-    softmax_scale=1.0 / (D ** 0.5),  # scale factor
 )[0]
 print(f"Reference causal output: {out_ref_causal.shape}")
 print(f"Flash causal output: {out_causal.shape}")
@@ -74,7 +70,7 @@ print(f"Outputs close: {torch.allclose(out_causal, out_ref_causal, atol=1e-2, rt
 
 def var_reference_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=False):
     batch_size = cu_seqlens_q.shape[0] - 1
-    # Return output in packed format
+    # Return output in packed format (same as flash attention)
     total_tokens_q = q.shape[0]
     out = torch.zeros((total_tokens_q, q.shape[1], q.shape[2]), device=q.device, dtype=q.dtype)
 
@@ -111,7 +107,7 @@ cu_k = torch.tensor([0, 4, 9, 12], device=device, dtype=torch.int32)
 
 out_var_ref = var_reference_attention(q_var, k_var, v_var, cu_q, cu_k, max_seqlen_q=4, max_seqlen_k=5, causal=False)
 # Custom function to handle variable
-out_var = flash_attn.mha_varlen_fwd(
+out_var = flash_attn.varlen_fwd(
     q=q_var,
     k=k_var,
     v=v_var,
@@ -119,7 +115,6 @@ out_var = flash_attn.mha_varlen_fwd(
     cu_seqlens_k=cu_k,
     max_seqlen_q=4,
     max_seqlen_k=5,
-    softmax_scale=1.0 / (D ** 0.5),  # scale factor
 )[0]
 print(f"Variable length output: {out_var.shape}")
 print(f"Reference variable length output: {out_var_ref.shape}")
@@ -133,21 +128,20 @@ uv run scripts/readme_example.py
 ```
 
 ```txt
-Reading inline script metadata from `flash-attn/scripts/readme_example.py`
-Fetching 4 files: 100%|████████████████████████████████████████| 4/4 [00:00<00:00, 33354.31it/s]
-Flash Attention functions: ['mha_bwd', 'mha_fwd', 'mha_fwd_kvcache', 'mha_varlen_bwd', 'mha_varlen_fwd']
+Reading inline script metadata from `scripts/readme_example.py`
+Fetching 20 files: 100%|██████████████████████████████████████████████████| 20/20 [00:00<00:00, 16371.21it/s]
 
 1. Standard attention:
 Reference output: torch.Size([2, 5, 4, 8])
 Flash output: torch.Size([2, 5, 4, 8])
 Outputs close: True
 
-1. Causal attention:
+2. Causal attention:
 Reference causal output: torch.Size([2, 5, 4, 8])
 Flash causal output: torch.Size([2, 5, 4, 8])
 Outputs close: True
 
-1. Variable length sequences:
+3. Variable length sequences:
 Variable length output: torch.Size([10, 4, 8])
 Reference variable length output: torch.Size([10, 4, 8])
 Outputs close: True
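
For reference, a minimal sketch of the dense calls after this API bump. It assumes `get_kernel` is imported from the `kernels` package (as in the untouched part of the README) and that the kernel falls back to the usual 1/sqrt(head_dim) softmax scaling when `softmax_scale` is omitted, which is why the explicit argument is dropped above.

```python
# Sketch only: mirrors the README hunks above after the rename mha_fwd -> fwd.
# Assumption: omitting softmax_scale uses the default 1/sqrt(head_dim) scaling.
import torch
from kernels import get_kernel

torch.manual_seed(42)
flash_attn = get_kernel("kernels-community/flash-attn")
device = torch.device("cuda")

B, S, H, D = 2, 5, 4, 8  # batch, seq_len, heads, head_dim
q = k = v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)

out = flash_attn.fwd(q=q, k=k, v=v, is_causal=False)[0]        # was mha_fwd
out_causal = flash_attn.fwd(q=q, k=k, v=v, is_causal=True)[0]  # causal variant
print(out.shape, out_causal.shape)  # torch.Size([2, 5, 4, 8]) for both
```
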
scripts/readme_example.py CHANGED
@@ -13,8 +13,6 @@ torch.manual_seed(42)
 flash_attn = get_kernel("kernels-community/flash-attn")
 device = torch.device("cuda")
 
-print("Flash Attention functions:", [i for i in dir(flash_attn) if i.startswith("mha")])
-
 # Create test tensors
 B, S, H, D = 2, 5, 4, 8  # batch, seq_len, heads, head_dim
 q = k = v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
@@ -29,12 +27,11 @@ def reference_attention(query, key, value, causal=False):
 # 1. Standard attention
 print("\n1. Standard attention:")
 out_ref = reference_attention(q, k, v)
-out_flash = flash_attn.mha_fwd(
+out_flash = flash_attn.fwd(
     q=q,
     k=k,
     v=v,
     is_causal=False,
-    softmax_scale=1.0 / (D ** 0.5),  # scale factor
 )[0]
 print(f"Reference output: {out_ref.shape}")
 print(f"Flash output: {out_flash.shape}")
@@ -44,12 +41,11 @@ print(f"Outputs close: {torch.allclose(out_flash, out_ref, atol=1e-2, rtol=1e-3)
 print("\n2. Causal attention:")
 
 out_ref_causal = reference_attention(q, k, v, causal=True)
-out_causal = flash_attn.mha_fwd(
+out_causal = flash_attn.fwd(
     q=q,
     k=k,
     v=v,
     is_causal=True,
-    softmax_scale=1.0 / (D ** 0.5),  # scale factor
 )[0]
 print(f"Reference causal output: {out_ref_causal.shape}")
 print(f"Flash causal output: {out_causal.shape}")
@@ -94,7 +90,7 @@ cu_k = torch.tensor([0, 4, 9, 12], device=device, dtype=torch.int32)
 
 out_var_ref = var_reference_attention(q_var, k_var, v_var, cu_q, cu_k, max_seqlen_q=4, max_seqlen_k=5, causal=False)
 # Custom function to handle variable
-out_var = flash_attn.mha_varlen_fwd(
+out_var = flash_attn.varlen_fwd(
     q=q_var,
     k=k_var,
     v=v_var,
@@ -102,7 +98,6 @@ out_var = flash_attn.mha_varlen_fwd(
     cu_seqlens_k=cu_k,
     max_seqlen_q=4,
     max_seqlen_k=5,
-    softmax_scale=1.0 / (D ** 0.5),  # scale factor
 )[0]
 print(f"Variable length output: {out_var.shape}")
 print(f"Reference variable length output: {out_var_ref.shape}")