Update skyreels_v2_infer/modules/attention.py
skyreels_v2_infer/modules/attention.py
CHANGED
@@ -104,7 +104,7 @@ def flash_attention(
             deterministic=deterministic,
         )[0].unflatten(0, (b, lq))
     else:
-
+        assert FLASH_ATTN_2_AVAILABLE
         x = flash_attn.flash_attn_varlen_func(
             q=q,
             k=k,
@@ -144,6 +144,23 @@ def attention(
     dtype=torch.bfloat16,
     fa_version=None,
 ):
+    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
+        return flash_attention(
+            q=q,
+            k=k,
+            v=v,
+            q_lens=q_lens,
+            k_lens=k_lens,
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            q_scale=q_scale,
+            causal=causal,
+            window_size=window_size,
+            deterministic=deterministic,
+            dtype=dtype,
+            version=fa_version,
+        )
+    else:
         if q_lens is not None or k_lens is not None:
             warnings.warn(
                 "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
@@ -159,4 +176,4 @@ def attention(
             )

         out = out.transpose(1, 2).contiguous()
-    return out
+        return out
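
For context: the restored dispatch, and the `assert FLASH_ATTN_2_AVAILABLE` before the `flash_attn.flash_attn_varlen_func` call (which fails loudly rather than raising a NameError at call time), both depend on the module-level flags FLASH_ATTN_2_AVAILABLE and FLASH_ATTN_3_AVAILABLE. A minimal sketch of how such flags are typically set at import time; the actual detection in attention.py is not shown in this diff and may differ:

# Sketch only: probe for the FlashAttention wheels at import time.
# The flag names match those used in the diff; the detection logic
# here is an assumption, not the file's confirmed contents.
try:
    import flash_attn_interface  # FlashAttention-3 interface, if installed
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn  # FlashAttention-2
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False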
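
And a hypothetical smoke test for the scaled_dot_product_attention fallback, runnable on a machine without flash-attn installed. The [batch, seq_len, num_heads, head_dim] layout is inferred from the unflatten(0, (b, lq)) and transpose(1, 2) calls visible in the diff; the shapes and dtype choice here are illustrative:

import torch
from skyreels_v2_infer.modules.attention import attention

# Dummy [batch, seq_len, num_heads, head_dim] tensors.
b, l, n, d = 1, 16, 4, 64
q = torch.randn(b, l, n, d)
k = torch.randn(b, l, n, d)
v = torch.randn(b, l, n, d)

# With neither FLASH_ATTN_2_AVAILABLE nor FLASH_ATTN_3_AVAILABLE set,
# this exercises the else branch: transpose to [b, n, l, d], run SDPA,
# transpose back, so the output shape matches the input layout.
out = attention(q, k, v, dtype=torch.float32)  # bfloat16 is the default
print(out.shape)  # torch.Size([1, 16, 4, 64])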