# Hyper-FLUX-8Steps-LoRA_CPU / gguf_loader.py
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
from huggingface_hub import hf_hub_download

# Set up logging
logging.basicConfig(level=logging.INFO)


class GGUFUNetLoader:
    """
    A class for loading and managing GGUF-formatted UNet models for diffusion.

    Supports quantized models with custom patch handling.
    """

    def __init__(self):
        self.model = None
        self.patches = {}
        self.backup = {}
        self.load_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.offload_device = "cpu"

    @staticmethod
    def is_quantized(weight: torch.Tensor) -> bool:
        """Check if a tensor is quantized."""
        return hasattr(weight, "patches")
    def patch_weight(self, key: str, weight: torch.Tensor, device_to: Optional[str] = None) -> torch.Tensor:
        """
        Apply patches to model weights with quantization support.

        Args:
            key: The parameter key to patch
            weight: The weight tensor to patch
            device_to: Target device for the patched weight

        Returns:
            Patched weight tensor
        """
        if key not in self.patches:
            return weight

        if self.is_quantized(weight):
            # Quantized weights: attach the patch list and the function that
            # applies it, so patching can be deferred until dequantization.
            out_weight = weight.to(device_to if device_to else self.load_device)
            patches = self.patches[key]
            out_weight.patches = [(self.calculate_weight, patches, key)]
            return out_weight
        else:
            # Regular weights: back up the original, apply patches in float32,
            # then cast back to the weight's original dtype.
            if key not in self.backup:
                self.backup[key] = weight.to(device=self.offload_device)

            temp_weight = weight.to(torch.float32)
            if device_to:
                temp_weight = temp_weight.to(device_to)

            # Apply patches
            for patch in self.patches[key]:
                temp_weight += patch

            return temp_weight.to(weight.dtype)
    def load_model(self,
                   model_path: Union[str, Path],
                   config: Optional[Dict[str, Any]] = None) -> None:
        """
        Load a GGUF model from disk.

        Args:
            model_path: Path to the GGUF model file
            config: Optional configuration dictionary for model loading
        """
        try:
            model_path = Path(model_path)
            if not model_path.exists():
                raise FileNotFoundError(f"Model file not found: {model_path}")
            if model_path.suffix != ".gguf":
                raise ValueError("Not a GGUF model file")

            # Load the model. `load_gguf_model` is a placeholder: this module
            # does not define it, so point the import below at whatever GGUF
            # reader your project actually provides.
            from .gguf_loader import load_gguf_model  # You'd need to implement this

            self.model = load_gguf_model(
                model_path,
                device=self.load_device,
                config=config or {}
            )
            logging.info(f"Successfully loaded GGUF model from {model_path}")
        except Exception as e:
            logging.error(f"Error loading model: {str(e)}")
            raise
    def add_patch(self, key: str, patch: torch.Tensor) -> None:
        """
        Add a patch for a specific model parameter.

        Args:
            key: Parameter key to patch
            patch: The patch tensor to apply
        """
        if key not in self.patches:
            self.patches[key] = []
        self.patches[key].append(patch)

    def clear_patches(self) -> None:
        """Remove all patches from the model."""
        self.patches.clear()

        # Clear quantized patches
        if self.model:
            for param in self.model.parameters():
                if self.is_quantized(param):
                    param.patches = []
    def to(self, device: str) -> 'GGUFUNetLoader':
        """
        Move model to specified device.

        Args:
            device: Target device ("cuda", "cpu", etc.)

        Returns:
            Self for method chaining
        """
        if self.model:
            self.model.to(device)
        self.load_device = device
        return self

    @staticmethod
    def calculate_weight(patches: list, base_weight: torch.Tensor, key: str) -> torch.Tensor:
        """
        Calculate final weight by applying patches.

        Args:
            patches: List of patches to apply
            base_weight: Base weight tensor
            key: Parameter key

        Returns:
            Patched weight tensor
        """
        result = base_weight.clone()
        for patch in patches:
            result += patch
        return result
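
# --- Example: patch bookkeeping for a regular (non-quantized) weight ---
# A minimal, illustrative sketch (not part of the original loader): the key and
# tensors below are hypothetical stand-ins for a real UNet state_dict entry. It
# shows how add_patch() and patch_weight() interact: the original weight is
# backed up on the offload device, patches are summed in float32, and the result
# is cast back to the weight's original dtype. Call it manually to sanity-check
# the non-quantized patch path.
def _demo_patch_weight() -> None:
    loader = GGUFUNetLoader()
    key = "double_blocks.0.img_attn.qkv.weight"    # hypothetical parameter key
    weight = torch.randn(8, 8, dtype=torch.float16)
    loader.add_patch(key, 0.1 * torch.ones(8, 8))  # e.g. a fused LoRA delta
    patched = loader.patch_weight(key, weight, device_to="cpu")
    assert key in loader.backup                    # original kept for later restore
    print("max |delta|:", (patched.float() - weight.float()).abs().max().item())
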
def main():
    # Initialize the loader (not exercised in the pipeline demo below; kept for reference)
    loader = GGUFUNetLoader()

    # Local model path and optional configuration for GGUFUNetLoader.load_model
    # (the path is intentionally a placeholder)
    model_path = Path("path/to/your/model.gguf")
    config = {
        "attention_slicing": "auto",
        "channels_last": True
    }

    # GGUF-quantized Flux transformer loaded directly from the Hub
    ckpt_path = (
        "https://huggingface.co/city96/flux.1-lite-8B-alpha-gguf"
        "/resolve/main/flux.1-lite-8B-alpha-Q3_K_S.gguf"
    )
    transformer = FluxTransformer2DModel.from_single_file(
        ckpt_path,
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    )
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )

    # Alternative GGUF checkpoint:
    # https://huggingface.co/martintomov/Hyper-FLUX.1-dev-gguf/resolve/main/hyper-flux-16step-Q3_K_M.gguf

    # Fuse the Hyper-SD 8-step LoRA so the pipeline can run with few steps
    pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
    pipe.fuse_lora(lora_scale=0.125)
    pipe.enable_model_cpu_offload()

    prompt = "A cat holding a sign that says hello world"
    image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
    image.save("flux-gguf.png")
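
# --- Optional variant: the pre-baked Hyper-FLUX GGUF referenced above ---
# A hedged sketch, not a verified recipe: it assumes the hyper-flux-16step GGUF
# linked in the comment inside main() loads with the same
# from_single_file + GGUFQuantizationConfig pattern used there. That checkpoint
# presumably already has the Hyper acceleration merged in, so no extra LoRA
# fusing is attempted here.
def load_hyper_flux_gguf_pipeline() -> FluxPipeline:
    ckpt_path = (
        "https://huggingface.co/martintomov/Hyper-FLUX.1-dev-gguf"
        "/resolve/main/hyper-flux-16step-Q3_K_M.gguf"
    )
    transformer = FluxTransformer2DModel.from_single_file(
        ckpt_path,
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    )
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )
    pipe.enable_model_cpu_offload()  # keep memory use low, matching main()
    return pipe
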
if __name__ == "__main__":
    main()