K00B404 commited on
Commit
b82f02e
·
verified ·
1 Parent(s): 7721cb1

Update gguf_loader.py

Browse files
Files changed (1) hide show
  1. gguf_loader.py +45 -1
gguf_loader.py CHANGED
@@ -2,6 +2,8 @@ import torch
2
  import logging
3
  from pathlib import Path
4
  from typing import Optional, Union, Dict, Any
 
 
5
 
6
  class GGUFUNetLoader:
7
  """
@@ -141,4 +143,46 @@ class GGUFUNetLoader:
141
  result = base_weight.clone()
142
  for patch in patches:
143
  result += patch
144
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import logging
3
  from pathlib import Path
4
  from typing import Optional, Union, Dict, Any
5
+ # Set up logging
6
+ logging.basicConfig(level=logging.INFO)
7
 
8
  class GGUFUNetLoader:
9
  """
 
143
  result = base_weight.clone()
144
  for patch in patches:
145
  result += patch
146
+ return result
147
+
148
+
149
+ def main():
150
+ # Initialize the loader
151
+ loader = GGUFUNetLoader()
152
+
153
+ # Specify model path
154
+ model_path = Path("path/to/your/model.gguf")
155
+
156
+ ckpt_path = (
157
+ "https://huggingface.co/city96/flux.1-lite-8B-alpha-gguf/flux.1-lite-8B-alpha-Q3_K_S.gguf"
158
+ )
159
+
160
+ transformer = FluxTransformer2DModel.from_single_file(
161
+ ckpt_path,
162
+ quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
163
+ torch_dtype=torch.bfloat16,
164
+ )
165
+
166
+ pipe = FluxPipeline.from_pretrained(
167
+ "black-forest-labs/FLUX.1-dev",
168
+ transformer=transformer,
169
+ torch_dtype=torch.bfloat16,
170
+ )
171
+ # https://huggingface.co/martintomov/Hyper-FLUX.1-dev-gguf/resolve/main/hyper-flux-16step-Q3_K_M.gguf
172
+ #pipe = FluxPipeline.from_pretrained("flux1-schnell-Q3_K_S.gguf")
173
+ pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
174
+ pipe.fuse_lora(lora_scale=0.125)
175
+
176
+ pipe.enable_model_cpu_offload()
177
+ prompt = "A cat holding a sign that says hello world"
178
+ image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
179
+ image.save("flux-gguf.png")
180
+ # Optional configuration for model loading
181
+ config = {
182
+ "attention_slicing": "auto",
183
+ "channels_last": True
184
+ }
185
+
186
+
187
+ if __name__ == "__main__":
188
+ main()