| """ | |
| Helper scripts for generating synthetic images using diffusion model. | |
| Functions: | |
| - get_top_misclassified | |
| - get_class_list | |
| - generateClassPairs | |
| - outputDirectory | |
| - pipe_img | |
| - createPrompts | |
| - interpolatePrompts | |
| - slerp | |
| - get_middle_elements | |
| - remove_middle | |
| - genClassImg | |
| - getMetadata | |
| - groupbyInterpolation | |
| - ungroupInterpolation | |
| - groupAllbyInterpolation | |
| - getPairIndices | |
| - generateImagesFromDataset | |
| - generateTrace | |
| """ | |
import json
import os

import numpy as np
import pandas as pd
import torch
from DeepCache import DeepCacheSDHelper
from diffusers import (
    LMSDiscreteScheduler,
    StableDiffusionImg2ImgPipeline,
)
from torch import nn
from torchmetrics.functional.image import structural_similarity_index_measure as ssim
from torchvision import transforms

def get_top_misclassified(val_classifier_json):
    """
    Retrieves the top misclassified classes from a validation classifier JSON file.

    Args:
        val_classifier_json (str): The path to the validation classifier JSON file.

    Returns:
        dict: A dictionary mapping each class name to the list of classes it is
            most often confused with (the "top_n_classes" entry from the JSON file).
    """
    with open(val_classifier_json) as f:
        val_output = json.load(f)
    val_metrics_df = pd.DataFrame.from_dict(
        val_output["val_metrics_details"], orient="index"
    )
    class_dict = dict()
    for k, v in val_metrics_df["top_n_classes"].items():
        class_dict[k] = v
    return class_dict

def get_class_list(val_classifier_json):
    """
    Retrieves the list of classes from the given validation classifier JSON file.

    Args:
        val_classifier_json (str): The path to the validation classifier JSON file.

    Returns:
        list: A sorted list of class names extracted from the JSON file.
    """
    with open(val_classifier_json, "r") as f:
        data = json.load(f)
    return sorted(list(data["val_metrics_details"].keys()))

def generateClassPairs(val_classifier_json):
    """
    Generate pairs of misclassified classes from the given validation classifier JSON file.

    Args:
        val_classifier_json (str): The path to the validation classifier JSON file.

    Returns:
        list: A sorted list of unique pairs of misclassified classes.
    """
    pairs = set()
    misclassified_classes = get_top_misclassified(val_classifier_json)
    for key, value in misclassified_classes.items():
        for v in value:
            pairs.add(tuple(sorted([key, v])))
    return sorted(list(pairs))
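

# A minimal, self-contained sketch of the JSON shape these helpers expect and
# the pairs they produce; the class names and "top_n_classes" values below are
# invented purely for illustration.
def _demo_generateClassPairs():
    import tempfile

    demo = {
        "val_metrics_details": {
            "cat": {"top_n_classes": ["dog"]},
            "dog": {"top_n_classes": ["cat", "fox"]},
            "fox": {"top_n_classes": ["dog"]},
        }
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
        json.dump(demo, tmp)
    return generateClassPairs(tmp.name)  # [('cat', 'dog'), ('dog', 'fox')]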

def outputDirectory(class_pairs, synth_path, metadata_path):
    """
    Creates the output directory structure for the synthesized data.

    Args:
        class_pairs (list): A list of class pairs.
        synth_path (str): The path to the directory where the synthesized data will be stored.
        metadata_path (str): The path to the directory where the metadata will be stored.

    Returns:
        None
    """
    for class_id in class_pairs:
        class_folder = f"{synth_path}/{class_id}"
        if not os.path.exists(class_folder):
            os.makedirs(class_folder)
    if not os.path.exists(metadata_path):
        os.makedirs(metadata_path)
    print("Info: Output directory ready.")

def pipe_img(
    model_path,
    device="cuda",
    apply_optimization=True,
    use_torchcompile=False,
    ci_cb=(5, 1),
    use_safetensors=None,
    cpu_offload=False,
    scheduler=None,
):
    """
    Creates and returns an image-to-image pipeline for Stable Diffusion.

    Args:
        model_path (str): The path to the pretrained model.
        device (str, optional): The device to use for computation. Defaults to "cuda".
        apply_optimization (bool, optional): Whether to apply optimization techniques. Defaults to True.
        use_torchcompile (bool, optional): Whether to use torch.compile for model compilation. Defaults to False.
        ci_cb (tuple, optional): A tuple containing the cache interval and cache branch ID. Defaults to (5, 1).
        use_safetensors (bool, optional): Whether to use safetensors. Defaults to None.
        cpu_offload (bool, optional): Whether to enable CPU offloading. Defaults to False.
        scheduler (LMSDiscreteScheduler, optional): The scheduler for the pipeline. Defaults to None.

    Returns:
        StableDiffusionImg2ImgPipeline: The image-to-image pipeline for Stable Diffusion.
    """
    ###############################
    # Reference:
    # Akimov, R. (2024) Images Interpolation with Stable Diffusion - Hugging Face Open-Source AI Cookbook. Available at: https://huggingface.co/learn/cookbook/en/stable_diffusion_interpolation (Accessed: 4 June 2024).
    ###############################
    if scheduler is None:
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            steps_offset=1,
        )
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_path,
        scheduler=scheduler,
        torch_dtype=torch.float32,
        use_safetensors=use_safetensors,
    ).to(device)
    if cpu_offload:
        pipe.enable_model_cpu_offload()
    if apply_optimization:
        # tomesd.apply_patch(pipe, ratio=0.5)
        helper = DeepCacheSDHelper(pipe=pipe)
        cache_interval, cache_branch_id = ci_cb
        helper.set_params(
            cache_interval=cache_interval, cache_branch_id=cache_branch_id
        )  # lower is faster but lower quality
        helper.enable()
        # if torch.cuda.is_available():
        #     pipe.to("cuda")
        #     pipe.enable_xformers_memory_efficient_attention()
    if use_torchcompile:
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    return pipe
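

# A minimal sketch of building the img2img pipeline. The checkpoint id below is
# an assumption for illustration (any Stable Diffusion v1.x img2img-compatible
# checkpoint should work); loading it requires network access and, realistically,
# a CUDA GPU.
def _demo_pipe_img():
    pipe = pipe_img(
        model_path="runwayml/stable-diffusion-v1-5",  # hypothetical choice
        device="cuda" if torch.cuda.is_available() else "cpu",
        apply_optimization=True,  # enable DeepCache
        ci_cb=(5, 1),  # cache_interval=5, cache_branch_id=1
    )
    return pipe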

def createPrompts(
    class_name_pairs,
    prompt_structure=None,
    use_default_negative_prompt=False,
    negative_prompt=None,
):
    """
    Create prompts for image generation.

    Args:
        class_name_pairs (list): A list of two class names.
        prompt_structure (str, optional): The structure of the prompt. Defaults to "a photo of a <class_name>".
        use_default_negative_prompt (bool, optional): Whether to use the default negative prompt. Defaults to False.
        negative_prompt (str, optional): The negative prompt to steer the generation away from certain features.

    Returns:
        tuple: A tuple containing two lists - prompts and negative_prompts.
            prompts (list): Text prompts that describe the desired output image.
            negative_prompts (list): Negative prompts that can be used to steer the generation away from certain features.
    """
    if prompt_structure is None:
        prompt_structure = "a photo of a <class_name>"
    elif "<class_name>" not in prompt_structure:
        raise ValueError(
            "The prompt structure must contain the <class_name> placeholder."
        )
    if use_default_negative_prompt:
        default_negative_prompt = (
            "blurry image, disfigured, deformed, distorted, cartoon, drawings"
        )
        negative_prompt = default_negative_prompt
    class1 = class_name_pairs[0]
    class2 = class_name_pairs[1]
    prompt1 = prompt_structure.replace("<class_name>", class1)
    prompt2 = prompt_structure.replace("<class_name>", class2)
    prompts = [prompt1, prompt2]
    if negative_prompt is None:
        print("Info: Negative prompt not provided, returning as None.")
        return prompts, None
    else:
        # Negative prompts that can be used to steer the generation away from certain features.
        negative_prompts = [negative_prompt] * len(prompts)
        return prompts, negative_prompts
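

# A quick usage sketch for createPrompts; the class names are placeholders.
def _demo_createPrompts():
    prompts, negative_prompts = createPrompts(
        ["tabby cat", "tiger cat"],
        prompt_structure="a photo of a <class_name>",
        use_default_negative_prompt=True,
    )
    # prompts -> ["a photo of a tabby cat", "a photo of a tiger cat"]
    # negative_prompts -> two copies of the default negative prompt
    return prompts, negative_prompts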

def interpolatePrompts(
    prompts,
    pipeline,
    num_interpolation_steps,
    sample_mid_interpolation,
    remove_n_middle=0,
    device="cuda",
):
    """
    Interpolates prompts by generating intermediate embeddings between pairs of prompts.

    Args:
        prompts (List[str]): A list of prompts to be interpolated.
        pipeline: The pipeline object containing the tokenizer and text encoder.
        num_interpolation_steps (int): The number of interpolation steps between each pair of prompts.
        sample_mid_interpolation (int): The number of intermediate embeddings to sample from the middle of the interpolated prompts.
        remove_n_middle (int, optional): The number of middle embeddings to remove from the interpolated prompts. Defaults to 0.
        device (str, optional): The device to run the interpolation on. Defaults to "cuda".

    Returns:
        interpolated_prompt_embeds (torch.Tensor): The interpolated prompt embeddings.
        prompt_metadata (dict): Metadata about the interpolation process, including similarity scores and nearest class information.

    Example:
        If num_interpolation_steps = 10, sample_mid_interpolation = 6, remove_n_middle = 2:
            Interpolated: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
            Sampled:      [2, 3, 4, 5, 6, 7]
            Removed:      [2, 3, x, x, 6, 7]
            Returns:      [2, 3, 6, 7]
    """
    ###############################
    # Reference:
    # Akimov, R. (2024) Images Interpolation with Stable Diffusion - Hugging Face Open-Source AI Cookbook. Available at: https://huggingface.co/learn/cookbook/en/stable_diffusion_interpolation (Accessed: 4 June 2024).
    ###############################
    def slerp(v0, v1, num, t0=0, t1=1):
        """
        Performs spherical linear interpolation between two vectors.

        Args:
            v0 (torch.Tensor): The starting vector.
            v1 (torch.Tensor): The ending vector.
            num (int): The number of interpolation points.
            t0 (float, optional): The starting time. Defaults to 0.
            t1 (float, optional): The ending time. Defaults to 1.

        Returns:
            torch.Tensor: The interpolated vectors.
        """
        ###############################
        # Reference:
        # Karpathy, A. (2022) hacky stablediffusion code for generating videos, Gist. Available at: https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355 (Accessed: 4 June 2024).
        ###############################
        v0 = v0.detach().cpu().numpy()
        v1 = v1.detach().cpu().numpy()

        def interpolation(t, v0, v1, DOT_THRESHOLD=0.9995):
            """Helper function to spherically interpolate the two arrays v0 and v1."""
            dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
            if np.abs(dot) > DOT_THRESHOLD:
                # Vectors are nearly parallel; fall back to linear interpolation.
                v2 = (1 - t) * v0 + t * v1
            else:
                theta_0 = np.arccos(dot)
                sin_theta_0 = np.sin(theta_0)
                theta_t = theta_0 * t
                sin_theta_t = np.sin(theta_t)
                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
                s1 = sin_theta_t / sin_theta_0
                v2 = s0 * v0 + s1 * v1
            return v2

        t = np.linspace(t0, t1, num)
        v3 = torch.tensor(np.array([interpolation(t[i], v0, v1) for i in range(num)]))
        return v3
    def get_middle_elements(lst, n):
        """
        Returns a tuple containing a slice of the middle `n` elements of `lst` and the range of indices of those elements.

        Args:
            lst: The sequence from which to extract the middle elements.
            n (int): The number of middle elements to extract.

        Returns:
            tuple: A tuple containing the middle elements and a range of indices.

        Examples:
            lst = [1, 2, 3, 4, 5]
            get_middle_elements(lst, 3)
            ([2, 3, 4], range(1, 4))
        """
        if n % 2 == 0:  # Even number of elements
            middle_index = len(lst) // 2 - 1
            start = middle_index - n // 2 + 1
            end = middle_index + n // 2 + 1
            return lst[start:end], range(start, end)
        else:  # Odd number of elements
            middle_index = len(lst) // 2
            start = middle_index - n // 2
            end = middle_index + n // 2 + 1
            return lst[start:end], range(start, end)
    def remove_middle(data, n):
        """
        Remove the middle n elements from a tensor of embeddings.

        Args:
            data (torch.Tensor): The input tensor, indexed along dim 0.
            n (int): The number of elements to remove from the middle.

        Returns:
            torch.Tensor: The tensor with the middle n elements removed.

        Raises:
            ValueError: If n is negative or greater than the length of the data.
        """
        if n < 0 or n > len(data):
            raise ValueError(
                "Invalid value for n. It should be non-negative and not greater "
                "than the length of the data."
            )
        # Find the middle index
        middle = len(data) // 2
        # Concatenate slices that exclude the middle n elements.
        # (Plain `+` on tensors is elementwise addition, not concatenation.)
        if n == 1:
            return torch.cat((data[:middle], data[middle + 1 :]), dim=0)
        elif n % 2 == 0:
            return torch.cat(
                (data[: middle - n // 2], data[middle + n // 2 :]), dim=0
            )
        else:
            return torch.cat(
                (data[: middle - n // 2], data[middle + n // 2 + 1 :]), dim=0
            )

    batch_size = len(prompts)
    # Tokenizing and encoding prompts into embeddings.
    prompts_tokens = pipeline.tokenizer(
        prompts,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    prompts_embeds = pipeline.text_encoder(prompts_tokens.input_ids.to(device))[0]
    # Interpolating between embedding pairs for the given number of interpolation steps.
    interpolated_prompt_embeds = []
    for i in range(batch_size - 1):
        interpolated_prompt_embeds.append(
            slerp(prompts_embeds[i], prompts_embeds[i + 1], num_interpolation_steps)
        )
    full_interpolated_prompt_embeds = interpolated_prompt_embeds[:]
    interpolated_prompt_embeds[0], sample_range = get_middle_elements(
        interpolated_prompt_embeds[0], sample_mid_interpolation
    )
    if remove_n_middle > 0:
        interpolated_prompt_embeds[0] = remove_middle(
            interpolated_prompt_embeds[0], remove_n_middle
        )
    prompt_metadata = dict()
    similarity = nn.CosineSimilarity(dim=-1, eps=1e-6)
    for i in range(num_interpolation_steps):
        class1_sim = (
            similarity(
                full_interpolated_prompt_embeds[0][0],
                full_interpolated_prompt_embeds[0][i],
            )
            .mean()
            .item()
        )
        class2_sim = (
            similarity(
                full_interpolated_prompt_embeds[0][num_interpolation_steps - 1],
                full_interpolated_prompt_embeds[0][i],
            )
            .mean()
            .item()
        )
        relative_distance = class1_sim / (class1_sim + class2_sim)
        prompt_metadata[i] = {
            "selected": i in sample_range,
            "similarity": {
                "class1": class1_sim,
                "class2": class2_sim,
                "class1_relative_distance": relative_distance,
                "class2_relative_distance": 1 - relative_distance,
            },
            "nearest_class": int(relative_distance < 0.5),
        }
    interpolated_prompt_embeds = torch.cat(interpolated_prompt_embeds, dim=0).to(
        device
    )
    return interpolated_prompt_embeds, prompt_metadata
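

# A sketch of the embedding-interpolation flow, assuming `pipe` was built with
# pipe_img and a CUDA device is available. The shape comment assumes a Stable
# Diffusion v1.x text encoder (77 tokens, 768 hidden dims).
def _demo_interpolatePrompts(pipe, device="cuda"):
    prompts, _ = createPrompts(["cat", "dog"])
    embeds, meta = interpolatePrompts(
        prompts,
        pipe,
        num_interpolation_steps=10,
        sample_mid_interpolation=6,
        remove_n_middle=2,
        device=device,
    )
    # embeds.shape -> (4, 77, 768): 10 slerp steps, middle 6 kept, innermost 2 removed.
    # meta[i]["selected"] flags which of the 10 steps fall in the sampled range.
    return embeds, meta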

def genClassImg(
    pipeline,
    pos_embed,
    neg_embed,
    input_image,
    generator,
    latents,
    num_imgs=1,
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=7.5,
):
    """
    Generate a class image using the given inputs.

    Args:
        pipeline: The pipeline object used for image generation.
        pos_embed: The positive embedding for the class.
        neg_embed: The negative embedding for the class (optional).
        input_image: The input image for guidance (optional).
        generator: The torch generator used for image generation.
        latents: The latent vectors used for image generation.
        num_imgs: The number of images to generate (default is 1).
        height: The height of the generated images (default is 512).
        width: The width of the generated images (default is 512).
        num_inference_steps: The number of inference steps for image generation (default is 25).
        guidance_scale: The scale factor for guidance (default is 7.5).

    Returns:
        The generated class image.
    """
    if neg_embed is not None:
        npe = neg_embed[None, ...]
    else:
        npe = None
    return pipeline(
        height=height,
        width=width,
        num_images_per_prompt=num_imgs,
        prompt_embeds=pos_embed[None, ...],
        negative_prompt_embeds=npe,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        latents=latents,
        image=input_image,
    ).images[0]

def getMetadata(
    class_pairs,
    path,
    seed,
    guidance_scale,
    num_inference_steps,
    num_interpolation_steps,
    sample_mid_interpolation,
    height,
    width,
    prompts,
    negative_prompts,
    pipeline,
    prompt_metadata,
    negative_prompt_metadata,
    ssim_metadata=None,
    save_json=True,
    save_path=".",
):
    """
    Generate metadata for the given parameters.

    Args:
        class_pairs (list): List of class pairs.
        path (str): Path to the data.
        seed (int): Seed value for randomization.
        guidance_scale (float): Scale factor for guidance.
        num_inference_steps (int): Number of inference steps.
        num_interpolation_steps (int): Number of interpolation steps.
        sample_mid_interpolation (int): Number of interpolation steps sampled from the middle.
        height (int): Height of the image.
        width (int): Width of the image.
        prompts (list): List of prompts.
        negative_prompts (list): List of negative prompts.
        pipeline (object): Pipeline object.
        prompt_metadata (dict): Metadata for prompts.
        negative_prompt_metadata (dict): Metadata for negative prompts.
        ssim_metadata (dict, optional): SSIM scores metadata. Defaults to None.
        save_json (bool, optional): Flag to save metadata as JSON. Defaults to True.
        save_path (str, optional): Path to save the JSON file. Defaults to ".".

    Returns:
        dict: Generated metadata.
    """
    metadata = dict()
    metadata["class_pairs"] = class_pairs
    metadata["path"] = path
    metadata["seed"] = seed
    metadata["params"] = {
        "CFG": guidance_scale,
        "inferenceSteps": num_inference_steps,
        "interpolationSteps": num_interpolation_steps,
        "sampleMidInterpolation": sample_mid_interpolation,
        "height": height,
        "width": width,
    }
    for i in range(len(prompts)):
        metadata[f"prompt_text_{i}"] = prompts[i]
        if negative_prompts is not None:
            metadata[f"negative_prompt_text_{i}"] = negative_prompts[i]
    metadata["pipe_config"] = dict(pipeline.config)
    metadata["prompt_embed_similarity"] = prompt_metadata
    metadata["negative_prompt_embed_similarity"] = negative_prompt_metadata
    if ssim_metadata is not None:
        print("Info: SSIM scores are available.")
        metadata["ssim_scores"] = ssim_metadata
    if save_json:
        with open(
            os.path.join(save_path, f"{'_'.join(i for i in class_pairs)}_{seed}.json"),
            "w",
        ) as f:
            json.dump(metadata, f, indent=4)
    return metadata

def groupbyInterpolation(dir_to_classfolder):
    """
    Group files in a class directory into subfolders by interpolation step.

    Args:
        dir_to_classfolder (str): The path to the directory containing the files.

    Returns:
        None
    """
    # File names follow "{seed}-{image_id}_{interpolation_step}.{ext}".
    files = [
        (f.split(sep="_")[1].split(sep=".")[0], os.path.join(dir_to_classfolder, f))
        for f in os.listdir(dir_to_classfolder)
    ]
    # Create a subfolder for each step of the interpolation.
    for interpolation_step, file_path in files:
        new_dir = os.path.join(dir_to_classfolder, interpolation_step)
        if not os.path.exists(new_dir):
            os.makedirs(new_dir)
        os.rename(file_path, os.path.join(new_dir, os.path.basename(file_path)))


def ungroupInterpolation(dir_to_classfolder):
    """
    Moves all files from subdirectories within `dir_to_classfolder` back up to
    `dir_to_classfolder` itself, and then removes the subdirectories.

    Args:
        dir_to_classfolder (str): The path to the directory containing the subdirectories.

    Returns:
        None
    """
    for interpolation_step in os.listdir(dir_to_classfolder):
        if os.path.isdir(os.path.join(dir_to_classfolder, interpolation_step)):
            for f in os.listdir(os.path.join(dir_to_classfolder, interpolation_step)):
                os.rename(
                    os.path.join(dir_to_classfolder, interpolation_step, f),
                    os.path.join(dir_to_classfolder, f),
                )
            os.rmdir(os.path.join(dir_to_classfolder, interpolation_step))

def groupAllbyInterpolation(
    data_path,
    group=True,
    fn_group=groupbyInterpolation,
    fn_ungroup=ungroupInterpolation,
):
    """
    Group or ungroup all data classes by interpolation step.

    Args:
        data_path (str): The path to the data.
        group (bool, optional): Whether to group (True) or ungroup (False) the data. Defaults to True.
        fn_group (function, optional): The function to use for grouping. Defaults to groupbyInterpolation.
        fn_ungroup (function, optional): The function to use for ungrouping. Defaults to ungroupInterpolation.
    """
    data_classes = sorted(os.listdir(data_path))
    if group:
        fn = fn_group
    else:
        fn = fn_ungroup
    for c in data_classes:
        c_path = os.path.join(data_path, c)
        if os.path.isdir(c_path):
            fn(c_path)
            print(f"Processed {c}")

def getPairIndices(subset_len, total_pair_count=1, seed=None):
    """
    Generate groups of dataset indices for a given subset length, one group per pair run.

    Args:
        subset_len (int): The length of the subset.
        total_pair_count (int, optional): The total number of groups to generate. Defaults to 1.
        seed (int, optional): The seed value for the random number generator. Defaults to None.

    Returns:
        list: A list of groups of indices.
    """
    rng = np.random.default_rng(seed)
    group_size = (subset_len + total_pair_count - 1) // total_pair_count
    numbers = list(range(subset_len))
    numbers_selection = list(range(subset_len))
    rng.shuffle(numbers)
    # Pad with leading indices from the unshuffled list so the total length is
    # divisible by group_size; any surplus is truncated below.
    for i in range(group_size - subset_len % group_size):
        numbers.append(numbers_selection[i])
    numbers = np.array(numbers)
    groups = numbers[: group_size * total_pair_count].reshape(-1, group_size)
    return groups.tolist()
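

# A worked example: 10 indices split into 3 groups of ceil(10 / 3) = 4. The
# shuffled indices 0..9 are padded with indices 0 and 1 (taken from the
# unshuffled list) so the groups come out rectangular.
def _demo_getPairIndices():
    groups = getPairIndices(subset_len=10, total_pair_count=3, seed=0)
    # -> three lists of four indices each; the exact order depends on the seed.
    return groups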

def generateImagesFromDataset(
    img_subsets,
    class_iterables,
    pipeline,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    num_inference_steps,
    guidance_scale,
    height=512,
    width=512,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    save_image=True,
    image_type="jpg",
    interpolate_range="full",
    device="cuda",
    return_images=False,
):
    """
    Generates images from a dataset using the given parameters.

    Args:
        img_subsets (dict): A dictionary containing image subsets for each class.
        class_iterables (dict): A dictionary containing iterable objects for each class.
        pipeline (object): The pipeline object used for image generation.
        interpolated_prompt_embeds (list): A list of interpolated prompt embeddings.
        interpolated_negative_prompts_embeds (list): A list of interpolated negative prompt embeddings.
        num_inference_steps (int): The number of inference steps for image generation.
        guidance_scale (float): The scale factor for classifier-free guidance during image generation.
        height (int, optional): The height of the generated images. Defaults to 512.
        width (int, optional): The width of the generated images. Defaults to 512.
        seed (int, optional): The seed value for random number generation. Defaults to None.
        save_path (str, optional): The path to save the generated images. Defaults to ".".
        class_pairs (tuple, optional): A tuple containing pairs of class identifiers. Defaults to ("0", "1").
        save_image (bool, optional): Whether to save the generated images. Defaults to True.
        image_type (str, optional): The file format of the saved images. Defaults to "jpg".
        interpolate_range (str, optional): The range of interpolation for prompt embeddings.
            Possible values are "full", "nearest", or "furthest". Defaults to "full".
        device (str, optional): The device to use for image generation. Defaults to "cuda".
        return_images (bool, optional): Whether to return the generated images. Defaults to False.

    Returns:
        dict or tuple: If return_images is True, returns a tuple of a dictionary containing
            the generated images for each class and a dictionary containing the SSIM scores
            for each class and interpolation step. If return_images is False, returns only
            the SSIM score dictionary.
    """
    if interpolate_range == "nearest":
        nearest_half = True
        furthest_half = False
    elif interpolate_range == "furthest":
        nearest_half = False
        furthest_half = True
    else:
        nearest_half = False
        furthest_half = False
    if seed is None:
        seed = torch.Generator().seed()
    generator = torch.manual_seed(seed)
    rng = np.random.default_rng(seed)
    # Generating initial U-Net latent vectors from a random normal distribution.
    latents = torch.randn(
        (1, pipeline.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(device)
    embed_len = len(interpolated_prompt_embeds)
    embed_pairs = zip(interpolated_prompt_embeds, interpolated_negative_prompts_embeds)
    embed_pairs_list = list(embed_pairs)
    if return_images:
        class_images = dict()
    class_ssim = dict()
    if nearest_half or furthest_half:
        if nearest_half:
            steps_range = (range(0, embed_len // 2), range(embed_len // 2, embed_len))
            multiplier = 2
        elif furthest_half:
            # Uses the opposite half of the text interpolation for each class.
            steps_range = (range(embed_len // 2, embed_len), range(0, embed_len // 2))
            multiplier = 2
    else:
        steps_range = (range(embed_len), range(embed_len))
        multiplier = 1
    for class_iter, class_id in enumerate(class_pairs):
        if return_images:
            class_images[class_id] = list()
        class_ssim[class_id] = {
            i: {"ssim_sum": 0, "ssim_count": 0, "ssim_avg": 0} for i in range(embed_len)
        }
        subset_len = len(img_subsets[class_id])
        # To efficiently randomize the steps to interpolate for each image in the class,
        # group_map is used.
        # group_map: index is the image id, element is the group id.
        # steps_range[class_iter] determines the range of steps to interpolate for the class,
        # so the first half of the steps are for the first class and so on
        # (e.g. steps 0-7 and 8-15 for 16 steps).
        # The steps are then repeated to cover the whole subset plus any remainder.
        group_map = (
            list(steps_range[class_iter]) * multiplier * (subset_len // embed_len + 1)
        )
        rng.shuffle(
            group_map
        )  # shuffle the steps to interpolate for each image; position in group_map maps to the image id
        iter_indices = class_iterables[class_id].pop()
        # Generate an image for each image in the class, using its randomly assigned interpolation step.
        for image_id in iter_indices:
            img, trg = img_subsets[class_id][image_id]
            input_image = img.unsqueeze(0)
            interpolate_step = group_map[image_id]
            prompt_embeds, negative_prompt_embeds = embed_pairs_list[interpolate_step]
            generated_image = genClassImg(
                pipeline,
                prompt_embeds,
                negative_prompt_embeds,
                input_image,
                generator,
                latents,
                num_imgs=1,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            )
            pred_image = transforms.ToTensor()(generated_image).unsqueeze(0)
            ssim_score = ssim(pred_image, input_image).item()
            class_ssim[class_id][interpolate_step]["ssim_sum"] += ssim_score
            class_ssim[class_id][interpolate_step]["ssim_count"] += 1
            if return_images:
                class_images[class_id].append(generated_image)
            if save_image:
                if image_type == "jpg":
                    generated_image.save(
                        f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}",
                        format="JPEG",
                        quality=95,
                    )
                elif image_type == "png":
                    generated_image.save(
                        f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}",
                        format="PNG",
                    )
                else:
                    generated_image.save(
                        f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}"
                    )
        # Calculate the SSIM average for the class.
        for i_step in range(embed_len):
            if class_ssim[class_id][i_step]["ssim_count"] > 0:
                class_ssim[class_id][i_step]["ssim_avg"] = (
                    class_ssim[class_id][i_step]["ssim_sum"]
                    / class_ssim[class_id][i_step]["ssim_count"]
                )
    if return_images:
        return class_images, class_ssim
    else:
        return class_ssim
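

# An end-to-end sketch wiring the helpers together for one class pair. The
# dataset handling is an assumption: img_subsets must map class ids to indexable
# subsets yielding (image_tensor, target) pairs, and a CUDA device plus the
# outputDirectory folder layout are expected.
def _demo_generateImagesFromDataset(img_subsets, pipe, device="cuda"):
    class_pairs = ("cat", "dog")
    prompts, negative_prompts = createPrompts(
        list(class_pairs), use_default_negative_prompt=True
    )
    pos_embeds, _ = interpolatePrompts(prompts, pipe, 16, 8, device=device)
    neg_embeds, _ = interpolatePrompts(negative_prompts, pipe, 16, 8, device=device)
    class_iterables = {
        c: getPairIndices(len(img_subsets[c]), total_pair_count=1, seed=0)
        for c in class_pairs
    }
    return generateImagesFromDataset(
        img_subsets,
        class_iterables,
        pipe,
        pos_embeds,
        neg_embeds,
        num_inference_steps=25,
        guidance_scale=7.5,
        seed=0,
        save_path="synth",
        class_pairs=class_pairs,
        device=device,
    )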

def generateTrace(
    prompts,
    img_subsets,
    class_iterables,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    subset_indices,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    image_type="jpg",
    interpolate_range="full",
    save_prompt_embeds=False,
):
    """
    Generate a trace dictionary containing information about the generated images.

    Args:
        prompts (list): Pair of prompt texts, [positive, negative].
        img_subsets (dict): Dictionary containing image subsets for each class.
        class_iterables (dict): Dictionary containing iterable objects for each class.
        interpolated_prompt_embeds (torch.Tensor): Tensor containing interpolated prompt embeddings.
        interpolated_negative_prompts_embeds (torch.Tensor): Tensor containing interpolated negative prompt embeddings.
        subset_indices (dict): Dictionary containing indices of subsets for each class.
        seed (int, optional): Seed value for random number generation. Defaults to None.
        save_path (str, optional): Path to save the generated images. Defaults to ".".
        class_pairs (tuple, optional): Tuple containing class pairs. Defaults to ("0", "1").
        image_type (str, optional): File type of the generated images. Defaults to "jpg".
        interpolate_range (str, optional): Range of interpolation. Defaults to "full".
        save_prompt_embeds (bool, optional): Flag to save prompt embeddings. Defaults to False.

    Returns:
        dict: Trace dictionary containing information about the generated images.
    """
    trace_dict = {
        "class_pairs": list(),
        "class_id": list(),
        "image_id": list(),
        "interpolation_step": list(),
        "embed_len": list(),
        "pos_prompt_text": list(),
        "neg_prompt_text": list(),
        "input_file_path": list(),
        "output_file_path": list(),
        "input_prompts_embed": list(),
    }
    if interpolate_range == "nearest":
        nearest_half = True
        furthest_half = False
    elif interpolate_range == "furthest":
        nearest_half = False
        furthest_half = True
    else:
        nearest_half = False
        furthest_half = False
    if seed is None:
        seed = torch.Generator().seed()
    rng = np.random.default_rng(seed)
    embed_len = len(interpolated_prompt_embeds)
    embed_pairs = zip(
        interpolated_prompt_embeds.cpu().numpy(),
        interpolated_negative_prompts_embeds.cpu().numpy(),
    )
    embed_pairs_list = list(embed_pairs)
    if nearest_half or furthest_half:
        if nearest_half:
            steps_range = (range(0, embed_len // 2), range(embed_len // 2, embed_len))
            multiplier = 2
        elif furthest_half:
            # Uses the opposite half of the text interpolation for each class.
            steps_range = (range(embed_len // 2, embed_len), range(0, embed_len // 2))
            multiplier = 2
    else:
        steps_range = (range(embed_len), range(embed_len))
        multiplier = 1
    for class_iter, class_id in enumerate(class_pairs):
        subset_len = len(img_subsets[class_id])
        # Same group_map scheme as in generateImagesFromDataset: index is the
        # image id, element is the interpolation step assigned to that image
        # (e.g. steps 0-7 and 8-15 for 16 steps when halving the range).
        group_map = (
            list(steps_range[class_iter]) * multiplier * (subset_len // embed_len + 1)
        )
        rng.shuffle(
            group_map
        )  # shuffle the steps to interpolate for each image; position in group_map maps to the image id
        iter_indices = class_iterables[class_id].pop()
        # Record the trace for each image in the class.
        for image_id in iter_indices:
            class_ds = img_subsets[class_id]
            interpolate_step = group_map[image_id]
            sample_count = subset_indices[class_id][0] + image_id
            input_file = os.path.normpath(class_ds.dataset.samples[sample_count][0])
            pos_prompt = prompts[0]
            neg_prompt = prompts[1]
            output_file = f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}"
            if save_prompt_embeds:
                input_prompts_embed = embed_pairs_list[interpolate_step]
            else:
                input_prompts_embed = None
            trace_dict["class_pairs"].append(class_pairs)
            trace_dict["class_id"].append(class_id)
            trace_dict["image_id"].append(image_id)
            trace_dict["interpolation_step"].append(interpolate_step)
            trace_dict["embed_len"].append(embed_len)
            trace_dict["pos_prompt_text"].append(pos_prompt)
            trace_dict["neg_prompt_text"].append(neg_prompt)
            trace_dict["input_file_path"].append(input_file)
            trace_dict["output_file_path"].append(output_file)
            trace_dict["input_prompts_embed"].append(input_prompts_embed)
    return trace_dict