Mariam-Elz committed (verified)
Commit e03ed9e · 1 Parent(s): 19cf59c

Upload train.py with huggingface_hub

Files changed (1)
  1. train.py +354 -0
train.py ADDED
@@ -0,0 +1,354 @@
+ """
+ Training script for ImageDream.
+ - The config system follows the Stable Diffusion LDM code base (OmegaConf YAML
+   with target/params initialization, etc.).
+ - The training loop follows the UniDiffuser training code base and uses accelerate.
+ """
+
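+ # For orientation, a sketch of the YAML layout this script reads. Illustrative
+ # only -- every key shown here is inferred from the accesses below; the actual
+ # files live under configs/*.yaml:
+ #
+ # config:
+ #   num_frames: ...
+ #   batch_size: ...
+ #   models:
+ #     config: <model yaml containing a `model:` target/params block>
+ #     resume: <full-model checkpoint to load>
+ #   train_data:
+ #     target: <dataset class path, built via instantiate_from_config>
+ #     params: {...}
+ #
+ # Typically launched via `accelerate launch train.py --config <yaml>` (assumed;
+ # the launch command is not part of this commit).
+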
+ from omegaconf import OmegaConf
+ import argparse
+ from pathlib import Path
+ from torch.utils.data import DataLoader
+ import os.path as osp
+ import numpy as np
+ import os
+ import torch
+ from PIL import Image
+ import wandb
+ from libs.base_utils import get_data_generator, PrintContext
+ from libs.base_utils import (
+     setup,
+     instantiate_from_config,
+     dct2str,
+     add_prefix,
+     get_obj_from_str,
+ )
+ from absl import logging
+ from einops import rearrange
+ from imagedream.camera_utils import get_camera
+ from libs.sample import ImageDreamDiffusion
+ from rich import print
+
+
+ def train(config, unk):
+     # use the pipeline helpers to set up the accelerator and device
+     accelerator, device = setup(config, unk)
+     with PrintContext(f"{'access STAT':-^50}", accelerator.is_main_process):
+         print(accelerator.state)
+     dtype = {
+         "fp16": torch.float16,
+         "fp32": torch.float32,
+         "no": torch.float32,
+         "bf16": torch.bfloat16,
+     }[accelerator.state.mixed_precision]
+
+     num_frames = config.num_frames
+
+     ################## load models ##################
+     model_config = config.models.config
+     model_config = OmegaConf.load(model_config)
+     model = instantiate_from_config(model_config.model)
+     state_dict = torch.load(config.models.resume, map_location="cpu")
+
+     print(model.load_state_dict(state_dict, strict=False))
+     print("loaded model from {}".format(config.models.resume))
+
+     latest_step = 0
+     if config.get("resume", False):
+         print("resuming from specified workdir")
+         ckpts = os.listdir(config.ckpt_root)
+         if len(ckpts) == 0:
+             print("no ckpt found")
+         else:
+             # checkpoints are named `unet-{step}`, so sort by the step suffix
+             latest_ckpt = sorted(ckpts, key=lambda x: int(x.split("-")[-1]))[-1]
+             latest_step = int(latest_ckpt.split("-")[-1])
+             print("loading ckpt from ", osp.join(config.ckpt_root, latest_ckpt))
+             unet_state_dict = torch.load(
+                 osp.join(config.ckpt_root, latest_ckpt), map_location="cpu"
+             )
+             print(model.model.load_state_dict(unet_state_dict, strict=False))
+
+     elif config.models.get("resume_unet", None) is not None:
+         unet_state_dict = torch.load(config.models.resume_unet, map_location="cpu")
+         print(model.model.load_state_dict(unet_state_dict, strict=False))
+         print(f"______ load unet from {config.models.resume_unet} ______")
+     model.to(device)
+     model.device = device
+     model.clip_model.device = device
+
+     ################# setup optimizer #################
+     from torch.optim import AdamW
+     from accelerate.utils import DummyOptim
+
+     # when the DeepSpeed config supplies its own optimizer section, accelerate
+     # expects a DummyOptim placeholder and DeepSpeed builds the real optimizer
+     optimizer_cls = (
+         AdamW
+         if accelerator.state.deepspeed_plugin is None
+         or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
+         else DummyOptim
+     )
+     optimizer = optimizer_cls(model.model.parameters(), **config.optimizer)
+
+     ################# prepare datasets #################
+     dataset = instantiate_from_config(config.train_data)
+     eval_dataset = instantiate_from_config(config.eval_data)
+     in_the_wild_images = (
+         instantiate_from_config(config.in_the_wild_images)
+         if config.get("in_the_wild_images", None) is not None
+         else None
+     )
+
+     dl_config = config.dataloader
+     dataloader = DataLoader(dataset, **dl_config, batch_size=config.batch_size)
+
+     (
+         model,
+         optimizer,
+         dataloader,
+     ) = accelerator.prepare(model, optimizer, dataloader)
+
+     generator = get_data_generator(dataloader, accelerator.is_main_process, "train")
+     if config.get("sampler", None) is not None:
+         sampler_cls = get_obj_from_str(config.sampler.target)
+         sampler = sampler_cls(model, device, dtype, **config.sampler.params)
+     else:
+         sampler = ImageDreamDiffusion(
+             model,
+             mode=config.mode,
+             num_frames=num_frames,
+             device=device,
+             dtype=dtype,
+             camera_views=dataset.camera_views,
+             offset_noise=config.get("offset_noise", False),
+             ref_position=dataset.ref_position,
+             random_background=dataset.random_background,
+             resize_rate=dataset.resize_rate,
+         )
+
+     ################# evaluation code #################
+     def evaluation():
+         # shard the eval set across processes: rank r handles items r, r+N, r+2N, ...
+         return_ls = []
+         for i in range(
+             accelerator.process_index, len(eval_dataset), accelerator.num_processes
+         ):
+             cond = eval_dataset[i]["cond"]
+
+             images = sampler.diffuse("3D assets.", cond, n_test=2)
+             images = np.concatenate(images, 0)
+             images = [Image.fromarray(images)]
+             return_ls.append(dict(images=images, ident=eval_dataset[i]["ident"]))
+         return return_ls
+
+     def evaluation2():
+         # evaluation on commonly used in-the-wild images
+         return_ls = []
+         in_the_wild_images.init_item()
+         for i in range(
+             accelerator.process_index,
+             len(in_the_wild_images),
+             accelerator.num_processes,
+         ):
+             cond = in_the_wild_images[i]["cond"]
+             images = sampler.diffuse("3D assets.", cond, n_test=2)
+             images = np.concatenate(images, 0)
+             images = [Image.fromarray(images)]
+             return_ls.append(dict(images=images, ident=in_the_wild_images[i]["ident"]))
+         return return_ls
+
+     if latest_step == 0:
+         global_step = 0
+         total_step = 0
+         log_step = 0
+         eval_step = 0
+         save_step = 0
+     else:
+         # when resuming, fast-forward all counters past the restored checkpoint
+         global_step = latest_step // config.total_batch_size
+         total_step = latest_step
+         log_step = latest_step + config.log_interval
+         eval_step = latest_step + config.eval_interval
+         save_step = latest_step + config.save_interval
+
+     unet = model.model
+     while True:
+         item = next(generator)
+         unet.train()
+         bs = item["clip_cond"].shape[0]
+         BS = bs * num_frames  # total number of views in the batch
+         item["clip_cond"] = item["clip_cond"].to(device).to(dtype)
+         item["vae_cond"] = item["vae_cond"].to(device).to(dtype)
+         camera_input = item["cameras"].to(device)
+         camera_input = camera_input.reshape((BS, camera_input.shape[-1]))
+
+         # select the ground-truth supervision signal: rgb pixels, xyz maps, or both
+         gd_type = config.get("gd_type", "pixel")
+         if gd_type == "pixel":
+             item["target_images_vae"] = item["target_images_vae"].to(device).to(dtype)
+             gd = item["target_images_vae"]
+         elif gd_type == "xyz":
+             item["target_images_xyz_vae"] = (
+                 item["target_images_xyz_vae"].to(device).to(dtype)
+             )
+             gd = item["target_images_xyz_vae"]
+         elif gd_type == "fusechannel":
+             item["target_images_vae"] = item["target_images_vae"].to(device).to(dtype)
+             item["target_images_xyz_vae"] = (
+                 item["target_images_xyz_vae"].to(device).to(dtype)
+             )
+             gd = torch.cat(
+                 (item["target_images_vae"], item["target_images_xyz_vae"]), dim=0
+             )
+         else:
+             raise NotImplementedError
+
+         with torch.no_grad(), accelerator.autocast():
+             ip_embed = model.clip_model.encode_image_with_transformer(item["clip_cond"])
+             ip_ = ip_embed.repeat_interleave(num_frames, dim=0)
+
+             ip_img = model.get_first_stage_encoding(
+                 model.encode_first_stage(item["vae_cond"])
+             )
+
+             gd = rearrange(gd, "B F C H W -> (B F) C H W")
+
+             latent_target_images = model.get_first_stage_encoding(
+                 model.encode_first_stage(gd)
+             )
+
+             if gd_type == "fusechannel":
+                 # re-split into rgb and xyz halves, then fuse them along the channel axis
+                 latent_target_images = rearrange(
+                     latent_target_images, "(B F) C H W -> B F C H W", B=bs * 2
+                 )
+                 image_latent, xyz_latent = torch.chunk(latent_target_images, 2)
+                 fused_channel_latent = torch.cat((image_latent, xyz_latent), dim=-3)
+                 latent_target_images = rearrange(
+                     fused_channel_latent, "B F C H W -> (B F) C H W"
+                 )
+
+         if item.get("captions", None) is not None:
+             caption_ls = np.array(item["captions"]).T.reshape((-1, BS)).squeeze()
+             prompt_cond = model.get_learned_conditioning(caption_ls)
+         elif item.get("caption", None) is not None:
+             prompt_cond = model.get_learned_conditioning(item["caption"])
+             prompt_cond = prompt_cond.repeat_interleave(num_frames, dim=0)
+         else:
+             # fall back to a generic prompt when the batch carries no captions
+             prompt_cond = model.get_learned_conditioning(["3D assets."]).repeat(
+                 BS, 1, 1
+             )
+         condition = {
+             "context": prompt_cond,
+             "ip": ip_,
+             "ip_img": ip_img,
+             "camera": camera_input,
+         }
+
+         with torch.autocast("cuda"), accelerator.accumulate(model):
+             time_steps = torch.randint(0, model.num_timesteps, (BS,), device=device)
+             noise = torch.randn_like(latent_target_images, device=device)
+             # noise_img, _ = torch.chunk(noise, 2, dim=1)
+             # noise = torch.cat((noise_img, noise_img), dim=1)
+             x_noisy = model.q_sample(latent_target_images, time_steps, noise)
+             output = unet(x_noisy, time_steps, **condition, num_frames=num_frames)
+             reshaped_pred = output.reshape(bs, num_frames, *output.shape[1:]).permute(
+                 1, 0, 2, 3, 4
+             )
+             reshaped_noise = noise.reshape(bs, num_frames, *noise.shape[1:]).permute(
+                 1, 0, 2, 3, 4
+             )
+             # split off the last view of each sample; its mse term is multiplied
+             # by 0 below, so only the first num_frames - 1 views drive the loss
+             true_pred = reshaped_pred[: num_frames - 1]
+             fake_pred = reshaped_pred[num_frames - 1 :]
+             true_noise = reshaped_noise[: num_frames - 1]
+             fake_noise = reshaped_noise[num_frames - 1 :]
+             loss = (
+                 torch.nn.functional.mse_loss(true_noise, true_pred)
+                 + torch.nn.functional.mse_loss(fake_noise, fake_pred) * 0
+             )
+
+             accelerator.backward(loss)
+             optimizer.step()
+             optimizer.zero_grad()
+             global_step += 1
+
+         # total_step counts samples seen, not optimizer updates
+         total_step = global_step * config.total_batch_size
+         if total_step > log_step:
+             metrics = dict(
+                 loss=accelerator.gather(loss.detach().mean()).mean().item(),
+                 scale=(
+                     accelerator.scaler.get_scale()
+                     if accelerator.scaler is not None
+                     else -1
+                 ),
+             )
+             log_step += config.log_interval
+             if accelerator.is_main_process:
+                 logging.info(dct2str(dict(step=total_step, **metrics)))
+                 wandb.log(add_prefix(metrics, "train"), step=total_step)
+
+         if total_step > save_step and accelerator.is_main_process:
+             logging.info("saving checkpoint")
+             torch.save(
+                 unet.state_dict(), osp.join(config.ckpt_root, f"unet-{total_step}")
+             )
+             save_step += config.save_interval
+             logging.info("save done")
+
+         if total_step > eval_step:
+             logging.info("evaluating")
+             unet.eval()
+             return_ls = evaluation()
+             cur_eval_base = osp.join(config.eval_root, f"{total_step:07d}")
+             os.makedirs(cur_eval_base, exist_ok=True)
+             for item in return_ls:
+                 for i, im in enumerate(item["images"]):
+                     im.save(
+                         osp.join(
+                             cur_eval_base,
+                             f"{item['ident']}-{i:03d}-{accelerator.process_index}-.png",
+                         )
+                     )
+
+             return_ls2 = evaluation2()
+             cur_eval_base = osp.join(config.eval_root2, f"{total_step:07d}")
+             os.makedirs(cur_eval_base, exist_ok=True)
+             for item in return_ls2:
+                 for i, im in enumerate(item["images"]):
+                     im.save(
+                         osp.join(
+                             cur_eval_base,
+                             f"{item['ident']}-{i:03d}-{accelerator.process_index}-inthewild.png",
+                         )
+                     )
+             eval_step += config.eval_interval
+             logging.info("evaluation done")
+
+         accelerator.wait_for_everyone()
+         if total_step > config.max_step:
+             break
+
+
+ if __name__ == "__main__":
+     # load config from the config path, then merge with CLI args
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--config", type=str, default="configs/nf7_v3_SNR_rd_size_stroke.yaml"
+     )
+     parser.add_argument(
+         "--logdir", type=str, default="train_logs", help="the dir to put logs"
+     )
+     parser.add_argument(
+         "--resume_workdir", type=str, default=None, help="specify to resume"
+     )
+     args, unk = parser.parse_known_args()
+     print(args, unk)
+     config = OmegaConf.load(args.config)
+     if args.resume_workdir is not None:
+         assert osp.exists(args.resume_workdir), f"{args.resume_workdir} does not exist"
+         config.config.workdir = args.resume_workdir
+         config.config.resume = True
+     OmegaConf.set_struct(config, True)  # prevent adding new keys
+     cli_conf = OmegaConf.from_cli(unk)
+     config = OmegaConf.merge(config, cli_conf)
+     config = config.config
+     OmegaConf.set_struct(config, False)
+     config.logdir = args.logdir
+     config.config_name = Path(args.config).stem
+
+     train(config, unk)
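
For readers unfamiliar with the config handling at the bottom of the script: OmegaConf merges the YAML file with dotlist-style overrides taken from the unparsed argv remainder (`unk`). A minimal, self-contained sketch of that behavior, with hypothetical keys for illustration only:

    from omegaconf import OmegaConf

    # base config, mirroring the nested `config:` layout the script expects
    base = OmegaConf.create({"config": {"batch_size": 32, "resume": False}})
    OmegaConf.set_struct(base, True)  # unknown keys now raise instead of merging silently

    # e.g. invoked as: python train.py config.batch_size=8
    cli = OmegaConf.from_cli(["config.batch_size=8"])
    merged = OmegaConf.merge(base, cli)
    print(merged.config.batch_size)  # -> 8

Because `set_struct` is enabled before the merge, a mistyped override such as `config.batchsize=8` fails loudly rather than being silently added as a new key.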