import logging
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
from huggingface_hub import hf_hub_download

# Set up logging
logging.basicConfig(level=logging.INFO)


class GGUFUNetLoader:
    """
    A class for loading and managing GGUF-formatted UNet models for diffusion.

    Supports quantized models with custom patch handling.
    """

    def __init__(self):
        self.model = None
        self.patches = {}  # key -> list of patch tensors
        self.backup = {}   # key -> original (unpatched) weight
        self.load_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.offload_device = "cpu"

    @staticmethod
    def is_quantized(weight: torch.Tensor) -> bool:
        """Check if a tensor is quantized (marked with a `patches` attribute)."""
        return hasattr(weight, "patches")

    def patch_weight(
        self,
        key: str,
        weight: torch.Tensor,
        device_to: Optional[str] = None,
    ) -> torch.Tensor:
        """
        Apply patches to model weights with quantization support.

        Args:
            key: The parameter key to patch
            weight: The weight tensor to patch
            device_to: Target device for the patched weight

        Returns:
            Patched weight tensor
        """
        if key not in self.patches:
            return weight

        if self.is_quantized(weight):
            # Quantized weights stay packed: defer the patch by attaching
            # (function, patches, key) triples to be applied at dequant time.
            out_weight = weight.to(device_to if device_to else self.load_device)
            patches = self.patches[key]
            out_weight.patches = [(self.calculate_weight, patches, key)]
            return out_weight

        # Regular weights: back up the original, then apply patches eagerly.
        if key not in self.backup:
            self.backup[key] = weight.to(device=self.offload_device)

        temp_weight = weight.to(torch.float32)
        if device_to:
            temp_weight = temp_weight.to(device_to)

        # Apply patches in float32, then cast back to the original dtype.
        for patch in self.patches[key]:
            temp_weight += patch
        return temp_weight.to(weight.dtype)

    def load_model(
        self,
        model_path: Union[str, Path],
        config: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Load a GGUF model from disk.

        Args:
            model_path: Path to the GGUF model file
            config: Optional configuration dictionary for model loading
        """
        try:
            model_path = Path(model_path)
            if not model_path.exists():
                raise FileNotFoundError(f"Model file not found: {model_path}")
            if model_path.suffix != ".gguf":
                raise ValueError("Not a GGUF model file")

            # Load the model (implementation depends on your GGUF loader).
            # You'd need to implement this; a hedged sketch appears at the
            # end of this file.
            from .gguf_loader import load_gguf_model

            self.model = load_gguf_model(
                model_path,
                device=self.load_device,
                config=config or {},
            )
            logging.info(f"Successfully loaded GGUF model from {model_path}")
        except Exception as e:
            logging.error(f"Error loading model: {e}")
            raise

    def add_patch(self, key: str, patch: torch.Tensor) -> None:
        """
        Add a patch for a specific model parameter.

        Args:
            key: Parameter key to patch
            patch: The patch tensor to apply
        """
        if key not in self.patches:
            self.patches[key] = []
        self.patches[key].append(patch)

    def clear_patches(self) -> None:
        """Remove all patches from the model."""
        self.patches.clear()
        # Clear deferred patches on quantized parameters as well.
        if self.model:
            for param in self.model.parameters():
                if self.is_quantized(param):
                    param.patches = []

    def to(self, device: str) -> "GGUFUNetLoader":
        """
        Move model to specified device.

        Args:
            device: Target device ("cuda", "cpu", etc.)

        Returns:
            Self for method chaining
        """
        if self.model:
            self.model.to(device)
        self.load_device = device
        return self

    @staticmethod
    def calculate_weight(
        patches: list, base_weight: torch.Tensor, key: str
    ) -> torch.Tensor:
        """
        Calculate final weight by applying patches.

        Args:
            patches: List of patches to apply
            base_weight: Base weight tensor
            key: Parameter key

        Returns:
            Patched weight tensor
        """
        result = base_weight.clone()
        for patch in patches:
            result += patch
        return result
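
# --- Usage sketch (illustrative, not part of any public API) --------------
# A minimal demonstration of the patch flow above on a plain, non-quantized
# tensor: register a delta with add_patch(), then request the patched weight
# via patch_weight(). The parameter key "unet.conv_in.weight" is hypothetical.
def demo_patch_flow() -> None:
    loader = GGUFUNetLoader()
    key = "unet.conv_in.weight"  # hypothetical parameter key
    base = torch.zeros(4, 4)
    loader.add_patch(key, torch.full((4, 4), 0.1))
    patched = loader.patch_weight(key, base, device_to="cpu")
    # Non-quantized path: the original weight is backed up and the patch is
    # applied eagerly in float32 before casting back to the input dtype.
    assert torch.allclose(patched, torch.full((4, 4), 0.1))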
def main():
    # Initialize the custom loader. With a local .gguf file and a
    # gguf_loader module in place (see the sketch at the end of this file),
    # you would call loader.load_model(model_path, config); the variables
    # below are illustrative and unused by the diffusers flow that follows.
    loader = GGUFUNetLoader()
    model_path = Path("path/to/your/model.gguf")

    # Optional configuration for model loading
    config = {
        "attention_slicing": "auto",
        "channels_last": True,
    }

    # Load a pre-quantized GGUF checkpoint directly into diffusers.
    ckpt_path = (
        "https://huggingface.co/city96/flux.1-lite-8B-alpha-gguf"
        "/resolve/main/flux.1-lite-8B-alpha-Q3_K_S.gguf"
    )
    transformer = FluxTransformer2DModel.from_single_file(
        ckpt_path,
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    )
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )

    # Alternative checkpoint:
    # https://huggingface.co/martintomov/Hyper-FLUX.1-dev-gguf/resolve/main/hyper-flux-16step-Q3_K_M.gguf
    # pipe = FluxPipeline.from_pretrained("flux1-schnell-Q3_K_S.gguf")

    # Fuse a Hyper-SD LoRA so the pipeline needs far fewer denoising steps.
    pipe.load_lora_weights(
        hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
    )
    pipe.fuse_lora(lora_scale=0.125)
    pipe.enable_model_cpu_offload()

    prompt = "A cat holding a sign that says hello world"
    image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
    image.save("flux-gguf.png")


if __name__ == "__main__":
    main()
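
# --- Sketch: the missing gguf_loader module (assumption) -------------------
# GGUFUNetLoader.load_model() imports load_gguf_model from a local
# gguf_loader module that this file does not define. Below is a minimal
# sketch of what gguf_loader.py could contain, assuming the `gguf` package
# that ships with llama.cpp (pip install gguf). It returns a plain
# {name: tensor} state dict; a real implementation would also construct the
# UNet object and dequantize k-quant blocks, which this sketch leaves packed.
def load_gguf_model(model_path, device="cpu", config=None):
    """Read a .gguf file into a {name: torch.Tensor} state dict (sketch)."""
    import numpy as np
    from gguf import GGUFReader  # assumed dependency: pip install gguf

    reader = GGUFReader(str(model_path))
    state_dict = {}
    for tensor in reader.tensors:
        data = np.asarray(tensor.data)
        if data.dtype in (np.float16, np.float32):
            # Unquantized tensors convert directly to torch.
            state_dict[tensor.name] = torch.from_numpy(data.copy()).to(device)
        else:
            # Quantized blocks (Q3_K, Q4_K, ...) stay packed as raw bytes;
            # dequantization is out of scope for this sketch.
            state_dict[tensor.name] = torch.from_numpy(
                data.view(np.uint8).copy()
            ).to(device)
    return state_dict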