# some parts are modified from Diffusers library (Apache License 2.0)

import math
from types import SimpleNamespace
from typing import Any, Optional

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

from library import sdxl_original_unet
from library.sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl


class ControlNetConditioningEmbedding(nn.Module):
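    """
    Maps a 3-channel conditioning image (e.g. canny or depth, at pixel resolution) into
    the UNet's first feature map: three stride-2 convolutions downsample by 8x to latent
    resolution, and the zero-initialized output conv makes the whole module a no-op at
    the start of training, so the ControlNet initially leaves the base model unchanged.
    """
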
    def __init__(self):
        super().__init__()

        dims = [16, 32, 96, 256]

        self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1)
        self.blocks = nn.ModuleList([])

        for i in range(len(dims) - 1):
            channel_in = dims[i]
            channel_out = dims[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1)
        nn.init.zeros_(self.conv_out.weight)  # zero module weight
        nn.init.zeros_(self.conv_out.bias)  # zero module bias

    def forward(self, x):
        x = self.conv_in(x)
        x = F.silu(x)

        for block in self.blocks:
            x = block(x)
            x = F.silu(x)

        x = self.conv_out(x)
        return x


class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel):
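    """
    ControlNet built by reusing the SDXL UNet: the encoder (input_blocks) and middle
    block are kept, the decoder (output_blocks) and final `out` head are removed, and a
    zero-initialized 1x1 conv is added for each of the nine encoder outputs plus the
    middle block, so training starts from a state that does not perturb the base model.
    """
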
    def __init__(self, multiplier: Optional[float] = None, **kwargs):
        super().__init__(**kwargs)
        self.multiplier = multiplier

        # remove unet layers
        self.output_blocks = nn.ModuleList([])
        del self.out

        self.controlnet_cond_embedding = ControlNetConditioningEmbedding()

        dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280]
        self.controlnet_down_blocks = nn.ModuleList([])
        for dim in dims:
            self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1))
            nn.init.zeros_(self.controlnet_down_blocks[-1].weight)  # zero module weight
            nn.init.zeros_(self.controlnet_down_blocks[-1].bias)  # zero module bias

        self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1)
        nn.init.zeros_(self.controlnet_mid_block.weight)  # zero module weight
        nn.init.zeros_(self.controlnet_mid_block.bias)  # zero module bias

    def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel):
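        """
        Copy the encoder and middle-block weights from a trained SDXL UNet; the
        `controlnet_*` modules do not exist in the UNet and keep their zero init.
        """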
        unet_sd = unet.state_dict()
        unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")}

        sd = super().state_dict()
        sd.update(unet_sd)
        info = super().load_state_dict(sd, strict=True, assign=True)
        return info

    def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any:
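        """
        Checkpoints store the UNet part in Diffusers key format; convert those keys back
        to the SAI format used by this implementation before loading. `controlnet_*`
        keys are loaded unchanged.
        """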
        # convert state_dict to SAI format
        unet_sd = {}
        for k in list(state_dict.keys()):
            if not k.startswith("controlnet_"):
                unet_sd[k] = state_dict.pop(k)
        unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd)
        state_dict.update(unet_sd)
        return super().load_state_dict(state_dict, strict=strict, assign=assign)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
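        """
        Export the UNet part in Diffusers key format so saved checkpoints match
        Diffusers-style ControlNet weights; `controlnet_*` keys are kept as-is.
        """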
        # convert state_dict to Diffusers format
        state_dict = super().state_dict(destination, prefix, keep_vars)

        control_net_sd = {}
        for k in list(state_dict.keys()):
            if k.startswith("controlnet_"):
                control_net_sd[k] = state_dict.pop(k)

        state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict)
        state_dict.update(control_net_sd)
        return state_dict

    def forward(
        self,
        x: torch.Tensor,
        timesteps: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        cond_image: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
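        """
        Returns the nine scaled encoder residuals (`hs`) and the scaled middle-block
        residual (`h`), to be added to the matching activations of SdxlControlledUNet.
        """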
        # broadcast timesteps to batch dimension
        timesteps = timesteps.expand(x.shape[0])

        t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
        t_emb = t_emb.to(x.dtype)
        emb = self.time_embed(t_emb)

        assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
        assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
        emb = emb + self.label_emb(y)

        def call_module(module, h, emb, context):
            x = h
            for layer in module:
                if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
                    x = layer(x, emb)
                elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
                    x = layer(x, context)
                else:
                    x = layer(x)
            return x

        h = x
        multiplier = self.multiplier if self.multiplier is not None else 1.0
        hs = []
        for i, module in enumerate(self.input_blocks):
            h = call_module(module, h, emb, context)
            if i == 0:
                h = self.controlnet_cond_embedding(cond_image) + h
            hs.append(self.controlnet_down_blocks[i](h) * multiplier)

        h = call_module(self.middle_block, h, emb, context)
        h = self.controlnet_mid_block(h) * multiplier

        return hs, h
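

# Typical pairing of the two classes (a runnable version is in the __main__ block below):
#   input_resi_add, mid_add = control_net(x, t, context, y, cond_image)
#   output = controlled_unet(x, t, context, y, input_resi_add, mid_add)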


class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel):
    """
    This class is for training purposes only.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
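        """
        Same as SdxlUNet2DConditionModel.forward, but adds one entry of `input_resi_add`
        (consumed from the end of the list) to each skip connection and `mid_add` to the
        middle-block output.
        """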
        # broadcast timesteps to batch dimension
        timesteps = timesteps.expand(x.shape[0])

        hs = []
        t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
        t_emb = t_emb.to(x.dtype)
        emb = self.time_embed(t_emb)

        assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
        assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
        emb = emb + self.label_emb(y)

        def call_module(module, h, emb, context):
            x = h
            for layer in module:
                if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
                    x = layer(x, emb)
                elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
                    x = layer(x, context)
                else:
                    x = layer(x)
            return x

        h = x
        for module in self.input_blocks:
            h = call_module(module, h, emb, context)
            hs.append(h)

        h = call_module(self.middle_block, h, emb, context)
        h = h + mid_add

        for module in self.output_blocks:
            resi = hs.pop() + input_resi_add.pop()
            h = torch.cat([h, resi], dim=1)
            h = call_module(module, h, emb, context)

        h = h.type(x.dtype)
        h = call_module(self.out, h, emb, context)

        return h


if __name__ == "__main__":
    import time

    logger.info("create unet")
    unet = SdxlControlledUNet()
    unet.to("cuda", torch.bfloat16)
    unet.set_use_sdpa(True)
    unet.set_gradient_checkpointing(True)
    unet.train()

    logger.info("create control_net")
    control_net = SdxlControlNet()
    control_net.to("cuda")
    control_net.set_use_sdpa(True)
    control_net.set_gradient_checkpointing(True)
    control_net.train()

    logger.info("Initialize control_net from unet")
    control_net.init_from_unet(unet)

    unet.requires_grad_(False)
    control_net.requires_grad_(True)
    # pseudo training loop for checking memory usage
    logger.info("preparing optimizer")

    # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working

    import bitsandbytes

    optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3)
    # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
    # optimizer = bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2

    # import transformers
    # optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2

    scaler = torch.cuda.amp.GradScaler(enabled=True)
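    # NOTE: GradScaler targets float16 training; loss scaling is unnecessary with the
    # bfloat16 autocast below (bf16 has fp32's exponent range), though harmless here.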
| logger.info("start training") | |
| steps = 10 | |
| batch_size = 1 | |
| for step in range(steps): | |
| logger.info(f"step {step}") | |
| if step == 1: | |
| time_start = time.perf_counter() | |
| x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024 | |
| t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda") | |
| txt = torch.randn(batch_size, 77, 2048).cuda() | |
| vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() | |
| cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda() | |
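        # the cond image stays at pixel resolution (1024x1024); ControlNetConditioningEmbedding
        # downsamples it 8x to match the 128x128 latents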

        with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
            input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img)
            output = unet(x, t, txt, vector, input_resi_add, mid_add)
            target = torch.randn_like(output)
            loss = torch.nn.functional.mse_loss(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    time_end = time.perf_counter()
    logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
    logger.info("finish training")
    sd = control_net.state_dict()

    from safetensors.torch import save_file

    save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors")