import logging
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
from huggingface_hub import hf_hub_download

# Set up logging
logging.basicConfig(level=logging.INFO)


class GGUFUNetLoader:
    """
    A class for loading and managing GGUF-formatted UNet models for diffusion.

    Supports quantized models with custom patch handling.
    """

    def __init__(self):
        self.model = None
        self.patches = {}
        self.backup = {}
        self.load_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.offload_device = "cpu"

    @staticmethod
    def is_quantized(weight: torch.Tensor) -> bool:
        """Check whether a tensor carries GGUF-style patch state (used as a proxy for quantization)."""
        return hasattr(weight, "patches")

    def patch_weight(self, key: str, weight: torch.Tensor, device_to: Optional[str] = None) -> torch.Tensor:
        """
        Apply patches to model weights, with quantization support.

        Args:
            key: The parameter key to patch
            weight: The weight tensor to patch
            device_to: Target device for the patched weight

        Returns:
            Patched weight tensor
        """
        if key not in self.patches:
            return weight

        if self.is_quantized(weight):
            # Quantized weights: defer patching by attaching
            # (function, patches, key) tuples to the tensor.
            out_weight = weight.to(device_to if device_to else self.load_device)
            patches = self.patches[key]
            out_weight.patches = [(self.calculate_weight, patches, key)]
            return out_weight

        # Regular weights: keep a pristine copy so patches can be undone.
        # copy=True guards against aliasing when the tensor already lives
        # on the offload device (``.to`` returns self in that case).
        if key not in self.backup:
            self.backup[key] = weight.to(device=self.offload_device, copy=True)

        # Work in float32 on a copy so the in-place adds below never
        # mutate the original weight.
        temp_weight = weight.to(torch.float32, copy=True)
        if device_to:
            temp_weight = temp_weight.to(device_to)

        # Apply patches additively
        for patch in self.patches[key]:
            temp_weight += patch

        return temp_weight.to(weight.dtype)
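
    # NOTE (illustrative, not part of the original code): consumers of the
    # quantized branch above are expected to invoke the stored
    # ``(function, patches, key)`` tuples at dequantization time, e.g.
    # ``weight = fn(patches, dequantized_weight, key)``, so patches are
    # applied lazily instead of materializing a dense patched tensor.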

    def load_model(self,
                   model_path: Union[str, Path],
                   config: Optional[Dict[str, Any]] = None) -> None:
        """
        Load a GGUF model from disk.

        Args:
            model_path: Path to the GGUF model file
            config: Optional configuration dictionary for model loading
        """
        try:
            model_path = Path(model_path)
            if not model_path.exists():
                raise FileNotFoundError(f"Model file not found: {model_path}")
            if model_path.suffix != ".gguf":
                raise ValueError(f"Not a GGUF model file: {model_path}")

            # Load the model (implementation depends on your GGUF loader)
            from .gguf_loader import load_gguf_model  # You'd need to implement this

            self.model = load_gguf_model(
                model_path,
                device=self.load_device,
                config=config or {},
            )
            logging.info(f"Successfully loaded GGUF model from {model_path}")
        except Exception as e:
            logging.error(f"Error loading model: {e}")
            raise
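
    # --- Hypothetical sketch (not part of the original module) -----------
    # ``load_model`` imports ``load_gguf_model`` from a ``.gguf_loader``
    # module that is left unimplemented above. The static method below is
    # a minimal sketch of what such a loader could do, assuming the
    # ``gguf`` package (``pip install gguf``): it dequantizes every tensor
    # to float and returns a plain state dict. A real loader would keep
    # weights quantized and restore tensor shapes (GGUF stores dimensions
    # in reverse order), both omitted here for brevity.
    @staticmethod
    def read_gguf_state_dict_sketch(model_path: Union[str, Path],
                                    device: str = "cpu") -> Dict[str, torch.Tensor]:
        """Read a GGUF file into a {name: tensor} dict (illustrative only)."""
        from gguf import GGUFReader
        from gguf.quants import dequantize

        reader = GGUFReader(str(model_path))
        state_dict = {}
        for tensor in reader.tensors:
            # Dequantize the raw GGUF block data to a float numpy array
            array = dequantize(tensor.data, tensor.tensor_type)
            state_dict[tensor.name] = torch.from_numpy(array.copy()).to(device)
        return state_dict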

    def add_patch(self, key: str, patch: torch.Tensor) -> None:
        """
        Add a patch for a specific model parameter.

        Args:
            key: Parameter key to patch
            patch: The patch tensor to apply
        """
        self.patches.setdefault(key, []).append(patch)

    def clear_patches(self) -> None:
        """Remove all patches from the model."""
        self.patches.clear()

        # Also clear deferred patches attached to quantized parameters
        if self.model:
            for param in self.model.parameters():
                if self.is_quantized(param):
                    param.patches = []

    def to(self, device: str) -> 'GGUFUNetLoader':
        """
        Move the model to the specified device.

        Args:
            device: Target device ("cuda", "cpu", etc.)

        Returns:
            Self, for method chaining
        """
        if self.model:
            self.model.to(device)
        self.load_device = device
        return self

    @staticmethod
    def calculate_weight(patches: list, base_weight: torch.Tensor, key: str) -> torch.Tensor:
        """
        Calculate the final weight by applying patches.

        Args:
            patches: List of patch tensors to apply
            base_weight: Base weight tensor
            key: Parameter key (unused here; kept for the callback signature)

        Returns:
            Patched weight tensor
        """
        result = base_weight.clone()
        for patch in patches:
            result += patch
        return result
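

# --- Hypothetical usage sketch (not part of the original module) ----------
# A minimal round trip through the dense (non-quantized) patching path:
# register a patch, apply it, and confirm the pristine weight was preserved
# in ``loader.backup``. The parameter key and values are illustrative.
def _demo_dense_patching() -> None:
    loader = GGUFUNetLoader()
    weight = torch.zeros(2, 2)

    # Register an additive patch for a made-up parameter key
    loader.add_patch("unet.block0.weight", torch.full((2, 2), 0.5))

    patched = loader.patch_weight("unet.block0.weight", weight)
    assert torch.allclose(patched, torch.full((2, 2), 0.5))

    # The original weight was backed up before patching and is untouched
    assert torch.allclose(loader.backup["unet.block0.weight"], torch.zeros(2, 2))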


def main():
    # Initialize the custom loader (requires a .gguf_loader implementation;
    # see GGUFUNetLoader.load_model)
    loader = GGUFUNetLoader()

    # Optional configuration for model loading
    config = {
        "attention_slicing": "auto",
        "channels_last": True,
    }

    # Example: load a local GGUF UNet through the loader
    # model_path = Path("path/to/your/model.gguf")
    # loader.load_model(model_path, config=config)

    # Load a quantized Flux transformer directly via diffusers' GGUF support
    ckpt_path = (
        "https://huggingface.co/city96/flux.1-lite-8B-alpha-gguf"
        "/resolve/main/flux.1-lite-8B-alpha-Q3_K_S.gguf"
    )
    transformer = FluxTransformer2DModel.from_single_file(
        ckpt_path,
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    )
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )
    # Alternative quantized checkpoint:
    # https://huggingface.co/martintomov/Hyper-FLUX.1-dev-gguf/resolve/main/hyper-flux-16step-Q3_K_M.gguf

    # Fuse the Hyper-SD LoRA for few-step inference
    pipe.load_lora_weights(
        hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
    )
    pipe.fuse_lora(lora_scale=0.125)
    pipe.enable_model_cpu_offload()

    prompt = "A cat holding a sign that says hello world"
    image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
    image.save("flux-gguf.png")


if __name__ == "__main__":
    main()