Chaitanya-02 committed on
Commit f128780 · verified · 1 Parent(s): d6abc0e

create synth.py

Files changed (1)
  1. synth.py +893 -0
synth.py ADDED
@@ -0,0 +1,893 @@
"""
Helper scripts for generating synthetic images using a diffusion model.

Functions:
- get_top_misclassified
- get_class_list
- generateClassPairs
- outputDirectory
- pipe_img
- createPrompts
- interpolatePrompts
- slerp
- get_middle_elements
- remove_middle
- genClassImg
- getMetadata
- groupbyInterpolation
- ungroupInterpolation
- groupAllbyInterpolation
- getPairIndices
- generateImagesFromDataset
- generateTrace
"""

import json
import os

import numpy as np
import pandas as pd
import torch
from DeepCache import DeepCacheSDHelper
from diffusers import (
    LMSDiscreteScheduler,
    StableDiffusionImg2ImgPipeline,
)
from torch import nn
from torchmetrics.functional.image import structural_similarity_index_measure as ssim
from torchvision import transforms


def get_top_misclassified(val_classifier_json):
    """
    Retrieves the top misclassified classes from a validation classifier JSON file.

    Args:
        val_classifier_json (str): The path to the validation classifier JSON file.

    Returns:
        dict: A dictionary keyed by class name, where each value holds the classes
            that class was most often misclassified as.
    """
    with open(val_classifier_json) as f:
        val_output = json.load(f)
    val_metrics_df = pd.DataFrame.from_dict(
        val_output["val_metrics_details"], orient="index"
    )
    class_dict = dict()
    for k, v in val_metrics_df["top_n_classes"].items():
        class_dict[k] = v
    return class_dict


def get_class_list(val_classifier_json):
    """
    Retrieves the list of classes from the given validation classifier JSON file.

    Args:
        val_classifier_json (str): The path to the validation classifier JSON file.

    Returns:
        list: A sorted list of class names extracted from the JSON file.
    """
    with open(val_classifier_json, "r") as f:
        data = json.load(f)
    return sorted(list(data["val_metrics_details"].keys()))


def generateClassPairs(val_classifier_json):
    """
    Generate pairs of misclassified classes from the given validation classifier JSON.

    Args:
        val_classifier_json (str): The path to the validation classifier JSON file.

    Returns:
        list: A sorted list of pairs of misclassified classes.
    """
    pairs = set()
    misclassified_classes = get_top_misclassified(val_classifier_json)
    for key, value in misclassified_classes.items():
        for v in value:
            pairs.add(tuple(sorted([key, v])))
    return sorted(list(pairs))


def outputDirectory(class_pairs, synth_path, metadata_path):
    """
    Creates the output directory structure for the synthesized data.

    Args:
        class_pairs (list): A list of class pairs.
        synth_path (str): The path to the directory where the synthesized data will be stored.
        metadata_path (str): The path to the directory where the metadata will be stored.

    Returns:
        None
    """
    for pair_id in class_pairs:
        class_folder = f"{synth_path}/{pair_id}"
        if not os.path.exists(class_folder):
            os.makedirs(class_folder)
    if not os.path.exists(metadata_path):
        os.makedirs(metadata_path)
    print("Info: Output directory ready.")


def pipe_img(
    model_path,
    device="cuda",
    apply_optimization=True,
    use_torchcompile=False,
    ci_cb=(5, 1),
    use_safetensors=None,
    cpu_offload=False,
    scheduler=None,
):
    """
    Creates and returns an image-to-image pipeline for Stable Diffusion.

    Args:
        model_path (str): The path to the pretrained model.
        device (str, optional): The device to use for computation. Defaults to "cuda".
        apply_optimization (bool, optional): Whether to apply optimization techniques. Defaults to True.
        use_torchcompile (bool, optional): Whether to use torch.compile for model compilation. Defaults to False.
        ci_cb (tuple, optional): A tuple containing the cache interval and cache branch ID. Defaults to (5, 1).
        use_safetensors (bool, optional): Whether to use safetensors. Defaults to None.
        cpu_offload (bool, optional): Whether to enable CPU offloading. Defaults to False.
        scheduler (LMSDiscreteScheduler, optional): The scheduler for the pipeline. Defaults to None.

    Returns:
        StableDiffusionImg2ImgPipeline: The image-to-image pipeline for Stable Diffusion.
    """
    ###############################
    # Reference:
    # Akimov, R. (2024) Images Interpolation with Stable Diffusion - Hugging Face Open-Source AI Cookbook. Available at: https://huggingface.co/learn/cookbook/en/stable_diffusion_interpolation (Accessed: 4 June 2024).
    ###############################
    if scheduler is None:
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            steps_offset=1,
        )
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_path,
        scheduler=scheduler,
        torch_dtype=torch.float32,
        use_safetensors=use_safetensors,
    ).to(device)
    if cpu_offload:
        pipe.enable_model_cpu_offload()
    if apply_optimization:
        # tomesd.apply_patch(pipe, ratio=0.5)
        helper = DeepCacheSDHelper(pipe=pipe)
        cache_interval, cache_branch_id = ci_cb
        helper.set_params(
            cache_interval=cache_interval, cache_branch_id=cache_branch_id
        )  # lower is faster but lower quality
        helper.enable()
        # if torch.cuda.is_available():
        #     pipe.to("cuda")
        #     pipe.enable_xformers_memory_efficient_attention()
    if use_torchcompile:
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    return pipe
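

# A minimal construction sketch for pipe_img. The checkpoint id
# "runwayml/stable-diffusion-v1-5" is an assumption for illustration; any
# img2img-compatible Stable Diffusion checkpoint path can be used instead.
def _demo_build_pipeline():
    return pipe_img(
        model_path="runwayml/stable-diffusion-v1-5",
        device="cuda",
        apply_optimization=True,  # DeepCache with the default ci_cb=(5, 1)
    )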


def createPrompts(
    class_name_pairs,
    prompt_structure=None,
    use_default_negative_prompt=False,
    negative_prompt=None,
):
    """
    Create prompts for image generation.

    Args:
        class_name_pairs (list): A list of two class names.
        prompt_structure (str, optional): The structure of the prompt. Defaults to "a photo of a <class_name>".
        use_default_negative_prompt (bool, optional): Whether to use the default negative prompt. Defaults to False.
        negative_prompt (str, optional): The negative prompt to steer the generation away from certain features.

    Returns:
        tuple: A tuple containing two lists - prompts and negative_prompts.
            prompts (list): Text prompts that describe the desired output image.
            negative_prompts (list): Negative prompts that can be used to steer the generation away from certain features.
    """
    if prompt_structure is None:
        prompt_structure = "a photo of a <class_name>"
    elif "<class_name>" not in prompt_structure:
        raise ValueError(
            "The prompt structure must contain the <class_name> placeholder."
        )
    if use_default_negative_prompt:
        default_negative_prompt = (
            "blurry image, disfigured, deformed, distorted, cartoon, drawings"
        )
        negative_prompt = default_negative_prompt

    class1 = class_name_pairs[0]
    class2 = class_name_pairs[1]
    prompt1 = prompt_structure.replace("<class_name>", class1)
    prompt2 = prompt_structure.replace("<class_name>", class2)
    prompts = [prompt1, prompt2]
    if negative_prompt is None:
        print("Info: Negative prompt not provided, returning as None.")
        return prompts, None
    else:
        # Negative prompts that can be used to steer the generation away from certain features.
        negative_prompts = [negative_prompt] * len(prompts)
        return prompts, negative_prompts
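

# Example of the prompt helper (class names "cat" and "dog" are illustrative).
# Expected result: (["a photo of a cat", "a photo of a dog"],
# [default_negative_prompt] * 2).
def _demo_create_prompts():
    prompts, negative_prompts = createPrompts(
        ["cat", "dog"], use_default_negative_prompt=True
    )
    return prompts, negative_prompts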


def interpolatePrompts(
    prompts,
    pipeline,
    num_interpolation_steps,
    sample_mid_interpolation,
    remove_n_middle=0,
    device="cuda",
):
    """
    Interpolates prompts by generating intermediate embeddings between pairs of prompts.

    Args:
        prompts (List[str]): A list of prompts to be interpolated.
        pipeline: The pipeline object containing the tokenizer and text encoder.
        num_interpolation_steps (int): The number of interpolation steps between each pair of prompts.
        sample_mid_interpolation (int): The number of intermediate embeddings to sample from the middle of the interpolated prompts.
        remove_n_middle (int, optional): The number of middle embeddings to remove from the interpolated prompts. Defaults to 0.
        device (str, optional): The device to run the interpolation on. Defaults to "cuda".

    Returns:
        interpolated_prompt_embeds (torch.Tensor): The interpolated prompt embeddings.
        prompt_metadata (dict): Metadata about the interpolation process, including similarity scores and nearest class information.

    e.g. if num_interpolation_steps = 10, sample_mid_interpolation = 6, remove_n_middle = 2:
        Interpolated: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        Sampled:            [2, 3, 4, 5, 6, 7]
        Removed:                  [4, 5]
        Returns:            [2, 3, 6, 7]
    """

    ###############################
    # Reference:
    # Akimov, R. (2024) Images Interpolation with Stable Diffusion - Hugging Face Open-Source AI Cookbook. Available at: https://huggingface.co/learn/cookbook/en/stable_diffusion_interpolation (Accessed: 4 June 2024).
    ###############################

    def slerp(v0, v1, num, t0=0, t1=1):
        """
        Performs spherical linear interpolation between two vectors.

        Args:
            v0 (torch.Tensor): The starting vector.
            v1 (torch.Tensor): The ending vector.
            num (int): The number of interpolation points.
            t0 (float, optional): The starting time. Defaults to 0.
            t1 (float, optional): The ending time. Defaults to 1.

        Returns:
            torch.Tensor: The interpolated vectors.
        """
        ###############################
        # Reference:
        # Karpathy, A. (2022) hacky stablediffusion code for generating videos, Gist. Available at: https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355 (Accessed: 4 June 2024).
        ###############################
        v0 = v0.detach().cpu().numpy()
        v1 = v1.detach().cpu().numpy()

        def interpolation(t, v0, v1, DOT_THRESHOLD=0.9995):
            """Helper function to spherically interpolate two arrays v0 and v1."""
            dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
            if np.abs(dot) > DOT_THRESHOLD:
                # Vectors are nearly parallel: fall back to linear interpolation.
                v2 = (1 - t) * v0 + t * v1
            else:
                theta_0 = np.arccos(dot)
                sin_theta_0 = np.sin(theta_0)
                theta_t = theta_0 * t
                sin_theta_t = np.sin(theta_t)
                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
                s1 = sin_theta_t / sin_theta_0
                v2 = s0 * v0 + s1 * v1
            return v2

        t = np.linspace(t0, t1, num)

        v3 = torch.tensor(np.array([interpolation(t[i], v0, v1) for i in range(num)]))

        return v3

    def get_middle_elements(lst, n):
        """
        Returns a tuple containing a sublist of the middle elements of the given list `lst` and a range of indices of those elements.

        Args:
            lst (list): The list from which to extract the middle elements.
            n (int): The number of middle elements to extract.

        Returns:
            tuple: A tuple containing the sublist of middle elements and a range of indices.

        Examples:
            lst = [1, 2, 3, 4, 5]
            get_middle_elements(lst, 3)
            ([2, 3, 4], range(1, 4))
        """
        if n % 2 == 0:  # Even number of elements
            middle_index = len(lst) // 2 - 1
            start = middle_index - n // 2 + 1
            end = middle_index + n // 2 + 1
            return lst[start:end], range(start, end)
        else:  # Odd number of elements
            middle_index = len(lst) // 2
            start = middle_index - n // 2
            end = middle_index + n // 2 + 1
            return lst[start:end], range(start, end)

    def remove_middle(data, n):
        """
        Remove the middle n elements from a sequence.

        Args:
            data (torch.Tensor): The input sequence.
            n (int): The number of elements to remove from the middle of the sequence.

        Returns:
            torch.Tensor: The modified sequence with the middle n elements removed.

        Raises:
            ValueError: If n is negative or greater than the length of the sequence.
        """
        if n < 0 or n > len(data):
            raise ValueError(
                "Invalid value for n. It should be non-negative and not greater than the sequence length."
            )

        # Find the middle index
        middle = len(data) // 2

        # Concatenate slices that exclude the middle n elements. torch.cat is used
        # because `data` is a tensor here; the `+` operator would add elementwise.
        if n == 1:
            return torch.cat((data[:middle], data[middle + 1 :]))
        elif n % 2 == 0:
            return torch.cat((data[: middle - n // 2], data[middle + n // 2 :]))
        else:
            return torch.cat((data[: middle - n // 2], data[middle + n // 2 + 1 :]))

    batch_size = len(prompts)

    # Tokenizing and encoding prompts into embeddings.
    prompts_tokens = pipeline.tokenizer(
        prompts,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    prompts_embeds = pipeline.text_encoder(prompts_tokens.input_ids.to(device))[0]

    # Interpolating between embedding pairs for the given number of interpolation steps.
    interpolated_prompt_embeds = []

    for i in range(batch_size - 1):
        interpolated_prompt_embeds.append(
            slerp(prompts_embeds[i], prompts_embeds[i + 1], num_interpolation_steps)
        )

    full_interpolated_prompt_embeds = interpolated_prompt_embeds[:]
    interpolated_prompt_embeds[0], sample_range = get_middle_elements(
        interpolated_prompt_embeds[0], sample_mid_interpolation
    )

    if remove_n_middle > 0:
        interpolated_prompt_embeds[0] = remove_middle(
            interpolated_prompt_embeds[0], remove_n_middle
        )

    prompt_metadata = dict()
    similarity = nn.CosineSimilarity(dim=-1, eps=1e-6)
    for i in range(num_interpolation_steps):
        class1_sim = (
            similarity(
                full_interpolated_prompt_embeds[0][0],
                full_interpolated_prompt_embeds[0][i],
            )
            .mean()
            .item()
        )
        class2_sim = (
            similarity(
                full_interpolated_prompt_embeds[0][num_interpolation_steps - 1],
                full_interpolated_prompt_embeds[0][i],
            )
            .mean()
            .item()
        )
        relative_distance = class1_sim / (class1_sim + class2_sim)

        prompt_metadata[i] = {
            "selected": i in sample_range,
            "similarity": {
                "class1": class1_sim,
                "class2": class2_sim,
                "class1_relative_distance": relative_distance,
                "class2_relative_distance": 1 - relative_distance,
            },
            "nearest_class": int(relative_distance < 0.5),
        }

    interpolated_prompt_embeds = torch.cat(interpolated_prompt_embeds, dim=0).to(device)
    return interpolated_prompt_embeds, prompt_metadata
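

# Shape sketch for interpolatePrompts, assuming a Stable Diffusion v1-style text
# encoder (77 tokens x 768 dims; both are assumptions about the loaded checkpoint).
# With num_interpolation_steps=10 and sample_mid_interpolation=6 the returned
# tensor is [6, 77, 768]; adding remove_n_middle=2 drops it to [4, 77, 768],
# matching the docstring example above.
def _demo_interpolate(pipe, prompts, negative_prompts):
    pos_embeds, pos_meta = interpolatePrompts(prompts, pipe, 10, 6)
    neg_embeds, neg_meta = interpolatePrompts(negative_prompts, pipe, 10, 6)
    return pos_embeds, neg_embeds, pos_meta, neg_meta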


def genClassImg(
    pipeline,
    pos_embed,
    neg_embed,
    input_image,
    generator,
    latents,
    num_imgs=1,
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=7.5,
):
    """
    Generate a class image using the given inputs.

    Args:
        pipeline: The pipeline object used for image generation.
        pos_embed: The positive embedding for the class.
        neg_embed: The negative embedding for the class (optional).
        input_image: The input image used to condition the img2img generation.
        generator: The torch random generator used for image generation.
        latents: The latent vectors used for image generation.
        num_imgs: The number of images to generate (default is 1).
        height: The height of the generated images (default is 512).
        width: The width of the generated images (default is 512).
        num_inference_steps: The number of inference steps for image generation (default is 25).
        guidance_scale: The scale factor for guidance (default is 7.5).

    Returns:
        The generated class image.
    """

    if neg_embed is not None:
        npe = neg_embed[None, ...]
    else:
        npe = None

    return pipeline(
        height=height,
        width=width,
        num_images_per_prompt=num_imgs,
        prompt_embeds=pos_embed[None, ...],
        negative_prompt_embeds=npe,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        latents=latents,
        image=input_image,
    ).images[0]
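

# Single-image call sketch (assumptions: `pipe` from pipe_img, `embeds` from
# interpolatePrompts, `img` a [1, C, H, W] tensor scaled to [0, 1]). The latents
# are shaped for the 8x-downsampled U-Net input, mirroring
# generateImagesFromDataset below.
def _demo_gen_one(pipe, embeds, img, seed=0, device="cuda"):
    generator = torch.manual_seed(seed)
    latents = torch.randn(
        (1, pipe.unet.config.in_channels, 512 // 8, 512 // 8), generator=generator
    ).to(device)
    return genClassImg(pipe, embeds[0], None, img, generator, latents)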


def getMetadata(
    class_pairs,
    path,
    seed,
    guidance_scale,
    num_inference_steps,
    num_interpolation_steps,
    sample_mid_interpolation,
    height,
    width,
    prompts,
    negative_prompts,
    pipeline,
    prompt_metadata,
    negative_prompt_metadata,
    ssim_metadata=None,
    save_json=True,
    save_path=".",
):
    """
    Generate metadata for the given parameters.

    Args:
        class_pairs (list): List of class pairs.
        path (str): Path to the data.
        seed (int): Seed value for randomization.
        guidance_scale (float): Scale factor for guidance.
        num_inference_steps (int): Number of inference steps.
        num_interpolation_steps (int): Number of interpolation steps.
        sample_mid_interpolation (int): Number of interpolation steps sampled from the middle.
        height (int): Height of the image.
        width (int): Width of the image.
        prompts (list): List of prompts.
        negative_prompts (list): List of negative prompts.
        pipeline (object): Pipeline object.
        prompt_metadata (dict): Metadata for prompts.
        negative_prompt_metadata (dict): Metadata for negative prompts.
        ssim_metadata (dict, optional): SSIM scores metadata. Defaults to None.
        save_json (bool, optional): Flag to save metadata as JSON. Defaults to True.
        save_path (str, optional): Path to save the JSON file. Defaults to ".".

    Returns:
        dict: Generated metadata.
    """

    metadata = dict()

    metadata["class_pairs"] = class_pairs
    metadata["path"] = path
    metadata["seed"] = seed
    metadata["params"] = {
        "CFG": guidance_scale,
        "inferenceSteps": num_inference_steps,
        "interpolationSteps": num_interpolation_steps,
        "sampleMidInterpolation": sample_mid_interpolation,
        "height": height,
        "width": width,
    }
    for i in range(len(prompts)):
        metadata[f"prompt_text_{i}"] = prompts[i]
        if negative_prompts is not None:
            metadata[f"negative_prompt_text_{i}"] = negative_prompts[i]
    metadata["pipe_config"] = dict(pipeline.config)
    metadata["prompt_embed_similarity"] = prompt_metadata
    metadata["negative_prompt_embed_similarity"] = negative_prompt_metadata
    if ssim_metadata is not None:
        print("Info: SSIM scores are available.")
        metadata["ssim_scores"] = ssim_metadata
    if save_json:
        with open(
            os.path.join(save_path, f"{'_'.join(class_pairs)}_{seed}.json"),
            "w",
        ) as f:
            json.dump(metadata, f, indent=4)
    return metadata


def groupbyInterpolation(dir_to_classfolder):
    """
    Group files in a directory by interpolation step.

    Args:
        dir_to_classfolder (str): The path to the directory containing the files.

    Returns:
        None
    """
    files = [
        (f.split(sep="_")[1].split(sep=".")[0], os.path.join(dir_to_classfolder, f))
        for f in os.listdir(dir_to_classfolder)
    ]
    # create a subfolder for each step of the interpolation
    for interpolation_step, file_path in files:
        new_dir = os.path.join(dir_to_classfolder, interpolation_step)
        if not os.path.exists(new_dir):
            os.makedirs(new_dir)
        os.rename(file_path, os.path.join(new_dir, os.path.basename(file_path)))


def ungroupInterpolation(dir_to_classfolder):
    """
    Moves all files from subdirectories within `dir_to_classfolder` to `dir_to_classfolder` itself,
    and then removes the subdirectories.

    Args:
        dir_to_classfolder (str): The path to the directory containing the subdirectories.

    Returns:
        None
    """
    for interpolation_step in os.listdir(dir_to_classfolder):
        if os.path.isdir(os.path.join(dir_to_classfolder, interpolation_step)):
            for f in os.listdir(os.path.join(dir_to_classfolder, interpolation_step)):
                os.rename(
                    os.path.join(dir_to_classfolder, interpolation_step, f),
                    os.path.join(dir_to_classfolder, f),
                )
            os.rmdir(os.path.join(dir_to_classfolder, interpolation_step))


def groupAllbyInterpolation(
    data_path,
    group=True,
    fn_group=groupbyInterpolation,
    fn_ungroup=ungroupInterpolation,
):
    """
    Group or ungroup all data classes by interpolation step.

    Args:
        data_path (str): The path to the data.
        group (bool, optional): Whether to group the data. Defaults to True.
        fn_group (function, optional): The function to use for grouping. Defaults to groupbyInterpolation.
        fn_ungroup (function, optional): The function to use for ungrouping. Defaults to ungroupInterpolation.
    """
    data_classes = sorted(os.listdir(data_path))
    if group:
        fn = fn_group
    else:
        fn = fn_ungroup
    for c in data_classes:
        c_path = os.path.join(data_path, c)
        if os.path.isdir(c_path):
            fn(c_path)
            print(f"Processed {c}")


def getPairIndices(subset_len, total_pair_count=1, seed=None):
    """
    Split the indices of a subset into shuffled groups, one per generation pass.

    Args:
        subset_len (int): The length of the subset.
        total_pair_count (int, optional): The number of index groups to generate. Defaults to 1.
        seed (int, optional): The seed value for the random number generator. Defaults to None.

    Returns:
        list: A list of index groups, each of length ceil(subset_len / total_pair_count);
            shuffled indices are re-appended as padding when the split is uneven.
    """
    rng = np.random.default_rng(seed)
    group_size = (subset_len + total_pair_count - 1) // total_pair_count
    numbers = list(range(subset_len))
    numbers_selection = list(range(subset_len))
    rng.shuffle(numbers)
    # pad with re-used indices so the total fills the groups evenly
    # (any excess is sliced off below)
    for i in range(group_size - subset_len % group_size):
        numbers.append(numbers_selection[i])
    numbers = np.array(numbers)
    groups = numbers[: group_size * total_pair_count].reshape(-1, group_size)
    return groups.tolist()
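

# Worked example for getPairIndices: with subset_len=5 and total_pair_count=2,
# group_size = ceil(5 / 2) = 3, so one shuffled index is re-appended as padding
# and the result is two lists of 3 indices covering 0..4 with a single repeat.
def _demo_pair_indices():
    groups = getPairIndices(5, total_pair_count=2, seed=0)
    assert len(groups) == 2 and all(len(g) == 3 for g in groups)
    return groups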


def generateImagesFromDataset(
    img_subsets,
    class_iterables,
    pipeline,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    num_inference_steps,
    guidance_scale,
    height=512,
    width=512,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    save_image=True,
    image_type="jpg",
    interpolate_range="full",
    device="cuda",
    return_images=False,
):
    """
    Generates images from a dataset using the given parameters.

    Args:
        img_subsets (dict): A dictionary containing image subsets for each class.
        class_iterables (dict): A dictionary containing iterable objects for each class.
        pipeline (object): The pipeline object used for image generation.
        interpolated_prompt_embeds (list): A list of interpolated prompt embeddings.
        interpolated_negative_prompts_embeds (list): A list of interpolated negative prompt embeddings.
        num_inference_steps (int): The number of inference steps for image generation.
        guidance_scale (float): The scale factor for guidance during image generation.
        height (int, optional): The height of the generated images. Defaults to 512.
        width (int, optional): The width of the generated images. Defaults to 512.
        seed (int, optional): The seed value for random number generation. Defaults to None.
        save_path (str, optional): The path to save the generated images. Defaults to ".".
        class_pairs (tuple, optional): A tuple containing pairs of class identifiers. Defaults to ("0", "1").
        save_image (bool, optional): Whether to save the generated images. Defaults to True.
        image_type (str, optional): The file format of the saved images. Defaults to "jpg".
        interpolate_range (str, optional): The range of interpolation for prompt embeddings.
            Possible values are "full", "nearest", or "furthest". Defaults to "full".
        device (str, optional): The device to use for image generation. Defaults to "cuda".
        return_images (bool, optional): Whether to return the generated images. Defaults to False.

    Returns:
        dict or tuple: If return_images is True, returns a tuple of a dictionary containing the
            generated images for each class and a dictionary containing the SSIM scores for each
            class and interpolation step. If return_images is False, returns only the SSIM dictionary.
    """
    if interpolate_range == "nearest":
        nearest_half = True
        furthest_half = False
    elif interpolate_range == "furthest":
        nearest_half = False
        furthest_half = True
    else:
        nearest_half = False
        furthest_half = False

    if seed is None:
        seed = torch.Generator().seed()
    generator = torch.manual_seed(seed)
    rng = np.random.default_rng(seed)
    # Generating initial U-Net latent vectors from a random normal distribution.
    latents = torch.randn(
        (1, pipeline.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(device)

    embed_len = len(interpolated_prompt_embeds)
    embed_pairs = zip(interpolated_prompt_embeds, interpolated_negative_prompts_embeds)
    embed_pairs_list = list(embed_pairs)
    if return_images:
        class_images = dict()
    class_ssim = dict()

    if nearest_half:
        steps_range = (range(0, embed_len // 2), range(embed_len // 2, embed_len))
        multiplier = 2
    elif furthest_half:
        # uses the opposite class of images for the text interpolation
        steps_range = (range(embed_len // 2, embed_len), range(0, embed_len // 2))
        multiplier = 2
    else:
        steps_range = (range(embed_len), range(embed_len))
        multiplier = 1

    for class_iter, class_id in enumerate(class_pairs):
        if return_images:
            class_images[class_id] = list()
        class_ssim[class_id] = {
            i: {"ssim_sum": 0, "ssim_count": 0, "ssim_avg": 0} for i in range(embed_len)
        }
        subset_len = len(img_subsets[class_id])
        # To efficiently randomize the interpolation steps for each image in the
        # class, group_map is used: the index is the image id, the element is the
        # interpolation step. steps_range[class_iter] restricts the steps for the
        # class, so the first half of the steps goes to the first class and so on
        # (e.g. indices 0-7 and 8-15 for 16 steps); the list is then repeated to
        # cover the whole subset plus remainder.
        group_map = (
            list(steps_range[class_iter]) * multiplier * (subset_len // embed_len + 1)
        )
        rng.shuffle(
            group_map
        )  # shuffle the steps to interpolate; the position in group_map maps to the image id

        iter_indices = class_iterables[class_id].pop()
        # generate an image for each image in the class, at a randomly selected interpolation step
        for image_id in iter_indices:
            img, trg = img_subsets[class_id][image_id]
            input_image = img.unsqueeze(0)
            interpolate_step = group_map[image_id]
            prompt_embeds, negative_prompt_embeds = embed_pairs_list[interpolate_step]
            generated_image = genClassImg(
                pipeline,
                prompt_embeds,
                negative_prompt_embeds,
                input_image,
                generator,
                latents,
                num_imgs=1,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            )
            pred_image = transforms.ToTensor()(generated_image).unsqueeze(0)
            ssim_score = ssim(pred_image, input_image).item()
            class_ssim[class_id][interpolate_step]["ssim_sum"] += ssim_score
            class_ssim[class_id][interpolate_step]["ssim_count"] += 1
            if return_images:
                class_images[class_id].append(generated_image)
            if save_image:
                if image_type == "jpg":
                    generated_image.save(
                        f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}",
                        format="JPEG",
                        quality=95,
                    )
                elif image_type == "png":
                    generated_image.save(
                        f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}",
                        format="PNG",
                    )
                else:
                    generated_image.save(
                        f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}"
                    )

        # calculate the SSIM average for the class
        for i_step in range(embed_len):
            if class_ssim[class_id][i_step]["ssim_count"] > 0:
                class_ssim[class_id][i_step]["ssim_avg"] = (
                    class_ssim[class_id][i_step]["ssim_sum"]
                    / class_ssim[class_id][i_step]["ssim_count"]
                )

    if return_images:
        return class_images, class_ssim
    else:
        return class_ssim
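

# Call sketch for the generation loop (assumptions: `subsets` maps each class id
# to a dataset of (image_tensor, target) pairs and the embeds come from
# interpolatePrompts; save_image=False keeps the sketch free of directory setup).
def _demo_generate(pipe, subsets, pos_embeds, neg_embeds, pair=("cat", "dog")):
    iterables = {c: getPairIndices(len(subsets[c]), 1, seed=0) for c in pair}
    return generateImagesFromDataset(
        subsets,
        iterables,
        pipe,
        pos_embeds,
        neg_embeds,
        num_inference_steps=25,
        guidance_scale=7.5,
        seed=0,
        class_pairs=pair,
        save_image=False,
    )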


def generateTrace(
    prompts,
    img_subsets,
    class_iterables,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    subset_indices,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    image_type="jpg",
    interpolate_range="full",
    save_prompt_embeds=False,
):
    """
    Generate a trace dictionary containing information about the generated images.

    Args:
        prompts (list): List of prompt texts.
        img_subsets (dict): Dictionary containing image subsets for each class.
        class_iterables (dict): Dictionary containing iterable objects for each class.
        interpolated_prompt_embeds (torch.Tensor): Tensor containing interpolated prompt embeddings.
        interpolated_negative_prompts_embeds (torch.Tensor): Tensor containing interpolated negative prompt embeddings.
        subset_indices (dict): Dictionary containing indices of subsets for each class.
        seed (int, optional): Seed value for random number generation. Defaults to None.
        save_path (str, optional): Path to save the generated images. Defaults to ".".
        class_pairs (tuple, optional): Tuple containing class pairs. Defaults to ("0", "1").
        image_type (str, optional): Type of the generated images. Defaults to "jpg".
        interpolate_range (str, optional): Range of interpolation. Defaults to "full".
        save_prompt_embeds (bool, optional): Flag to save prompt embeddings. Defaults to False.

    Returns:
        dict: Trace dictionary containing information about the generated images.
    """
    trace_dict = {
        "class_pairs": list(),
        "class_id": list(),
        "image_id": list(),
        "interpolation_step": list(),
        "embed_len": list(),
        "pos_prompt_text": list(),
        "neg_prompt_text": list(),
        "input_file_path": list(),
        "output_file_path": list(),
        "input_prompts_embed": list(),
    }

    if interpolate_range == "nearest":
        nearest_half = True
        furthest_half = False
    elif interpolate_range == "furthest":
        nearest_half = False
        furthest_half = True
    else:
        nearest_half = False
        furthest_half = False

    if seed is None:
        seed = torch.Generator().seed()
    rng = np.random.default_rng(seed)

    embed_len = len(interpolated_prompt_embeds)
    embed_pairs = zip(
        interpolated_prompt_embeds.cpu().numpy(),
        interpolated_negative_prompts_embeds.cpu().numpy(),
    )
    embed_pairs_list = list(embed_pairs)

    if nearest_half:
        steps_range = (range(0, embed_len // 2), range(embed_len // 2, embed_len))
        multiplier = 2
    elif furthest_half:
        # uses the opposite class of images for the text interpolation
        steps_range = (range(embed_len // 2, embed_len), range(0, embed_len // 2))
        multiplier = 2
    else:
        steps_range = (range(embed_len), range(embed_len))
        multiplier = 1

    for class_iter, class_id in enumerate(class_pairs):
        subset_len = len(img_subsets[class_id])
        # Same group_map trick as in generateImagesFromDataset: the index is the
        # image id, the element is the interpolation step, and steps_range[class_iter]
        # restricts each class to its half of the steps when a half range is used.
        group_map = (
            list(steps_range[class_iter]) * multiplier * (subset_len // embed_len + 1)
        )
        rng.shuffle(
            group_map
        )  # shuffle the steps to interpolate; the position in group_map maps to the image id

        iter_indices = class_iterables[class_id].pop()
        # record trace information for each image in the class
        for image_id in iter_indices:
            class_ds = img_subsets[class_id]
            interpolate_step = group_map[image_id]
            sample_count = subset_indices[class_id][0] + image_id
            input_file = os.path.normpath(class_ds.dataset.samples[sample_count][0])
            pos_prompt = prompts[0]
            neg_prompt = prompts[1]
            output_file = f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}"
            if save_prompt_embeds:
                input_prompts_embed = embed_pairs_list[interpolate_step]
            else:
                input_prompts_embed = None

            trace_dict["class_pairs"].append(class_pairs)
            trace_dict["class_id"].append(class_id)
            trace_dict["image_id"].append(image_id)
            trace_dict["interpolation_step"].append(interpolate_step)
            trace_dict["embed_len"].append(embed_len)
            trace_dict["pos_prompt_text"].append(pos_prompt)
            trace_dict["neg_prompt_text"].append(neg_prompt)
            trace_dict["input_file_path"].append(input_file)
            trace_dict["output_file_path"].append(output_file)
            trace_dict["input_prompts_embed"].append(input_prompts_embed)

    return trace_dict
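

# The trace dict is column-oriented, so it can be persisted directly with pandas
# ("trace.csv" is an illustrative output path):
def _demo_trace_to_csv(trace_dict, path="trace.csv"):
    pd.DataFrame(trace_dict).to_csv(path, index=False)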