""" training script for imagedream - the config system is similar with stable diffusion ldm code base(using omigaconf, yaml; target, params initialization, etc.) - the training code base is similar with unidiffuser training code base using accelerate """ from omegaconf import OmegaConf import argparse from pathlib import Path from torch.utils.data import DataLoader import os.path as osp import numpy as np import os import torch from PIL import Image import numpy as np import wandb from libs.base_utils import get_data_generator, PrintContext from libs.base_utils import ( setup, instantiate_from_config, dct2str, add_prefix, get_obj_from_str, ) from absl import logging from einops import rearrange from imagedream.camera_utils import get_camera from libs.sample import ImageDreamDiffusion from rich import print def train(config, unk): # using pipeline to extract models accelerator, device = setup(config, unk) with PrintContext(f"{'access STAT':-^50}", accelerator.is_main_process): print(accelerator.state) dtype = { "fp16": torch.float16, "fp32": torch.float32, "no": torch.float32, "bf16": torch.bfloat16, }[accelerator.state.mixed_precision] num_frames = config.num_frames ################## load models ################## model_config = config.models.config model_config = OmegaConf.load(model_config) model = instantiate_from_config(model_config.model) state_dict = torch.load(config.models.resume, map_location="cpu") print(model.load_state_dict(state_dict, strict=False)) print("loaded model from {}".format(config.models.resume)) latest_step = 0 if config.get("resume", False): print("resuming from specified workdir") ckpts = os.listdir(config.ckpt_root) if len(ckpts) == 0: print("no ckpt found") else: latest_ckpt = sorted(ckpts, key=lambda x: int(x.split("-")[-1]))[-1] latest_step = int(latest_ckpt.split("-")[-1]) print("loadding ckpt from ", osp.join(config.ckpt_root, latest_ckpt)) unet_state_dict = torch.load( osp.join(config.ckpt_root, latest_ckpt), map_location="cpu" ) print(model.model.load_state_dict(unet_state_dict, strict=False)) elif config.models.get("resume_unet", None) is not None: unet_state_dict = torch.load(config.models.resume_unet, map_location="cpu") print(model.model.load_state_dict(unet_state_dict, strict=False)) print(f"______ load unet from {config.models.resume_unet} ______") model.to(device) model.device = device model.clip_model.device = device ################# setup optimizer ################# from torch.optim import AdamW from accelerate.utils import DummyOptim optimizer_cls = ( AdamW if accelerator.state.deepspeed_plugin is None or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config else DummyOptim ) optimizer = optimizer_cls(model.model.parameters(), **config.optimizer) ################# prepare datasets ################# dataset = instantiate_from_config(config.train_data) eval_dataset = instantiate_from_config(config.eval_data) in_the_wild_images = ( instantiate_from_config(config.in_the_wild_images) if config.get("in_the_wild_images", None) is not None else None ) dl_config = config.dataloader dataloader = DataLoader(dataset, **dl_config, batch_size=config.batch_size) ( model, optimizer, dataloader, ) = accelerator.prepare(model, optimizer, dataloader) generator = get_data_generator(dataloader, accelerator.is_main_process, "train") if config.get("sampler", None) is not None: sampler_cls = get_obj_from_str(config.sampler.target) sampler = sampler_cls(model, device, dtype, **config.sampler.params) else: sampler = ImageDreamDiffusion( model, 

    ################# evaluation code #################
    def evaluation():
        # each process handles a strided shard of the eval set
        return_ls = []
        for i in range(
            accelerator.process_index, len(eval_dataset), accelerator.num_processes
        ):
            cond = eval_dataset[i]["cond"]
            images = sampler.diffuse("3D assets.", cond, n_test=2)
            images = np.concatenate(images, 0)
            images = [Image.fromarray(images)]
            return_ls.append(dict(images=images, ident=eval_dataset[i]["ident"]))
        return return_ls

    def evaluation2():
        # evaluation on commonly used in-the-wild images
        return_ls = []
        in_the_wild_images.init_item()
        for i in range(
            accelerator.process_index,
            len(in_the_wild_images),
            accelerator.num_processes,
        ):
            cond = in_the_wild_images[i]["cond"]
            images = sampler.diffuse("3D assets.", cond, n_test=2)
            images = np.concatenate(images, 0)
            images = [Image.fromarray(images)]
            return_ls.append(dict(images=images, ident=in_the_wild_images[i]["ident"]))
        return return_ls

    # step counters; when resuming, schedule the next log/eval/save relative
    # to the restored step
    if latest_step == 0:
        global_step = 0
        total_step = 0
        log_step = 0
        eval_step = 0
        save_step = 0
    else:
        global_step = latest_step // config.total_batch_size
        total_step = latest_step
        log_step = latest_step + config.log_interval
        eval_step = latest_step + config.eval_interval
        save_step = latest_step + config.save_interval

    unet = model.model
    while True:
        item = next(generator)
        unet.train()
        bs = item["clip_cond"].shape[0]
        BS = bs * num_frames
        item["clip_cond"] = item["clip_cond"].to(device).to(dtype)
        item["vae_cond"] = item["vae_cond"].to(device).to(dtype)
        camera_input = item["cameras"].to(device)
        camera_input = camera_input.reshape((BS, camera_input.shape[-1]))

        # choose the ground-truth target: rgb pixels, xyz maps, or both fused
        # along the channel dimension
        gd_type = config.get("gd_type", "pixel")
        if gd_type == "pixel":
            item["target_images_vae"] = item["target_images_vae"].to(device).to(dtype)
            gd = item["target_images_vae"]
        elif gd_type == "xyz":
            item["target_images_xyz_vae"] = (
                item["target_images_xyz_vae"].to(device).to(dtype)
            )
            gd = item["target_images_xyz_vae"]
        elif gd_type == "fusechannel":
            item["target_images_vae"] = item["target_images_vae"].to(device).to(dtype)
            item["target_images_xyz_vae"] = (
                item["target_images_xyz_vae"].to(device).to(dtype)
            )
            gd = torch.cat(
                (item["target_images_vae"], item["target_images_xyz_vae"]), dim=0
            )
        else:
            raise NotImplementedError

        with torch.no_grad(), accelerator.autocast():
            ip_embed = model.clip_model.encode_image_with_transformer(item["clip_cond"])
            ip_ = ip_embed.repeat_interleave(num_frames, dim=0)
            ip_img = model.get_first_stage_encoding(
                model.encode_first_stage(item["vae_cond"])
            )
            gd = rearrange(gd, "B F C H W -> (B F) C H W")
            latent_target_images = model.get_first_stage_encoding(
                model.encode_first_stage(gd)
            )

            if gd_type == "fusechannel":
                # re-split the doubled batch and stack image/xyz latents on the
                # channel axis
                latent_target_images = rearrange(
                    latent_target_images, "(B F) C H W -> B F C H W", B=bs * 2
                )
                image_latent, xyz_latent = torch.chunk(latent_target_images, 2)
                fused_channel_latent = torch.cat((image_latent, xyz_latent), dim=-3)
                latent_target_images = rearrange(
                    fused_channel_latent, "B F C H W -> (B F) C H W"
                )

            if item.get("captions", None) is not None:
                # flatten batched per-view captions to a length-BS list
                caption_ls = np.array(item["captions"]).T.reshape((-1, BS)).squeeze()
                prompt_cond = model.get_learned_conditioning(caption_ls)
            elif item.get("caption", None) is not None:
                prompt_cond = model.get_learned_conditioning(item["caption"])
                prompt_cond = prompt_cond.repeat_interleave(num_frames, dim=0)
            else:
                prompt_cond = model.get_learned_conditioning(["3D assets."]).repeat(
                    BS, 1, 1
                )
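        # Shape summary going into the UNet, with B = batch size and
        # BS = B * num_frames (latent channel count assumes the usual SD-style
        # 4-channel VAE; 8 for the fusechannel target):
        #   prompt_cond:          (BS, seq_len, ctx_dim)  text tokens
        #   ip_:                  (BS, n_img_tokens, dim) image-prompt tokens
        #   ip_img:               (B, 4, h, w)            latent of the input view
        #   camera_input:         (BS, cam_dim)           flattened camera matrices
        #   latent_target_images: (BS, 4 or 8, h, w)      denoising target latents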
        condition = {
            "context": prompt_cond,
            "ip": ip_,
            "ip_img": ip_img,
            "camera": camera_input,
        }

        with torch.autocast("cuda"), accelerator.accumulate(model):
            time_steps = torch.randint(0, model.num_timesteps, (BS,), device=device)
            noise = torch.randn_like(latent_target_images, device=device)
            # noise_img, _ = torch.chunk(noise, 2, dim=1)
            # noise = torch.cat((noise_img, noise_img), dim=1)
            x_noisy = model.q_sample(latent_target_images, time_steps, noise)
            output = unet(x_noisy, time_steps, **condition, num_frames=num_frames)
            # regroup as (num_frames, bs, C, H, W) so frames can be split apart
            reshaped_pred = output.reshape(bs, num_frames, *output.shape[1:]).permute(
                1, 0, 2, 3, 4
            )
            reshaped_noise = noise.reshape(bs, num_frames, *noise.shape[1:]).permute(
                1, 0, 2, 3, 4
            )
            # the last frame is the reference view; its loss term is multiplied
            # by 0, so only the first num_frames - 1 frames are supervised
            true_pred = reshaped_pred[: num_frames - 1]
            fake_pred = reshaped_pred[num_frames - 1 :]
            true_noise = reshaped_noise[: num_frames - 1]
            fake_noise = reshaped_noise[num_frames - 1 :]
            loss = (
                torch.nn.functional.mse_loss(true_noise, true_pred)
                + torch.nn.functional.mse_loss(fake_noise, fake_pred) * 0
            )
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

        global_step += 1
        total_step = global_step * config.total_batch_size
        if total_step > log_step:
            metrics = dict(
                loss=accelerator.gather(loss.detach().mean()).mean().item(),
                scale=(
                    accelerator.scaler.get_scale()
                    if accelerator.scaler is not None
                    else -1
                ),
            )
            log_step += config.log_interval
            if accelerator.is_main_process:
                logging.info(dct2str(dict(step=total_step, **metrics)))
                wandb.log(add_prefix(metrics, "train"), step=total_step)

        if total_step > save_step and accelerator.is_main_process:
            logging.info("saving checkpoint")
            torch.save(
                unet.state_dict(), osp.join(config.ckpt_root, f"unet-{total_step}")
            )
            save_step += config.save_interval
            logging.info("save done")

        if total_step > eval_step:
            logging.info("evaluating")
            unet.eval()
            return_ls = evaluation()
            cur_eval_base = osp.join(config.eval_root, f"{total_step:07d}")
            os.makedirs(cur_eval_base, exist_ok=True)
            for item in return_ls:
                for i, im in enumerate(item["images"]):
                    im.save(
                        osp.join(
                            cur_eval_base,
                            f"{item['ident']}-{i:03d}-{accelerator.process_index}-.png",
                        )
                    )
            return_ls2 = evaluation2()
            cur_eval_base = osp.join(config.eval_root2, f"{total_step:07d}")
            os.makedirs(cur_eval_base, exist_ok=True)
            for item in return_ls2:
                for i, im in enumerate(item["images"]):
                    im.save(
                        osp.join(
                            cur_eval_base,
                            f"{item['ident']}-{i:03d}-{accelerator.process_index}-inthewild.png",
                        )
                    )
            eval_step += config.eval_interval
            logging.info("evaluation done")

        accelerator.wait_for_everyone()
        if total_step > config.max_step:
            break


if __name__ == "__main__":
    # load config from the config path, then merge with cli args
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, default="configs/nf7_v3_SNR_rd_size_stroke.yaml"
    )
    parser.add_argument(
        "--logdir", type=str, default="train_logs", help="the dir to put logs"
    )
    parser.add_argument(
        "--resume_workdir", type=str, default=None, help="specify to resume"
    )
    args, unk = parser.parse_known_args()
    print(args, unk)
    config = OmegaConf.load(args.config)
    if args.resume_workdir is not None:
        assert osp.exists(args.resume_workdir), f"{args.resume_workdir} does not exist"
        config.config.workdir = args.resume_workdir
        config.config.resume = True
    OmegaConf.set_struct(config, True)  # prevent adding new keys
    cli_conf = OmegaConf.from_cli(unk)
    config = OmegaConf.merge(config, cli_conf)
    config = config.config
    OmegaConf.set_struct(config, False)
    config.logdir = args.logdir
    config.config_name = Path(args.config).stem
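    # CLI overrides use OmegaConf's dotlist syntax and are merged before the
    # top-level "config" node is selected, so keys are prefixed with "config.",
    # e.g. (hypothetical values):
    #   python train.py --config configs/nf7_v3_SNR_rd_size_stroke.yaml \
    #       config.batch_size=8 config.max_step=100000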
    train(config, unk)
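
# Typical launch, per the accelerate-based setup described in the module
# docstring (a sketch: "train.py" stands in for this file's name, and the
# exact flags depend on your local accelerate config):
#   accelerate launch --multi_gpu train.py \
#       --config configs/nf7_v3_SNR_rd_size_stroke.yaml --logdir train_logs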