Mariam-Elz committed
Commit 0f95337 · verified · 1 Parent(s): 1e5acec

Upload libs/data.py with huggingface_hub

Files changed (1)
  1. libs/data.py +580 -0
libs/data.py ADDED
@@ -0,0 +1,580 @@
+ import numpy
+ import numpy as np
+ import torch
+ import os
+ import random
+ import pandas as pd
+ import os.path as osp
+ import PIL.Image as Image
+ from torch.utils.data import Dataset
+ from pathlib import Path
+ from imagedream.ldm.util import add_random_background
+ from imagedream.camera_utils import get_camera_for_index
+ from libs.base_utils import do_resize_content, add_stroke
+
+ import torchvision.transforms as transforms
+
+
+ def to_rgb_image(maybe_rgba: Image.Image):
+     """Composite an RGBA image onto a neutral (value 127) gray background; pass RGB through."""
+     if maybe_rgba.mode == "RGB":
+         return maybe_rgba
+     elif maybe_rgba.mode == "RGBA":
+         rgba = maybe_rgba
+         img = numpy.random.randint(
+             127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8
+         )
+         img = Image.fromarray(img, "RGB")
+         img.paste(rgba, mask=rgba.getchannel("A"))
+         return img
+     else:
+         raise ValueError("Unsupported image type.", maybe_rgba.mode)
+
+
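+ # A minimal sanity-check sketch for `to_rgb_image` (illustrative, not part of the
+ # original upload): an RGBA image should come back as RGB on a uniform 127-gray
+ # background, since randint(127, 128) always yields 127.
+ #
+ #   rgba = Image.new("RGBA", (8, 8), (255, 0, 0, 0))  # fully transparent red
+ #   rgb = to_rgb_image(rgba)
+ #   assert rgb.mode == "RGB" and rgb.getpixel((0, 0)) == (127, 127, 127)
+
+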
+ def axis_rotate_xyz(img: Image.Image, rotate_axis="z", angle=90.0):
+     """Rotate an xyz coordinate map (encoded as RGB, centered at 127) about one
+     axis. Note that `angle` is used in radians (callers pass multiples of np.pi / 2)."""
+     img = img.convert("RGB")
+     img = np.array(img) - 127
+     img = img.astype(np.float32)
+     # perform the element-wise sin/cos rotation about the chosen axis
+     if rotate_axis == "z":
+         img = np.stack(
+             [
+                 img[..., 0] * np.cos(angle) - img[..., 1] * np.sin(angle),
+                 img[..., 0] * np.sin(angle) + img[..., 1] * np.cos(angle),
+                 img[..., 2],
+             ],
+             -1,
+         )
+     elif rotate_axis == "y":
+         img = np.stack(
+             [
+                 img[..., 0] * np.cos(angle) + img[..., 2] * np.sin(angle),
+                 img[..., 1],
+                 -img[..., 0] * np.sin(angle) + img[..., 2] * np.cos(angle),
+             ],
+             -1,
+         )
+     elif rotate_axis == "x":
+         img = np.stack(
+             [
+                 img[..., 0],
+                 img[..., 1] * np.cos(angle) - img[..., 2] * np.sin(angle),
+                 img[..., 1] * np.sin(angle) + img[..., 2] * np.cos(angle),
+             ],
+             -1,
+         )
+     else:
+         raise ValueError("Unsupported rotate_axis.", rotate_axis)
+
+     return Image.fromarray(img.astype(np.uint8) + 127)
+
+
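+ # A quick illustrative check (not part of the original upload): rotating a
+ # coordinate map by pi/2 about z should move the x-channel into y. With a pixel
+ # encoding (227, 127, 127), i.e. the centered vector (100, 0, 0):
+ #
+ #   px = Image.new("RGB", (1, 1), (227, 127, 127))
+ #   out = axis_rotate_xyz(px, rotate_axis="z", angle=np.pi / 2)
+ #   out.getpixel((0, 0))  # (127, 227, 127): x has rotated into y
+
+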
+ class DataHQCRelative(Dataset):
+     """
+     Expected directory layout:
+     - base_dir
+         - uid1
+             - 000.png
+             - 001.png
+             - ...
+         - uid2
+     - xyz_base
+         - uid1
+             - xyz_new_000.png
+             - xyz_new_001.png
+             - ...
+     Accepts caption data (in CSV format).
+     """
+
+     def __init__(
+         self,
+         base_dir,
+         caption_csv,
+         ref_indexs=[0],
+         ref_position=-1,
+         xyz_base=None,
+         camera_views=[3, 6, 9, 12, 15],  # camera views are relative views, not abs
+         split="train",
+         image_size=256,
+         random_background=False,
+         resize_rate=1,
+         num_frames=5,
+         repeat=100,
+         outer_file=None,
+         debug=False,
+         eval_size=100,
+     ):
+         print(__class__)
+         OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+         OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
+         df = pd.read_csv(caption_csv, sep=",", names=["id", "caption"])
+         id_to_caption = {}
+         for i in range(len(df.index)):
+             item = df.iloc[i]
+             id_to_caption[item["id"]] = item["caption"]
+
+         # outer_file is a txt file with one ident per line; the listed idents
+         # are excluded from the training process
+         outer_set = (
+             set(open(outer_file, "r").read().strip().split("\n"))
+             if outer_file is not None
+             else set()
+         )
+         xyz_set = set(os.listdir(xyz_base)) if xyz_base is not None else set()
+         common_keys = set(id_to_caption.keys()) & set(os.listdir(base_dir))
+         common_keys = common_keys & xyz_set if xyz_base is not None else common_keys
+         common_keys = common_keys - outer_set
+         self.common_keys = common_keys
+         self.id_to_caption = id_to_caption
+         final_dict = {key: id_to_caption[key] for key in common_keys}
+         self.image_size = image_size
+         self.base_dir = Path(base_dir)
+         self.xyz_base = xyz_base
+         self.repeat = repeat
+         self.num_frames = num_frames
+         self.camera_views = camera_views[:num_frames]
+         self.split = split
+         self.ref_indexs = ref_indexs
+         self.ref_position = ref_position
+         self.resize_rate = resize_rate
+         self.random_background = random_background
+         self.debug = debug
+         assert split in ["train", "eval"]
+
+         clip_size = 224
+         self.transfrom_clip = transforms.Compose(
+             [
+                 transforms.Resize(
+                     (clip_size, clip_size),
+                     interpolation=Image.BICUBIC,
+                     antialias="warn",
+                 ),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
+             ]
+         )
+
+         self.transfrom_vae = transforms.Compose(
+             [
+                 transforms.Resize((image_size, image_size)),
+                 transforms.ToTensor(),
+                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+             ]
+         )
+         # when the i-th view is used as the reference: the image-name indices of
+         # the left, down, back, right, and top views
+         import torchvision.transforms.functional as TF
+         from functools import partial as PA
+
+         self.index_mapping = [
+             # front left down back right top
+             [0, 1, 2, 3, 4, 5],  # 0
+             [1, 3, 2, 4, 0, 5],  # 1
+             [2, 1, 3, 5, 4, 0],  # 2
+             [3, 4, 2, 0, 1, 5],  # 3
+             [4, 0, 2, 1, 3, 5],  # 4
+             [5, 1, 0, 2, 4, 3],  # 5
+         ]
+         TT = {
+             "r90": PA(TF.rotate, angle=-90.0),  # 90° clockwise
+             "r180": PA(TF.rotate, angle=-180.0),  # 180° clockwise
+             "r270": PA(TF.rotate, angle=-270.0),  # 270° clockwise
+             "s90": PA(TF.rotate, angle=90.0),  # 90° counterclockwise
+             "s180": PA(TF.rotate, angle=180.0),  # 180° counterclockwise
+             "s270": PA(TF.rotate, angle=270.0),  # 270° counterclockwise
+         }
+
+         self.transfroms_mapping = [
+             # front left down back right top
+             [None, None, None, None, None, None],  # 0
+             [None, None, TT["r90"], None, None, TT["s90"]],  # 1
+             [None, TT["s90"], TT["s180"], TT["r180"], TT["r90"], None],  # 2
+             [None, None, TT["r180"], None, None, TT["s180"]],  # 3
+             [None, None, TT["s90"], None, None, TT["r90"]],  # 4
+             [None, TT["r90"], None, TT["r180"], TT["s90"], TT["s180"]],  # 5
+         ]
+
+         XT = {  # xyz transforms
+             "zRrota90": PA(axis_rotate_xyz, rotate_axis="z", angle=np.pi / 2),
+             "zSrota90": PA(axis_rotate_xyz, rotate_axis="z", angle=-np.pi / 2),
+             "zrota180": PA(axis_rotate_xyz, rotate_axis="z", angle=-np.pi),
+             "xRrota90": PA(axis_rotate_xyz, rotate_axis="x", angle=np.pi / 2),
+             "xSrota90": PA(axis_rotate_xyz, rotate_axis="x", angle=-np.pi / 2),
+         }
+
+         self.xyz_transforms_mapping = [
+             # front left down back right top
+             [None] * 6,  # 0
+             [XT["zRrota90"]] * 6,  # 1
+             [XT["xRrota90"]] * 6,  # 2
+             [XT["zrota180"]] * 6,  # 3
+             [XT["zSrota90"]] * 6,  # 4
+             [XT["xSrota90"]] * 6,  # 5
+         ]
+
+         total_items = [
+             {
+                 "path": os.path.join(base_dir, k),
+                 "xyz_path": os.path.join(xyz_base, k) if xyz_base is not None else None,
+                 "caption": v,
+             }
+             for k, v in final_dict.items()
+         ]
+         total_items.sort(key=lambda x: x["path"])
+
+         if len(total_items) > eval_size:
+             if split == "train":
+                 self.items = total_items[eval_size:]
+             else:
+                 self.items = total_items[:eval_size]
+         else:
+             self.items = total_items
+
+         print("============= length of dataset %d =============" % len(self.items))
+
+     def __len__(self):
+         return len(self.items) * self.repeat
+
+     def __getitem__(self, index):
+         """
+         Choose the indices for the 6 target images and select one of them as the
+         input image.
+         target_images_vae: batch of `num_frames` images of one object from different views, processed by vae_processor
+         ref_ip: ref image in pixel space
+         ref_ip_img:
+         camera_views decide the logical camera pose of the images:
+             000 is front, ev: 0,   azimuth: 0
+             001 is left,  ev: 0,   azimuth: -90
+             002 is down,  ev: -90, azimuth: 0
+             003 is back,  ev: 0,   azimuth: 180
+             004 is right, ev: 0,   azimuth: 90
+             005 is top,   ev: 90,  azimuth: 0
+         ref_index decides which image is chosen as the input image.
+
+         For example, when camera_views = [1, 2, 3, 4, 5, 0] and ref_position = 5,
+         the dataset returns the instance images in the order
+         [left, down, back, right, top, front], in which
+         view[ref_position] = view[5] = 0, so the reference image is the front image.
+
+         Since every face can be rotated to the front, any image can be placed at
+         ref_position as the ref image (after some transforms). To control which
+         images may be placed at ref_position, set ref_indexs. It defaults to [0],
+         meaning only images named 000 are placed at ref_position; with
+         ref_indexs=[0, 1, 3, 4], only images named 000, 001, 003, and 004 are used.
+         """
+         index_mapping = self.index_mapping
+         transfroms_mapping = self.transfroms_mapping
+         index = index % len(self.items)
+
+         target_dir = self.items[index]["path"]
+         target_xyz_dir = self.items[index]["xyz_path"]
+         caption = self.items[index]["caption"]
+
+         bg_color = np.random.rand() * 255
+         target_images = []
+         target_xyz_images = []
+         raw_images = []
+         raw_xyz_images = []
+         alpha_masks = []
+         ref_index = random.choice(self.ref_indexs)
+         cur_index_mapping = index_mapping[ref_index]
+         cur_transfroms_mapping = transfroms_mapping[ref_index]
+         cur_xyz_transfroms_mapping = self.xyz_transforms_mapping[ref_index]
+         for relative_view in self.camera_views:
+             image_index = cur_index_mapping[relative_view]
+             trans = cur_transfroms_mapping[relative_view]
+             trans_xyz = cur_xyz_transfroms_mapping[relative_view]
+             # open the view image and apply the in-plane rotation for this reference
+             img = Image.open(
+                 os.path.join(target_dir, f"{image_index:03d}.png")
+             ).convert("RGBA")
+             if trans is not None:
+                 img = trans(img)
+             img = do_resize_content(img, self.resize_rate)
+             alpha_mask = img.getchannel("A")
+             alpha_masks.append(alpha_mask)
+             if self.random_background:
+                 img = add_random_background(img, bg_color)
+             img = img.convert("RGB")
+             target_images.append(self.transfrom_vae(img))
+
+             raw_images.append(img)
+
+             if self.xyz_base is not None:
+                 img_xyz = Image.open(
+                     os.path.join(target_xyz_dir, f"xyz_new_{image_index:03d}.png")
+                 ).convert("RGBA")
+                 img_xyz = trans_xyz(img_xyz) if trans_xyz is not None else img_xyz
+                 img_xyz = trans(img_xyz) if trans is not None else img_xyz
+                 img_xyz = do_resize_content(img_xyz, self.resize_rate)
+                 img_xyz.putalpha(alpha_mask)
+                 if self.random_background:
+                     img_xyz = add_random_background(img_xyz, bg_color)
+                 img_xyz = img_xyz.convert("RGB")
+                 target_xyz_images.append(self.transfrom_vae(img_xyz))
+                 if self.debug:
+                     raw_xyz_images.append(img_xyz)
+
+         cameras = [get_camera_for_index(i).squeeze() for i in self.camera_views]
+         if self.ref_position is not None:
+             cameras[self.ref_position] = torch.zeros_like(
+                 cameras[self.ref_position]
+             )  # set ref camera to zero
+
+         cameras = torch.stack(cameras)
+
+         input_img = Image.open(
+             os.path.join(target_dir, f"{ref_index:03d}.png")
+         ).convert("RGBA")
+         input_img = do_resize_content(input_img, self.resize_rate)
+         if self.random_background:
+             input_img = add_random_background(input_img, bg_color)
+         input_img = input_img.convert("RGB")
+
+         clip_cond = self.transfrom_clip(input_img)
+         vae_cond = self.transfrom_vae(input_img)
+
+         vae_target = torch.stack(target_images, dim=0)
+         if self.xyz_base is not None:
+             xyz_vae_target = torch.stack(target_xyz_images, dim=0)
+         else:
+             xyz_vae_target = []
+
+         if self.debug:
+             print(f"debug!!,{bg_color}")
+             return {
+                 "target_images": raw_images,
+                 "target_images_xyz": raw_xyz_images,
+                 "input_img": input_img,
+                 "cameras": cameras,
+                 "caption": caption,
+                 "item": self.items[index],
+                 "alpha_masks": alpha_masks,
+             }
+
+         if self.split == "train":
+             return {
+                 "target_images_vae": vae_target,
+                 "target_images_xyz_vae": xyz_vae_target,
+                 "clip_cond": clip_cond,
+                 "vae_cond": vae_cond,
+                 "cameras": cameras,
+                 "caption": caption,
+             }
+         else:  # eval
+             path = os.path.join(target_dir, f"{ref_index:03d}.png")
+             return dict(
+                 path=path,
+                 target_dir=target_dir,
+                 cond_raw_images=raw_images,
+                 cond=input_img,
+                 ref_index=ref_index,
+                 ident=f"{index}-{Path(target_dir).stem}",
+             )
+
+
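+ # A minimal usage sketch for DataHQCRelative (illustrative, not part of the
+ # original upload; the paths below are hypothetical). It mirrors the
+ # __getitem__ docstring example: six relative views with the reference placed
+ # last, so view[ref_position] is the front image.
+ #
+ #   from torch.utils.data import DataLoader
+ #   dataset = DataHQCRelative(
+ #       base_dir="data/views",            # per-uid folders of 000.png..005.png
+ #       caption_csv="data/captions.csv",  # rows of "<uid>,<caption>"
+ #       camera_views=[1, 2, 3, 4, 5, 0],
+ #       ref_position=5,
+ #       num_frames=6,
+ #       random_background=True,
+ #   )
+ #   loader = DataLoader(dataset, batch_size=4, shuffle=True)
+ #   batch = next(iter(loader))
+ #   batch["target_images_vae"].shape  # (4, 6, 3, 256, 256)
+
+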
+ class DataRelativeStroke(DataHQCRelative):
+     """A temp dataset for the add-sync base, using FOV data as the ref image;
+     it additionally applies random stroke and content-resize augmentation to
+     the reference image."""
+
+     def __init__(
+         self,
+         base_dir,
+         caption_csv,
+         ref_indexs=[0],
+         ref_position=-1,
+         xyz_base=None,
+         camera_views=[3, 6, 9, 12, 15],  # camera views are relative views, not abs
+         split="train",
+         image_size=256,
+         random_background=False,
+         resize_rate=1,
+         num_frames=5,
+         repeat=100,
+         outer_file=None,
+         debug=False,
+         eval_size=100,
+         stroke_p=0.3,
+         resize_range=None,
+     ):
+         print(__class__)
+         super().__init__(
+             base_dir,
+             caption_csv,
+             ref_indexs=ref_indexs,
+             ref_position=ref_position,
+             xyz_base=xyz_base,
+             camera_views=camera_views,
+             split=split,
+             image_size=image_size,
+             random_background=random_background,
+             resize_rate=resize_rate,
+             num_frames=num_frames,
+             repeat=repeat,
+             outer_file=outer_file,
+             debug=debug,
+             eval_size=eval_size,
+         )
+         self.stroke_p = stroke_p
+         assert (
+             resize_range is None or len(resize_range) == 2
+         ), "resize_range should be a tuple of 2 elements"
+         self.resize_range = resize_range
+
+     def __len__(self):
+         return len(self.items) * self.repeat
+
+     def __getitem__(self, index):
+         index_mapping = self.index_mapping
+         transfroms_mapping = self.transfroms_mapping
+         index = index % len(self.items)
+
+         target_dir = self.items[index]["path"]
+         target_xyz_dir = self.items[index]["xyz_path"]
+         caption = self.items[index]["caption"]
+
+         bg_color = np.random.rand() * 255
+         target_images = []
+         target_xyz_images = []
+         raw_images = []
+         raw_xyz_images = []
+         alpha_masks = []
+         ref_index = random.choice(self.ref_indexs)
+         cur_index_mapping = index_mapping[ref_index]
+         cur_transfroms_mapping = transfroms_mapping[ref_index]
+         cur_xyz_transfroms_mapping = self.xyz_transforms_mapping[ref_index]
+         cur_resize_rate = (
+             random.uniform(*self.resize_range) * self.resize_rate
+             if self.resize_range is not None
+             else self.resize_rate
+         )
+         for relative_view in self.camera_views:
+             image_index = cur_index_mapping[relative_view]
+             trans = cur_transfroms_mapping[relative_view]
+             trans_xyz = cur_xyz_transfroms_mapping[relative_view]
+             # open the view image and apply the in-plane rotation for this reference
+             img = Image.open(
+                 os.path.join(target_dir, f"{image_index:03d}.png")
+             ).convert("RGBA")
+             if trans is not None:
+                 img = trans(img)
+             img = do_resize_content(img, cur_resize_rate)
+             alpha_mask = img.getchannel("A")
+             alpha_masks.append(alpha_mask)
+             if self.random_background:
+                 img = add_random_background(img, bg_color)
+
+             img = img.convert("RGB")
+             target_images.append(self.transfrom_vae(img))
+             raw_images.append(img)
+
+             if self.xyz_base is not None:
+                 img_xyz = Image.open(
+                     os.path.join(target_xyz_dir, f"xyz_new_{image_index:03d}.png")
+                 ).convert("RGBA")
+                 img_xyz = trans_xyz(img_xyz) if trans_xyz is not None else img_xyz
+                 img_xyz = trans(img_xyz) if trans is not None else img_xyz
+                 img_xyz = do_resize_content(img_xyz, cur_resize_rate)
+                 img_xyz.putalpha(alpha_mask)
+                 if self.random_background:
+                     img_xyz = add_random_background(img_xyz, bg_color)
+                 img_xyz = img_xyz.convert("RGB")
+                 target_xyz_images.append(self.transfrom_vae(img_xyz))
+                 if self.debug:
+                     raw_xyz_images.append(img_xyz)
+
+         cameras = [get_camera_for_index(i).squeeze() for i in self.camera_views]
+         if self.ref_position is not None:
+             cameras[self.ref_position] = torch.zeros_like(
+                 cameras[self.ref_position]
+             )  # set ref camera to zero
+
+         cameras = torch.stack(cameras)
+
+         input_img = Image.open(
+             os.path.join(target_dir, f"{ref_index:03d}.png")
+         ).convert("RGBA")
+         input_img = do_resize_content(input_img, cur_resize_rate)
+         if random.random() < self.stroke_p:
+             # outline the reference image in a random RGB color
+             color = (
+                 random.randint(0, 255),
+                 random.randint(0, 255),
+                 random.randint(0, 255),
+             )
+             radius = random.randint(1, 3)
+             input_img = add_stroke(input_img, color=color, stroke_radius=radius)
+         if self.random_background:
+             input_img = add_random_background(input_img, bg_color)
+         input_img = input_img.convert("RGB")
+
+         clip_cond = self.transfrom_clip(input_img)
+         vae_cond = self.transfrom_vae(input_img)
+
+         vae_target = torch.stack(target_images, dim=0)
+         if self.xyz_base is not None:
+             xyz_vae_target = torch.stack(target_xyz_images, dim=0)
+         else:
+             xyz_vae_target = []
+
+         if self.debug:
+             print(f"debug!!,{bg_color}")
+             return {
+                 "target_images": raw_images,
+                 "target_images_xyz": raw_xyz_images,
+                 "input_img": input_img,
+                 "cameras": cameras,
+                 "caption": caption,
+                 "item": self.items[index],
+                 "alpha_masks": alpha_masks,
+                 "cur_resize_rate": cur_resize_rate,
+             }
+
+         if self.split == "train":
+             return {
+                 "target_images_vae": vae_target,
+                 "target_images_xyz_vae": xyz_vae_target,
+                 "clip_cond": clip_cond,
+                 "vae_cond": vae_cond,
+                 "cameras": cameras,
+                 "caption": caption,
+             }
+         else:  # eval
+             path = os.path.join(target_dir, f"{ref_index:03d}.png")
+             return dict(
+                 path=path,
+                 target_dir=target_dir,
+                 cond_raw_images=raw_images,
+                 cond=input_img,
+                 ref_index=ref_index,
+                 ident=f"{index}-{Path(target_dir).stem}",
+             )
+
+
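+ # Illustrative sketch (not part of the original upload; hypothetical paths and
+ # values): the stroke variant only changes how the reference image is built,
+ # adding a random colored outline with probability stroke_p and jittering the
+ # content scale within resize_range.
+ #
+ #   dataset = DataRelativeStroke(
+ #       base_dir="data/views",
+ #       caption_csv="data/captions.csv",
+ #       camera_views=[1, 2, 3, 4, 5, 0],
+ #       ref_position=5,
+ #       num_frames=6,
+ #       stroke_p=0.3,
+ #       resize_range=(0.8, 1.1),
+ #   )
+
+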
+ class InTheWildImages(Dataset):
+     """
+     A dataset for in-the-wild images; accepts base folders, image path lists,
+     and path files (one image path per line) as input.
+     """
+
+     def __init__(self, base_dirs=[], image_paths=[], path_files=[]):
+         print(__class__)
+         self.base_dirs = base_dirs
+         self.image_paths = image_paths
+         self.path_files = path_files
+         self.init_item()
+
+     def init_item(self):
+         items = []
+         for d in self.base_dirs:
+             items += [osp.join(d, f) for f in os.listdir(d)]
+         items = items + self.image_paths
+
+         for file in self.path_files:
+             with open(file, "r") as f:
+                 items += [line.strip() for line in f.readlines()]
+         items.sort()
+         self.items = items
+
+     def __len__(self):
+         return len(self.items)
+
+     def __getitem__(self, index):
+         item = self.items[index]
+         # alpha_composite requires RGBA inputs, so convert before compositing
+         img = Image.open(item).convert("RGBA")
+         background = Image.new("RGBA", img.size, (0, 0, 0, 0))
+         cond = Image.alpha_composite(background, img)
+         return dict(
+             path=item, ident=f"{index}-{Path(item).stem}", cond=cond.convert("RGB")
+         )
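+
+
+ if __name__ == "__main__":
+     # A minimal smoke-test sketch (not part of the original upload). The path
+     # below is hypothetical; the dataset simply yields RGB-converted images
+     # keyed by path and a "<index>-<stem>" ident.
+     wild = InTheWildImages(base_dirs=["data/wild_images"])
+     print(len(wild))
+     sample = wild[0]
+     print(sample["ident"], sample["cond"].size)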