"""GGUF UNet loader utilities plus a GGUF-quantized FLUX inference example (diffusers)."""

import logging
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
from huggingface_hub import hf_hub_download

# Set up logging
logging.basicConfig(level=logging.INFO)

class GGUFUNetLoader:
    """
    A class for loading and managing GGUF-formatted UNet models for diffusion.
    Supports quantized models with custom patch handling.
    """
    def __init__(self):
        self.model = None
        self.patches = {}
        self.backup = {}
        self.load_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.offload_device = "cpu"

    @staticmethod
    def is_quantized(weight: torch.Tensor) -> bool:
        """Check if a tensor is quantized."""
        return hasattr(weight, "patches")

    def patch_weight(self, key: str, weight: torch.Tensor, device_to: Optional[str] = None) -> torch.Tensor:
        """
        Apply patches to model weights with quantization support.
        
        Args:
            key: The parameter key to patch
            weight: The weight tensor to patch
            device_to: Target device for the patched weight
        
        Returns:
            Patched weight tensor
        """
        if key not in self.patches:
            return weight

        if self.is_quantized(weight):
            # Handle quantized weights
            out_weight = weight.to(device_to if device_to else self.load_device)
            patches = self.patches[key]
            out_weight.patches = [(self.calculate_weight, patches, key)]
            return out_weight
        else:
            # Handle regular weights
            if key not in self.backup:
                self.backup[key] = weight.to(device=self.offload_device)

            temp_weight = weight.to(torch.float32)
            if device_to:
                temp_weight = temp_weight.to(device_to)

            # Apply patches
            for patch in self.patches[key]:
                temp_weight += patch

            return temp_weight.to(weight.dtype)

    def load_model(self, 
                  model_path: Union[str, Path], 
                  config: Optional[Dict[str, Any]] = None) -> None:
        """
        Load a GGUF model from disk.
        
        Args:
            model_path: Path to the GGUF model file
            config: Optional configuration dictionary for model loading
        """
        try:
            model_path = Path(model_path)
            if not model_path.exists():
                raise FileNotFoundError(f"Model file not found: {model_path}")
            
            if model_path.suffix != ".gguf":
                raise ValueError(f"Not a GGUF model file: {model_path}")

            # Load the model. `.gguf_loader` is a placeholder module you would
            # provide; the implementation depends on your GGUF parser.
            from .gguf_loader import load_gguf_model
            self.model = load_gguf_model(
                model_path,
                device=self.load_device,
                config=config or {}
            )
            
            logging.info(f"Successfully loaded GGUF model from {model_path}")
            
        except Exception as e:
            logging.error(f"Error loading model: {e}")
            raise

    def add_patch(self, key: str, patch: torch.Tensor) -> None:
        """
        Add a patch for a specific model parameter.
        
        Args:
            key: Parameter key to patch
            patch: The patch tensor to apply
        """
        if key not in self.patches:
            self.patches[key] = []
        self.patches[key].append(patch)

    def clear_patches(self) -> None:
        """Remove all patches from the model."""
        self.patches.clear()
        
        # Clear quantized patches
        if self.model:
            for param in self.model.parameters():
                if self.is_quantized(param):
                    param.patches = []

    def to(self, device: str) -> 'GGUFUNetLoader':
        """
        Move model to specified device.
        
        Args:
            device: Target device ("cuda", "cpu", etc.)
        
        Returns:
            Self for method chaining
        """
        if self.model:
            self.model.to(device)
            self.load_device = device
        return self

    @staticmethod
    def calculate_weight(patches: list, base_weight: torch.Tensor, key: str) -> torch.Tensor:
        """
        Calculate final weight by applying patches.
        
        Args:
            patches: List of patches to apply
            base_weight: Base weight tensor
            key: Parameter key
        
        Returns:
            Patched weight tensor
        """
        result = base_weight.clone()
        for patch in patches:
            result += patch
        return result


def main():
    # --- Example 1: custom loader (requires a local GGUF file and a
    # `.gguf_loader` implementation; see GGUFUNetLoader.load_model) ---
    loader = GGUFUNetLoader()
    model_path = Path("path/to/your/model.gguf")
    # Optional configuration for model loading
    config = {
        "attention_slicing": "auto",
        "channels_last": True,
    }
    # loader.load_model(model_path, config=config)

    # --- Example 2: GGUF-quantized FLUX inference with diffusers ---
    ckpt_path = (
        "https://huggingface.co/city96/flux.1-lite-8B-alpha-gguf"
        "/resolve/main/flux.1-lite-8B-alpha-Q3_K_S.gguf"
    )

    transformer = FluxTransformer2DModel.from_single_file(
        ckpt_path,
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    )

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )

    # Alternative GGUF checkpoint:
    # https://huggingface.co/martintomov/Hyper-FLUX.1-dev-gguf/resolve/main/hyper-flux-16step-Q3_K_M.gguf

    # Fuse the Hyper-SD LoRA for few-step sampling
    pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
    pipe.fuse_lora(lora_scale=0.125)

    pipe.enable_model_cpu_offload()
    prompt = "A cat holding a sign that says hello world"
    # The fused Hyper-SD LoRA targets 8-step sampling
    image = pipe(prompt, num_inference_steps=8, generator=torch.manual_seed(0)).images[0]
    image.save("flux-gguf.png")


if __name__ == "__main__":
    main()