VRAM usage jumps above 24GB at the end of inference

#2
by sanntann - opened

Thank you! This works super well and is really easy to use.

I am wondering if the jump in VRAM usage at the end of the 4-step inference causes a slowdown that could be avoided somehow. I'm new to this tech, so it's a bit hard to tell.

  1. When I run it on Windows (using the inference instructions in the README), generating the first image takes about 22GB of VRAM and the 4 inference steps run in roughly 2 seconds.
  2. Then VRAM usage jumps by 4GB, pushing it above the GPU's 24GB of VRAM and spilling about 2GB into Shared GPU Memory.
  3. At this point there is a 6-second delay before the image is returned from the pipe.

So it looks as if this model could do 2-second inference on an NVIDIA 4090 but is held back by this final stage of inference, which pushes the total to about 8 seconds instead.

  • What might this operation after the 4 inference steps be that requires the extra 4GB of memory and causes this slowdown because it cannot fit in VRAM?
  • Would moving that step to the CPU speed it up, or allow it to run in parallel with inference on queued prompts?
  • Could the pipeline be shrunk by just 4GB in some other way to avoid this slowdown?

Thanks!

By point 2 in the first paragraph, do you mean it takes less VRAM while generating the very first image after loading the pipeline, and then it shoots up?

To debug this further, I would consider taking the original pipeline file and logging the memory consumption with torch.cuda.memory_allocated() in strategic locations. This will give us better signals and will help us localize the point in code that leads to memory growth.
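For example, a minimal logging helper could look like the sketch below (the tag names and placement are just placeholders; the calls would go into a local copy of the pipeline file around the denoising loop and the VAE decode):

```python
import torch

def log_vram(tag: str):
    # memory_allocated() counts tensors PyTorch currently holds;
    # memory_reserved() also includes the caching allocator's pool.
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    print(f"[{tag}] allocated={allocated:.2f} GiB | reserved={reserved:.2f} GiB")

# Example placements (hypothetical tags):
# log_vram("before denoising loop")
# log_vram("after denoising loop")
# log_vram("after vae.decode")
```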

VRAM shoots up right after the completion of the final inference step of the first image (based on the tqdm progress log in the console), and stays up for any following images. It is not clear whether VRAM usage drops during the 2-second period while the inference steps of the following images run, or whether it just stays up. The processing duration of the first and any following images seems identical: roughly 2 seconds for the 4 inference steps and then another 6 seconds until the image is returned.

I will try to get the memory usage logs for the different steps in the pipeline.

Maybe it's because of the VAE. Can you try calling pipeline.vae.enable_slicing() before performing inference? It might have a small latency penalty but should reduce the memory.
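A minimal sketch of what that would look like, assuming the pipeline from the README is a diffusers Flux pipeline (the checkpoint and prompt here are placeholders; substitute however the README loads the int8 weights):

```python
import torch
from diffusers import FluxPipeline

# Placeholder checkpoint; replace with the loading code from the README.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# Decode latents in slices instead of one big batch to lower peak VRAM.
pipeline.vae.enable_slicing()

image = pipeline(
    "a photo of a cat", num_inference_steps=4, guidance_scale=0.0
).images[0]
```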

Also, the 2 seconds is only for the denoising step. There is a separate decoding step which we need to account for as well:
https://github.com/huggingface/diffusers/blob/a98a839de75f1ad82d8d200c3bc2e4ff89929081/src/diffusers/pipelines/flux/pipeline_flux.py#L775

Maybe to have full transparency into which step consumes how much time and how much memory, it would be better to decouple all the crucial steps as I have done here:
https://gist.github.com/sayakpaul/23862a2e7f5ab73dfdcc513751289bea

You will have to make some changes to suit the int8 weights, but I think it's doable.
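As a rough illustration of that kind of decoupling (not the gist itself; this assumes the stock FluxPipeline API with its `_unpack_latents` helper, and the resolution, prompt, and checkpoint are placeholders to be swapped for the int8 setup):

```python
import time
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

height = width = 1024  # must match the resolution you generate at

# 1) Denoising only: ask the pipeline for packed latents instead of images.
torch.cuda.synchronize(); t0 = time.perf_counter()
latents = pipe(
    "a photo of a cat", num_inference_steps=4, guidance_scale=0.0,
    height=height, width=width, output_type="latent",
).images
torch.cuda.synchronize(); t1 = time.perf_counter()

# 2) Decoding: the same post-processing the pipeline does internally.
latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
latents = latents / pipe.vae.config.scaling_factor + pipe.vae.config.shift_factor
image = pipe.vae.decode(latents, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="pil")[0]
torch.cuda.synchronize(); t2 = time.perf_counter()

print(f"denoise: {t1 - t0:.2f}s | decode: {t2 - t1:.2f}s")
```

Wrapping each stage with the `log_vram` helper from above would also show which of the two is responsible for the 4GB jump.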
