from dataclasses import dataclass, field

import numpy as np
import json
import copy
import torch
import torch.nn.functional as F
from skimage import measure
from einops import repeat
from tqdm import tqdm
from PIL import Image

from diffusers import (
    DDPMScheduler,
    DDIMScheduler,
    UniPCMultistepScheduler,
    KarrasVeScheduler,
    DPMSolverMultistepScheduler
)

import craftsman
from craftsman.systems.base import BaseSystem
from craftsman.utils.ops import generate_dense_grid_points
from craftsman.utils.misc import get_rank
from craftsman.utils.typing import *

def compute_snr(noise_scheduler, timesteps):
    """
    Computes the signal-to-noise ratio (SNR) for each sampled timestep, following
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # Gather sqrt(alpha_cumprod) for the sampled timesteps and broadcast to the timestep shape.
    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

    # Gather sqrt(1 - alpha_cumprod) likewise.
    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

    # SNR = (alpha / sigma) ** 2
    snr = (alpha / sigma) ** 2
    return snr

def ddim_sample(ddim_scheduler: DDIMScheduler,
                diffusion_model: torch.nn.Module,
                shape: Union[List[int], Tuple[int]],
                cond: torch.FloatTensor,
                steps: int,
                eta: float = 0.0,
                guidance_scale: float = 3.0,
                do_classifier_free_guidance: bool = True,
                generator: Optional[torch.Generator] = None,
                device: torch.device = "cuda:0",
                disable_prog: bool = True):

    assert steps > 0, f"steps must be > 0, got {steps}."

    # With classifier-free guidance, cond holds the unconditional and conditional embeddings stacked together.
    bsz = cond.shape[0]
    if do_classifier_free_guidance:
        bsz = bsz // 2

    # Start from Gaussian noise scaled by the scheduler's initial sigma.
    latents = torch.randn(
        (bsz, *shape),
        generator=generator,
        device=cond.device,
        dtype=cond.dtype,
    )
    latents = latents * ddim_scheduler.init_noise_sigma

    ddim_scheduler.set_timesteps(steps)
    timesteps = ddim_scheduler.timesteps.to(device)

    extra_step_kwargs = {
        "generator": generator
    }

    for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)):
        # Duplicate the latents so the unconditional and conditional branches share one forward pass.
        latent_model_input = (
            torch.cat([latents] * 2)
            if do_classifier_free_guidance
            else latents
        )

        timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
        timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
        noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)

        # Classifier-free guidance: extrapolate from the unconditional towards the conditional prediction.
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        # Scheduler step: x_t -> x_{t-1}
        latents = ddim_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        yield latents, t

@craftsman.register("shape-diffusion-system")
class ShapeDiffusionSystem(BaseSystem):
    @dataclass
    class Config(BaseSystem.Config):
        val_samples_json: str = None
        z_scale_factor: float = 1.0
        guidance_scale: float = 7.5
        num_inference_steps: int = 50
        eta: float = 0.0
        snr_gamma: float = 5.0

        # shape VAE that maps surfaces to latents and back
        shape_model_type: str = None
        shape_model: dict = field(default_factory=dict)

        # condition encoder (image / multi-view images)
        condition_model_type: str = None
        condition_model: dict = field(default_factory=dict)

        # denoiser network
        denoiser_model_type: str = None
        denoiser_model: dict = field(default_factory=dict)

        # scheduler used to add noise during training
        noise_scheduler_type: str = None
        noise_scheduler: dict = field(default_factory=dict)

        # scheduler used for denoising at inference time
        denoise_scheduler_type: str = None
        denoise_scheduler: dict = field(default_factory=dict)

    cfg: Config

    def configure(self):
        super().configure()

        # Frozen shape VAE: only used to encode surfaces to latents and decode latents to geometry.
        self.shape_model = craftsman.find(self.cfg.shape_model_type)(self.cfg.shape_model)
        self.shape_model.eval()
        self.shape_model.requires_grad_(False)

        self.condition = craftsman.find(self.cfg.condition_model_type)(self.cfg.condition_model)

        self.denoiser_model = craftsman.find(self.cfg.denoiser_model_type)(self.cfg.denoiser_model)

        self.noise_scheduler = craftsman.find(self.cfg.noise_scheduler_type)(**self.cfg.noise_scheduler)
        self.denoise_scheduler = craftsman.find(self.cfg.denoise_scheduler_type)(**self.cfg.denoise_scheduler)

        self.z_scale_factor = self.cfg.z_scale_factor

    def forward(self, batch: Dict[str, Any]):
        # 1. Encode the surface point cloud into a latent code with the frozen shape VAE.
        shape_embeds, kl_embed, posterior = self.shape_model.encode(
            batch["surface"][..., :3 + self.cfg.shape_model.point_feats],
            sample_posterior=True
        )
        latents = kl_embed * self.z_scale_factor

        # 2. Encode the condition into conditioning tokens.
        cond_latents = self.condition(batch)
        cond_latents = cond_latents.to(latents).view(latents.shape[0], -1, cond_latents.shape[-1])

        # 3. Sample noise and a random timestep for each latent.
        noise = torch.randn_like(latents).to(latents)
        bs = latents.shape[0]

        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bs,),
            device=latents.device,
        )
        timesteps = timesteps.long()

        # 4. Add noise to the latents according to the noise magnitude at each timestep (forward diffusion).
        noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # 5. Predict the noise (or velocity) residual.
        noise_pred = self.denoiser_model(noisy_z, timesteps, cond_latents)

        # 6. Compute the diffusion loss against the scheduler's prediction target.
        if self.noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.noise_scheduler.config.prediction_type == "v_prediction":
            target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.config.prediction_type} not yet supported.")

        if self.cfg.snr_gamma == 0:
            if self.cfg.loss.loss_type == "l1":
                loss = F.l1_loss(noise_pred, target, reduction="mean")
            elif self.cfg.loss.loss_type in ["mse", "l2"]:
                loss = F.mse_loss(noise_pred, target, reduction="mean")
            else:
                raise NotImplementedError(f"Loss Type: {self.cfg.loss.loss_type} not yet supported.")
        else:
            # Min-SNR loss weighting: clamp the per-timestep weight to min(SNR, snr_gamma)
            # so that very easy (high-SNR) timesteps do not dominate the loss.
            snr = compute_snr(self.noise_scheduler, timesteps)
            mse_loss_weights = torch.stack([snr, self.cfg.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
                dim=1
            )[0]
            if self.noise_scheduler.config.prediction_type == "epsilon":
                mse_loss_weights = mse_loss_weights / snr
            elif self.noise_scheduler.config.prediction_type == "v_prediction":
                mse_loss_weights = mse_loss_weights / (snr + 1)

            if self.cfg.loss.loss_type == "l1":
                loss = F.l1_loss(noise_pred, target, reduction="none")
            elif self.cfg.loss.loss_type in ["mse", "l2"]:
                loss = F.mse_loss(noise_pred, target, reduction="none")
            else:
                raise NotImplementedError(f"Loss Type: {self.cfg.loss.loss_type} not yet supported.")
            loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
            loss = loss.mean()

        return {
            "loss_diffusion": loss,
            "latents": latents,
            "x_0": noisy_z,
            "noise": noise,
            "noise_pred": noise_pred,
            "timesteps": timesteps,
        }

    def training_step(self, batch, batch_idx):
        out = self(batch)

        # Sum every "loss_*" term, each weighted by its corresponding "lambda_*" coefficient.
        loss = 0.
        for name, value in out.items():
            if name.startswith("loss_"):
                self.log(f"train/{name}", value)
                loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")])

        for name, value in self.cfg.loss.items():
            if name.startswith("lambda_"):
                self.log(f"train_params/{name}", self.C(value))

        return {"loss": loss}

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        self.eval()

        # Sample and export meshes for the validation prompts on rank 0 only.
        if get_rank() == 0:
            sample_inputs = json.loads(open(self.cfg.val_samples_json).read())
            sample_inputs_ = copy.deepcopy(sample_inputs)
            sample_outputs = self.sample(sample_inputs)
            for i, sample_output in enumerate(sample_outputs):
                mesh_v_f, has_surface = self.shape_model.extract_geometry(sample_output, octree_depth=7)
                for j in range(len(mesh_v_f)):
                    # Derive an output name from whichever condition modality is present.
                    if "text" in sample_inputs_ and "image" in sample_inputs_:
                        name = sample_inputs_["image"][j].split("/")[-1].replace(".png", "")
                    elif "text" in sample_inputs_ and "mvimages" in sample_inputs_:
                        name = sample_inputs_["mvimages"][j][0].split("/")[-2].replace(".png", "")
                    elif "text" in sample_inputs_:
                        name = sample_inputs_["text"][j].replace(" ", "_")
                    elif "image" in sample_inputs_:
                        name = sample_inputs_["image"][j].split("/")[-1].replace(".png", "")
                    elif "mvimages" in sample_inputs_:
                        name = sample_inputs_["mvimages"][j][0].split("/")[-2].replace(".png", "")
                    self.save_mesh(
                        f"it{self.true_global_step}/{name}_{i}.obj",
                        mesh_v_f[j][0], mesh_v_f[j][1]
                    )

        out = self(batch)
        # At step 0, also export the VAE reconstruction of the ground-truth latent as a sanity check.
        if self.global_step == 0:
            latents = self.shape_model.decode(out["latents"])
            mesh_v_f, has_surface = self.shape_model.extract_geometry(latents)
            self.save_mesh(
                f"it{self.true_global_step}/{batch['uid'][0]}_{batch['sel_idx'][0] if 'sel_idx' in batch.keys() else 0}.obj",
                mesh_v_f[0][0], mesh_v_f[0][1]
            )

        torch.cuda.empty_cache()

        return {"val/loss": out["loss_diffusion"]}

    @torch.no_grad()
    def sample(self,
               sample_inputs: Dict[str, Union[torch.FloatTensor, List[str]]],
               sample_times: int = 1,
               steps: Optional[int] = None,
               guidance_scale: Optional[float] = None,
               eta: float = 0.0,
               return_intermediates: bool = False,
               camera_embeds: Optional[torch.Tensor] = None,
               seed: Optional[int] = None,
               **kwargs):

        if steps is None:
            steps = self.cfg.num_inference_steps
        if guidance_scale is None:
            guidance_scale = self.cfg.guidance_scale
        do_classifier_free_guidance = guidance_scale > 0

        # Encode the condition; for classifier-free guidance, prepend the unconditional (empty) embedding.
        if "image" in sample_inputs:
            sample_inputs["image"] = [Image.open(img) for img in sample_inputs["image"]]
            cond = self.condition.encode_image(sample_inputs["image"])
            if do_classifier_free_guidance:
                un_cond = self.condition.empty_image_embeds.repeat(len(sample_inputs["image"]), 1, 1).to(cond)
                cond = torch.cat([un_cond, cond], dim=0)
        elif "mvimages" in sample_inputs:
            bs = len(sample_inputs["mvimages"])
            cond = []
            for image in sample_inputs["mvimages"]:
                if isinstance(image, list) and isinstance(image[0], str):
                    sample_inputs["image"] = [Image.open(img) for img in image]
                else:
                    sample_inputs["image"] = image
                cond += [self.condition.encode_image(sample_inputs["image"])]
            cond = torch.stack(cond, dim=0)
            if do_classifier_free_guidance:
                un_cond = self.condition.empty_image_embeds.unsqueeze(0).repeat(len(sample_inputs["mvimages"]), cond.shape[1] // self.condition.cfg.n_views, 1, 1).to(cond)
                cond = torch.cat([un_cond, cond], dim=0).view(bs * 2, -1, cond[0].shape[-1])
        else:
            raise NotImplementedError("Only text, image or mvimages condition is supported.")

        outputs = []
        latents = None

        if seed is not None:
            generator = torch.Generator(device="cuda").manual_seed(seed)
        else:
            generator = None

        if not return_intermediates:
            # Decode only the final latent of each sampling run.
            for _ in range(sample_times):
                sample_loop = ddim_sample(
                    self.denoise_scheduler,
                    self.denoiser_model.eval(),
                    shape=self.shape_model.latent_shape,
                    cond=cond,
                    steps=steps,
                    guidance_scale=guidance_scale,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    device=self.device,
                    eta=eta,
                    disable_prog=False,
                    generator=generator
                )
                for sample, t in sample_loop:
                    latents = sample
                outputs.append(self.shape_model.decode(latents / self.z_scale_factor, **kwargs))
        else:
            # Decode intermediate latents every `iter_size` steps to expose the denoising trajectory.
            sample_loop = ddim_sample(
                self.denoise_scheduler,
                self.denoiser_model.eval(),
                shape=self.shape_model.latent_shape,
                cond=cond,
                steps=steps,
                guidance_scale=guidance_scale,
                do_classifier_free_guidance=do_classifier_free_guidance,
                device=self.device,
                eta=eta,
                disable_prog=False,
                generator=generator
            )

            iter_size = steps // sample_times
            i = 0
            for sample, t in sample_loop:
                latents = sample
                if i % iter_size == 0 or i == steps - 1:
                    outputs.append(self.shape_model.decode(latents / self.z_scale_factor, **kwargs))
                i += 1

        return outputs

    def on_validation_epoch_end(self):
        pass