import json
import os

import numpy as np
import pandas as pd
import torch
from DeepCache import DeepCacheSDHelper
from diffusers import (
    LMSDiscreteScheduler,
    StableDiffusionImg2ImgPipeline,
)
from torch import nn
from torchmetrics.functional.image import structural_similarity_index_measure as ssim
from torchvision import transforms

# Token in a prompt template that createPrompts replaces with a class name.
CLASS_NAME_PLACEHOLDER = "<class_name>"


def get_top_misclassified(val_classifier_json):
    """Map each class to its ``top_n_classes`` entry from a classifier report.

    Args:
        val_classifier_json: Path to a JSON file containing a
            ``"val_metrics_details"`` mapping of class name -> per-class
            metrics; every per-class record must have a ``"top_n_classes"`` key.

    Returns:
        dict: class name -> that class's ``top_n_classes`` value (iterated as a
        collection of class names by ``generateClassPairs``).
    """
    with open(val_classifier_json) as f:
        val_output = json.load(f)
    val_metrics_df = pd.DataFrame.from_dict(
        val_output["val_metrics_details"], orient="index"
    )
    return dict(val_metrics_df["top_n_classes"].items())


def get_class_list(val_classifier_json):
    """Return the sorted class names listed under ``"val_metrics_details"``."""
    with open(val_classifier_json, "r") as f:
        data = json.load(f)
    return sorted(data["val_metrics_details"])


def generateClassPairs(val_classifier_json):
    """Build the sorted, de-duplicated list of (class, confused class) pairs.

    Each pair is stored as a sorted tuple so (a, b) and (b, a) collapse into one.
    """
    pairs = set()
    misclassified_classes = get_top_misclassified(val_classifier_json)
    for key, value in misclassified_classes.items():
        for v in value:
            pairs.add(tuple(sorted([key, v])))
    return sorted(pairs)


def outputDirectory(class_pairs, synth_path, metadata_path):
    """Create ``synth_path/<class>`` for every entry of ``class_pairs`` plus
    ``metadata_path``; existing directories are left untouched."""
    for class_id in class_pairs:
        os.makedirs(f"{synth_path}/{class_id}", exist_ok=True)
    os.makedirs(metadata_path, exist_ok=True)
    print("Info: Output directory ready.")


def pipe_img(
    model_path,
    device="cuda",
    apply_optimization=True,
    use_torchcompile=False,
    ci_cb=(5, 1),
    use_safetensors=None,
    cpu_offload=False,
    scheduler=None,
):
    """Load a float32 ``StableDiffusionImg2ImgPipeline`` and optionally speed it up.

    Args:
        model_path: Model id or local path passed to ``from_pretrained``.
        device: Device the pipeline is moved to.
        apply_optimization: Enable a ``DeepCacheSDHelper`` on the pipeline.
        use_torchcompile: Wrap the U-Net with ``torch.compile``.
        ci_cb: ``(cache_interval, cache_branch_id)`` for DeepCache.
        use_safetensors: Forwarded to ``from_pretrained``.
        cpu_offload: Call ``enable_model_cpu_offload`` after loading.
        scheduler: Scheduler to use; defaults to an ``LMSDiscreteScheduler``
            with a scaled-linear beta schedule.

    Returns:
        The configured pipeline.
    """
    if scheduler is None:
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            steps_offset=1,
        )
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_path,
        scheduler=scheduler,
        torch_dtype=torch.float32,
        use_safetensors=use_safetensors,
    ).to(device)
    if cpu_offload:
        pipe.enable_model_cpu_offload()
    if apply_optimization:
        helper = DeepCacheSDHelper(pipe=pipe)
        cache_interval, cache_branch_id = ci_cb
        helper.set_params(
            cache_interval=cache_interval, cache_branch_id=cache_branch_id
        )
        helper.enable()
    if use_torchcompile:
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    return pipe


def createPrompts(
    class_name_pairs,
    prompt_structure=None,
    use_default_negative_prompt=False,
    negative_prompt=None,
    placeholder=CLASS_NAME_PLACEHOLDER,
):
    """Fill a prompt template with each class name of a pair.

    Args:
        class_name_pairs: Sequence whose first two items are the class names.
        prompt_structure: Template containing ``placeholder``; defaults to
            ``"a photo of a <placeholder>"``.
        use_default_negative_prompt: If True, a built-in negative prompt
            replaces ``negative_prompt``.
        negative_prompt: Negative prompt text shared by both prompts, or None.
        placeholder: Token in ``prompt_structure`` replaced by the class name.

    Returns:
        ``(prompts, negative_prompts)`` where ``prompts`` has one entry per
        class and ``negative_prompts`` is None or a same-length list.

    Raises:
        ValueError: If ``placeholder`` is empty or missing from the template.
    """
    # An empty placeholder is "contained" in every string and str.replace("", x)
    # would insert x between every character, so reject it explicitly.
    if not placeholder:
        raise ValueError("placeholder must be a non-empty string.")
    if prompt_structure is None:
        prompt_structure = f"a photo of a {placeholder}"
    elif placeholder not in prompt_structure:
        raise ValueError(
            f"The prompt structure must contain the {placeholder} placeholder."
        )
    if use_default_negative_prompt:
        negative_prompt = (
            "blurry image, disfigured, deformed, distorted, cartoon, drawings"
        )
    class1, class2 = class_name_pairs[0], class_name_pairs[1]
    prompts = [
        prompt_structure.replace(placeholder, class1),
        prompt_structure.replace(placeholder, class2),
    ]
    if negative_prompt is None:
        print("Info: Negative prompt not provided, returning as None.")
        return prompts, None
    return prompts, [negative_prompt] * len(prompts)


def _slerp(v0, v1, num, t0=0, t1=1, dot_threshold=0.9995):
    """Spherically interpolate between tensors ``v0`` and ``v1`` at ``num``
    evenly spaced t in ``[t0, t1]``; returns a CPU tensor stacked on dim 0."""
    v0 = v0.detach().cpu().numpy()
    v1 = v1.detach().cpu().numpy()
    # Cosine of the angle between the whole (flattened) arrays; np.linalg.norm
    # of a 2-D array is the Frobenius norm. It does not depend on t.
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))

    def interpolation(t):
        if np.abs(dot) > dot_threshold:
            # Nearly (anti-)parallel: sin(theta_0) ~ 0 would make slerp
            # unstable, so fall back to linear interpolation.
            return (1 - t) * v0 + t * v1
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = np.sin(theta_t) / sin_theta_0
        return s0 * v0 + s1 * v1

    ts = np.linspace(t0, t1, num)
    return torch.tensor(np.array([interpolation(t) for t in ts]))


def _get_middle_elements(seq, n):
    """Return ``(seq[start:end], range(start, end))`` for ``n`` central items.

    Raises:
        ValueError: If ``n`` is negative or larger than ``len(seq)``.
    """
    if n < 0 or n > len(seq):
        raise ValueError(
            "sample_mid_interpolation must be between 0 and num_interpolation_steps."
        )
    if n % 2 == 0:
        middle_index = len(seq) // 2 - 1
        start = middle_index - n // 2 + 1
        end = middle_index + n // 2 + 1
    else:
        middle_index = len(seq) // 2
        start = middle_index - n // 2
        end = middle_index + n // 2 + 1
    return seq[start:end], range(start, end)


def _remove_middle(data, n):
    """Return ``data`` with ``n`` central items removed (list or tensor).

    Raises:
        ValueError: If ``n`` is negative or larger than ``len(data)``.
    """
    if n < 0 or n > len(data):
        raise ValueError(
            "Invalid value for n. It should be non-negative and not greater than the list length"
        )
    middle = len(data) // 2
    if n % 2 == 0:
        start, stop = middle - n // 2, middle + n // 2
    else:
        start, stop = middle - n // 2, middle + n // 2 + 1
    if isinstance(data, torch.Tensor):
        # `+` on tensors is element-wise addition, not concatenation.
        return torch.cat((data[:start], data[stop:]), dim=0)
    return data[:start] + data[stop:]


def interpolatePrompts(
    prompts,
    pipeline,
    num_interpolation_steps,
    sample_mid_interpolation,
    remove_n_middle=0,
    device="cuda",
):
    """Slerp between the text embeddings of consecutive prompts.

    Only the first prompt pair's path is sampled: its ``sample_mid_interpolation``
    central steps are kept, then ``remove_n_middle`` central steps of that
    sample are dropped.

    Args:
        prompts: At least two prompt strings.
        pipeline: Pipeline exposing ``tokenizer`` and ``text_encoder``.
        num_interpolation_steps: Number of slerp steps between two prompts.
        sample_mid_interpolation: How many central steps to keep.
        remove_n_middle: How many central steps of the kept sample to drop.
        device: Device for the text encoder input and the returned tensor.

    Returns:
        ``(embeds, metadata)``: the selected embeddings concatenated on dim 0 and
        moved to ``device``, and a dict keyed by step index with cosine
        similarities to each endpoint, whether the step was kept
        (``"selected"``) and ``"nearest_class"`` (1 when the step is more
        similar to the second prompt's endpoint).

    Raises:
        ValueError: If fewer than two prompts are given, or a count is out of range.
    """
    batch_size = len(prompts)
    if batch_size < 2:
        raise ValueError("At least two prompts are required for interpolation.")
    prompts_tokens = pipeline.tokenizer(
        prompts,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    # Inference only: no autograd graph is needed for the text encoder.
    with torch.no_grad():
        prompts_embeds = pipeline.text_encoder(prompts_tokens.input_ids.to(device))[0]

    full_interpolated_prompt_embeds = [
        _slerp(prompts_embeds[i], prompts_embeds[i + 1], num_interpolation_steps)
        for i in range(batch_size - 1)
    ]
    interpolated_prompt_embeds = full_interpolated_prompt_embeds[:]
    interpolated_prompt_embeds[0], sample_range = _get_middle_elements(
        interpolated_prompt_embeds[0], sample_mid_interpolation
    )
    selected_steps = list(sample_range)
    if remove_n_middle > 0:
        interpolated_prompt_embeds[0] = _remove_middle(
            interpolated_prompt_embeds[0], remove_n_middle
        )
        # Apply the same removal to the step indices so the metadata only
        # marks steps that are actually returned.
        selected_steps = _remove_middle(selected_steps, remove_n_middle)
    selected_steps = set(selected_steps)

    prompt_metadata = dict()
    similarity = nn.CosineSimilarity(dim=-1, eps=1e-6)
    full_path = full_interpolated_prompt_embeds[0]
    for i in range(num_interpolation_steps):
        class1_sim = similarity(full_path[0], full_path[i]).mean().item()
        class2_sim = (
            similarity(full_path[num_interpolation_steps - 1], full_path[i])
            .mean()
            .item()
        )
        relative_distance = class1_sim / (class1_sim + class2_sim)
        prompt_metadata[i] = {
            "selected": i in selected_steps,
            "similarity": {
                "class1": class1_sim,
                "class2": class2_sim,
                "class1_relative_distance": relative_distance,
                "class2_relative_distance": 1 - relative_distance,
            },
            "nearest_class": int(relative_distance < 0.5),
        }

    interpolated_prompt_embeds = torch.cat(interpolated_prompt_embeds, dim=0).to(device)
    return interpolated_prompt_embeds, prompt_metadata


def genClassImg(
    pipeline,
    pos_embed,
    neg_embed,
    input_image,
    generator,
    latents,
    num_imgs=1,
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=7.5,
):
    """Run the pipeline on one prompt embedding and return the first image.

    ``pos_embed``/``neg_embed`` are single embeddings without a batch dim; one
    is added here. ``neg_embed`` may be None.
    """
    npe = None if neg_embed is None else neg_embed[None, ...]
    return pipeline(
        height=height,
        width=width,
        num_images_per_prompt=num_imgs,
        prompt_embeds=pos_embed[None, ...],
        negative_prompt_embeds=npe,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        latents=latents,
        image=input_image,
    ).images[0]


def getMetadata(
    class_pairs,
    path,
    seed,
    guidance_scale,
    num_inference_steps,
    num_interpolation_steps,
    sample_mid_interpolation,
    height,
    width,
    prompts,
    negative_prompts,
    pipeline,
    prompt_metadata,
    negative_prompt_metadata,
    ssim_metadata=None,
    save_json=True,
    save_path=".",
):
    """Collect the generation settings into a dict and optionally save it.

    When ``save_json`` is True the dict is written to
    ``save_path/<class_pairs joined by "_">_<seed>.json``.

    Returns:
        dict: The collected metadata.
    """
    metadata = dict()
    metadata["class_pairs"] = class_pairs
    metadata["path"] = path
    metadata["seed"] = seed
    metadata["params"] = {
        "CFG": guidance_scale,
        "inferenceSteps": num_inference_steps,
        "interpolationSteps": num_interpolation_steps,
        "sampleMidInterpolation": sample_mid_interpolation,
        "height": height,
        "width": width,
    }
    for idx, prompt in enumerate(prompts):
        metadata[f"prompt_text_{idx}"] = prompt
        if negative_prompts is not None:
            metadata[f"negative_prompt_text_{idx}"] = negative_prompts[idx]
    metadata["pipe_config"] = dict(pipeline.config)
    metadata["prompt_embed_similarity"] = prompt_metadata
    metadata["negative_prompt_embed_similarity"] = negative_prompt_metadata
    if ssim_metadata is not None:
        print("Info: SSIM scores are available.")
        metadata["ssim_scores"] = ssim_metadata
    if save_json:
        file_name = f"{'_'.join(class_pairs)}_{seed}.json"
        with open(os.path.join(save_path, file_name), "w") as f:
            json.dump(metadata, f, indent=4)
    return metadata


def groupbyInterpolation(dir_to_classfolder):
    """Move each generated image into a sub-folder named after its step.

    File names are expected to look like ``<seed>-<image_id>_<step>.<ext>``
    (the format the generators in this module write). Entries that are not
    regular files, e.g. folders from a previous grouping, are skipped.
    """
    files = [
        (f.split(sep="_")[1].split(sep=".")[0], os.path.join(dir_to_classfolder, f))
        for f in os.listdir(dir_to_classfolder)
        if os.path.isfile(os.path.join(dir_to_classfolder, f))
    ]
    for interpolation_step, file_path in files:
        new_dir = os.path.join(dir_to_classfolder, interpolation_step)
        os.makedirs(new_dir, exist_ok=True)
        os.rename(file_path, os.path.join(new_dir, os.path.basename(file_path)))


def ungroupInterpolation(dir_to_classfolder):
    """Undo ``groupbyInterpolation``: move files up one level and remove the
    then-empty step folders."""
    for interpolation_step in os.listdir(dir_to_classfolder):
        step_dir = os.path.join(dir_to_classfolder, interpolation_step)
        if os.path.isdir(step_dir):
            for f in os.listdir(step_dir):
                os.rename(
                    os.path.join(step_dir, f),
                    os.path.join(dir_to_classfolder, f),
                )
            os.rmdir(step_dir)


def groupAllbyInterpolation(
    data_path,
    group=True,
    fn_group=groupbyInterpolation,
    fn_ungroup=ungroupInterpolation,
):
    """Apply ``fn_group`` (or ``fn_ungroup`` when ``group`` is False) to every
    class sub-directory of ``data_path``, in sorted order."""
    fn = fn_group if group else fn_ungroup
    for c in sorted(os.listdir(data_path)):
        c_path = os.path.join(data_path, c)
        if os.path.isdir(c_path):
            fn(c_path)
            print(f"Processed {c}")


def getPairIndices(subset_len, total_pair_count=1, seed=None):
    """Split shuffled indices ``0..subset_len-1`` into equally sized groups.

    Exactly ``total_pair_count`` groups of ``ceil(subset_len / total_pair_count)``
    indices are returned; the last group(s) are topped up with the leading
    (unshuffled) indices ``0, 1, ...`` so every group has the same size.

    Raises:
        ValueError: If ``total_pair_count`` is not positive.
    """
    if total_pair_count < 1:
        raise ValueError("total_pair_count must be a positive integer.")
    if subset_len <= 0:
        return [[] for _ in range(total_pair_count)]
    rng = np.random.default_rng(seed)
    group_size = (subset_len + total_pair_count - 1) // total_pair_count
    numbers = list(range(subset_len))
    rng.shuffle(numbers)
    # Pad to exactly group_size * total_pair_count entries; the old padding
    # only reached a multiple of group_size and could yield too few groups.
    padding = group_size * total_pair_count - subset_len
    numbers.extend(i % subset_len for i in range(padding))
    return np.array(numbers).reshape(total_pair_count, group_size).tolist()


def _resolve_steps_range(interpolate_range, embed_len):
    """Return ``((steps_class0, steps_class1), multiplier)`` for a range mode.

    ``"nearest"`` gives each class the half of the interpolation closer to its
    own prompt, ``"furthest"`` the half closer to the other class; anything else
    uses every step for both classes.
    """
    half = embed_len // 2
    if interpolate_range == "nearest":
        return (range(0, half), range(half, embed_len)), 2
    if interpolate_range == "furthest":
        return (range(half, embed_len), range(0, half)), 2
    return (range(embed_len), range(embed_len)), 1


def _build_group_map(steps, multiplier, subset_len, embed_len, rng):
    """Return a shuffled list mapping image index -> interpolation step."""
    steps = list(steps)
    repeats = multiplier * (subset_len // embed_len + 1)
    if steps and len(steps) * repeats < subset_len:
        # With an odd embed_len a half range is too short for this sizing;
        # top up only in that case so existing seeds keep their shuffles.
        repeats = -(-subset_len // len(steps))
    group_map = steps * repeats
    rng.shuffle(group_map)
    return group_map


def _pair_embeddings(prompt_embeds, negative_prompt_embeds):
    """Zip positive and negative embeddings; pair with None when there are no
    negative embeddings."""
    if negative_prompt_embeds is None:
        return [(embed, None) for embed in prompt_embeds]
    return list(zip(prompt_embeds, negative_prompt_embeds))


def _output_file(save_path, class_id, seed, image_id, interpolate_step, image_type):
    """Path of a generated image (shared by generation and trace code)."""
    return f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}"


def _save_generated_image(image, path, image_type):
    """Save a PIL image; JPEG at quality 95, PNG, or format inferred from path."""
    if image_type == "jpg":
        image.save(path, format="JPEG", quality=95)
    elif image_type == "png":
        image.save(path, format="PNG")
    else:
        image.save(path)


def generateImagesFromDataset(
    img_subsets,
    class_iterables,
    pipeline,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    num_inference_steps,
    guidance_scale,
    height=512,
    width=512,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    save_image=True,
    image_type="jpg",
    interpolate_range="full",
    device="cuda",
    return_images=False,
):
    """Generate one img2img image per selected input image and score it by SSIM.

    Each input image gets an interpolation step from a seeded shuffle (see
    ``interpolate_range``) and is regenerated with that step's embedding.

    Args:
        img_subsets: ``class_id -> dataset`` whose items unpack to
            ``(image_tensor, target)``; images presumably CHW in [0, 1] sized
            ``height`` x ``width`` (TODO confirm: SSIM compares them with
            ``ToTensor`` output).
        class_iterables: ``class_id -> list`` of index groups; one group is
            ``pop``-ed (mutating the caller's list) per call.
        interpolated_negative_prompts_embeds: Negative embeddings, or None.
        interpolate_range: ``"nearest"``, ``"furthest"`` or anything else for all steps.
        seed: Seeds torch's global RNG (latents) and the step shuffle; a random
            seed is drawn when None.

    Returns:
        ``class_ssim`` (``class_id -> step -> {"ssim_sum", "ssim_count",
        "ssim_avg"}``), or ``(class_images, class_ssim)`` when ``return_images``.
    """
    if seed is None:
        seed = torch.Generator().seed()
    generator = torch.manual_seed(seed)
    rng = np.random.default_rng(seed)

    # Initial U-Net latents drawn once and shared by every generation.
    latents = torch.randn(
        (1, pipeline.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(device)

    embed_len = len(interpolated_prompt_embeds)
    embed_pairs_list = _pair_embeddings(
        interpolated_prompt_embeds, interpolated_negative_prompts_embeds
    )
    steps_range, multiplier = _resolve_steps_range(interpolate_range, embed_len)

    if return_images:
        class_images = dict()
    class_ssim = dict()

    for class_iter, class_id in enumerate(class_pairs):
        if return_images:
            class_images[class_id] = list()
        step_stats = {
            i: {"ssim_sum": 0, "ssim_count": 0, "ssim_avg": 0}
            for i in range(embed_len)
        }
        class_ssim[class_id] = step_stats
        class_ds = img_subsets[class_id]
        group_map = _build_group_map(
            steps_range[class_iter], multiplier, len(class_ds), embed_len, rng
        )
        iter_indices = class_iterables[class_id].pop()
        for image_id in iter_indices:
            img, _ = class_ds[image_id]
            input_image = img.unsqueeze(0)
            interpolate_step = group_map[image_id]
            prompt_embeds, negative_prompt_embeds = embed_pairs_list[interpolate_step]
            generated_image = genClassImg(
                pipeline,
                prompt_embeds,
                negative_prompt_embeds,
                input_image,
                generator,
                latents,
                num_imgs=1,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            )
            pred_image = transforms.ToTensor()(generated_image).unsqueeze(0)
            ssim_score = ssim(pred_image, input_image).item()
            step_stats[interpolate_step]["ssim_sum"] += ssim_score
            step_stats[interpolate_step]["ssim_count"] += 1
            if return_images:
                class_images[class_id].append(generated_image)
            if save_image:
                _save_generated_image(
                    generated_image,
                    _output_file(
                        save_path, class_id, seed, image_id, interpolate_step, image_type
                    ),
                    image_type,
                )
        for stats in step_stats.values():
            if stats["ssim_count"] > 0:
                stats["ssim_avg"] = stats["ssim_sum"] / stats["ssim_count"]

    if return_images:
        return class_images, class_ssim
    return class_ssim


def generateTrace(
    prompts,
    img_subsets,
    class_iterables,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    subset_indices,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    image_type="jpg",
    interpolate_range="full",
    save_prompt_embeds=False,
):
    """Record, without generating, which input/step/output each image maps to.

    Uses the same seeded step assignment as ``generateImagesFromDataset``, so
    the same seed and inputs give matching output paths. ``class_iterables`` is
    ``pop``-ed just like there.

    Args:
        subset_indices: ``class_id -> sequence`` whose first item is added to
            ``image_id`` to index ``img_subsets[class_id].dataset.samples``
            (presumably a contiguous Subset of an ImageFolder-like dataset —
            TODO confirm).
        save_prompt_embeds: Store the (positive, negative) numpy embeddings.

    Returns:
        dict of parallel lists, one entry per traced image.
    """
    trace_dict = {
        "class_pairs": list(),
        "class_id": list(),
        "image_id": list(),
        "interpolation_step": list(),
        "embed_len": list(),
        "pos_prompt_text": list(),
        "neg_prompt_text": list(),
        "input_file_path": list(),
        "output_file_path": list(),
        "input_prompts_embed": list(),
    }
    if seed is None:
        seed = torch.Generator().seed()
    rng = np.random.default_rng(seed)

    embed_len = len(interpolated_prompt_embeds)
    negative_np = (
        None
        if interpolated_negative_prompts_embeds is None
        else interpolated_negative_prompts_embeds.cpu().numpy()
    )
    embed_pairs_list = _pair_embeddings(
        interpolated_prompt_embeds.cpu().numpy(), negative_np
    )
    steps_range, multiplier = _resolve_steps_range(interpolate_range, embed_len)

    for class_iter, class_id in enumerate(class_pairs):
        class_ds = img_subsets[class_id]
        group_map = _build_group_map(
            steps_range[class_iter], multiplier, len(class_ds), embed_len, rng
        )
        iter_indices = class_iterables[class_id].pop()
        for image_id in iter_indices:
            interpolate_step = group_map[image_id]
            sample_count = subset_indices[class_id][0] + image_id
            record = {
                "class_pairs": class_pairs,
                "class_id": class_id,
                "image_id": image_id,
                "interpolation_step": interpolate_step,
                "embed_len": embed_len,
                "pos_prompt_text": prompts[0],
                "neg_prompt_text": prompts[1],
                "input_file_path": os.path.normpath(
                    class_ds.dataset.samples[sample_count][0]
                ),
                "output_file_path": _output_file(
                    save_path, class_id, seed, image_id, interpolate_step, image_type
                ),
                "input_prompts_embed": (
                    embed_pairs_list[interpolate_step] if save_prompt_embeds else None
                ),
            }
            for key, value in record.items():
                trace_dict[key].append(value)
    return trace_dict