import logging
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
from huggingface_hub import hf_hub_download

# Set up logging
logging.basicConfig(level=logging.INFO)
class GGUFUNetLoader:
    """
    A class for loading and managing GGUF-formatted UNet models for diffusion.
    Supports quantized models with custom patch handling.
    """

    def __init__(self):
        self.model = None
        self.patches = {}  # key -> list of patch tensors
        self.backup = {}   # key -> original weight, kept on the offload device
        self.load_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.offload_device = "cpu"
    @staticmethod
    def is_quantized(weight: torch.Tensor) -> bool:
        """Check if a tensor carries quantization patches (duck-typed)."""
        return hasattr(weight, "patches")
    def patch_weight(self, key: str, weight: torch.Tensor,
                     device_to: Optional[str] = None) -> torch.Tensor:
        """
        Apply patches to a model weight, with quantization support.

        Args:
            key: The parameter key to patch
            weight: The weight tensor to patch
            device_to: Target device for the patched weight

        Returns:
            Patched weight tensor
        """
        if key not in self.patches:
            return weight

        if self.is_quantized(weight):
            # Quantized weights are not modified in place; instead the patch
            # list is attached so it can be applied lazily at dequantization.
            out_weight = weight.to(device_to if device_to else self.load_device)
            out_weight.patches = [(self.calculate_weight, self.patches[key], key)]
            return out_weight

        # Regular weights: back up the original, then patch in float32 to
        # avoid precision loss before casting back to the original dtype.
        if key not in self.backup:
            self.backup[key] = weight.to(device=self.offload_device)
        temp_weight = weight.to(torch.float32)
        if device_to:
            temp_weight = temp_weight.to(device_to)
        for patch in self.patches[key]:
            temp_weight += patch
        return temp_weight.to(weight.dtype)
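    # Note on the quantized branch above (sketch-level explanation): nothing
    # is dequantized in patch_weight itself. The (calculate_weight, patches,
    # key) tuple is attached to the tensor so that a GGUF-aware dequantizer
    # can call calculate_weight(patches, base_weight, key) when it
    # materializes the weight, leaving the quantized storage untouched.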
    def load_model(self,
                   model_path: Union[str, Path],
                   config: Optional[Dict[str, Any]] = None) -> None:
        """
        Load a GGUF model from disk.

        Args:
            model_path: Path to the GGUF model file
            config: Optional configuration dictionary for model loading
        """
        try:
            model_path = Path(model_path)
            if not model_path.exists():
                raise FileNotFoundError(f"Model file not found: {model_path}")
            if model_path.suffix != ".gguf":
                raise ValueError(f"Not a GGUF model file: {model_path}")

            # Load the model (implementation depends on your GGUF loader)
            from .gguf_loader import load_gguf_model  # you'd need to implement this

            self.model = load_gguf_model(
                model_path,
                device=self.load_device,
                config=config or {},
            )
            logging.info(f"Successfully loaded GGUF model from {model_path}")
        except Exception as e:
            logging.error(f"Error loading model: {e}")
            raise
    def add_patch(self, key: str, patch: torch.Tensor) -> None:
        """
        Add a patch for a specific model parameter.

        Args:
            key: Parameter key to patch
            patch: The patch tensor to apply
        """
        self.patches.setdefault(key, []).append(patch)
    def clear_patches(self) -> None:
        """Remove all patches from the model."""
        self.patches.clear()
        # Also clear lazy patch lists attached to quantized parameters
        if self.model:
            for param in self.model.parameters():
                if self.is_quantized(param):
                    param.patches = []
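    # Hypothetical helper (not in the original): patch_weight() fills
    # self.backup but nothing reads it back, so a restore method is the
    # natural counterpart. A minimal sketch, assuming the backup keys match
    # the model's state_dict keys:
    def restore_weights(self) -> None:
        """Restore original (pre-patch) weights from the backup store."""
        if not self.model:
            return
        state = self.model.state_dict()
        with torch.no_grad():
            for key, original in self.backup.items():
                if key in state:
                    state[key].copy_(original.to(device=state[key].device,
                                                 dtype=state[key].dtype))
        self.backup.clear()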
    def to(self, device: str) -> 'GGUFUNetLoader':
        """
        Move the model to the specified device.

        Args:
            device: Target device ("cuda", "cpu", etc.)

        Returns:
            Self for method chaining
        """
        if self.model:
            self.model.to(device)
        self.load_device = device
        return self
    @staticmethod
    def calculate_weight(patches: list, base_weight: torch.Tensor, key: str) -> torch.Tensor:
        """
        Calculate the final weight by applying patches to a base weight.

        Args:
            patches: List of patch tensors to apply
            base_weight: Base weight tensor
            key: Parameter key (unused here; kept for the patch-tuple signature)

        Returns:
            Patched weight tensor
        """
        result = base_weight.clone()
        for patch in patches:
            result += patch
        return result
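# Minimal usage sketch for the non-quantized patching path. The parameter key
# and patch values here are made up for illustration; real keys come from the
# loaded model's state_dict:
def _loader_demo():
    loader = GGUFUNetLoader()
    weight = torch.randn(8, 8)  # stand-in for a model weight
    loader.add_patch("mid_block.attn.to_q.weight", 0.1 * torch.randn(8, 8))
    patched = loader.patch_weight("mid_block.attn.to_q.weight", weight,
                                  device_to="cpu")
    assert patched.shape == weight.shape
    assert "mid_block.attn.to_q.weight" in loader.backup  # original preserved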
def main():
    # --- Custom GGUFUNetLoader demo -----------------------------------------
    # model_path is a placeholder; point it at a local .gguf file to exercise
    # the loader above. The config dict is passed through to load_model().
    loader = GGUFUNetLoader()
    model_path = Path("path/to/your/model.gguf")
    config = {
        "attention_slicing": "auto",
        "channels_last": True,
    }
    if model_path.exists():
        loader.load_model(model_path, config=config)

    # --- Diffusers GGUF workflow ---------------------------------------------
    ckpt_path = (
        "https://huggingface.co/city96/flux.1-lite-8B-alpha-gguf"
        "/resolve/main/flux.1-lite-8B-alpha-Q3_K_S.gguf"
    )
    transformer = FluxTransformer2DModel.from_single_file(
        ckpt_path,
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    )
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )
    # Alternative GGUF checkpoint (load via from_single_file, as above):
    # https://huggingface.co/martintomov/Hyper-FLUX.1-dev-gguf/resolve/main/hyper-flux-16step-Q3_K_M.gguf
    # Note: FluxPipeline.from_pretrained("flux1-schnell-Q3_K_S.gguf") does not
    # work; GGUF files are loaded per-component with from_single_file.

    # Fuse the Hyper-SD 8-step LoRA for fast sampling
    pipe.load_lora_weights(
        hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
    )
    pipe.fuse_lora(lora_scale=0.125)
    pipe.enable_model_cpu_offload()

    prompt = "A cat holding a sign that says hello world"
    image = pipe(
        prompt,
        num_inference_steps=8,  # matches the 8-step Hyper LoRA
        generator=torch.manual_seed(0),
    ).images[0]
    image.save("flux-gguf.png")


if __name__ == "__main__":
    main()
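# Variant (sketch): from_single_file also accepts a local GGUF path instead of
# a URL, assuming a diffusers build with GGUF support and the `gguf` package
# installed:
#
#   transformer = FluxTransformer2DModel.from_single_file(
#       "flux.1-lite-8B-alpha-Q3_K_S.gguf",
#       quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
#       torch_dtype=torch.bfloat16,
#   )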