John6666 committed
Commit 419d7f4 · verified · 1 Parent(s): 8a9850c

Upload 2 files

Files changed (2)
  1. handler.py +79 -20
  2. requirements.txt +2 -3
handler.py CHANGED
@@ -1,42 +1,101 @@
+# https://github.com/sayakpaul/diffusers-torchao
+
 import os
 from typing import Any, Dict
 
 from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
 from PIL import Image
 import torch
+from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight, int8_dynamic_activation_int4_weight
+from huggingface_hub import hf_hub_download
 
 IS_COMPILE = False
+IS_TURBO = False
+IS_4BIT = True
 
 if IS_COMPILE:
     import torch._dynamo
     torch._dynamo.config.suppress_errors = True
 
-#from huggingface_inference_toolkit.logging import logger
+from huggingface_inference_toolkit.logging import logger
+
+def load_pipeline_stable(repo_id: str, dtype: torch.dtype) -> Any:
+    quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "int8dq")
+    vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
+    pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
+    pipe.transformer.fuse_qkv_projections()
+    pipe.vae.fuse_qkv_projections()
+    pipe.to("cuda")
+    return pipe
+
+def load_pipeline_compile(repo_id: str, dtype: torch.dtype) -> Any:
+    quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "int8dq")
+    vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
+    pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
+    pipe.transformer.fuse_qkv_projections()
+    pipe.vae.fuse_qkv_projections()
+    pipe.transformer.to(memory_format=torch.channels_last)
+    pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False)
+    pipe.vae.to(memory_format=torch.channels_last)
+    pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=False, dynamic=False)
+    pipe.to("cuda")
+    return pipe
+
+def load_pipeline_autoquant(repo_id: str, dtype: torch.dtype) -> Any:
+    pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
+    pipe.transformer.fuse_qkv_projections()
+    pipe.vae.fuse_qkv_projections()
+    pipe.transformer.to(memory_format=torch.channels_last)
+    pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
+    pipe.vae.to(memory_format=torch.channels_last)
+    pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
+    pipe.transformer = autoquant(pipe.transformer, error_on_unseen=False)
+    pipe.vae = autoquant(pipe.vae, error_on_unseen=False)
+    pipe.to("cuda")
+    return pipe
+
+def load_pipeline_turbo(repo_id: str, dtype: torch.dtype) -> Any:
+    pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
+    pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
+    pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
+    pipe.fuse_lora()
+    pipe.transformer.fuse_qkv_projections()
+    pipe.vae.fuse_qkv_projections()
+    weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
+    quantize_(pipe.transformer, weight, device="cuda")
+    quantize_(pipe.vae, weight, device="cuda")
+    quantize_(pipe.text_encoder_2, weight, device="cuda")
+    pipe.to("cuda")
+    return pipe
 
-def compile_pipeline(pipe) -> Any:
+def load_pipeline_turbo_compile(repo_id: str, dtype: torch.dtype) -> Any:
+    pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
+    pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
+    pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
+    pipe.fuse_lora()
+    pipe.transformer.fuse_qkv_projections()
+    pipe.vae.fuse_qkv_projections()
+    weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
+    quantize_(pipe.transformer, weight, device="cuda")
+    quantize_(pipe.vae, weight, device="cuda")
+    quantize_(pipe.text_encoder_2, weight, device="cuda")
     pipe.transformer.to(memory_format=torch.channels_last)
-    pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
-    #pipe.vae.to(memory_format=torch.channels_last)
-    #pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
+    pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False)
+    pipe.vae.to(memory_format=torch.channels_last)
+    pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=False, dynamic=False)
+    pipe.to("cuda")
     return pipe
 
 class EndpointHandler:
     def __init__(self, path=""):
-        repo_id = "camenduru/FLUX.1-dev-diffusers"
-        #repo_id = "NoMoreCopyright/FLUX.1-dev-test"
-        dtype = torch.bfloat16
-        quantization_config = TorchAoConfig("int8dq")
-        vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
-        #transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, quantization_config=quantization_config).to("cuda")
-        self.pipeline = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
-        self.pipeline.transformer.fuse_qkv_projections()
-        self.pipeline.vae.fuse_qkv_projections()
-        if IS_COMPILE: self.pipeline = compile_pipeline(self.pipeline)
-        self.pipeline.to("cuda")
-
-    @torch.inference_mode()
+        repo_id = "NoMoreCopyrightOrg/flux-dev-8step" if IS_TURBO else "NoMoreCopyrightOrg/flux-dev"
+        #dtype = torch.bfloat16
+        dtype = torch.float16 # for older NVIDIA GPUs
+        if IS_COMPILE: self.pipeline = load_pipeline_compile(repo_id, dtype)
+        else: self.pipeline = load_pipeline_stable(repo_id, dtype)
+
     def __call__(self, data: Dict[str, Any]) -> Image.Image:
-        #logger.info(f"Received incoming request with {data=}")
+        logger.info(f"Received incoming request with {data=}")
 
         if "inputs" in data and isinstance(data["inputs"], str):
             prompt = data.pop("inputs")
@@ -50,7 +109,7 @@ class EndpointHandler:
 
         parameters = data.pop("parameters", {})
 
-        num_inference_steps = parameters.get("num_inference_steps", 28)
+        num_inference_steps = parameters.get("num_inference_steps", 8 if IS_TURBO else 28)
         width = parameters.get("width", 1024)
         height = parameters.get("height", 1024)
         guidance_scale = parameters.get("guidance_scale", 3.5)
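
For a quick sanity check, the updated handler can be exercised locally before deployment. Below is a minimal sketch, assuming a CUDA GPU and the pinned requirements installed; the prompt and output path are placeholders, and the payload keys mirror what __call__ reads above.

# Minimal local smoke test (sketch) for the new handler.
from handler import EndpointHandler

handler = EndpointHandler()
image = handler({
    "inputs": "a watercolor fox in a snowy forest",  # placeholder prompt
    "parameters": {
        "num_inference_steps": 28,  # 8 is the default when IS_TURBO is enabled
        "width": 1024,
        "height": 1024,
        "guidance_scale": 3.5,
    },
})
image.save("output.png")  # __call__ returns a PIL Image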
requirements.txt CHANGED
@@ -1,15 +1,14 @@
 huggingface_hub
 torch==2.4.0
 torchvision
+torchaudio
 torchao==0.9.0
-diffusers
+diffusers==0.32.2
 peft
-accelerate
 transformers
 numpy
 scipy
 Pillow
 sentencepiece
 protobuf
-pytorch-lightning
 triton
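
Once deployed as a Hugging Face Inference Endpoint, the handler is reached over HTTP. Here is a hypothetical client-side sketch, assuming the toolkit's default serialization returns the generated image as raw bytes; the endpoint URL and token are placeholders, and requests is a client dependency, not part of requirements.txt above.

# Hypothetical client request to the deployed endpoint.
import requests

response = requests.post(
    "https://<your-endpoint>.endpoints.huggingface.cloud",  # placeholder URL
    headers={"Authorization": "Bearer <HF_TOKEN>", "Content-Type": "application/json"},
    json={
        "inputs": "an astronaut riding a horse on the moon",
        "parameters": {"num_inference_steps": 28, "guidance_scale": 3.5},
    },
)
response.raise_for_status()
with open("image.png", "wb") as f:
    f.write(response.content)  # assumed: image bytes in the response body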