1inkusFace committed · verified
Commit 67ad25f · 1 Parent(s): 3150bd7

Update skyreels_v2_infer/modules/attention.py

skyreels_v2_infer/modules/attention.py CHANGED
```diff
@@ -104,7 +104,7 @@ def flash_attention(
             deterministic=deterministic,
         )[0].unflatten(0, (b, lq))
     else:
-        #assert FLASH_ATTN_2_AVAILABLE
+        assert FLASH_ATTN_2_AVAILABLE
         x = flash_attn.flash_attn_varlen_func(
             q=q,
             k=k,
@@ -144,6 +144,23 @@ def attention(
     dtype=torch.bfloat16,
     fa_version=None,
 ):
+    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
+        return flash_attention(
+            q=q,
+            k=k,
+            v=v,
+            q_lens=q_lens,
+            k_lens=k_lens,
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            q_scale=q_scale,
+            causal=causal,
+            window_size=window_size,
+            deterministic=deterministic,
+            dtype=dtype,
+            version=fa_version,
+        )
+    else:
         if q_lens is not None or k_lens is not None:
             warnings.warn(
                 "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
@@ -159,4 +176,4 @@ def attention(
             )
 
         out = out.transpose(1, 2).contiguous()
-    return out
+        return out
```
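
Both restored pieces hinge on the module-level flags FLASH_ATTN_2_AVAILABLE and FLASH_ATTN_3_AVAILABLE, which the diff references but does not define. A minimal sketch of how such flags are conventionally set, assuming the usual import-guard pattern (the actual detection code in attention.py is not shown in this commit):

```python
# Hypothetical reconstruction of the availability flags the diff relies on;
# the real attention.py may probe for the packages differently.
try:
    import flash_attn_interface  # FlashAttention 3 interface package

    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn  # FlashAttention 2

    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False
```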
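
When neither flag is set, attention() now drops into the else branch, of which only the padding-mask warning and the final transpose/return are visible as context above. A self-contained sketch of that fallback, assuming it wraps torch.nn.functional.scaled_dot_product_attention as the warning message suggests; the name sdpa_fallback and every line not visible in the diff are assumptions:

```python
import warnings

import torch


def sdpa_fallback(q, k, v, causal=False, dropout_p=0.0, dtype=torch.bfloat16,
                  q_lens=None, k_lens=None):
    # Varlen padding masks are not supported on this path, matching the
    # warning kept as context in the diff.
    if q_lens is not None or k_lens is not None:
        warnings.warn(
            "Padding mask is disabled when using scaled_dot_product_attention. "
            "It can have a significant impact on performance."
        )
    # SDPA expects (batch, heads, seq, head_dim); the surrounding code keeps
    # (batch, seq, heads, head_dim), hence the transposes in and out.
    q = q.transpose(1, 2).to(dtype)
    k = k.transpose(1, 2).to(dtype)
    v = v.transpose(1, 2).to(dtype)
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, is_causal=causal, dropout_p=dropout_p
    )
    out = out.transpose(1, 2).contiguous()
    return out


if __name__ == "__main__":
    # (batch, seq_len, heads, head_dim) layout, matching the flash path.
    q = k = v = torch.randn(1, 16, 8, 64)
    print(sdpa_fallback(q, k, v, causal=True).shape)  # torch.Size([1, 16, 8, 64])
```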
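
On the flash path, flash_attn.flash_attn_varlen_func (visible in the first hunk) takes flattened (total_tokens, heads, head_dim) tensors plus int32 cumulative sequence lengths instead of a padding mask, which is why attention() forwards q_lens and k_lens. A small helper sketch for building those offsets; build_cu_seqlens is a hypothetical name, not part of this file:

```python
import torch


def build_cu_seqlens(lens: torch.Tensor) -> torch.Tensor:
    # flash_attn_varlen_func expects cumulative lengths with a leading zero,
    # e.g. per-sequence lens [3, 5] -> cu_seqlens [0, 3, 8], dtype int32.
    return torch.cat([lens.new_zeros(1), lens.cumsum(0)]).to(torch.int32)


# Example: two sequences of 3 and 5 tokens batched for the varlen kernel.
print(build_cu_seqlens(torch.tensor([3, 5])))  # tensor([0, 3, 8], dtype=torch.int32)
```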