import logging
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
from huggingface_hub import hf_hub_download

# Set up logging
logging.basicConfig(level=logging.INFO)
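
# `GGUFUNetLoader.load_model` below imports a `gguf_loader` module that is not
# included in this file. The helper here is a hedged, minimal sketch of what
# such a loader could look like, assuming the `gguf` package from llama.cpp is
# installed (pip install gguf); its name is hypothetical, and dequantizing
# k-quant blocks is left to a real implementation.
def load_gguf_state_dict(path: Union[str, Path], device: str = "cpu") -> Dict[str, torch.Tensor]:
    """Read raw tensors from a GGUF file into a torch state dict (sketch)."""
    from gguf import GGUFReader  # assumption: llama.cpp's `gguf` package

    reader = GGUFReader(str(path))
    state_dict = {}
    for tensor in reader.tensors:
        # tensor.data is a (possibly memory-mapped) numpy array; copy it so
        # torch owns writable memory before moving it to the target device.
        state_dict[tensor.name] = torch.from_numpy(tensor.data.copy()).to(device)
    return state_dict
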

class GGUFUNetLoader:
    """
    A class for loading and managing GGUF-formatted UNet models for diffusion.

    Supports quantized models with custom patch handling.
    """

    def __init__(self):
        self.model = None
        self.patches = {}
        self.backup = {}
        self.load_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.offload_device = "cpu"

    @staticmethod
    def is_quantized(weight: torch.Tensor) -> bool:
        """Check whether a tensor carries quantization patch metadata."""
        return hasattr(weight, "patches")

    def patch_weight(self, key: str, weight: torch.Tensor,
                     device_to: Optional[str] = None) -> torch.Tensor:
        """
        Apply patches to a model weight, with quantization support.

        Args:
            key: The parameter key to patch
            weight: The weight tensor to patch
            device_to: Target device for the patched weight

        Returns:
            Patched weight tensor
        """
        if key not in self.patches:
            return weight

        if self.is_quantized(weight):
            # Quantized weights: attach the patch list so it can be resolved
            # lazily by calculate_weight instead of being applied eagerly.
            out_weight = weight.to(device_to if device_to else self.load_device)
            out_weight.patches = [(self.calculate_weight, self.patches[key], key)]
            return out_weight

        # Regular weights: back up the original, then patch in float32.
        if key not in self.backup:
            self.backup[key] = weight.to(device=self.offload_device)
        temp_weight = weight.to(torch.float32)
        if device_to:
            temp_weight = temp_weight.to(device_to)
        for patch in self.patches[key]:
            temp_weight += patch.to(temp_weight.device)
        return temp_weight.to(weight.dtype)

    def load_model(self,
                   model_path: Union[str, Path],
                   config: Optional[Dict[str, Any]] = None) -> None:
        """
        Load a GGUF model from disk.

        Args:
            model_path: Path to the GGUF model file
            config: Optional configuration dictionary for model loading
        """
        try:
            model_path = Path(model_path)
            if not model_path.exists():
                raise FileNotFoundError(f"Model file not found: {model_path}")
            if model_path.suffix != ".gguf":
                raise ValueError("Not a GGUF model file")

            # Load the model (implementation depends on your GGUF loader)
            from .gguf_loader import load_gguf_model  # You'd need to implement this

            self.model = load_gguf_model(
                model_path,
                device=self.load_device,
                config=config or {},
            )
            logging.info(f"Successfully loaded GGUF model from {model_path}")
        except Exception as e:
            logging.error(f"Error loading model: {e}")
            raise

    def add_patch(self, key: str, patch: torch.Tensor) -> None:
        """
        Add a patch for a specific model parameter.

        Args:
            key: Parameter key to patch
            patch: The patch tensor to apply
        """
        if key not in self.patches:
            self.patches[key] = []
        self.patches[key].append(patch)

    def clear_patches(self) -> None:
        """Remove all patches from the model."""
        self.patches.clear()
        # Clear patch lists attached to quantized parameters
        if self.model:
            for param in self.model.parameters():
                if self.is_quantized(param):
                    param.patches = []

    def to(self, device: str) -> 'GGUFUNetLoader':
        """
        Move the model to the specified device.

        Args:
            device: Target device ("cuda", "cpu", etc.)

        Returns:
            Self for method chaining
        """
        if self.model:
            self.model.to(device)
        self.load_device = device
        return self

    @staticmethod
    def calculate_weight(patches: list, base_weight: torch.Tensor, key: str) -> torch.Tensor:
        """
        Calculate the final weight by applying patches to a base weight.

        Args:
            patches: List of patch tensors to apply
            base_weight: Base weight tensor
            key: Parameter key (kept for interface symmetry; unused here)

        Returns:
            Patched weight tensor
        """
        result = base_weight.clone()
        for patch in patches:
            result += patch.to(result.device)
        return result
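
# Minimal, self-contained sketch of the patching flow above: a plain tensor
# stands in for a real UNet weight, so no GGUF file is needed. The key name
# "unet.w" is illustrative only.
def demo_patching() -> None:
    loader = GGUFUNetLoader()
    weight = torch.zeros(2, 2)  # stand-in for a model parameter
    loader.add_patch("unet.w", torch.full((2, 2), 0.5))
    patched = loader.patch_weight("unet.w", weight)
    assert torch.allclose(patched, torch.full((2, 2), 0.5))
    # The unpatched original is kept on the offload device for restoration.
    assert "unet.w" in loader.backup
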

def main():
    # Initialize the loader
    loader = GGUFUNetLoader()

    # Optional configuration for model loading
    config = {
        "attention_slicing": "auto",
        "channels_last": True,
    }

    # To load a local GGUF UNet through the loader above (requires a real
    # gguf_loader implementation):
    # loader.load_model(Path("path/to/your/model.gguf"), config=config)

    # Load a quantized FLUX transformer via diffusers' built-in GGUF support.
    ckpt_path = (
        "https://huggingface.co/city96/flux.1-lite-8B-alpha-gguf"
        "/blob/main/flux.1-lite-8B-alpha-Q3_K_S.gguf"
    )
    transformer = FluxTransformer2DModel.from_single_file(
        ckpt_path,
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    )
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )

    # Alternative checkpoint:
    # https://huggingface.co/martintomov/Hyper-FLUX.1-dev-gguf/resolve/main/hyper-flux-16step-Q3_K_M.gguf
    # pipe = FluxPipeline.from_pretrained("flux1-schnell-Q3_K_S.gguf")

    # Fuse the Hyper-SD 8-step LoRA to reduce the number of denoising steps.
    pipe.load_lora_weights(
        hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
    )
    pipe.fuse_lora(lora_scale=0.125)
    pipe.enable_model_cpu_offload()

    prompt = "A cat holding a sign that says hello world"
    image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
    image.save("flux-gguf.png")
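
# Hedged sketch: the `config` keys in main() are not consumed anywhere in this
# file. Assuming they are meant as pipeline-level switches, they could map to
# standard diffusers / PyTorch calls as below; `apply_config` is a hypothetical
# helper, not part of the loader or diffusers API.
def apply_config(pipe: FluxPipeline, config: Dict[str, Any]) -> None:
    if config.get("attention_slicing"):
        # Slice attention computation to trade speed for lower peak memory.
        pipe.enable_attention_slicing(config["attention_slicing"])
    if config.get("channels_last"):
        # Standard PyTorch memory-format switch.
        pipe.transformer.to(memory_format=torch.channels_last)
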

if __name__ == "__main__":
    main()