	Upload 9 files
- src/flux/__init__.py +11 -0
- src/flux/__main__.py +4 -0
- src/flux/api.py +194 -0
- src/flux/cli.py +254 -0
- src/flux/controlnet.py +222 -0
- src/flux/math.py +30 -0
- src/flux/model.py +217 -0
- src/flux/sampling.py +188 -0
- src/flux/util.py +237 -0
    	
src/flux/__init__.py
ADDED
@@ -0,0 +1,11 @@

try:
    from ._version import version as __version__  # type: ignore
    from ._version import version_tuple
except ImportError:
    __version__ = "unknown (no version information available)"
    version_tuple = (0, 0, "unknown", "noinfo")

from pathlib import Path

PACKAGE = __package__.replace("_", "-")
PACKAGE_ROOT = Path(__file__).parent
src/flux/__main__.py
ADDED
@@ -0,0 +1,4 @@

from .cli import app

if __name__ == "__main__":
    app()
src/flux/api.py
ADDED
@@ -0,0 +1,194 @@

import io
import os
import time
from pathlib import Path

import requests
from PIL import Image

API_ENDPOINT = "https://api.bfl.ml"


class ApiException(Exception):
    def __init__(self, status_code: int, detail: str | list[dict] | None = None):
        super().__init__()
        self.detail = detail
        self.status_code = status_code

    def __str__(self) -> str:
        return self.__repr__()

    def __repr__(self) -> str:
        if self.detail is None:
            message = None
        elif isinstance(self.detail, str):
            message = self.detail
        else:
            message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
        return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"


class ImageRequest:
    def __init__(
        self,
        prompt: str,
        width: int = 1024,
        height: int = 1024,
        name: str = "flux.1-pro",
        num_steps: int = 50,
        prompt_upsampling: bool = False,
        seed: int | None = None,
        validate: bool = True,
        launch: bool = True,
        api_key: str | None = None,
    ):
        """
        Manages an image generation request to the API.

        Args:
            prompt: Prompt to sample
            width: Width of the image in pixels
            height: Height of the image in pixels
            name: Name of the model
            num_steps: Number of network evaluations
            prompt_upsampling: Use prompt upsampling
            seed: Fix the generation seed
            validate: Run input validation
            launch: Directly launches request
            api_key: Your API key if not provided by the environment

        Raises:
            ValueError: For invalid input
            ApiException: For errors raised from the API
        """
        if validate:
            if name not in ["flux.1-pro"]:
                raise ValueError(f"Invalid model {name}")
            elif width % 32 != 0:
                raise ValueError(f"width must be divisible by 32, got {width}")
            elif not (256 <= width <= 1440):
                raise ValueError(f"width must be between 256 and 1440, got {width}")
            elif height % 32 != 0:
                raise ValueError(f"height must be divisible by 32, got {height}")
            elif not (256 <= height <= 1440):
                raise ValueError(f"height must be between 256 and 1440, got {height}")
            elif not (1 <= num_steps <= 50):
                raise ValueError(f"steps must be between 1 and 50, got {num_steps}")

        self.request_json = {
            "prompt": prompt,
            "width": width,
            "height": height,
            "variant": name,
            "steps": num_steps,
            "prompt_upsampling": prompt_upsampling,
        }
        if seed is not None:
            self.request_json["seed"] = seed

        self.request_id: str | None = None
        self.result: dict | None = None
        self._image_bytes: bytes | None = None
        self._url: str | None = None
        if api_key is None:
            self.api_key = os.environ.get("BFL_API_KEY")
        else:
            self.api_key = api_key

        if launch:
            self.request()

    def request(self):
        """
        Request to generate the image.
        """
        if self.request_id is not None:
            return
        response = requests.post(
            f"{API_ENDPOINT}/v1/image",
            headers={
                "accept": "application/json",
                "x-key": self.api_key,
                "Content-Type": "application/json",
            },
            json=self.request_json,
        )
        result = response.json()
        if response.status_code != 200:
            raise ApiException(status_code=response.status_code, detail=result.get("detail"))
        self.request_id = response.json()["id"]

    def retrieve(self) -> dict:
        """
        Wait for the generation to finish and retrieve response.
        """
        if self.request_id is None:
            self.request()
        while self.result is None:
            response = requests.get(
                f"{API_ENDPOINT}/v1/get_result",
                headers={
                    "accept": "application/json",
                    "x-key": self.api_key,
                },
                params={
                    "id": self.request_id,
                },
            )
            result = response.json()
            if "status" not in result:
                raise ApiException(status_code=response.status_code, detail=result.get("detail"))
            elif result["status"] == "Ready":
                self.result = result["result"]
            elif result["status"] == "Pending":
                time.sleep(0.5)
            else:
                raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
        return self.result

    @property
    def bytes(self) -> bytes:
        """
        Generated image as bytes.
        """
        if self._image_bytes is None:
            response = requests.get(self.url)
            if response.status_code == 200:
                self._image_bytes = response.content
            else:
                raise ApiException(status_code=response.status_code)
        return self._image_bytes

    @property
    def url(self) -> str:
        """
        Public url to retrieve the image from
        """
        if self._url is None:
            result = self.retrieve()
            self._url = result["sample"]
        return self._url

    @property
    def image(self) -> Image.Image:
        """
        Load the image as a PIL Image
        """
        return Image.open(io.BytesIO(self.bytes))

    def save(self, path: str):
        """
        Save the generated image to a local path
        """
        suffix = Path(self.url).suffix
        if not path.endswith(suffix):
            path = path + suffix
        Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as file:
            file.write(self.bytes)


if __name__ == "__main__":
    from fire import Fire

    Fire(ImageRequest)
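A minimal usage sketch of the `ImageRequest` class above, assuming a valid `BFL_API_KEY` is exported in the environment and the api.bfl.ml endpoint is reachable; the prompt, sizes, and output path are illustrative placeholders, not values from this commit.

# Sketch: generate one image through the hosted API defined in src/flux/api.py.
# Assumes BFL_API_KEY is set in the environment; prompt and path are placeholders.
from flux.api import ImageRequest

request = ImageRequest(
    prompt="a cabin in the alps at sunrise",
    width=1024,   # must be a multiple of 32 within [256, 1440]
    height=768,
    num_steps=40,
)
print(request.url)          # public URL of the generated sample
request.save("alps_cabin")  # file suffix is appended automatically from the URL

Because `launch=True` by default, the request is submitted in the constructor; `url`, `bytes`, `image`, and `save` all block on `retrieve()` until the API reports the result as ready.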
    	
src/flux/cli.py
ADDED
@@ -0,0 +1,254 @@

import os
import re
import time
from dataclasses import dataclass
from glob import iglob

import torch
from einops import rearrange
from fire import Fire
from PIL import ExifTags, Image

from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import (configs, embed_watermark, load_ae, load_clip,
                       load_flow_model, load_t5)
from transformers import pipeline

NSFW_THRESHOLD = 0.85


@dataclass
class SamplingOptions:
    prompt: str
    width: int
    height: int
    num_steps: int
    guidance: float
    seed: int | None


def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/w <width>' will set the width of the generated image\n"
        "- '/h <height>' will set the height of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )

    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/w"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, width = prompt.split()
            options.width = 16 * (int(width) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/h"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, height = prompt.split()
            options.height = 16 * (int(height) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/g"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, guidance = prompt.split()
            options.guidance = float(guidance)
            print(f"Setting guidance to {options.guidance}")
        elif prompt.startswith("/s"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, seed = prompt.split()
            options.seed = int(seed)
            print(f"Setting seed to {options.seed}")
        elif prompt.startswith("/n"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, steps = prompt.split()
            options.num_steps = int(steps)
            print(f"Setting number of steps to {options.num_steps}")
        elif prompt.startswith("/q"):
            print("Quitting")
            return None
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    if prompt != "":
        options.prompt = prompt
    return options


@torch.inference_mode()
def main(
    name: str = "flux-schnell",
    width: int = 1360,
    height: int = 768,
    seed: int | None = None,
    prompt: str = (
        "a photo of a forest with mist swirling around the tree trunks. The word "
        '"FLUX" is painted over it in big, red brush strokes with visible texture'
    ),
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int | None = None,
    loop: bool = False,
    guidance: float = 3.5,
    offload: bool = False,
    output_dir: str = "output",
    add_sampling_metadata: bool = True,
):
    """
    Sample the flux model. Either interactively (set `--loop`) or run for a
    single image.

    Args:
        name: Name of the model to load
        height: height of the sample in pixels (should be a multiple of 16)
        width: width of the sample in pixels (should be a multiple of 16)
        seed: Set a seed for sampling
        output_dir: where to save the output images, as `img_{idx}.jpg` with
            `{idx}` replaced by the index of the sample
        prompt: Prompt used for sampling
        device: PyTorch device
        num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
        loop: start an interactive session and sample multiple times
        guidance: guidance value used for guidance distillation
        add_sampling_metadata: Add the prompt to the image Exif metadata
    """
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")

    if name not in configs:
        available = ", ".join(configs.keys())
        raise ValueError(f"Got unknown model name: {name}, choose from {available}")

    torch_device = torch.device(device)
    if num_steps is None:
        num_steps = 4 if name == "flux-schnell" else 50

    # allow for packing and conversion to latent space
    height = 16 * (height // 16)
    width = 16 * (width // 16)

    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    # init all components
    t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
    clip = load_clip(torch_device)
    model = load_flow_model(name, device="cpu" if offload else torch_device)
    ae = load_ae(name, device="cpu" if offload else torch_device)

    rng = torch.Generator(device="cpu")
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    if loop:
        opts = parse_prompt(opts)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        opts.seed = None
        if offload:
            ae = ae.cpu()
            torch.cuda.empty_cache()
            t5, clip = t5.to(torch_device), clip.to(torch_device)
        inp = prepare(t5, clip, x, prompt=opts.prompt)
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # offload TEs to CPU, load model to gpu
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # denoise initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)
        t1 = time.perf_counter()

        fn = output_name.format(idx=idx)
        print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
        # bring into PIL format and save
        x = x.clamp(-1, 1)
        x = embed_watermark(x.float())
        x = rearrange(x[0], "c h w -> h w c")

        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
        nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]

        if nsfw_score < NSFW_THRESHOLD:
            exif_data = Image.Exif()
            exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
            exif_data[ExifTags.Base.Make] = "Black Forest Labs"
            exif_data[ExifTags.Base.Model] = name
            if add_sampling_metadata:
                exif_data[ExifTags.Base.ImageDescription] = prompt
            img.save(fn, exif=exif_data, quality=95, subsampling=0)
            idx += 1
        else:
            print("Your generated image may contain NSFW content.")

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
        else:
            opts = None


def app():
    Fire(main)


if __name__ == "__main__":
    app()
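Since `app()` only wraps `main` with Fire, the same sampling pipeline can also be driven programmatically; a small sketch, assuming the checkpoints referenced by `flux.util.configs` can be loaded and a CUDA device is available (the prompt and sizes below are illustrative):

# Sketch: invoke the CLI's main() directly instead of through `python -m flux`.
# Assumes flux-schnell weights are obtainable via flux.util and a GPU is present.
from flux.cli import main

main(
    name="flux-schnell",
    width=1024,                # rounded down to a multiple of 16 internally
    height=768,
    prompt="a watercolor map of an imaginary archipelago",
    num_steps=4,               # schnell default
    offload=True,              # keep unused components on CPU between stages
    output_dir="output",
)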
    	
        src/flux/controlnet.py
    ADDED
    
    | @@ -0,0 +1,222 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from dataclasses import dataclass
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from torch import Tensor, nn
         | 
| 5 | 
            +
            from einops import rearrange
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
         | 
| 8 | 
            +
                                             MLPEmbedder, SingleStreamBlock,
         | 
| 9 | 
            +
                                             timestep_embedding)
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            @dataclass
         | 
| 13 | 
            +
            class FluxParams:
         | 
| 14 | 
            +
                in_channels: int
         | 
| 15 | 
            +
                vec_in_dim: int
         | 
| 16 | 
            +
                context_in_dim: int
         | 
| 17 | 
            +
                hidden_size: int
         | 
| 18 | 
            +
                mlp_ratio: float
         | 
| 19 | 
            +
                num_heads: int
         | 
| 20 | 
            +
                depth: int
         | 
| 21 | 
            +
                depth_single_blocks: int
         | 
| 22 | 
            +
                axes_dim: list[int]
         | 
| 23 | 
            +
                theta: int
         | 
| 24 | 
            +
                qkv_bias: bool
         | 
| 25 | 
            +
                guidance_embed: bool
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            def zero_module(module):
         | 
| 28 | 
            +
                for p in module.parameters():
         | 
| 29 | 
            +
                    nn.init.zeros_(p)
         | 
| 30 | 
            +
                return module
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            class ControlNetFlux(nn.Module):
         | 
| 34 | 
            +
                """
         | 
| 35 | 
            +
                Transformer model for flow matching on sequences.
         | 
| 36 | 
            +
                """
         | 
| 37 | 
            +
                _supports_gradient_checkpointing = True
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                def __init__(self, params: FluxParams, controlnet_depth=2):
         | 
| 40 | 
            +
                    super().__init__()
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                    self.params = params
         | 
| 43 | 
            +
                    self.in_channels = params.in_channels
         | 
| 44 | 
            +
                    self.out_channels = self.in_channels
         | 
| 45 | 
            +
                    if params.hidden_size % params.num_heads != 0:
         | 
| 46 | 
            +
                        raise ValueError(
         | 
| 47 | 
            +
                            f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
         | 
| 48 | 
            +
                        )
         | 
| 49 | 
            +
                    pe_dim = params.hidden_size // params.num_heads
         | 
| 50 | 
            +
                    if sum(params.axes_dim) != pe_dim:
         | 
| 51 | 
            +
                        raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
         | 
| 52 | 
            +
                    self.hidden_size = params.hidden_size
         | 
| 53 | 
            +
                    self.num_heads = params.num_heads
         | 
| 54 | 
            +
                    self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
         | 
| 55 | 
            +
                    self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
         | 
| 56 | 
            +
                    self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
         | 
| 57 | 
            +
                    self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
         | 
| 58 | 
            +
                    self.guidance_in = (
         | 
| 59 | 
            +
                        MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
         | 
| 60 | 
            +
                    )
         | 
| 61 | 
            +
                    self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    self.double_blocks = nn.ModuleList(
         | 
| 64 | 
            +
                        [
         | 
| 65 | 
            +
                            DoubleStreamBlock(
         | 
| 66 | 
            +
                                self.hidden_size,
         | 
| 67 | 
            +
                                self.num_heads,
         | 
| 68 | 
            +
                                mlp_ratio=params.mlp_ratio,
         | 
| 69 | 
            +
                                qkv_bias=params.qkv_bias,
         | 
| 70 | 
            +
                            )
         | 
| 71 | 
            +
                            for _ in range(controlnet_depth)
         | 
| 72 | 
            +
                        ]
         | 
| 73 | 
            +
                    )
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    # add ControlNet blocks
         | 
| 76 | 
            +
                    self.controlnet_blocks = nn.ModuleList([])
         | 
| 77 | 
            +
                    for _ in range(controlnet_depth):
         | 
| 78 | 
            +
                        controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
         | 
| 79 | 
            +
                        controlnet_block = zero_module(controlnet_block)
         | 
| 80 | 
            +
                        self.controlnet_blocks.append(controlnet_block)
         | 
| 81 | 
            +
                    self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
         | 
| 82 | 
            +
                    self.gradient_checkpointing = False
         | 
| 83 | 
            +
                    self.input_hint_block = nn.Sequential(
         | 
| 84 | 
            +
                        nn.Conv2d(3, 16, 3, padding=1),
         | 
| 85 | 
            +
                        nn.SiLU(),
         | 
| 86 | 
            +
                        nn.Conv2d(16, 16, 3, padding=1),
         | 
| 87 | 
            +
                        nn.SiLU(),
         | 
| 88 | 
            +
                        nn.Conv2d(16, 16, 3, padding=1, stride=2),
         | 
| 89 | 
            +
                        nn.SiLU(),
         | 
| 90 | 
            +
                        nn.Conv2d(16, 16, 3, padding=1),
         | 
| 91 | 
            +
                        nn.SiLU(),
         | 
| 92 | 
            +
                        nn.Conv2d(16, 16, 3, padding=1, stride=2),
         | 
| 93 | 
            +
                        nn.SiLU(),
         | 
| 94 | 
            +
                        nn.Conv2d(16, 16, 3, padding=1),
         | 
| 95 | 
            +
                        nn.SiLU(),
         | 
| 96 | 
            +
                        nn.Conv2d(16, 16, 3, padding=1, stride=2),
         | 
| 97 | 
            +
                        nn.SiLU(),
         | 
| 98 | 
            +
                        zero_module(nn.Conv2d(16, 16, 3, padding=1))
         | 
| 99 | 
            +
                    )
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                def _set_gradient_checkpointing(self, module, value=False):
         | 
| 102 | 
            +
                    if hasattr(module, "gradient_checkpointing"):
         | 
| 103 | 
            +
                        module.gradient_checkpointing = value
         | 
| 104 | 
            +
             | 
| 105 | 
            +
             | 
| 106 | 
            +
                @property
         | 
| 107 | 
            +
                def attn_processors(self):
         | 
| 108 | 
            +
                    # set recursively
         | 
| 109 | 
            +
                    processors = {}
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
         | 
| 112 | 
            +
                        if hasattr(module, "set_processor"):
         | 
| 113 | 
            +
                            processors[f"{name}.processor"] = module.processor
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                        for sub_name, child in module.named_children():
         | 
| 116 | 
            +
                            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                        return processors
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    for name, module in self.named_children():
         | 
| 121 | 
            +
                        fn_recursive_add_processors(name, module, processors)
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    return processors
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                def set_attn_processor(self, processor):
         | 
| 126 | 
            +
                    r"""
         | 
| 127 | 
            +
                    Sets the attention processor to use to compute attention.
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    Parameters:
         | 
| 130 | 
            +
                        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
         | 
| 131 | 
            +
                            The instantiated processor class or a dictionary of processor classes that will be set as the processor
         | 
| 132 | 
            +
                            for **all** `Attention` layers.
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
         | 
| 135 | 
            +
                            processor. This is strongly recommended when setting trainable attention processors.
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    """
         | 
| 138 | 
            +
                    count = len(self.attn_processors.keys())
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    if isinstance(processor, dict) and len(processor) != count:
         | 
| 141 | 
            +
                        raise ValueError(
         | 
| 142 | 
            +
                            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
         | 
| 143 | 
            +
                            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
         | 
| 144 | 
            +
                        )
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         | 
| 147 | 
            +
                        if hasattr(module, "set_processor"):
         | 
| 148 | 
            +
                            if not isinstance(processor, dict):
         | 
| 149 | 
            +
                                module.set_processor(processor)
         | 
| 150 | 
            +
                            else:
         | 
| 151 | 
            +
                                module.set_processor(processor.pop(f"{name}.processor"))
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                        for sub_name, child in module.named_children():
         | 
| 154 | 
            +
                            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    for name, module in self.named_children():
         | 
| 157 | 
            +
                        fn_recursive_attn_processor(name, module, processor)
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                def forward(
         | 
| 160 | 
            +
                    self,
         | 
| 161 | 
            +
                    img: Tensor,
         | 
| 162 | 
            +
                    img_ids: Tensor,
         | 
| 163 | 
            +
                    controlnet_cond: Tensor,
         | 
| 164 | 
            +
                    txt: Tensor,
         | 
| 165 | 
            +
                    txt_ids: Tensor,
         | 
| 166 | 
            +
                    timesteps: Tensor,
         | 
| 167 | 
            +
                    y: Tensor,
         | 
| 168 | 
            +
                    guidance: Tensor | None = None,
         | 
| 169 | 
            +
                ) -> Tensor:
         | 
| 170 | 
            +
                    if img.ndim != 3 or txt.ndim != 3:
         | 
| 171 | 
            +
                        raise ValueError("Input img and txt tensors must have 3 dimensions.")
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                    # running on sequences img
         | 
| 174 | 
            +
                    img = self.img_in(img)
         | 
| 175 | 
            +
                    controlnet_cond = self.input_hint_block(controlnet_cond)
         | 
| 176 | 
            +
                    controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
         | 
| 177 | 
            +
                    controlnet_cond = self.pos_embed_input(controlnet_cond)
         | 
| 178 | 
            +
                    img = img + controlnet_cond
         | 
| 179 | 
            +
                    vec = self.time_in(timestep_embedding(timesteps, 256))
         | 
| 180 | 
            +
                    if self.params.guidance_embed:
         | 
| 181 | 
            +
                        if guidance is None:
         | 
| 182 | 
            +
                            raise ValueError("Didn't get guidance strength for guidance distilled model.")
         | 
| 183 | 
            +
                        vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
         | 
| 184 | 
            +
                    vec = vec + self.vector_in(y)
         | 
| 185 | 
            +
                    txt = self.txt_in(txt)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    ids = torch.cat((txt_ids, img_ids), dim=1)
         | 
| 188 | 
            +
                    pe = self.pe_embedder(ids)
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                    block_res_samples = ()
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                    for block in self.double_blocks:
         | 
| 193 | 
            +
                        if self.training and self.gradient_checkpointing:
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                            def create_custom_forward(module, return_dict=None):
         | 
| 196 | 
            +
                                def custom_forward(*inputs):
         | 
| 197 | 
            +
                                    if return_dict is not None:
         | 
| 198 | 
            +
                                        return module(*inputs, return_dict=return_dict)
         | 
| 199 | 
            +
                                    else:
         | 
| 200 | 
            +
                                        return module(*inputs)
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                                return custom_forward
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
         | 
| 205 | 
            +
                            img, txt = torch.utils.checkpoint.checkpoint(
         | 
| 206 | 
            +
                                create_custom_forward(block),
         | 
| 207 | 
            +
                                img,
         | 
| 208 | 
            +
                                txt,
         | 
| 209 | 
            +
                                vec,
         | 
| 210 | 
            +
                                pe,
                                **ckpt_kwargs,
         | 
| 211 | 
            +
                            )
         | 
| 212 | 
            +
                        else:
         | 
| 213 | 
            +
                            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                        block_res_samples = block_res_samples + (img,)
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                    controlnet_block_res_samples = ()
         | 
| 218 | 
            +
                    for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
         | 
| 219 | 
            +
                        block_res_sample = controlnet_block(block_res_sample)
         | 
| 220 | 
            +
                        controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                    return controlnet_block_res_samples
         | 
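
Note on the conditioning path in the forward pass above: the hint produced by input_hint_block is folded into the same packed-token layout as the image latents before being added to img, and each double block's img activations are then projected by controlnet_blocks, so the forward returns one residual per double block for the main model to consume. A tiny self-contained check of that 2x2 packing (the sizes below are hypothetical; a 16-channel map at latent resolution packs to 64 channels per token, the same layout the packed image latents use):

import torch
from einops import rearrange

b, c, h, w = 1, 16, 64, 64                       # hypothetical latent-resolution hint tensor
cond = torch.randn(b, c, h, w)
packed = rearrange(cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(packed.shape)                              # torch.Size([1, 1024, 64]): (h/2)*(w/2) tokens, c*4 channels
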
    	
        src/flux/math.py
    ADDED
    
    | @@ -0,0 +1,30 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from einops import rearrange
         | 
| 3 | 
            +
            from torch import Tensor
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
         | 
| 7 | 
            +
                q, k = apply_rope(q, k, pe)
         | 
| 8 | 
            +
             | 
| 9 | 
            +
                x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
         | 
| 10 | 
            +
                x = rearrange(x, "B H L D -> B L (H D)")
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                return x
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
         | 
| 16 | 
            +
                assert dim % 2 == 0
         | 
| 17 | 
            +
                scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
         | 
| 18 | 
            +
                omega = 1.0 / (theta**scale)
         | 
| 19 | 
            +
                out = torch.einsum("...n,d->...nd", pos, omega)
         | 
| 20 | 
            +
                out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
         | 
| 21 | 
            +
                out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
         | 
| 22 | 
            +
                return out.float()
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
         | 
| 26 | 
            +
                xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
         | 
| 27 | 
            +
                xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
         | 
| 28 | 
            +
                xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
         | 
| 29 | 
            +
                xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
         | 
| 30 | 
            +
                return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
         | 
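
A quick shape and invariance check for the rotary helpers above (sizes are made up, and this assumes the package is importable as flux). rope builds a 2x2 rotation matrix per position and frequency, and apply_rope rotates q/k without changing their shapes or norms:

import torch
from flux.math import rope, apply_rope

B, H, L, D = 1, 2, 8, 16                          # batch, heads, tokens, head dim (must be even)
pos = torch.arange(L, dtype=torch.float64)[None]  # (B, L) token positions
pe = rope(pos, D, theta=10_000)                   # (B, L, D/2, 2, 2) rotation matrices
pe = pe.unsqueeze(1)                              # broadcast axis over heads, as the model's EmbedND adds

q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
qr, kr = apply_rope(q, k, pe)
print(qr.shape == q.shape and kr.shape == k.shape)                   # True: shapes unchanged
print(torch.allclose(q.norm(dim=-1), qr.norm(dim=-1), atol=1e-4))    # True: rotations preserve norms
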
    	
        src/flux/model.py
    ADDED
    
    | @@ -0,0 +1,217 @@ | |
| 1 | 
            +
            from dataclasses import dataclass
            from typing import Any, Dict
            
            from diffusers.utils import is_torch_version  # assumed available; the checkpointing helpers below were adapted from diffusers
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from torch import Tensor, nn
         | 
| 5 | 
            +
            from einops import rearrange
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
         | 
| 8 | 
            +
                                             MLPEmbedder, SingleStreamBlock,
         | 
| 9 | 
            +
                                             timestep_embedding)
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            @dataclass
         | 
| 13 | 
            +
            class FluxParams:
         | 
| 14 | 
            +
                in_channels: int
         | 
| 15 | 
            +
                vec_in_dim: int
         | 
| 16 | 
            +
                context_in_dim: int
         | 
| 17 | 
            +
                hidden_size: int
         | 
| 18 | 
            +
                mlp_ratio: float
         | 
| 19 | 
            +
                num_heads: int
         | 
| 20 | 
            +
                depth: int
         | 
| 21 | 
            +
                depth_single_blocks: int
         | 
| 22 | 
            +
                axes_dim: list[int]
         | 
| 23 | 
            +
                theta: int
         | 
| 24 | 
            +
                qkv_bias: bool
         | 
| 25 | 
            +
                guidance_embed: bool
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            class Flux(nn.Module):
         | 
| 29 | 
            +
                """
         | 
| 30 | 
            +
                Transformer model for flow matching on sequences.
         | 
| 31 | 
            +
                """
         | 
| 32 | 
            +
                _supports_gradient_checkpointing = True
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                def __init__(self, params: FluxParams):
         | 
| 35 | 
            +
                    super().__init__()
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                    self.params = params
         | 
| 38 | 
            +
                    self.in_channels = params.in_channels
         | 
| 39 | 
            +
                    self.out_channels = self.in_channels
         | 
| 40 | 
            +
                    if params.hidden_size % params.num_heads != 0:
         | 
| 41 | 
            +
                        raise ValueError(
         | 
| 42 | 
            +
                            f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
         | 
| 43 | 
            +
                        )
         | 
| 44 | 
            +
                    pe_dim = params.hidden_size // params.num_heads
         | 
| 45 | 
            +
                    if sum(params.axes_dim) != pe_dim:
         | 
| 46 | 
            +
                        raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
         | 
| 47 | 
            +
                    self.hidden_size = params.hidden_size
         | 
| 48 | 
            +
                    self.num_heads = params.num_heads
         | 
| 49 | 
            +
                    self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
         | 
| 50 | 
            +
                    self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
         | 
| 51 | 
            +
                    self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
         | 
| 52 | 
            +
                    self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
         | 
| 53 | 
            +
                    self.guidance_in = (
         | 
| 54 | 
            +
                        MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
         | 
| 55 | 
            +
                    )
         | 
| 56 | 
            +
                    self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    self.double_blocks = nn.ModuleList(
         | 
| 59 | 
            +
                        [
         | 
| 60 | 
            +
                            DoubleStreamBlock(
         | 
| 61 | 
            +
                                self.hidden_size,
         | 
| 62 | 
            +
                                self.num_heads,
         | 
| 63 | 
            +
                                mlp_ratio=params.mlp_ratio,
         | 
| 64 | 
            +
                                qkv_bias=params.qkv_bias,
         | 
| 65 | 
            +
                            )
         | 
| 66 | 
            +
                            for _ in range(params.depth)
         | 
| 67 | 
            +
                        ]
         | 
| 68 | 
            +
                    )
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    self.single_blocks = nn.ModuleList(
         | 
| 71 | 
            +
                        [
         | 
| 72 | 
            +
                            SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
         | 
| 73 | 
            +
                            for _ in range(params.depth_single_blocks)
         | 
| 74 | 
            +
                        ]
         | 
| 75 | 
            +
                    )
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
         | 
| 78 | 
            +
                    self.gradient_checkpointing = False
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                def _set_gradient_checkpointing(self, module, value=False):
         | 
| 81 | 
            +
                    if hasattr(module, "gradient_checkpointing"):
         | 
| 82 | 
            +
                        module.gradient_checkpointing = value
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                @property
         | 
| 85 | 
            +
                def attn_processors(self):
         | 
| 86 | 
            +
                    # set recursively
         | 
| 87 | 
            +
                    processors = {}
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
         | 
| 90 | 
            +
                        if hasattr(module, "set_processor"):
         | 
| 91 | 
            +
                            processors[f"{name}.processor"] = module.processor
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                        for sub_name, child in module.named_children():
         | 
| 94 | 
            +
                            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                        return processors
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    for name, module in self.named_children():
         | 
| 99 | 
            +
                        fn_recursive_add_processors(name, module, processors)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                    return processors
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                def set_attn_processor(self, processor):
         | 
| 104 | 
            +
                    r"""
         | 
| 105 | 
            +
                    Sets the attention processor to use to compute attention.
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    Parameters:
         | 
| 108 | 
            +
                        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
         | 
| 109 | 
            +
                            The instantiated processor class or a dictionary of processor classes that will be set as the processor
         | 
| 110 | 
            +
                            for **all** `Attention` layers.
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
         | 
| 113 | 
            +
                            processor. This is strongly recommended when setting trainable attention processors.
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    """
         | 
| 116 | 
            +
                    count = len(self.attn_processors.keys())
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    if isinstance(processor, dict) and len(processor) != count:
         | 
| 119 | 
            +
                        raise ValueError(
         | 
| 120 | 
            +
                            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
         | 
| 121 | 
            +
                            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
         | 
| 122 | 
            +
                        )
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         | 
| 125 | 
            +
                        if hasattr(module, "set_processor"):
         | 
| 126 | 
            +
                            if not isinstance(processor, dict):
         | 
| 127 | 
            +
                                module.set_processor(processor)
         | 
| 128 | 
            +
                            else:
         | 
| 129 | 
            +
                                module.set_processor(processor.pop(f"{name}.processor"))
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                        for sub_name, child in module.named_children():
         | 
| 132 | 
            +
                            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    for name, module in self.named_children():
         | 
| 135 | 
            +
                        fn_recursive_attn_processor(name, module, processor)
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                def forward(
         | 
| 138 | 
            +
                    self,
         | 
| 139 | 
            +
                    img: Tensor,
         | 
| 140 | 
            +
                    img_ids: Tensor,
         | 
| 141 | 
            +
                    txt: Tensor,
         | 
| 142 | 
            +
                    txt_ids: Tensor,
         | 
| 143 | 
            +
                    timesteps: Tensor,
         | 
| 144 | 
            +
                    y: Tensor,
         | 
| 145 | 
            +
                    block_controlnet_hidden_states=None,
         | 
| 146 | 
            +
                    guidance: Tensor | None = None,
         | 
| 147 | 
            +
                ) -> Tensor:
         | 
| 148 | 
            +
                    if img.ndim != 3 or txt.ndim != 3:
         | 
| 149 | 
            +
                        raise ValueError("Input img and txt tensors must have 3 dimensions.")
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                    # running on sequences img
         | 
| 152 | 
            +
                    img = self.img_in(img)
         | 
| 153 | 
            +
                    vec = self.time_in(timestep_embedding(timesteps, 256))
         | 
| 154 | 
            +
                    if self.params.guidance_embed:
         | 
| 155 | 
            +
                        if guidance is None:
         | 
| 156 | 
            +
                            raise ValueError("Didn't get guidance strength for guidance distilled model.")
         | 
| 157 | 
            +
                        vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
         | 
| 158 | 
            +
                    vec = vec + self.vector_in(y)
         | 
| 159 | 
            +
                    txt = self.txt_in(txt)
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    ids = torch.cat((txt_ids, img_ids), dim=1)
         | 
| 162 | 
            +
                    pe = self.pe_embedder(ids)
         | 
| 163 | 
            +
                    if block_controlnet_hidden_states is not None:
         | 
| 164 | 
            +
                        controlnet_depth = len(block_controlnet_hidden_states)
         | 
| 165 | 
            +
                    for index_block, block in enumerate(self.double_blocks):
         | 
| 166 | 
            +
                        if self.training and self.gradient_checkpointing:
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                            def create_custom_forward(module, return_dict=None):
         | 
| 169 | 
            +
                                def custom_forward(*inputs):
         | 
| 170 | 
            +
                                    if return_dict is not None:
         | 
| 171 | 
            +
                                        return module(*inputs, return_dict=return_dict)
         | 
| 172 | 
            +
                                    else:
         | 
| 173 | 
            +
                                        return module(*inputs)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                                return custom_forward
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
         | 
| 178 | 
            +
                            img, txt = torch.utils.checkpoint.checkpoint(
         | 
| 179 | 
            +
                                create_custom_forward(block),
         | 
| 180 | 
            +
                                img,
         | 
| 181 | 
            +
                                txt,
         | 
| 182 | 
            +
                                vec,
         | 
| 183 | 
            +
                                pe,
                                **ckpt_kwargs,
         | 
| 184 | 
            +
                            )
         | 
| 185 | 
            +
                        else:
         | 
| 186 | 
            +
                            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
         | 
| 187 | 
            +
                        # controlnet residual
         | 
| 188 | 
            +
                        if block_controlnet_hidden_states is not None:
         | 
| 189 | 
            +
                            img = img + block_controlnet_hidden_states[index_block % controlnet_depth]
         | 
| 190 | 
            +
             | 
| 191 | 
            +
             | 
| 192 | 
            +
                    img = torch.cat((txt, img), 1)
         | 
| 193 | 
            +
                    for block in self.single_blocks:
         | 
| 194 | 
            +
                        if self.training and self.gradient_checkpointing:
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                            def create_custom_forward(module, return_dict=None):
         | 
| 197 | 
            +
                                def custom_forward(*inputs):
         | 
| 198 | 
            +
                                    if return_dict is not None:
         | 
| 199 | 
            +
                                        return module(*inputs, return_dict=return_dict)
         | 
| 200 | 
            +
                                    else:
         | 
| 201 | 
            +
                                        return module(*inputs)
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                                return custom_forward
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
         | 
| 206 | 
            +
                            img = torch.utils.checkpoint.checkpoint(
         | 
| 207 | 
            +
                                create_custom_forward(block),
         | 
| 208 | 
            +
                                img,
         | 
| 209 | 
            +
                                vec,
         | 
| 210 | 
            +
                                pe,
                                **ckpt_kwargs,
         | 
| 211 | 
            +
                            )
         | 
| 212 | 
            +
                        else:
         | 
| 213 | 
            +
                            img = block(img, vec=vec, pe=pe)
         | 
| 214 | 
            +
                    img = img[:, txt.shape[1] :, ...]
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         | 
| 217 | 
            +
                    return img
         | 
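
A small smoke test of the class above with deliberately tiny, made-up hyperparameters (the real flux-dev/schnell configs live in util.py; this assumes flux.modules.layers from this upload is importable). It shows what the forward pass expects: packed image tokens with 3-axis position ids, T5 token embeddings, the pooled CLIP vector y, and one timestep per sample:

import torch
from flux.model import Flux, FluxParams

params = FluxParams(
    in_channels=64, vec_in_dim=32, context_in_dim=48, hidden_size=128,
    mlp_ratio=4.0, num_heads=4, depth=2, depth_single_blocks=2,
    axes_dim=[8, 12, 12], theta=10_000, qkv_bias=True, guidance_embed=False,
)
model = Flux(params)                      # axes_dim sums to hidden_size // num_heads = 32

bs, n_img, n_txt = 1, 64, 16
out = model(
    img=torch.randn(bs, n_img, 64),       # packed image latents
    img_ids=torch.zeros(bs, n_img, 3),    # (t, h, w) position ids per image token
    txt=torch.randn(bs, n_txt, 48),       # T5 token embeddings
    txt_ids=torch.zeros(bs, n_txt, 3),
    timesteps=torch.rand(bs),
    y=torch.randn(bs, 32),                # pooled CLIP vector
)
print(out.shape)                          # torch.Size([1, 64, 64]): one 64-channel prediction per image token
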
    	
        src/flux/sampling.py
    ADDED
    
    | @@ -0,0 +1,188 @@ | |
| 1 | 
            +
            import math
         | 
| 2 | 
            +
            from typing import Callable
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            from einops import rearrange, repeat
         | 
| 6 | 
            +
            from torch import Tensor
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from .model import Flux
         | 
| 9 | 
            +
            from .modules.conditioner import HFEmbedder
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            def get_noise(
         | 
| 13 | 
            +
                num_samples: int,
         | 
| 14 | 
            +
                height: int,
         | 
| 15 | 
            +
                width: int,
         | 
| 16 | 
            +
                device: torch.device,
         | 
| 17 | 
            +
                dtype: torch.dtype,
         | 
| 18 | 
            +
                seed: int,
         | 
| 19 | 
            +
            ):
         | 
| 20 | 
            +
                return torch.randn(
         | 
| 21 | 
            +
                    num_samples,
         | 
| 22 | 
            +
                    16,
         | 
| 23 | 
            +
                    # allow for packing
         | 
| 24 | 
            +
                    2 * math.ceil(height / 16),
         | 
| 25 | 
            +
                    2 * math.ceil(width / 16),
         | 
| 26 | 
            +
                    device=device,
         | 
| 27 | 
            +
                    dtype=dtype,
         | 
| 28 | 
            +
                    generator=torch.Generator(device=device).manual_seed(seed),
         | 
| 29 | 
            +
                )
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
         | 
| 33 | 
            +
                bs, c, h, w = img.shape
         | 
| 34 | 
            +
                if bs == 1 and not isinstance(prompt, str):
         | 
| 35 | 
            +
                    bs = len(prompt)
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
         | 
| 38 | 
            +
                if img.shape[0] == 1 and bs > 1:
         | 
| 39 | 
            +
                    img = repeat(img, "1 ... -> bs ...", bs=bs)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                img_ids = torch.zeros(h // 2, w // 2, 3)
         | 
| 42 | 
            +
                img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
         | 
| 43 | 
            +
                img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
         | 
| 44 | 
            +
                img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                if isinstance(prompt, str):
         | 
| 47 | 
            +
                    prompt = [prompt]
         | 
| 48 | 
            +
                txt = t5(prompt)
         | 
| 49 | 
            +
                if txt.shape[0] == 1 and bs > 1:
         | 
| 50 | 
            +
                    txt = repeat(txt, "1 ... -> bs ...", bs=bs)
         | 
| 51 | 
            +
                txt_ids = torch.zeros(bs, txt.shape[1], 3)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                vec = clip(prompt)
         | 
| 54 | 
            +
                if vec.shape[0] == 1 and bs > 1:
         | 
| 55 | 
            +
                    vec = repeat(vec, "1 ... -> bs ...", bs=bs)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                return {
         | 
| 58 | 
            +
                    "img": img,
         | 
| 59 | 
            +
                    "img_ids": img_ids.to(img.device),
         | 
| 60 | 
            +
                    "txt": txt.to(img.device),
         | 
| 61 | 
            +
                    "txt_ids": txt_ids.to(img.device),
         | 
| 62 | 
            +
                    "vec": vec.to(img.device),
         | 
| 63 | 
            +
                }
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
            def time_shift(mu: float, sigma: float, t: Tensor):
         | 
| 67 | 
            +
                return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
             | 
| 70 | 
            +
            def get_lin_function(
         | 
| 71 | 
            +
                x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
         | 
| 72 | 
            +
            ) -> Callable[[float], float]:
         | 
| 73 | 
            +
                m = (y2 - y1) / (x2 - x1)
         | 
| 74 | 
            +
                b = y1 - m * x1
         | 
| 75 | 
            +
                return lambda x: m * x + b
         | 
| 76 | 
            +
             | 
| 77 | 
            +
             | 
| 78 | 
            +
            def get_schedule(
         | 
| 79 | 
            +
                num_steps: int,
         | 
| 80 | 
            +
                image_seq_len: int,
         | 
| 81 | 
            +
                base_shift: float = 0.5,
         | 
| 82 | 
            +
                max_shift: float = 1.15,
         | 
| 83 | 
            +
                shift: bool = True,
         | 
| 84 | 
            +
            ) -> list[float]:
         | 
| 85 | 
            +
                # extra step for zero
         | 
| 86 | 
            +
                timesteps = torch.linspace(1, 0, num_steps + 1)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                # shifting the schedule to favor high timesteps for higher signal images
         | 
| 89 | 
            +
                if shift:
         | 
| 90 | 
            +
                    # estimate mu based on linear interpolation between two points
         | 
| 91 | 
            +
                    mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
         | 
| 92 | 
            +
                    timesteps = time_shift(mu, 1.0, timesteps)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                return timesteps.tolist()
         | 
| 95 | 
            +
             | 
| 96 | 
            +
             | 
| 97 | 
            +
            def denoise(
         | 
| 98 | 
            +
                model: Flux,
         | 
| 99 | 
            +
                # model input
         | 
| 100 | 
            +
                img: Tensor,
         | 
| 101 | 
            +
                img_ids: Tensor,
         | 
| 102 | 
            +
                txt: Tensor,
         | 
| 103 | 
            +
                txt_ids: Tensor,
         | 
| 104 | 
            +
                vec: Tensor,
         | 
| 105 | 
            +
                # sampling parameters
         | 
| 106 | 
            +
                timesteps: list[float],
         | 
| 107 | 
            +
                guidance: float = 4.0,
         | 
| 108 | 
            +
                use_gs=False,
         | 
| 109 | 
            +
                gs=4,
         | 
| 110 | 
            +
            ):
         | 
| 111 | 
            +
                # this is ignored for schnell
         | 
| 112 | 
            +
                guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
         | 
| 113 | 
            +
                for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
         | 
| 114 | 
            +
                    t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
         | 
| 115 | 
            +
                    pred = model(
         | 
| 116 | 
            +
                        img=img,
         | 
| 117 | 
            +
                        img_ids=img_ids,
         | 
| 118 | 
            +
                        txt=txt,
         | 
| 119 | 
            +
                        txt_ids=txt_ids,
         | 
| 120 | 
            +
                        y=vec,
         | 
| 121 | 
            +
                        timesteps=t_vec,
         | 
| 122 | 
            +
                        guidance=guidance_vec,
         | 
| 123 | 
            +
                    )
         | 
| 124 | 
            +
                    if use_gs:
         | 
| 125 | 
            +
                        pred_uncond, pred_text = pred.chunk(2)
         | 
| 126 | 
            +
                        pred = pred_uncond + gs * (pred_text - pred_uncond)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    img = img + (t_prev - t_curr) * pred
         | 
| 129 | 
            +
                    #if use_gs:
         | 
| 130 | 
            +
                    #    img = torch.cat([img] * 2)
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                return img
         | 
| 133 | 
            +
             | 
| 134 | 
            +
            def denoise_controlnet(
         | 
| 135 | 
            +
                model: Flux,
         | 
| 136 | 
            +
                controlnet: torch.nn.Module,
         | 
| 137 | 
            +
                # model input
         | 
| 138 | 
            +
                img: Tensor,
         | 
| 139 | 
            +
                img_ids: Tensor,
         | 
| 140 | 
            +
                txt: Tensor,
         | 
| 141 | 
            +
                txt_ids: Tensor,
         | 
| 142 | 
            +
                vec: Tensor,
         | 
| 143 | 
            +
                controlnet_cond,
         | 
| 144 | 
            +
                # sampling parameters
         | 
| 145 | 
            +
                timesteps: list[float],
         | 
| 146 | 
            +
                guidance: float = 4.0,
         | 
| 147 | 
            +
                controlnet_gs=0.7,
         | 
| 148 | 
            +
            ):
         | 
| 149 | 
            +
                # this is ignored for schnell
         | 
| 150 | 
            +
                guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
         | 
| 151 | 
            +
                for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
         | 
| 152 | 
            +
                    t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
         | 
| 153 | 
            +
                    block_res_samples = controlnet(
         | 
| 154 | 
            +
                                img=img,
         | 
| 155 | 
            +
                                img_ids=img_ids,
         | 
| 156 | 
            +
                                controlnet_cond=controlnet_cond,
         | 
| 157 | 
            +
                                txt=txt,
         | 
| 158 | 
            +
                                txt_ids=txt_ids,
         | 
| 159 | 
            +
                                y=vec,
         | 
| 160 | 
            +
                                timesteps=t_vec,
         | 
| 161 | 
            +
                                guidance=guidance_vec,
         | 
| 162 | 
            +
                            )
         | 
| 163 | 
            +
                    pred = model(
         | 
| 164 | 
            +
                        img=img,
         | 
| 165 | 
            +
                        img_ids=img_ids,
         | 
| 166 | 
            +
                        txt=txt,
         | 
| 167 | 
            +
                        txt_ids=txt_ids,
         | 
| 168 | 
            +
                        y=vec,
         | 
| 169 | 
            +
                        timesteps=t_vec,
         | 
| 170 | 
            +
                        guidance=guidance_vec,
         | 
| 171 | 
            +
                        block_controlnet_hidden_states=[i * controlnet_gs for i in block_res_samples]
         | 
| 172 | 
            +
                    )
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    img = img + (t_prev - t_curr) * pred
         | 
| 175 | 
            +
                    #if use_gs:
         | 
| 176 | 
            +
                    #    img = torch.cat([img] * 2)
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                return img
         | 
| 179 | 
            +
             | 
| 180 | 
            +
            def unpack(x: Tensor, height: int, width: int) -> Tensor:
         | 
| 181 | 
            +
                return rearrange(
         | 
| 182 | 
            +
                    x,
         | 
| 183 | 
            +
                    "b (h w) (c ph pw) -> b c (h ph) (w pw)",
         | 
| 184 | 
            +
                    h=math.ceil(height / 16),
         | 
| 185 | 
            +
                    w=math.ceil(width / 16),
         | 
| 186 | 
            +
                    ph=2,
         | 
| 187 | 
            +
                    pw=2,
         | 
| 188 | 
            +
                )
         | 
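
The denoise loops above are plain Euler integrators: at each step the model predicts a velocity and the state is updated as img = img + (t_prev - t_curr) * pred along the schedule from get_schedule. The schedule itself needs no models, so its shifting behaviour is easy to inspect (values below are approximate):

from flux.sampling import get_schedule

short = get_schedule(num_steps=4, image_seq_len=256)    # few tokens: mu stays at base_shift
long = get_schedule(num_steps=4, image_seq_len=4096)    # many tokens: mu approaches max_shift
print([round(t, 2) for t in short])   # roughly [1.0, 0.83, 0.62, 0.35, 0.0]
print([round(t, 2) for t in long])    # roughly [1.0, 0.90, 0.76, 0.51, 0.0]

In other words, longer (higher-resolution) token sequences spend more of the step budget near t = 1, where the sample is still mostly noise, which is what the "favor high timesteps" comment in get_schedule refers to.
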
    	
        src/flux/util.py
    ADDED
    
    | @@ -0,0 +1,237 @@ | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            from dataclasses import dataclass
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            from einops import rearrange
         | 
| 6 | 
            +
            from huggingface_hub import hf_hub_download
         | 
| 7 | 
            +
            from safetensors.torch import load_file as load_sft
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from .model import Flux, FluxParams
         | 
| 10 | 
            +
            from .controlnet import ControlNetFlux
         | 
| 11 | 
            +
            from .modules.autoencoder import AutoEncoder, AutoEncoderParams
         | 
| 12 | 
            +
            from .modules.conditioner import HFEmbedder
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from safetensors import safe_open
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            def load_safetensors(path):
         | 
| 17 | 
            +
                tensors = {}
         | 
| 18 | 
            +
                with safe_open(path, framework="pt", device="cpu") as f:
         | 
| 19 | 
            +
                    for key in f.keys():
         | 
| 20 | 
            +
                        tensors[key] = f.get_tensor(key)
         | 
| 21 | 
            +
                return tensors
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            @dataclass
         | 
| 24 | 
            +
            class ModelSpec:
         | 
| 25 | 
            +
                params: FluxParams
         | 
| 26 | 
            +
                ae_params: AutoEncoderParams
         | 
| 27 | 
            +
                ckpt_path: str | None
         | 
| 28 | 
            +
                ae_path: str | None
         | 
| 29 | 
            +
                repo_id: str | None
         | 
| 30 | 
            +
                repo_flow: str | None
         | 
| 31 | 
            +
                repo_ae: str | None
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            configs = {
         | 
| 35 | 
            +
                "flux-dev": ModelSpec(
         | 
| 36 | 
            +
                    repo_id="black-forest-labs/FLUX.1-dev",
         | 
| 37 | 
            +
                    repo_flow="flux1-dev.safetensors",
         | 
| 38 | 
            +
                    repo_ae="ae.safetensors",
         | 
| 39 | 
            +
                    ckpt_path=os.getenv("FLUX_DEV"),
         | 
| 40 | 
            +
                    params=FluxParams(
         | 
| 41 | 
            +
                        in_channels=64,
         | 
| 42 | 
            +
                        vec_in_dim=768,
         | 
| 43 | 
            +
                        context_in_dim=4096,
         | 
| 44 | 
            +
                        hidden_size=3072,
         | 
| 45 | 
            +
                        mlp_ratio=4.0,
         | 
| 46 | 
            +
                        num_heads=24,
         | 
| 47 | 
            +
                        depth=19,
         | 
| 48 | 
            +
                        depth_single_blocks=38,
         | 
| 49 | 
            +
                        axes_dim=[16, 56, 56],
         | 
| 50 | 
            +
                        theta=10_000,
         | 
| 51 | 
            +
                        qkv_bias=True,
         | 
| 52 | 
            +
                        guidance_embed=True,
         | 
| 53 | 
            +
                    ),
         | 
| 54 | 
            +
                    ae_path=os.getenv("AE"),
         | 
| 55 | 
            +
                    ae_params=AutoEncoderParams(
         | 
| 56 | 
            +
                        resolution=256,
         | 
| 57 | 
            +
                        in_channels=3,
         | 
| 58 | 
            +
                        ch=128,
         | 
| 59 | 
            +
                        out_ch=3,
         | 
| 60 | 
            +
                        ch_mult=[1, 2, 4, 4],
         | 
| 61 | 
            +
                        num_res_blocks=2,
         | 
| 62 | 
            +
                        z_channels=16,
         | 
| 63 | 
            +
                        scale_factor=0.3611,
         | 
| 64 | 
            +
                        shift_factor=0.1159,
         | 
| 65 | 
            +
                    ),
         | 
| 66 | 
            +
                ),
         | 
| 67 | 
            +
                "flux-schnell": ModelSpec(
         | 
| 68 | 
            +
                    repo_id="black-forest-labs/FLUX.1-schnell",
         | 
| 69 | 
            +
                    repo_flow="flux1-schnell.safetensors",
         | 
| 70 | 
            +
                    repo_ae="ae.safetensors",
         | 
| 71 | 
            +
                    ckpt_path=os.getenv("FLUX_SCHNELL"),
         | 
| 72 | 
            +
                    params=FluxParams(
         | 
| 73 | 
            +
                        in_channels=64,
         | 
| 74 | 
            +
                        vec_in_dim=768,
         | 
| 75 | 
            +
                        context_in_dim=4096,
         | 
| 76 | 
            +
                        hidden_size=3072,
         | 
| 77 | 
            +
                        mlp_ratio=4.0,
         | 
| 78 | 
            +
                        num_heads=24,
         | 
| 79 | 
            +
                        depth=19,
         | 
| 80 | 
            +
                        depth_single_blocks=38,
         | 
| 81 | 
            +
                        axes_dim=[16, 56, 56],
         | 
| 82 | 
            +
                        theta=10_000,
         | 
| 83 | 
            +
                        qkv_bias=True,
         | 
| 84 | 
            +
                        guidance_embed=False,
         | 
| 85 | 
            +
                    ),
         | 
| 86 | 
            +
                    ae_path=os.getenv("AE"),
         | 
| 87 | 
            +
                    ae_params=AutoEncoderParams(
         | 
| 88 | 
            +
                        resolution=256,
         | 
| 89 | 
            +
                        in_channels=3,
         | 
| 90 | 
            +
                        ch=128,
         | 
| 91 | 
            +
                        out_ch=3,
         | 
| 92 | 
            +
                        ch_mult=[1, 2, 4, 4],
         | 
| 93 | 
            +
                        num_res_blocks=2,
         | 
| 94 | 
            +
                        z_channels=16,
         | 
| 95 | 
            +
                        scale_factor=0.3611,
         | 
| 96 | 
            +
                        shift_factor=0.1159,
         | 
| 97 | 
            +
                    ),
         | 
| 98 | 
            +
                ),
         | 
| 99 | 
            +
            }
         | 
| 100 | 
            +
             | 
| 101 | 
            +
             | 
| 102 | 
            +
            def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
         | 
| 103 | 
            +
                if len(missing) > 0 and len(unexpected) > 0:
         | 
| 104 | 
            +
                    print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
         | 
| 105 | 
            +
                    print("\n" + "-" * 79 + "\n")
         | 
| 106 | 
            +
                    print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
         | 
| 107 | 
            +
                elif len(missing) > 0:
         | 
| 108 | 
            +
                    print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
         | 
| 109 | 
            +
                elif len(unexpected) > 0:
         | 
| 110 | 
            +
                    print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
         | 
| 111 | 
            +
             | 
| 112 | 
            +
             | 
| 113 | 
            +
            def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
         | 
| 114 | 
            +
                # Loading Flux
         | 
| 115 | 
            +
                print("Init model")
         | 
| 116 | 
            +
                ckpt_path = configs[name].ckpt_path
         | 
| 117 | 
            +
                if (
         | 
| 118 | 
            +
                    ckpt_path is None
         | 
| 119 | 
            +
                    and configs[name].repo_id is not None
         | 
| 120 | 
            +
                    and configs[name].repo_flow is not None
         | 
| 121 | 
            +
                    and hf_download
         | 
| 122 | 
            +
                ):
         | 
| 123 | 
            +
                    ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                with torch.device("meta" if ckpt_path is not None else device):
         | 
| 126 | 
            +
                    model = Flux(configs[name].params).to(torch.bfloat16)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                if ckpt_path is not None:
         | 
| 129 | 
            +
                    print("Loading checkpoint")
         | 
| 130 | 
            +
                    # load_sft doesn't support torch.device
         | 
| 131 | 
            +
                    sd = load_sft(ckpt_path, device=str(device))
         | 
| 132 | 
            +
                    missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
         | 
| 133 | 
            +
                    print_load_warning(missing, unexpected)
         | 
| 134 | 
            +
                return model
         | 
| 135 | 
            +
                
         | 
| 136 | 
            +
            def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
         | 
| 137 | 
            +
                # Loading Flux
         | 
| 138 | 
            +
                print("Init model")
         | 
| 139 | 
            +
                ckpt_path = configs[name].ckpt_path
         | 
| 140 | 
            +
                if (
         | 
| 141 | 
            +
                    ckpt_path is None
         | 
| 142 | 
            +
                    and configs[name].repo_id is not None
         | 
| 143 | 
            +
                    and configs[name].repo_flow is not None
         | 
| 144 | 
            +
                    and hf_download
         | 
| 145 | 
            +
                ):
         | 
| 146 | 
            +
                    ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                with torch.device("meta" if ckpt_path is not None else device):
         | 
| 149 | 
            +
                    model = Flux(configs[name].params)
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                if ckpt_path is not None:
         | 
| 152 | 
            +
                    print("Loading checkpoint")
         | 
| 153 | 
            +
                    # load_sft doesn't support torch.device
         | 
| 154 | 
            +
                    sd = load_sft(ckpt_path, device=str(device))
         | 
| 155 | 
            +
                    missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
         | 
| 156 | 
            +
                    print_load_warning(missing, unexpected)
         | 
| 157 | 
            +
                return model
         | 
| 158 | 
            +
             | 
| 159 | 
            +
            def load_controlnet(name, device, transformer=None):
         | 
| 160 | 
            +
                with torch.device(device):
         | 
| 161 | 
            +
                    controlnet = ControlNetFlux(configs[name].params)
         | 
| 162 | 
            +
                if transformer is not None:
         | 
| 163 | 
            +
                    controlnet.load_state_dict(transformer.state_dict(), strict=False)
         | 
| 164 | 
            +
                return controlnet
         | 
| 165 | 
            +
             | 
| 166 | 
            +
            def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
         | 
| 167 | 
            +
                # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
         | 
| 168 | 
            +
                return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
         | 
| 169 | 
            +
             | 
| 170 | 
            +
             | 
| 171 | 
            +
            def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
         | 
| 172 | 
            +
                return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
         | 
| 173 | 
            +
             | 
| 174 | 
            +
             | 
| 175 | 
            +
def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
    ckpt_path = configs[name].ae_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_ae is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)

    # Loading the autoencoder
    print("Init AE")
    with torch.device("meta" if ckpt_path is not None else device):
        ae = AutoEncoder(configs[name].ae_params)

    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return ae


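# Invisible watermarking: embeds a fixed bit pattern (WATERMARK_BITS, defined at the
# bottom of this file) into generated images via a DWT-DCT WatermarkEncoder.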
class WatermarkEmbedder:
    def __init__(self, watermark):
        self.watermark = watermark
        self.num_bits = len(WATERMARK_BITS)
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Adds a predefined watermark to the input image

        Args:
            image: ([N,] B, RGB, H, W) in range [-1, 1]

        Returns:
            same as input but watermarked
        """
        image = 0.5 * image + 0.5
        squeeze = len(image.shape) == 4
        if squeeze:
            image = image[None, ...]
        n = image.shape[0]
        image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) in [0, 255]
        # the watermarking library expects input in cv2 BGR format
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
            image.device
        )
        image = torch.clamp(image / 255, min=0.0, max=1.0)
        if squeeze:
            image = image[0]
        image = 2 * image - 1
        return image


# A fixed 48-bit message that was chosen at random
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
# bin(x)[2:] gives the bits of x as a str; int converts each one to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
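A minimal usage sketch of the helpers above (illustrative only, not part of the uploaded file). It assumes a CUDA device, that the import path follows the src/flux package layout, and that "flux-dev" is one of the keys defined in `configs` earlier in util.py.

import torch

from flux.util import WATERMARK_BITS, WatermarkEmbedder, load_ae, load_clip, load_t5

device = torch.device("cuda")
name = "flux-dev"  # assumed config key; use any key present in `configs`

t5 = load_t5(device, max_length=512)  # T5-XXL prompt encoder (bfloat16)
clip = load_clip(device)              # CLIP ViT-L/14 prompt encoder (bfloat16)
ae = load_ae(name, device=device)     # autoencoder between pixel and latent space

# After decoding latents to an image tensor x in [-1, 1]:
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
# x = embed_watermark(x)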
