# Credits:
# Original Flux code can be found at: https://github.com/black-forest-labs/flux
# Chroma Radiance adaptation referenced from https://github.com/lodestone-rock/flow

from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor, nn
from einops import repeat

import comfy.ldm.common_dit
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.chroma.model import Chroma, ChromaParams
from comfy.ldm.chroma.layers import (
    DoubleStreamBlock,
    SingleStreamBlock,
    Approximator,
)

from .layers import (
    NerfEmbedder,
    NerfGLUBlock,
    NerfFinalLayer,
    NerfFinalLayerConv,
)


@dataclass
class ChromaRadianceParams(ChromaParams):
    patch_size: int
    nerf_hidden_size: int
    nerf_mlp_ratio: int
    nerf_depth: int
    nerf_max_freqs: int
    # Setting nerf_tile_size to 0 disables tiling.
    nerf_tile_size: int
    # Currently one of linear (legacy) or conv.
    nerf_final_head_type: str
    # None means use the same dtype as the model.
    nerf_embedder_dtype: Optional[torch.dtype]


class ChromaRadiance(Chroma):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
        if operations is None:
            raise RuntimeError("Attempt to create ChromaRadiance object without setting operations")
        nn.Module.__init__(self)
        self.dtype = dtype
        params = ChromaRadianceParams(**kwargs)
        self.params = params
        self.patch_size = params.patch_size
        self.in_channels = params.in_channels
        self.out_channels = params.out_channels

        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")

        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.in_dim = params.in_dim
        self.out_dim = params.out_dim
        self.hidden_dim = params.hidden_dim
        self.n_layers = params.n_layers
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)

        self.img_in_patch = operations.Conv2d(
            params.in_channels,
            params.hidden_size,
            kernel_size=params.patch_size,
            stride=params.patch_size,
            bias=True,
            dtype=dtype,
            device=device,
        )

        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)

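        # The Approximator ("distilled guidance" MLP inherited from the base Chroma model) maps
        # the timestep/guidance embedding to the per-block modulation vectors consumed in
        # forward_orig.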
        self.distilled_guidance_layer = Approximator(
            in_dim=self.in_dim,
            hidden_dim=self.hidden_dim,
            out_dim=self.out_dim,
            n_layers=self.n_layers,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for _ in range(params.depth_single_blocks)
            ]
        )

        # pixel channel concat with DCT
        self.nerf_image_embedder = NerfEmbedder(
            in_channels=params.in_channels,
            hidden_size_input=params.nerf_hidden_size,
            max_freqs=params.nerf_max_freqs,
            dtype=params.nerf_embedder_dtype or dtype,
            device=device,
            operations=operations,
        )

        self.nerf_blocks = nn.ModuleList([
            NerfGLUBlock(
                hidden_size_s=params.hidden_size,
                hidden_size_x=params.nerf_hidden_size,
                mlp_ratio=params.nerf_mlp_ratio,
                dtype=dtype,
                device=device,
                operations=operations,
            )
            for _ in range(params.nerf_depth)
        ])

        if params.nerf_final_head_type == "linear":
            self.nerf_final_layer = NerfFinalLayer(
                params.nerf_hidden_size,
                out_channels=params.in_channels,
                dtype=dtype,
                device=device,
                operations=operations,
            )
        elif params.nerf_final_head_type == "conv":
            self.nerf_final_layer_conv = NerfFinalLayerConv(
                params.nerf_hidden_size,
                out_channels=params.in_channels,
                dtype=dtype,
                device=device,
                operations=operations,
            )
        else:
            errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
            raise ValueError(errstr)

        self.skip_mmdit = []
        self.skip_dit = []
        self.lite = False

    @property
    def _nerf_final_layer(self) -> nn.Module:
        if self.params.nerf_final_head_type == "linear":
            return self.nerf_final_layer
        if self.params.nerf_final_head_type == "conv":
            return self.nerf_final_layer_conv
        # Impossible to get here as we raise an error on unexpected types on initialization.
        raise NotImplementedError

    def img_in(self, img: Tensor) -> Tensor:
        img = self.img_in_patch(img)  # -> [B, Hidden, H/P, W/P]
        # flatten into a sequence for the transformer.
        return img.flatten(2).transpose(1, 2)  # -> [B, NumPatches, Hidden]

    def forward_nerf(
        self,
        img_orig: Tensor,
        img_out: Tensor,
        params: ChromaRadianceParams,
    ) -> Tensor:
        B, C, H, W = img_orig.shape
        num_patches = img_out.shape[1]
        patch_size = params.patch_size

        # Store the raw pixel values of each patch for the NeRF head later.
        # unfold creates patches: [B, C * P * P, NumPatches]
        nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
        nerf_pixels = nerf_pixels.transpose(1, 2)  # -> [B, NumPatches, C * P * P]

        if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
            # Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
            # the tile size.
            img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
        else:
            # Reshape for per-patch processing
            nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
            nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)

            # Get DCT-encoded pixel embeddings [pixel-dct]
            img_dct = self.nerf_image_embedder(nerf_pixels)

            # Pass through the dynamic MLP blocks (the NeRF)
            for block in self.nerf_blocks:
                img_dct = block(img_dct, nerf_hidden)

        # Reassemble the patches into the final image.
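        # Note: from here on the channel dimension is params.nerf_hidden_size (the NeRF feature
        # width produced by nerf_image_embedder and nerf_blocks), not the original image C.
        # fold() below rebuilds a [B, nerf_hidden_size, H, W] feature map and the final head
        # projects it back to in_channels.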
        img_dct = img_dct.transpose(1, 2)  # -> [B*NumPatches, C, P*P]

        # Reshape to combine with batch dimension for fold
        img_dct = img_dct.reshape(B, num_patches, -1)  # -> [B, NumPatches, C*P*P]
        img_dct = img_dct.transpose(1, 2)  # -> [B, C*P*P, NumPatches]
        img_dct = nn.functional.fold(
            img_dct,
            output_size=(H, W),
            kernel_size=patch_size,
            stride=patch_size,
        )
        return self._nerf_final_layer(img_dct)

    def forward_tiled_nerf(
        self,
        nerf_hidden: Tensor,
        nerf_pixels: Tensor,
        batch: int,
        channels: int,
        num_patches: int,
        patch_size: int,
        params: ChromaRadianceParams,
    ) -> Tensor:
        """
        Processes the NeRF head in tiles to save memory.
        nerf_hidden has shape [B, L, D]
        nerf_pixels has shape [B, L, C * P * P]
        """
        tile_size = params.nerf_tile_size
        output_tiles = []

        # Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
        for i in range(0, num_patches, tile_size):
            end = min(i + tile_size, num_patches)

            # Slice the current tile from the input tensors
            nerf_hidden_tile = nerf_hidden[:, i:end, :]
            nerf_pixels_tile = nerf_pixels[:, i:end, :]

            # Get the actual number of patches in this tile (can be smaller for the last tile)
            num_patches_tile = nerf_hidden_tile.shape[1]

            # Reshape the tile for per-patch processing
            # [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
            nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)

            # [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
            nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)

            # get DCT-encoded pixel embeddings [pixel-dct]
            img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)

            # pass through the dynamic MLP blocks (the NeRF)
            for block in self.nerf_blocks:
                img_dct_tile = block(img_dct_tile, nerf_hidden_tile)

            output_tiles.append(img_dct_tile)

        # Concatenate the processed tiles along the patch dimension
        return torch.cat(output_tiles, dim=0)

    def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams:
        params = self.params
        if not overrides:
            return params
        params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__}
        nullable_keys = frozenset(("nerf_embedder_dtype",))
        bad_keys = tuple(k for k in overrides if k not in params_dict)
        if bad_keys:
            e = f"Unknown key(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
            raise ValueError(e)
        bad_keys = tuple(
            k for k, v in overrides.items()
            if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
        )
        if bad_keys:
            e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
            raise ValueError(e)
        # At this point it's all valid keys and values so we can merge with the existing params.
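        # dict |= gives the override values precedence over the existing field values; the merged
        # dict is used to build a fresh params dataclass, so self.params itself stays unmodified.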
        params_dict |= overrides
        return params.__class__(**params_dict)

    def _forward(
        self,
        x: Tensor,
        timestep: Tensor,
        context: Tensor,
        guidance: Optional[Tensor],
        control: Optional[dict] = None,
        transformer_options: dict = {},
        **kwargs: dict,
    ) -> Tensor:
        bs, c, h, w = x.shape
        img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

        if img.ndim != 4:
            raise ValueError("Input img tensor must be in [B, C, H, W] format.")
        if context.ndim != 3:
            raise ValueError("Input txt tensors must have 3 dimensions.")

        params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {}))

        h_len = (img.shape[-2] // self.patch_size)
        w_len = (img.shape[-1] // self.patch_size)

        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)

        img_out = self.forward_orig(
            img,
            img_ids,
            context,
            txt_ids,
            timestep,
            guidance,
            control,
            transformer_options,
            attn_mask=kwargs.get("attention_mask", None),
        )
        return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
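
# Usage sketch (illustrative values): ChromaRadianceParams fields can be overridden per call
# through transformer_options, e.g.:
#
#   transformer_options["chroma_radiance_options"] = {
#       "nerf_tile_size": 64,         # int; 0 disables tiling of the NeRF head
#       "nerf_embedder_dtype": None,  # None falls back to the model dtype
#   }
#
# Keys must be ChromaRadianceParams fields and values must match the existing field's type
# (nerf_embedder_dtype may also be None); anything else makes radiance_get_override_params
# raise ValueError.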