kernel
File size: 5,264 Bytes
a7165c8
 
c743a32
a7165c8
 
 
 
 
c743a32
9002ff5
d774688
 
dd2f0f9
 
9002ff5
a7165c8
 
c743a32
a7165c8
 
 
 
 
 
 
 
39b4aba
 
 
 
 
 
 
876ac68
b0d3c12
 
 
 
 
 
 
 
 
 
 
 
9002ff5
 
 
 
b0d3c12
 
 
 
 
 
 
 
9002ff5
 
 
a7165c8
876ac68
b0d3c12
 
 
 
 
 
 
 
 
 
 
 
9002ff5
 
 
 
b0d3c12
 
 
 
 
 
 
 
a7165c8
 
c743a32
876ac68
b0d3c12
 
 
 
 
 
 
 
 
 
 
 
9002ff5
 
 
 
b0d3c12
 
 
 
 
 
 
 
a7165c8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Package-wide build settings.
[general]
name = "flash_attn"
# NOTE(review): presumably marks this package as not "universal" (i.e. it ships
# compiled, backend-specific code) — confirm against the build tool's docs.
universal=false

# Torch extension sources: the binding translation unit and its header.
[torch]
src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"]

# The flash_attn kernel component: CUDA backend, target compute capabilities,
# the source/header list, and the components it depends on.
[kernel.flash_attn]
backend = "cuda"
# NOTE(review): every .cu file listed below carries an "sm80" suffix while
# capabilities up to 12.0 are requested — presumably the sm80 kernels are simply
# compiled for each listed capability; confirm.
cuda-capabilities = [
  "8.0",
  "9.0",
  "10.0",
  "12.0",
]
src = [
  # API entry point
  "flash_attn/flash_api.cpp",

  # shared headers
  "flash_attn/src/philox_unpack.cuh",
  "flash_attn/src/namespace_config.h",
  "flash_attn/src/hardware_info.h",
  "flash_attn/src/flash.h",
  "flash_attn/src/static_switch.h",
  "flash_attn/src/alibi.h",
  "flash_attn/src/block_info.h",
  "flash_attn/src/dropout.h",
  "flash_attn/src/kernel_traits.h",
  "flash_attn/src/mask.h",
  "flash_attn/src/philox.cuh",
  "flash_attn/src/rotary.h",
  "flash_attn/src/softmax.h",
  "flash_attn/src/utils.h",

  # bwd kernels: bf16/fp16, causal and non-causal, for hdim 32/64/96/128/192/256.
  # NOTE(review): this comment previously said these entries were commented out
  # because the mha_bwd functions are disabled, yet every entry below is active.
  # Confirm whether the bwd sources should really be built.
  "flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
  "flash_attn/src/flash_bwd_kernel.h",
  "flash_attn/src/flash_bwd_launch_template.h",
  "flash_attn/src/flash_bwd_preprocess_kernel.h",

  # fwd kernels: bf16/fp16, causal and non-causal, for hdim 32/64/96/128/192/256.
  # NOTE(review): this comment previously claimed only FP16 kernels for hdim 64
  # and 128 were kept; the list below contains the full set. Confirm intent.
  "flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_kernel.h",
  "flash_attn/src/flash_fwd_launch_template.h",

  # split fwd kernels: bf16/fp16, causal and non-causal, for hdim 32/64/96/128/192/256.
  # NOTE(review): this comment previously claimed only FP16 kernels for hdim 64
  # and 128 were kept; the list below contains the full set. Confirm intent.
  "flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
]
# NOTE(review): "cutlass_3_6" presumably pins the CUTLASS 3.6 dependency — confirm.
depends = ["torch", "cutlass_3_6"]