Quantizing flux

#1
by ukaprch - opened

Not a problem, just a general observation:
I don't see how this approach works any better or faster than
quantizing the transformer and text_encoder_2
directly after loading the pipeline:

import torch
from diffusers import FluxPipeline
from optimum.quanto import freeze, qint8, quantize

repo_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)

quantize(pipe.transformer, weights=qint8)
print("Running transformer freeze DEV")
freeze(pipe.transformer)

quantize(pipe.text_encoder_2, weights=qint8)
print("Running text_encoder_2 freeze DEV")
freeze(pipe.text_encoder_2)

What am I missing here?

This is a known issue: https://github.com/huggingface/optimum-quanto/issues/270
Disabling GEMM reduces the loading time from 250 seconds to 20 seconds.

You can disable GEMM if you want:

from optimum import quanto

# Route QBitsTensor.create straight to the plain constructor,
# bypassing the slow packed (GEMM) tensor creation path
quanto.tensor.qbits.QBitsTensor.create = lambda *args, **kwargs: quanto.tensor.qbits.QBitsTensor(*args, **kwargs)
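To see why that one-line patch helps, here is a toy sketch of the mechanism (these are stand-in classes, not the real optimum-quanto internals): `create` is a factory classmethod that normally goes through an expensive packing step, and rebinding it to a plain constructor call skips that step entirely.

```python
import time

class QBitsTensor:
    """Toy stand-in for quanto's QBitsTensor; not the real class."""
    def __init__(self, data):
        # plain construction: cheap, stores the data as-is
        self.data = data

    @classmethod
    def create(cls, data):
        # stand-in for the slow GEMM-friendly repacking path
        time.sleep(0.01)
        return cls(data)

# The workaround: rebind .create so it calls the constructor directly,
# skipping the packing step above.
QBitsTensor.create = lambda *args, **kwargs: QBitsTensor(*args, **kwargs)

t = QBitsTensor.create([1, 2, 3])
print(type(t).__name__, t.data)
```

The trade-off is that any performance benefit the packed layout provides at inference time is lost; you are only speeding up model loading.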
