import os
import pathlib
from functools import partial

import torch
import torch.nn.functional as F
import numpy as np
import diffusers
from PIL import Image
from torchvision import transforms
from torchvision.datasets.folder import is_image_file
from tqdm.auto import tqdm
from skimage.color import rgb2gray

from libs.engine import ModelState
from libs.metric.lpips_origin import LPIPS
from libs.metric.piq.perceptual import DISTS as DISTS_PIQ
from libs.metric.clip_score import CLIPScoreWrapper
from methods.painter.diffsketchedit import (
    Painter, SketchPainterOptimizer, Token2AttnMixinASDSPipeline, Token2AttnMixinASDSSDXLPipeline)
from methods.painter.diffsketchedit.sketch_utils import (
    log_tensor_img, plt_batch, plt_attn, save_tensor_img, fix_image_scale)
from methods.painter.diffsketchedit.mask_utils import get_mask_u2net
from methods.token2attn.attn_control import (
    AttentionStore, EmptyControl, LocalBlend, AttentionReplace,
    AttentionRefine, AttentionReweight, get_equalizer)
from methods.token2attn.ptp_utils import view_images, get_word_inds
from methods.diffusers_warp import init_diffusion_pipeline, model2res
from methods.diffvg_warp import init_diffvg
from methods.painter.diffsketchedit.process_svg import remove_low_opacity_paths


class DiffSketchEditPipeline(ModelState):

    def __init__(self, args):
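        # Set up everything shared across editing stages: the diffvg rasterizer,
        # the diffusion pipeline variant matching `args.model_id` (used both for
        # sampling the target image and for the ASDS loss), and the CLIP wrapper
        # backing the visual and semantic distances.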
					
						
        super().__init__(args, ignore_log=True)

        init_diffvg(self.device, True, args.print_timing)

        if args.model_id == "sdxl":
            custom_pipeline = Token2AttnMixinASDSSDXLPipeline
            custom_scheduler = diffusers.DPMSolverMultistepScheduler
            # SDXL denoises at a higher latent resolution, so cross-attention
            # maps are extracted at twice the base resolution
            self.args.cross_attn_res = self.args.cross_attn_res * 2
        elif args.model_id == 'sd21':
            custom_pipeline = Token2AttnMixinASDSPipeline
            custom_scheduler = diffusers.DDIMScheduler
        elif args.model_id == 'sd15':
            custom_pipeline = Token2AttnMixinASDSPipeline
            custom_scheduler = diffusers.DDIMScheduler
        else:  # fall back to the scheduler bundled with the model
            custom_pipeline = Token2AttnMixinASDSPipeline
            custom_scheduler = None

        self.diffusion = init_diffusion_pipeline(
            self.args.model_id,
            custom_pipeline=custom_pipeline,
            custom_scheduler=custom_scheduler,
            device=self.device,
            local_files_only=not args.download,
            force_download=args.force_download,
            resume_download=args.resume_download,
            ldm_speed_up=args.ldm_speed_up,
            enable_xformers=args.enable_xformers,
            gradient_checkpoint=args.gradient_checkpoint,
        )

        self.cargs = self.args.clip
        self.clip_score_fn = CLIPScoreWrapper(self.cargs.model_name,
                                              device=self.device,
                                              visual_score=True,
                                              feats_loss_type=self.cargs.feats_loss_type,
                                              feats_loss_weights=self.cargs.feats_loss_weights,
                                              fc_loss_weight=self.cargs.fc_loss_weight)

    def update_info(self, seed, token_ind, prompt_input):
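        """Reconfigure logging paths and the RNG for one (seed, prompt, stage) run.

        For stages after the first, point `args.path_svg` at the best SVG of the
        previous stage and disable attention-based stroke initialization, so the
        new stage continues from the existing sketch.
        """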
					
						
        prompt_dir_name = '_'.join(prompt_input.split(' '))

        attn_log_ = f"-tk{token_ind}"
        logdir_ = f"seed{seed}" \
                  f"{attn_log_}" \
                  f"-stage={self.args.run_stage}"
        logdir_sec_ = f""
        self.args.path_svg = ""

        if self.args.run_stage > 0:
            logdir_sec_ = f"{logdir_sec_}-local={self.args.vector_local_edit}"
            # the log dir of the previous stage: same prefix, stage index minus one
            last_svg_base = os.path.join(self.args.results_path, self.args.edit_type, prompt_dir_name,
                                         logdir_[:-1] + str(self.args.run_stage - 1))
            if self.args.run_stage != 1:
                last_svg_base += logdir_sec_
            self.args.path_svg = os.path.join(last_svg_base, "visual_best.svg")
            self.args.attention_init = False

        logdir_ = f"{prompt_dir_name}" + f"/" + logdir_ + logdir_sec_
        super().__init__(self.args, log_path_suffix=logdir_)

        self.png_logs_dir = self.results_path / "png_logs"
        self.svg_logs_dir = self.results_path / "svg_logs"
        self.attn_logs_dir = self.results_path / "attn_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)
            self.attn_logs_dir.mkdir(parents=True, exist_ok=True)

        self.g_device = torch.Generator().manual_seed(seed)

    def load_render(self, target_img, attention_map, mask=None):
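        """Construct a Painter over `target_img`; strokes can be seeded from `attention_map`."""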
					
						
        renderer = Painter(self.args,
                           num_strokes=self.args.num_paths,
                           num_segments=self.args.num_segments,
                           imsize=self.args.image_size,
                           device=self.device,
                           target_im=target_img,
                           attention_map=attention_map,
                           mask=mask)
        return renderer

    def attn_map_normalizing(self, cross_attn_map):
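        """Upsample a single-token cross-attention map to image resolution.

        The (res, res) map is scaled to [0, 255], expanded to three channels,
        bicubically resized to `image_size`, converted back to grayscale, and
        finally min-max normalized to [0, 1].
        """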
					
						
        cross_attn_map = 255 * cross_attn_map / cross_attn_map.max()
        cross_attn_map = cross_attn_map.unsqueeze(-1).expand(*cross_attn_map.shape, 3)
        cross_attn_map = cross_attn_map.permute(2, 0, 1).unsqueeze(0)
        cross_attn_map = F.interpolate(cross_attn_map, size=self.args.image_size, mode='bicubic')
        cross_attn_map = torch.clamp(cross_attn_map, min=0, max=255)
        cross_attn_map = rgb2gray(cross_attn_map.squeeze(0).permute(1, 2, 0)).astype(np.float32)

        if cross_attn_map.shape[-1] != self.args.image_size and cross_attn_map.shape[-2] != self.args.image_size:
            cross_attn_map = cross_attn_map.reshape(self.args.image_size, self.args.image_size)

        cross_attn_map = (cross_attn_map - cross_attn_map.min()) / (cross_attn_map.max() - cross_attn_map.min())
        return cross_attn_map

    def compute_local_edit_maps(self, cross_attn_maps_src_tar, prompts, words, save_path, threshold=0.3):
        """Binarize each word's cross-attention map and save the union as a mask.

        cross_attn_maps_src_tar: [(res, res, 77), (res, res, 77)]
        """
					
						
        local_edit_region = np.zeros(shape=(self.args.image_size, self.args.image_size), dtype=np.float32)
        for i, (prompt, word) in enumerate(zip(prompts, words)):
            ind = get_word_inds(prompt, word, self.diffusion.tokenizer)
            assert len(ind) == 1
            ind = ind[0]

            cross_attn_map = cross_attn_maps_src_tar[i][:, :, ind]
            cross_attn_map = self.attn_map_normalizing(cross_attn_map)
            cross_attn_map_bin = cross_attn_map >= threshold
            local_edit_region += cross_attn_map_bin
        local_edit_region = (np.clip(local_edit_region, 0, 1) * 255).astype(np.uint8)
        local_edit_region = Image.fromarray(local_edit_region, 'L')
        local_edit_region.save(save_path, 'PNG')

    def extract_ldm_attn(self, prompts, token_ind, changing_region_words, reweight_word, reweight_weight):
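        """Sample the prompt batch under a Prompt-to-Prompt attention controller.

        Saves the generated images plus their cross-attention visualizations and
        local-edit masks, and returns the source image path together with a fused
        cross/self-attention map used to initialize the strokes.
        """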
					
						
        # choose the attention controller that implements the requested edit
        if not self.args.attention_init:
            controller = EmptyControl()
        else:
            lb = LocalBlend(prompts=prompts,
                            words=changing_region_words,
                            tokenizer=self.diffusion.tokenizer,
                            device=self.device)

            if self.args.edit_type == "replace":
                controller = AttentionReplace(prompts=prompts,
                                              num_steps=self.args.num_inference_steps,
                                              cross_replace_steps=0.4,
                                              self_replace_steps=0.4,
                                              local_blend=lb,
                                              tokenizer=self.diffusion.tokenizer,
                                              device=self.device)
            elif self.args.edit_type == "refine":
                controller = AttentionRefine(prompts=prompts,
                                             num_steps=self.args.num_inference_steps,
                                             cross_replace_steps=0.8,
                                             self_replace_steps=0.4,
                                             local_blend=lb,
                                             tokenizer=self.diffusion.tokenizer,
                                             device=self.device)
            elif self.args.edit_type == "reweight":
                equalizer = get_equalizer(self.diffusion.tokenizer, prompts[1:],
                                          reweight_word, reweight_weight)
                controller = AttentionReweight(prompts=prompts,
                                               num_steps=self.args.num_inference_steps,
                                               cross_replace_steps=0.8,
                                               self_replace_steps=0.4,
                                               local_blend=lb,
                                               equalizer=equalizer,
                                               tokenizer=self.diffusion.tokenizer,
                                               device=self.device)
            else:
                raise ValueError(f"unknown edit_type: {self.args.edit_type}")

        height = width = model2res(self.args.model_id)
        outputs = self.diffusion(prompt=prompts,
                                 negative_prompt=[self.args.negative_prompt] * len(prompts),
                                 height=height,
                                 width=width,
                                 controller=controller,
                                 num_inference_steps=self.args.num_inference_steps,
                                 guidance_scale=self.args.guidance_scale,
                                 generator=self.g_device)

        # the first image corresponds to the source prompt, the rest to the edits
        self.print(f"num generated images: {len(outputs.images)}")
        for ii, img in enumerate(outputs.images):
            if ii == 0:
                filename = "ldm_generated_image.png"
                target_file = self.results_path / filename
            else:
                filename = "ldm_generated_image" + str(ii) + ".png"

            target_file_tmp = self.results_path / filename
            view_images([np.array(img)], save_image=True, fp=target_file_tmp)

        if self.args.attention_init:
            # ldm cross-attention map
            cross_attention_maps, tokens = \
                self.diffusion.get_cross_attention(prompts,
                                                   controller,
                                                   res=self.args.cross_attn_res,
                                                   from_where=("up", "down"),
                                                   save_path=self.results_path / "cross_attn.png",
                                                   select=0)
            # also save the cross-attention visualization of each edited prompt
            for ii in range(1, len(outputs.images)):
                cross_attn_png_name = "cross_attn" + str(ii) + ".png"
                cross_attention_maps_i, tokens_i = \
                    self.diffusion.get_cross_attention(prompts,
                                                       controller,
                                                       res=self.args.cross_attn_res,
                                                       from_where=("up", "down"),
                                                       save_path=self.results_path / cross_attn_png_name,
                                                       select=ii)

            self.print(f"the length of tokens is {len(tokens)}, selecting the {token_ind}-th token")
            self.print(f"origin cross_attn_map shape: {cross_attention_maps.shape}")

            cross_attn_map = cross_attention_maps[:, :, token_ind]
            self.print(f"selected cross_attn_map shape: {cross_attn_map.shape}\n")
            cross_attn_map = self.attn_map_normalizing(cross_attn_map)

            # low-resolution cross-attention maps used to derive local-edit masks
            cross_attention_maps_local_list = []
            for ii in range(len(outputs.images)):
                cross_attention_maps_local = \
                    self.diffusion.get_cross_attention2(prompts,
                                                        controller,
                                                        res=self.args.vector_local_edit_attn_res,
                                                        from_where=("up", "down"),
                                                        select=ii)
                cross_attention_maps_local_list.append(cross_attention_maps_local)

                if ii == 0:  # the source prompt needs no edit mask
                    continue

                save_name = "cross_attn_local_edit_" + str(self.args.vector_local_edit_attn_res) + "-" + str(ii) + ".png"

                # "replace" and "reweight" localize the edit via the word in the
                # preceding prompt; "refine" uses the added word in the current prompt
                if self.args.edit_type == "replace":
                    self.compute_local_edit_maps([cross_attention_maps_local_list[ii - 1]], [prompts[ii - 1]],
                                                 [changing_region_words[ii][0]],
                                                 save_path=self.results_path / save_name,
                                                 threshold=self.args.vector_local_edit_bin_threshold_replace)
                elif self.args.edit_type == "refine":
                    self.compute_local_edit_maps([cross_attention_maps_local_list[ii]], [prompts[ii]],
                                                 [changing_region_words[ii][1]],
                                                 save_path=self.results_path / save_name,
                                                 threshold=self.args.vector_local_edit_bin_threshold_refine)
                elif self.args.edit_type == "reweight":
                    self.compute_local_edit_maps([cross_attention_maps_local_list[ii - 1]], [prompts[ii - 1]],
                                                 [changing_region_words[ii][0]],
                                                 save_path=self.results_path / save_name,
                                                 threshold=self.args.vector_local_edit_bin_threshold_reweight)

            if self.args.sd_image_only:
                return target_file.as_posix(), None

            # ldm self-attention map
            self_attention_maps, svd, vh_ = \
                self.diffusion.get_self_attention_comp(prompts,
                                                       controller,
                                                       res=self.args.self_attn_res,
                                                       from_where=("up", "down"),
                                                       img_size=self.args.image_size,
                                                       max_com=self.args.max_com,
                                                       save_path=self.results_path)

            if self.args.mean_comp:
                self_attn = np.mean(vh_, axis=0)
                self.print(f"use the mean of {self.args.max_com} comps.")
            else:
                self_attn = vh_[self.args.comp_idx]
                self.print(f"select {self.args.comp_idx}-th comp.")

            self_attn = (self_attn - self_attn.min()) / (self_attn.max() - self_attn.min())

            self_attn_vis = np.copy(self_attn)
            self_attn_vis = self_attn_vis * 255
            self_attn_vis = np.repeat(np.expand_dims(self_attn_vis, axis=2), 3, axis=2).astype(np.uint8)
            view_images(self_attn_vis, save_image=True, fp=self.results_path / "self-attn-final.png")

            # attention map fusion
            attn_map = self.args.attn_coeff * cross_attn_map + (1 - self.args.attn_coeff) * self_attn
            attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
            self.print(f"-> fusion attn_map: {attn_map.shape}")
        else:
            attn_map = None

        return target_file.as_posix(), attn_map

    @property
    def clip_norm_(self):
        # CLIP's standard image normalization statistics
        return transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

    def clip_pair_augment(self,
                          x: torch.Tensor,
                          y: torch.Tensor,
                          im_res: int,
                          augments: str = "affine_norm",
                          num_aug: int = 4):
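        """Produce `num_aug` jointly augmented views of the sketch/target pair.

        Both images are concatenated before augmentation so each random
        perspective/crop is applied identically to sketch and target; the
        un-augmented (normalized) pair is always kept as the first view.
        """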
					
						
        augment_list = []
        if "affine" in augments:
            augment_list.append(
                transforms.RandomPerspective(fill=0, p=1.0, distortion_scale=0.5)
            )
            augment_list.append(
                transforms.RandomResizedCrop(im_res, scale=(0.8, 0.8), ratio=(1.0, 1.0))
            )
        augment_list.append(self.clip_norm_)

        augment_compose = transforms.Compose(augment_list)

        x_augs, y_augs = [self.clip_score_fn.normalize(x)], [self.clip_score_fn.normalize(y)]
        for n in range(num_aug):
            augmented_pair = augment_compose(torch.cat([x, y]))
            x_augs.append(augmented_pair[0].unsqueeze(0))
            y_augs.append(augmented_pair[1].unsqueeze(0))
        xs = torch.cat(x_augs, dim=0)
        ys = torch.cat(y_augs, dim=0)
        return xs, ys

    def painterly_rendering(self, prompts, token_ind, changing_region_words, reweight_word, reweight_weight):
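        """Optimize the vector sketch for the current editing stage.

        Stage 0 samples the target image (and attention maps) from the diffusion
        model; later stages reuse the stage-0 image for their prompt. The strokes
        are then optimized under a weighted sum of ASDS, CLIP visual, perceptual,
        and CLIP text-visual losses.
        """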
					
						
        self.print(f"prompts: {prompts}")
        self.print(f"negative_prompt: {self.args.negative_prompt}")
        self.print(f"token_ind: {token_ind}")
        self.print(f"changing_region_words: {changing_region_words}")
        self.print(f"reweight_word: {reweight_word}")
        self.print(f"reweight_weight: {reweight_weight}\n")
        if self.args.negative_prompt is None:
            self.args.negative_prompt = ""

        log_path = os.path.join(self.results_path.as_posix(), 'log.txt')
        with open(log_path, "w") as f:
            f.write("prompts: " + str(prompts) + "\n")
            f.write("negative_prompt: " + self.args.negative_prompt + "\n")
            f.write("token_ind: " + str(token_ind) + "\n")
            f.write("changing_region_words: " + str(changing_region_words) + "\n")
            f.write("reweight_word: " + str(reweight_word) + "\n")
            f.write("reweight_weight: " + str(reweight_weight) + "\n")

        if self.args.run_stage == 0:
            target_file, attention_map = self.extract_ldm_attn(prompts, token_ind, changing_region_words,
                                                               reweight_word, reweight_weight)
        else:
            # later stages reuse the image generated for this stage's prompt at stage 0
            results_base = self.results_path.as_posix()
            target_file = os.path.join(
                results_base[:results_base.find('stage=' + str(self.args.run_stage))] + 'stage=0',
                "ldm_generated_image" + str(self.args.run_stage) + ".png")
            attention_map = None

        if not self.args.sd_image_only:
            # optional perceptual loss between the rendered sketch and the target
            perceptual_loss_fn = None
            if self.args.perceptual.coeff > 0:
                if self.args.perceptual.name == "lpips":
                    lpips_loss_fn = LPIPS(net=self.args.perceptual.lpips_net).to(self.device)
                    perceptual_loss_fn = partial(lpips_loss_fn.forward, return_per_layer=False, normalize=False)
                elif self.args.perceptual.name == "dists":
                    perceptual_loss_fn = DISTS_PIQ()

            inputs, mask = self.get_target(target_file,
                                           self.args.image_size,
                                           self.results_path,
                                           self.args.u2net_path,
                                           self.args.mask_object,
                                           self.args.fix_scale,
                                           self.device)
            inputs = inputs.detach()
            self.print("inputs shape: ", inputs.shape)

            renderer = Painter(self.args,
                               num_strokes=self.args.num_paths,
                               num_segments=self.args.num_segments,
                               imsize=self.args.image_size,
                               device=self.device,
                               target_im=inputs,
                               attention_map=attention_map,
                               mask=mask,
                               results_base=self.results_path.as_posix())

            img = renderer.init_image(stage=0)
            self.print("init_image shape: ", img.shape)
            log_tensor_img(img, self.results_path, output_prefix="init_sketch")

            optimizer = SketchPainterOptimizer(renderer,
                                               self.args.lr,
                                               self.args.optim_opacity,
                                               self.args.optim_rgba,
                                               self.args.color_lr,
                                               self.args.optim_width,
                                               self.args.width_lr)
            optimizer.init_optimizers()

            self.print(f"-> Painter points Params: {len(renderer.get_points_params())}")
            self.print(f"-> Painter width Params: {len(renderer.get_width_parameters())}")
            self.print(f"-> Painter opacity Params: {len(renderer.get_color_parameters())}")

            best_visual_loss, best_semantic_loss = 100, 100
            best_iter_v, best_iter_s = 0, 0
            min_delta = 1e-6
            vid_idx = 1

            self.print(f"\ntotal optimization steps: {self.args.num_iter}")
            with tqdm(initial=self.step, total=self.args.num_iter,
                      disable=not self.accelerator.is_main_process) as pbar:
                while self.step < self.args.num_iter:
                    raster_sketch = renderer.get_image().to(self.device)

                    target_prompt = prompts[self.args.run_stage]

                    # ASDS loss (score distillation), enabled after the warmup steps
                    sds_loss, grad = torch.tensor(0), torch.tensor(0)
                    if self.step >= self.args.sds.warmup:
                        grad_scale = self.args.sds.grad_scale if self.step > self.args.sds.warmup else 0
                        sds_loss, grad = self.diffusion.score_distillation_sampling(
                            raster_sketch,
                            crop_size=self.args.sds.crop_size,
                            augments=self.args.sds.augmentations,
                            prompt=[target_prompt],
                            negative_prompt=[self.args.negative_prompt],
                            guidance_scale=self.args.sds.guidance_scale,
                            grad_scale=grad_scale,
                            t_range=list(self.args.sds.t_range),
                        )

                    # jointly augmented sketch/target views for the CLIP losses
                    raster_sketch_aug, inputs_aug = self.clip_pair_augment(
                        raster_sketch, inputs,
                        im_res=224,
                        augments=self.cargs.augmentations,
                        num_aug=self.cargs.num_aug
                    )

                    # CLIP visual loss: FC distance plus conv-feature distances
                    total_visual_loss = torch.tensor(0)
                    l_clip_fc, l_clip_conv, clip_conv_loss_sum = torch.tensor(0), [], torch.tensor(0)
                    if self.args.clip.vis_loss > 0:
                        l_clip_fc, l_clip_conv = self.clip_score_fn.compute_visual_distance(
                            raster_sketch_aug, inputs_aug, clip_norm=False
                        )
                        clip_conv_loss_sum = sum(l_clip_conv)
                        total_visual_loss = self.args.clip.vis_loss * (clip_conv_loss_sum + l_clip_fc)

                    # perceptual loss (LPIPS or DISTS) on the un-augmented pair
                    l_percep = torch.tensor(0.)
                    if perceptual_loss_fn is not None:
                        l_perceptual = perceptual_loss_fn(raster_sketch, inputs).mean()
                        l_percep = l_perceptual * self.args.perceptual.coeff

                    # CLIP text-visual loss against the target prompt
                    l_tvd = torch.tensor(0.)
                    if self.cargs.text_visual_coeff > 0:
                        l_tvd = self.clip_score_fn.compute_text_visual_distance(
                            raster_sketch_aug, target_prompt
                        ) * self.cargs.text_visual_coeff

                    # total loss
                    loss = sds_loss + total_visual_loss + l_percep + l_tvd

                    # optimization step
                    optimizer.zero_grad_()
                    loss.backward()
                    optimizer.step_()

                    if self.args.lr_scheduler:
                        optimizer.update_lr(self.step, self.args.lr, self.args.decay_steps)

                    pbar.set_description(
                        f"lr: {optimizer.get_lr():.2f}, "
                        f"l_total: {loss.item():.4f}, "
                        f"l_clip_fc: {l_clip_fc.item():.4f}, "
                        f"l_clip_conv({len(l_clip_conv)}): {clip_conv_loss_sum.item():.4f}, "
                        f"l_tvd: {l_tvd.item():.4f}, "
                        f"l_percep: {l_percep.item():.4f}, "
                        f"sds: {grad.item():.4e}"
                    )

                    # log video frames
                    if self.args.make_video and (self.step % self.args.video_frame_freq == 0) \
                            and self.accelerator.is_main_process:
                        log_tensor_img(raster_sketch, output_dir=self.png_logs_dir,
                                       output_prefix=f'frame{vid_idx}', dpi=100)
                        vid_idx += 1

                    # periodic PNG/SVG checkpoints
                    if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                        plt_batch(inputs,
                                  raster_sketch,
                                  self.step,
                                  target_prompt,
                                  save_path=self.png_logs_dir.as_posix(),
                                  name=f"iter{self.step}")
                        renderer.save_svg(self.svg_logs_dir.as_posix(), f"svg_iter{self.step}")

                        if self.args.log_cross_attn:
                            controller = AttentionStore()
                            _, _ = self.diffusion.get_cross_attention([target_prompt],
                                                                      controller,
                                                                      res=self.args.cross_attn_res,
                                                                      from_where=("up", "down"),
                                                                      save_path=self.attn_logs_dir / f"iter{self.step}.png")

                    # track the best sketch under the visual and semantic metrics
                    if self.step % self.args.eval_step == 0 and self.accelerator.is_main_process:
                        with torch.no_grad():
                            # visual metric: CLIP feature distances to the target image
                            l_clip_fc, l_clip_conv = self.clip_score_fn.compute_visual_distance(
                                raster_sketch_aug, inputs_aug, clip_norm=False
                            )
                            loss_eval = sum(l_clip_conv) + l_clip_fc

                            cur_delta = loss_eval.item() - best_visual_loss
                            if abs(cur_delta) > min_delta and cur_delta < 0:
                                best_visual_loss = loss_eval.item()
                                best_iter_v = self.step
                                plt_batch(inputs,
                                          raster_sketch,
                                          best_iter_v,
                                          target_prompt,
                                          save_path=self.results_path.as_posix(),
                                          name="visual_best")
                                renderer.save_svg(self.results_path.as_posix(), "visual_best")

                            # semantic metric: CLIP text-visual distance to the prompt
                            loss_eval = self.clip_score_fn.compute_text_visual_distance(
                                raster_sketch_aug, target_prompt
                            )
                            cur_delta = loss_eval.item() - best_semantic_loss
                            if abs(cur_delta) > min_delta and cur_delta < 0:
                                best_semantic_loss = loss_eval.item()
                                best_iter_s = self.step
                                plt_batch(inputs,
                                          raster_sketch,
                                          best_iter_s,
                                          target_prompt,
                                          save_path=self.results_path.as_posix(),
                                          name="semantic_best")
                                renderer.save_svg(self.results_path.as_posix(), "semantic_best")

                    # visualize the attention-based stroke initialization once
                    if self.step == 0 and self.args.attention_init and self.accelerator.is_main_process:
                        plt_attn(renderer.get_attn(),
                                 renderer.get_thresh(),
                                 inputs,
                                 renderer.get_inds(),
                                 (self.results_path / "attention_map.jpg").as_posix())

                    self.step += 1
                    pbar.update(1)

            renderer.save_svg(self.svg_logs_dir.as_posix(), "final_svg_tmp")

            # drop nearly transparent paths from the final SVG
            if self.args.opacity_delta != 0:
                remove_low_opacity_paths(self.svg_logs_dir / "final_svg_tmp.svg",
                                         self.results_path / "final_svg.svg",
                                         self.args.opacity_delta)

            final_raster_sketch = renderer.get_image().to(self.device)
            save_tensor_img(final_raster_sketch,
                            save_path=self.results_path,
                            name='final_render')

            # stitch the logged frames into a video; every element passed to
            # subprocess.call must be a string, including the framerate
            if self.args.make_video:
                from subprocess import call
                call([
                    "ffmpeg",
                    "-framerate", "24",
                    "-i", (self.png_logs_dir / "frame%d.png").as_posix(),
                    "-vb", "20M",
                    (self.results_path / "out.mp4").as_posix()
                ])

    def get_target(self,
                   target_file,
                   image_size,
                   output_dir,
                   u2net_path,
                   mask_object,
                   fix_scale,
                   device):
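        """Load the target image as a (1, 3, H, W) tensor plus an optional mask.

        RGBA inputs are flattened onto a white background; if `mask_object` is
        set and a U2-Net checkpoint exists, the salient object is masked out
        first.
        """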
					
						
        if not is_image_file(target_file):
            raise TypeError(f"{target_file} is not an image file.")

        target = Image.open(target_file)

        if target.mode == "RGBA":
            # composite onto a white background to drop the alpha channel
            new_image = Image.new("RGBA", target.size, "WHITE")
            new_image.paste(target, (0, 0), target)
            target = new_image
        target = target.convert("RGB")

        mask = target
        if mask_object:
            if pathlib.Path(u2net_path).exists():
                masked_im, mask = get_mask_u2net(target, output_dir, u2net_path, device)
                target = masked_im
            else:
                self.print(f"'{u2net_path}' does not exist, disabling mask target")

        if fix_scale:
            target = fix_image_scale(target)

        transforms_ = []
        if target.size[0] != target.size[1]:
            # non-square images are squashed to image_size x image_size
            transforms_.append(transforms.Resize((image_size, image_size)))
        else:
            transforms_.append(transforms.Resize(image_size))
            transforms_.append(transforms.CenterCrop(image_size))
        transforms_.append(transforms.ToTensor())

        data_transforms = transforms.Compose(transforms_)
        target_ = data_transforms(target).unsqueeze(0).to(self.device)

        return target_, mask
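

# A rough sketch of the expected driver loop (hypothetical -- the actual entry
# script with its config/argument construction lives elsewhere in the repo):
#
#   args = ...  # config carrying model_id, edit_type, run_stage, clip, sds, etc.
#   pipe = DiffSketchEditPipeline(args)
#   for stage in range(num_stages):
#       args.run_stage = stage
#       pipe.update_info(seed, token_ind, prompts[0])
#       pipe.painterly_rendering(prompts, token_ind, changing_region_words,
#                                reweight_word, reweight_weight)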