FLUX.1-dev-fa3-aoti / optimization.py
"""
"""
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from fa3 import FlashFusedFluxAttnProcessor3_0
P = ParamSpec('P')
# Inductor options forwarded to spaces.aoti_compile.
INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):

    @spaces.GPU(duration=1500)
    def compile_transformer():
        # Run the pipeline once to capture the exact args/kwargs the transformer receives.
        with spaces.aoti_capture(pipeline.transformer) as call:
            pipeline(*args, **kwargs)
        # Export the transformer with the captured inputs, then AoT-compile it with Inductor.
        exported = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs,
        )
        return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)

    # Fuse Q/K/V projections and switch to the FlashAttention-3 processor before export,
    # then bind the compiled artifact back onto the transformer.
    pipeline.transformer.fuse_qkv_projections()
    pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
    spaces.aoti_apply(compile_transformer(), pipeline.transformer)
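

# Minimal usage sketch, assuming a Diffusers FluxPipeline and an illustrative
# prompt and step count; the Space's own app code handles the actual wiring.
#
#     import torch
#     from diffusers import FluxPipeline
#
#     pipe = FluxPipeline.from_pretrained(
#         'black-forest-labs/FLUX.1-dev', torch_dtype=torch.bfloat16
#     ).to('cuda')
#     optimize_pipeline_(pipe, prompt='a photo of a forest', num_inference_steps=28)
#     image = pipe(prompt='a photo of a forest', num_inference_steps=28).images[0]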