刘虹雨 committed
Commit ab06a25 · 1 Parent(s): 7ead217

update code

.gitignore CHANGED
@@ -5,7 +5,6 @@ output*
 logs*
 taming*
 samples*
-datasets*
 asset*
 temp_samples*
 wandb*
 
DiT_VAE/diffusion/data/datasets/TriplaneData.py ADDED
@@ -0,0 +1,141 @@
+import os
+import random
+from PIL import Image
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from transformers import AutoImageProcessor
+from DiT_VAE.diffusion.data.builder import DATASETS
+from omegaconf import OmegaConf
+from torchvision import transforms
+from transformers import CLIPImageProcessor
+import io
+import zipfile
+import numpy
+import json
+
+
+def to_rgb_image(maybe_rgba: Image.Image):
+    if maybe_rgba.mode == 'RGB':
+        return maybe_rgba
+    elif maybe_rgba.mode == 'RGBA':
+        rgba = maybe_rgba
+        img = numpy.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)
+        img = Image.fromarray(img, 'RGB')
+        img.paste(rgba, mask=rgba.getchannel('A'))
+        return img
+    else:
+        raise ValueError("Unsupported image type.", maybe_rgba.mode)
+
+
+@DATASETS.register_module()
+class TriplaneData(Dataset):
+    def __init__(self,
+                 data_base_dir,
+                 model_names,
+                 data_json_file,
+                 dino_path,
+                 i_drop_rate=0.1,
+                 image_size=256,
+                 **kwargs):
+        self.dict_data_image = json.load(open(data_json_file))  # {'image_name': {'model_name', 'z_dir', 'vert_dir', 'img_dir'}}
+        self.data_base_dir = data_base_dir
+        self.dino_img_processor = AutoImageProcessor.from_pretrained(dino_path)
+        self.size = image_size
+        self.data_list = list(self.dict_data_image.keys())
+        self.zip_file_dict = {}
+        config_gan_model = OmegaConf.load(model_names)
+        all_models = config_gan_model['gan_models'].keys()
+        for model_name in all_models:
+            zipfile_path = os.path.join(self.data_base_dir, model_name + '.zip')
+            zipfile_load = zipfile.ZipFile(zipfile_path)
+            self.zip_file_dict[model_name] = zipfile_load
+        self.transform = transforms.Compose([
+            transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(self.size),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
+        ])
+        self.clip_image_processor = CLIPImageProcessor()
+        self.i_drop_rate = i_drop_rate
+
+    def getdata(self, idx):
+        data_name = self.data_list[idx]
+        data_model_name = self.dict_data_image[data_name]['model_name']
+        zipfile_loaded = self.zip_file_dict[data_model_name]
+        # zipfile_path = os.path.join(self.data_base_dir, data_model_name)
+        # zipfile_loaded = zipfile.ZipFile(zipfile_path)
+        with zipfile_loaded.open(self.dict_data_image[data_name]['z_dir'], 'r') as f:
+            buffer = io.BytesIO(f.read())
+            data_z = torch.load(buffer)
+
+        with zipfile_loaded.open(self.dict_data_image[data_name]['vert_dir'], 'r') as f:
+            buffer = io.BytesIO(f.read())
+            data_vert = torch.load(buffer)
+
+        with zipfile_loaded.open(self.dict_data_image[data_name]['img_dir'], 'r') as f:
+            raw_image = to_rgb_image(Image.open(f))
+            dino_img = self.dino_img_processor(images=raw_image, return_tensors="pt").pixel_values
+            image = self.transform(raw_image.convert("RGB"))
+            clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
+        drop_image_embed = 0
+        rand_num = random.random()
+        if rand_num < self.i_drop_rate:
+            drop_image_embed = 1
+        return {
+            "raw_image": raw_image,
+            "dino_img": dino_img,
+            "image": image,
+            "clip_image": clip_image.clone(),
+            "data_z": data_z,
+            "data_vert": data_vert,
+            "data_model_name": data_model_name,
+            "drop_image_embed": drop_image_embed,
+        }
+
+        # img_path = self.img_samples[index]
+        # npz_path = self.txt_feat_samples[index]
+        # npy_path = self.vae_feat_samples[index]
+        # prompt = self.prompt_samples[index]
+        # data_info = {
+        #     'img_hw': torch.tensor([torch.tensor(self.resolution), torch.tensor(self.resolution)], dtype=torch.float32),
+        #     'aspect_ratio': torch.tensor(1.)
+        # }
+        #
+        # img = self.loader(npy_path) if self.load_vae_feat else self.loader(img_path)
+        # txt_info = np.load(npz_path)
+        # txt_fea = torch.from_numpy(txt_info['caption_feature'])  # 1xTx4096
+        # attention_mask = torch.ones(1, 1, txt_fea.shape[1])  # 1x1xT
+        # if 'attention_mask' in txt_info.keys():
+        #     attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]
+        # if txt_fea.shape[1] != self.max_lenth:
+        #     txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth - txt_fea.shape[1], 1)], dim=1)
+        #     attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth - attention_mask.shape[-1])], dim=-1)
+        #
+        # if self.transform:
+        #     img = self.transform(img)
+        #
+        # data_info['prompt'] = prompt
+        # return img, txt_fea, attention_mask, data_info
+
+    def __getitem__(self, idx):
+        for _ in range(20):
+            try:
+                return self.getdata(idx)
+            except Exception as e:
+                print(f"Error details: {str(e)}")
+                idx = np.random.randint(len(self))
+        raise RuntimeError('Too many bad data.')
+
+    def __len__(self):
+        return len(self.data_list)
+
+    def __getattr__(self, name):
+        if name == "set_epoch":
+            return lambda epoch: None
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
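
For orientation, a minimal usage sketch of the dataset added above; this is not part of the commit, and the paths and checkpoint name are hypothetical placeholders chosen only to match the __init__ signature. It assumes data_base_dir holds one <model_name>.zip per entry in the gan_models config, and data_json_file maps each sample name to its model_name/z_dir/vert_dir/img_dir entries inside that archive.

from DiT_VAE.diffusion.data.datasets import TriplaneData

# Hypothetical paths and image-processor checkpoint, for illustration only.
dataset = TriplaneData(
    data_base_dir='data/triplane_zips',
    model_names='configs/gan_models.yaml',
    data_json_file='data/triplane_index.json',
    dino_path='facebook/dinov2-base',
    i_drop_rate=0.1,
    image_size=256,
)
sample = dataset[0]
print(len(dataset), sample['dino_img'].shape, sample['data_model_name'])

Note that the returned dict keeps 'raw_image' as a PIL image, so wrapping this in a torch DataLoader would need a custom collate_fn (or dropping that key), and the long-lived ZipFile handles mean multi-worker loading should be treated with care.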
DiT_VAE/diffusion/data/datasets/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .TriplaneData import TriplaneData
+from .utils import *
DiT_VAE/diffusion/data/datasets/utils.py ADDED
@@ -0,0 +1,84 @@
+
+
+ASPECT_RATIO_1024 = {
+    '0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.],
+    '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
+    '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
+    '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
+    '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
+    '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
+    '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
+    '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
+    '2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.],
+    '3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.],
+}
+
+ASPECT_RATIO_512 = {
+    '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
+    '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
+    '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
+    '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
+    '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
+    '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
+    '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
+    '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
+    '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
+    '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
+}
+
+ASPECT_RATIO_256 = {
+    '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0],
+    '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
+    '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
+    '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
+    '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
+    '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
+    '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
+    '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
+    '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0],
+    '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0]
+}
+
+ASPECT_RATIO_256_TEST = {
+    '0.25': [128.0, 512.0], '0.28': [128.0, 464.0],
+    '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
+    '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
+    '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
+    '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
+    '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
+    '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
+    '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
+    '2.5': [400.0, 160.0], '3.0': [432.0, 144.0],
+    '4.0': [512.0, 128.0]
+}
+
+ASPECT_RATIO_512_TEST = {
+    '0.25': [256.0, 1024.0], '0.28': [256.0, 928.0],
+    '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
+    '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
+    '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
+    '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
+    '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
+    '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
+    '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
+    '2.5': [800.0, 320.0], '3.0': [864.0, 288.0],
+    '4.0': [1024.0, 256.0]
+}
+
+ASPECT_RATIO_1024_TEST = {
+    '0.25': [512., 2048.], '0.28': [512., 1856.],
+    '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
+    '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
+    '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
+    '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
+    '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
+    '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
+    '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
+    '2.5': [1600., 640.], '3.0': [1728., 576.],
+    '4.0': [2048., 512.],
+}
+
+
+def get_chunks(lst, n):
+    for i in range(0, len(lst), n):
+        yield lst[i:i + n]
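
For reference, a brief illustrative sketch of how these bucket tables and get_chunks are commonly used; the closest_bucket helper below is not part of the commit. Each key is a height/width ratio stored as a string, and the value is the [height, width] size to resize toward for that bucket.

from DiT_VAE.diffusion.data.datasets.utils import ASPECT_RATIO_512, get_chunks

def closest_bucket(height, width, ratios=ASPECT_RATIO_512):
    # Pick the bucket whose ratio key is numerically closest to height/width.
    target = height / width
    key = min(ratios, key=lambda k: abs(float(k) - target))
    return ratios[key]

print(closest_bucket(640, 480))              # [576.0, 448.0] -- the '1.29' bucket is nearest to 640/480 ≈ 1.33
print(list(get_chunks(list(range(10)), 4)))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]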