刘虹雨 committed ab06a25 (Parent: 7ead217): update code
.gitignore
CHANGED
@@ -5,7 +5,6 @@ output*
 logs*
 taming*
 samples*
-datasets*
 asset*
 temp_samples*
 wandb*
DiT_VAE/diffusion/data/datasets/TriplaneData.py
ADDED
@@ -0,0 +1,141 @@
import os
import random
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import AutoImageProcessor
from DiT_VAE.diffusion.data.builder import DATASETS
from omegaconf import OmegaConf
from torchvision import transforms
from transformers import CLIPImageProcessor
import io
import zipfile
import numpy
import json


def to_rgb_image(maybe_rgba: Image.Image):
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    elif maybe_rgba.mode == 'RGBA':
        # composite the RGBA image onto a neutral gray (value 127) background
        rgba = maybe_rgba
        img = numpy.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)
        img = Image.fromarray(img, 'RGB')
        img.paste(rgba, mask=rgba.getchannel('A'))
        return img
    else:
        raise ValueError("Unsupported image type.", maybe_rgba.mode)


@DATASETS.register_module()
class TriplaneData(Dataset):
    def __init__(self,
                 data_base_dir,
                 model_names,
                 data_json_file,
                 dino_path,
                 i_drop_rate=0.1,
                 image_size=256,
                 **kwargs):
        self.dict_data_image = json.load(open(data_json_file))  # {'image_name': pose}
        self.data_base_dir = data_base_dir
        self.dino_img_processor = AutoImageProcessor.from_pretrained(dino_path)
        self.size = image_size
        self.data_list = list(self.dict_data_image.keys())
        self.zip_file_dict = {}
        config_gan_model = OmegaConf.load(model_names)
        all_models = config_gan_model['gan_models'].keys()
        # keep one open ZipFile handle per GAN-model archive so samples can be read on demand
        for model_name in all_models:
            zipfile_path = os.path.join(self.data_base_dir, model_name + '.zip')
            zipfile_load = zipfile.ZipFile(zipfile_path)
            self.zip_file_dict[model_name] = zipfile_load
        self.transform = transforms.Compose([
            transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        self.clip_image_processor = CLIPImageProcessor()
        self.i_drop_rate = i_drop_rate

    def getdata(self, idx):
        data_name = self.data_list[idx]
        data_model_name = self.dict_data_image[data_name]['model_name']
        zipfile_loaded = self.zip_file_dict[data_model_name]
        # zipfile_path = os.path.join(self.data_base_dir, data_model_name)
        # zipfile_loaded = zipfile.ZipFile(zipfile_path)
        with zipfile_loaded.open(self.dict_data_image[data_name]['z_dir'], 'r') as f:
            buffer = io.BytesIO(f.read())
            data_z = torch.load(buffer)

        with zipfile_loaded.open(self.dict_data_image[data_name]['vert_dir'], 'r') as f:
            buffer = io.BytesIO(f.read())
            data_vert = torch.load(buffer)

        with zipfile_loaded.open(self.dict_data_image[data_name]['img_dir'], 'r') as f:
            raw_image = to_rgb_image(Image.open(f))
            dino_img = self.dino_img_processor(images=raw_image, return_tensors="pt").pixel_values
            image = self.transform(raw_image.convert("RGB"))
            clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
        # randomly flag this sample's image embedding to be dropped (probability i_drop_rate)
        drop_image_embed = 0
        rand_num = random.random()
        if rand_num < self.i_drop_rate:
            drop_image_embed = 1
        return {
            "raw_image": raw_image,
            "dino_img": dino_img,
            "image": image,
            "clip_image": clip_image.clone(),
            "data_z": data_z,
            "data_vert": data_vert,
            "data_model_name": data_model_name,
            "drop_image_embed": drop_image_embed,
        }

    #
    # img_path = self.img_samples[index]
    # npz_path = self.txt_feat_samples[index]
    # npy_path = self.vae_feat_samples[index]
    # prompt = self.prompt_samples[index]
    # data_info = {
    #     'img_hw': torch.tensor([torch.tensor(self.resolution), torch.tensor(self.resolution)], dtype=torch.float32),
    #     'aspect_ratio': torch.tensor(1.)
    # }
    #
    # img = self.loader(npy_path) if self.load_vae_feat else self.loader(img_path)
    # txt_info = np.load(npz_path)
    # txt_fea = torch.from_numpy(txt_info['caption_feature'])  # 1xTx4096
    # attention_mask = torch.ones(1, 1, txt_fea.shape[1])  # 1x1xT
    # if 'attention_mask' in txt_info.keys():
    #     attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]
    # if txt_fea.shape[1] != self.max_lenth:
    #     txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth-txt_fea.shape[1], 1)], dim=1)
    #     attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth-attention_mask.shape[-1])], dim=-1)
    #
    # if self.transform:
    #     img = self.transform(img)
    #
    # data_info['prompt'] = prompt
    # return img, txt_fea, attention_mask, data_info

    def __getitem__(self, idx):
        # retry a handful of times on corrupt samples, falling back to a random index
        for _ in range(20):
            try:
                return self.getdata(idx)
            except Exception as e:
                print(f"Error details: {str(e)}")
                idx = np.random.randint(len(self))
        raise RuntimeError('Too many bad data.')

    def __len__(self):
        return len(self.data_list)

    def __getattr__(self, name):
        if name == "set_epoch":
            return lambda epoch: None
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
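For orientation, here is a minimal usage sketch of the dataset added above. It is not part of the commit: every path, the GAN-model YAML, the JSON manifest, and the DINO checkpoint name are placeholders assumed purely for illustration; the real values come from the repo's configs.

# Hypothetical usage sketch (not part of this commit).
from DiT_VAE.diffusion.data.datasets import TriplaneData

dataset = TriplaneData(
    data_base_dir='data/triplane_zips',        # assumed dir with one <model_name>.zip per YAML entry
    model_names='configs/gan_models.yaml',     # assumed OmegaConf file with a top-level 'gan_models' mapping
    data_json_file='data/manifest.json',       # assumed {name: {'model_name', 'z_dir', 'vert_dir', 'img_dir', ...}}
    dino_path='facebook/dinov2-base',          # any AutoImageProcessor-compatible checkpoint
    i_drop_rate=0.1,
    image_size=256,
)

sample = dataset[0]
print(sample['image'].shape)        # torch.Size([3, 256, 256])
print(sample['dino_img'].shape)     # (1, 3, H, W) as returned by the DINO processor
print(sample['drop_image_embed'])   # 1 with probability i_drop_rate, else 0

Note that each sample carries a PIL raw_image, so a custom collate_fn would be needed before batching with a default DataLoader; the __getattr__ shim also turns set_epoch into a no-op so samplers that call it do not fail.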
DiT_VAE/diffusion/data/datasets/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .TriplaneData import TriplaneData
from .utils import *
DiT_VAE/diffusion/data/datasets/utils.py
ADDED
@@ -0,0 +1,84 @@
ASPECT_RATIO_1024 = {
    '0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.],
    '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
    '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
    '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
    '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
    '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
    '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
    '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
    '2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.],
    '3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.],
}

ASPECT_RATIO_512 = {
    '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
    '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
    '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
    '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
    '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
    '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
    '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
    '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
    '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
    '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
}

ASPECT_RATIO_256 = {
    '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0],
    '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
    '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
    '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
    '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
    '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
    '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
    '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
    '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0],
    '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0]
}

ASPECT_RATIO_256_TEST = {
    '0.25': [128.0, 512.0], '0.28': [128.0, 464.0],
    '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
    '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
    '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
    '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
    '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
    '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
    '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
    '2.5': [400.0, 160.0], '3.0': [432.0, 144.0],
    '4.0': [512.0, 128.0]
}

ASPECT_RATIO_512_TEST = {
    '0.25': [256.0, 1024.0], '0.28': [256.0, 928.0],
    '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
    '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
    '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
    '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
    '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
    '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
    '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
    '2.5': [800.0, 320.0], '3.0': [864.0, 288.0],
    '4.0': [1024.0, 256.0]
}

ASPECT_RATIO_1024_TEST = {
    '0.25': [512., 2048.], '0.28': [512., 1856.],
    '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
    '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
    '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
    '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
    '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
    '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
    '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
    '2.5': [1600., 640.], '3.0': [1728., 576.],
    '4.0': [2048., 512.],
}


def get_chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]
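For orientation, a short hedged sketch of how these lookup tables and get_chunks are typically consumed: each table maps a height/width ratio (as a string key) to a [height, width] resolution bucket. closest_bucket is an illustrative helper invented here, not something defined in this commit.

# Hypothetical example (not part of this commit): snap an image's aspect ratio to the
# nearest predefined bucket, then split a work list into fixed-size chunks.
def closest_bucket(height, width, ratios=ASPECT_RATIO_512):
    aspect = height / width
    key = min(ratios, key=lambda r: abs(float(r) - aspect))
    return ratios[key]  # [bucket_height, bucket_width]

print(closest_bucket(720, 1280))              # 0.5625 -> [384.0, 672.0]
print(list(get_chunks(list(range(10)), 4)))   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]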