| """ | |
| training script for imagedream | |
| - the config system is similar with stable diffusion ldm code base(using omigaconf, yaml; target, params initialization, etc.) | |
| - the training code base is similar with unidiffuser training code base using accelerate | |
| """ | |
from omegaconf import OmegaConf
import argparse
from pathlib import Path
from torch.utils.data import DataLoader
import os.path as osp
import numpy as np
import os
import torch
from PIL import Image
import wandb
from libs.base_utils import (
    get_data_generator,
    PrintContext,
    setup,
    instantiate_from_config,
    dct2str,
    add_prefix,
    get_obj_from_str,
)
from absl import logging
from einops import rearrange
from imagedream.camera_utils import get_camera
from libs.sample import ImageDreamDiffusion
from rich import print


def train(config, unk):
    # set up the accelerator and target device from the config
    accelerator, device = setup(config, unk)
    with PrintContext(f"{'access STAT':-^50}", accelerator.is_main_process):
        print(accelerator.state)
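    # Map the accelerate mixed-precision mode to the torch dtype used for
    # manual casts below ("no" means full fp32 training).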
    dtype = {
        "fp16": torch.float16,
        "fp32": torch.float32,
        "no": torch.float32,
        "bf16": torch.bfloat16,
    }[accelerator.state.mixed_precision]
    num_frames = config.num_frames

    ################## load models ##################
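    # The model config is a separate OmegaConf YAML; instantiate_from_config
    # builds the model from its `target`/`params` fields, then the full
    # checkpoint is loaded non-strictly on CPU first.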
    model_config = config.models.config
    model_config = OmegaConf.load(model_config)
    model = instantiate_from_config(model_config.model)
    state_dict = torch.load(config.models.resume, map_location="cpu")
    print(model.load_state_dict(state_dict, strict=False))
    print("loaded model from {}".format(config.models.resume))

    latest_step = 0
    if config.get("resume", False):
        print("resuming from specified workdir")
        ckpts = os.listdir(config.ckpt_root)
        if len(ckpts) == 0:
            print("no ckpt found")
        else:
            latest_ckpt = sorted(ckpts, key=lambda x: int(x.split("-")[-1]))[-1]
            latest_step = int(latest_ckpt.split("-")[-1])
            print("loading ckpt from ", osp.join(config.ckpt_root, latest_ckpt))
            unet_state_dict = torch.load(
                osp.join(config.ckpt_root, latest_ckpt), map_location="cpu"
            )
            print(model.model.load_state_dict(unet_state_dict, strict=False))
    elif config.models.get("resume_unet", None) is not None:
        unet_state_dict = torch.load(config.models.resume_unet, map_location="cpu")
        print(model.model.load_state_dict(unet_state_dict, strict=False))
        print(f"______ load unet from {config.models.resume_unet} ______")
    model.to(device)
    model.device = device
    model.clip_model.device = device

    ################# setup optimizer #################
    from torch.optim import AdamW
    from accelerate.utils import DummyOptim
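    # When DeepSpeed supplies its own optimizer via the deepspeed_config,
    # accelerate expects a DummyOptim placeholder instead of a real AdamW.
    # Note that only the UNet (`model.model`) parameters are trained.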
    optimizer_cls = (
        AdamW
        if accelerator.state.deepspeed_plugin is None
        or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
        else DummyOptim
    )
    optimizer = optimizer_cls(model.model.parameters(), **config.optimizer)

    ################# prepare datasets #################
    dataset = instantiate_from_config(config.train_data)
    eval_dataset = instantiate_from_config(config.eval_data)
    in_the_wild_images = (
        instantiate_from_config(config.in_the_wild_images)
        if config.get("in_the_wild_images", None) is not None
        else None
    )

    dl_config = config.dataloader
    dataloader = DataLoader(dataset, **dl_config, batch_size=config.batch_size)
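    # accelerator.prepare wraps the model/optimizer/dataloader for the
    # current distributed setup (DDP, DeepSpeed, mixed precision).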
    (
        model,
        optimizer,
        dataloader,
    ) = accelerator.prepare(model, optimizer, dataloader)

    generator = get_data_generator(dataloader, accelerator.is_main_process, "train")
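    # Use a sampler class from the config if one is given; otherwise fall
    # back to the default ImageDreamDiffusion sampler, mirroring the
    # training dataset's camera and preprocessing settings.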
| if config.get("sampler", None) is not None: | |
| sampler_cls = get_obj_from_str(config.sampler.target) | |
| sampler = sampler_cls(model, device, dtype, **config.sampler.params) | |
| else: | |
| sampler = ImageDreamDiffusion( | |
| model, | |
| mode=config.mode, | |
| num_frames=num_frames, | |
| device=device, | |
| dtype=dtype, | |
| camera_views=dataset.camera_views, | |
| offset_noise=config.get("offset_noise", False), | |
| ref_position=dataset.ref_position, | |
| random_background=dataset.random_background, | |
| resize_rate=dataset.resize_rate, | |
| ) | |

    ################# evaluation code #################
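    # Both evaluation helpers stride the dataset across processes: rank k
    # handles indices k, k + num_processes, ... so samples are generated in
    # parallel without overlap.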
    def evaluation():
        return_ls = []
        for i in range(
            accelerator.process_index, len(eval_dataset), accelerator.num_processes
        ):
            cond = eval_dataset[i]["cond"]
            images = sampler.diffuse("3D assets.", cond, n_test=2)
            images = np.concatenate(images, 0)
            images = [Image.fromarray(images)]
            return_ls.append(dict(images=images, ident=eval_dataset[i]["ident"]))
        return return_ls

    def evaluation2():
        # eval on in-the-wild images commonly used for comparison
        return_ls = []
        in_the_wild_images.init_item()
        for i in range(
            accelerator.process_index,
            len(in_the_wild_images),
            accelerator.num_processes,
        ):
            cond = in_the_wild_images[i]["cond"]
            images = sampler.diffuse("3D assets.", cond, n_test=2)
            images = np.concatenate(images, 0)
            images = [Image.fromarray(images)]
            return_ls.append(dict(images=images, ident=in_the_wild_images[i]["ident"]))
        return return_ls
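    # Step bookkeeping: total_step counts training samples seen (the
    # log/eval/save intervals are measured in samples), while global_step
    # counts optimizer steps. On resume, the next triggers are pushed one
    # interval past the restored step.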
    if latest_step == 0:
        global_step = 0
        total_step = 0
        log_step = 0
        eval_step = 0
        save_step = 0
    else:
        global_step = latest_step // config.total_batch_size
        total_step = latest_step
        log_step = latest_step + config.log_interval
        eval_step = latest_step + config.eval_interval
        save_step = latest_step + config.save_interval

    unet = model.model
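    # Main training loop: runs until total_step exceeds config.max_step.
    # bs is the number of objects per batch; each object contributes
    # num_frames views, so the effective view batch size is BS.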
    while True:
        item = next(generator)
        unet.train()
        bs = item["clip_cond"].shape[0]
        BS = bs * num_frames
        item["clip_cond"] = item["clip_cond"].to(device).to(dtype)
        item["vae_cond"] = item["vae_cond"].to(device).to(dtype)
        camera_input = item["cameras"].to(device)
        camera_input = camera_input.reshape((BS, camera_input.shape[-1]))
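        # Choose the supervision target: RGB pixels, xyz coordinate maps, or
        # both batched together for the fused-channel variant.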
| gd_type = config.get("gd_type", "pixel") | |
| if gd_type == "pixel": | |
| item["target_images_vae"] = item["target_images_vae"].to(device).to(dtype) | |
| gd = item["target_images_vae"] | |
| elif gd_type == "xyz": | |
| item["target_images_xyz_vae"] = ( | |
| item["target_images_xyz_vae"].to(device).to(dtype) | |
| ) | |
| gd = item["target_images_xyz_vae"] | |
| elif gd_type == "fusechannel": | |
| item["target_images_vae"] = item["target_images_vae"].to(device).to(dtype) | |
| item["target_images_xyz_vae"] = ( | |
| item["target_images_xyz_vae"].to(device).to(dtype) | |
| ) | |
| gd = torch.cat( | |
| (item["target_images_vae"], item["target_images_xyz_vae"]), dim=0 | |
| ) | |
| else: | |
| raise NotImplementedError | |
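        # Conditioning comes from frozen encoders under no_grad: a CLIP
        # image embedding repeated once per view ("ip"), the VAE latent of
        # the conditioning image ("ip_img"), and VAE latents of the targets.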
| with torch.no_grad(), accelerator.autocast("cuda"): | |
| ip_embed = model.clip_model.encode_image_with_transformer(item["clip_cond"]) | |
| ip_ = ip_embed.repeat_interleave(num_frames, dim=0) | |
| ip_img = model.get_first_stage_encoding( | |
| model.encode_first_stage(item["vae_cond"]) | |
| ) | |
| gd = rearrange(gd, "B F C H W -> (B F) C H W") | |
| latent_target_images = model.get_first_stage_encoding( | |
| model.encode_first_stage(gd) | |
| ) | |
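            # For fusechannel, the 2*bs*F latents are split back into image
            # and xyz halves, re-paired per object, and concatenated along
            # the channel axis (dim=-3), yielding bs*F latents with doubled
            # channels.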
| if gd_type == "fusechannel": | |
| latent_target_images = rearrange( | |
| latent_target_images, "(B F) C H W -> B F C H W", B=bs * 2 | |
| ) | |
| image_latent, xyz_latent = torch.chunk(latent_target_images, 2) | |
| fused_channel_latent = torch.cat((image_latent, xyz_latent), dim=-3) | |
| latent_target_images = rearrange( | |
| fused_channel_latent, "B F C H W -> (B F) C H W" | |
| ) | |
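            # Text conditioning falls back in order: per-view "captions",
            # a per-object "caption" repeated across views, then the fixed
            # "3D assets." prompt.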
| if item.get("captions", None) is not None: | |
| caption_ls = np.array(item["caption"]).T.reshape((-1, BS)).squeeze() | |
| prompt_cond = model.get_learned_conditioning(caption_ls) | |
| elif item.get("caption", None) is not None: | |
| prompt_cond = model.get_learned_conditioning(item["caption"]) | |
| prompt_cond = prompt_cond.repeat_interleave(num_frames, dim=0) | |
| else: | |
| prompt_cond = model.get_learned_conditioning(["3D assets."]).repeat( | |
| BS, 1, 1 | |
| ) | |
| condition = { | |
| "context": prompt_cond, | |
| "ip": ip_, | |
| "ip_img": ip_img, | |
| "camera": camera_input, | |
| } | |
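        # Diffusion step: sample a timestep per view, noise the target
        # latents with q_sample, and regress the UNet output against the
        # noise. The last view of each object is excluded from the loss (its
        # MSE term is zero-weighted), so only the first num_frames - 1 views
        # are supervised; this appears to reserve the final frame for the
        # reference image.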
| with torch.autocast("cuda"), accelerator.accumulate(model): | |
| time_steps = torch.randint(0, model.num_timesteps, (BS,), device=device) | |
| noise = torch.randn_like(latent_target_images, device=device) | |
| # noise_img, _ = torch.chunk(noise, 2, dim=1) | |
| # noise = torch.cat((noise_img, noise_img), dim=1) | |
| x_noisy = model.q_sample(latent_target_images, time_steps, noise) | |
| output = unet(x_noisy, time_steps, **condition, num_frames=num_frames) | |
| reshaped_pred = output.reshape(bs, num_frames, *output.shape[1:]).permute( | |
| 1, 0, 2, 3, 4 | |
| ) | |
| reshaped_noise = noise.reshape(bs, num_frames, *noise.shape[1:]).permute( | |
| 1, 0, 2, 3, 4 | |
| ) | |
| true_pred = reshaped_pred[: num_frames - 1] | |
| fake_pred = reshaped_pred[num_frames - 1 :] | |
| true_noise = reshaped_noise[: num_frames - 1] | |
| fake_noise = reshaped_noise[num_frames - 1 :] | |
| loss = ( | |
| torch.nn.functional.mse_loss(true_noise, true_pred) | |
| + torch.nn.functional.mse_loss(fake_noise, fake_pred) * 0 | |
| ) | |
| accelerator.backward(loss) | |
| optimizer.step() | |
| optimizer.zero_grad() | |

        global_step += 1
        total_step = global_step * config.total_batch_size
        if total_step > log_step:
            metrics = dict(
                loss=accelerator.gather(loss.detach().mean()).mean().item(),
                scale=(
                    accelerator.scaler.get_scale()
                    if accelerator.scaler is not None
                    else -1
                ),
            )
            log_step += config.log_interval
            if accelerator.is_main_process:
                logging.info(dct2str(dict(step=total_step, **metrics)))
                wandb.log(add_prefix(metrics, "train"), step=total_step)
        if total_step > save_step and accelerator.is_main_process:
            logging.info("saving checkpoint")
            torch.save(
                unet.state_dict(), osp.join(config.ckpt_root, f"unet-{total_step}")
            )
            save_step += config.save_interval
            logging.info("save done")
        if total_step > eval_step:
            logging.info("evaluating")
            unet.eval()
            return_ls = evaluation()
            cur_eval_base = osp.join(config.eval_root, f"{total_step:07d}")
            os.makedirs(cur_eval_base, exist_ok=True)
            for ret in return_ls:
                for i, im in enumerate(ret["images"]):
                    im.save(
                        osp.join(
                            cur_eval_base,
                            f"{ret['ident']}-{i:03d}-{accelerator.process_index}-.png",
                        )
                    )
            return_ls2 = evaluation2()
            cur_eval_base = osp.join(config.eval_root2, f"{total_step:07d}")
            os.makedirs(cur_eval_base, exist_ok=True)
            for ret in return_ls2:
                for i, im in enumerate(ret["images"]):
                    im.save(
                        osp.join(
                            cur_eval_base,
                            f"{ret['ident']}-{i:03d}-{accelerator.process_index}-inthewild.png",
                        )
                    )
            eval_step += config.eval_interval
            logging.info("evaluation done")
        accelerator.wait_for_everyone()
        if total_step > config.max_step:
            break
| if __name__ == "__main__": | |
| # load config from config path, then merge with cli args | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument( | |
| "--config", type=str, default="configs/nf7_v3_SNR_rd_size_stroke.yaml" | |
| ) | |
| parser.add_argument( | |
| "--logdir", type=str, default="train_logs", help="the dir to put logs" | |
| ) | |
| parser.add_argument( | |
| "--resume_workdir", type=str, default=None, help="specify to do resume" | |
| ) | |
| args, unk = parser.parse_known_args() | |
| print(args, unk) | |
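    # Unknown args are parsed by OmegaConf.from_cli as dotlist overrides
    # (e.g. config.batch_size=4) and merged over the YAML config.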
    config = OmegaConf.load(args.config)
    if args.resume_workdir is not None:
        assert osp.exists(args.resume_workdir), f"{args.resume_workdir} does not exist"
        config.config.workdir = args.resume_workdir
        config.config.resume = True

    OmegaConf.set_struct(config, True)  # prevent adding new keys
    cli_conf = OmegaConf.from_cli(unk)
    config = OmegaConf.merge(config, cli_conf)
    config = config.config
    OmegaConf.set_struct(config, False)
    config.logdir = args.logdir
    config.config_name = Path(args.config).stem
    train(config, unk)