import os
import random
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler, UNet2DConditionModel
from diffusers.models import AutoencoderKL
from diffusers.training_utils import compute_snr
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image

from ..pipelines.ig2mv_sdxl_pipeline import IG2MVSDXLPipeline
from ..schedulers.scheduling_shift_snr import ShiftSNRScheduler
from ..utils.core import find
from ..utils.typing import *
from .base import BaseSystem
from .utils import encode_prompt, vae_encode
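
# Overview (inferred from the rearrange patterns and batch keys used below):
#   batch["rgb"]           : (B * N, C, H, W) target multi-view images
#   batch["reference_rgb"] : (B, C, H, W)     reference image conditioning
#   batch["source_rgb"]    : (B * N, C, H, W) per-view control images for cond_encoder
# with N = batch["num_views"]. This system fine-tunes an SDXL UNet with a custom
# multi-view / reference-attention adapter (IG2MVSDXLPipeline) on such batches.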


def compute_embeddings(
    prompt_batch,
    empty_prompt_indices,
    text_encoders,
    tokenizers,
    is_train=True,
    **kwargs,
):
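    """Build the SDXL conditioning inputs for a batch of prompts.

    Prompts flagged in ``empty_prompt_indices`` are replaced by the empty string
    (prompt dropping for classifier-free guidance). ``kwargs`` must provide
    ``original_size``, ``target_size`` and ``crops_coords_top_left``. Returns
    ``prompt_embeds`` together with the ``text_embeds`` / ``time_ids`` entries
    expected by the SDXL UNet's ``added_cond_kwargs``.
    """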
    original_size = kwargs["original_size"]
    target_size = kwargs["target_size"]
    crops_coords_top_left = kwargs["crops_coords_top_left"]

    for i in range(empty_prompt_indices.shape[0]):
        if empty_prompt_indices[i]:
            prompt_batch[i] = ""

    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        prompt_batch, text_encoders, tokenizers, 0, is_train
    )
    add_text_embeds = pooled_prompt_embeds.to(
        device=prompt_embeds.device, dtype=prompt_embeds.dtype
    )

    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    add_time_ids = torch.tensor([add_time_ids])
    add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
    add_time_ids = add_time_ids.to(
        device=prompt_embeds.device, dtype=prompt_embeds.dtype
    )

    unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
    return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}


class IG2MVSDXLSystem(BaseSystem):
    @dataclass
    class Config(BaseSystem.Config):
        # Model / Adapter
        pretrained_model_name_or_path: str = "stabilityai/stable-diffusion-xl-base-1.0"
        pretrained_vae_name_or_path: Optional[str] = "madebyollin/sdxl-vae-fp16-fix"
        pretrained_adapter_name_or_path: Optional[str] = None
        pretrained_unet_name_or_path: Optional[str] = None
        init_adapter_kwargs: Dict[str, Any] = field(default_factory=dict)
        use_fp16_vae: bool = True
        use_fp16_clip: bool = True

        # Training
        trainable_modules: List[str] = field(default_factory=list)
        train_cond_encoder: bool = True
        prompt_drop_prob: float = 0.0
        image_drop_prob: float = 0.0
        cond_drop_prob: float = 0.0
        gradient_checkpointing: bool = False

        # Noise sampler
        noise_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)
        noise_offset: float = 0.0
        input_perturbation: float = 0.0
        snr_gamma: Optional[float] = 5.0
        prediction_type: Optional[str] = None
        shift_noise: bool = False
        shift_noise_mode: str = "interpolated"
        shift_noise_scale: float = 1.0

        # Evaluation
        eval_seed: int = 0
        eval_num_inference_steps: int = 30
        eval_guidance_scale: float = 1.0
        eval_height: int = 512
        eval_width: int = 512

    cfg: Config

    def configure(self):
        super().configure()

        # Prepare pipeline
        pipeline_kwargs = {}
        if self.cfg.pretrained_vae_name_or_path is not None:
            pipeline_kwargs["vae"] = AutoencoderKL.from_pretrained(
                self.cfg.pretrained_vae_name_or_path
            )
        if self.cfg.pretrained_unet_name_or_path is not None:
            pipeline_kwargs["unet"] = UNet2DConditionModel.from_pretrained(
                self.cfg.pretrained_unet_name_or_path
            )

        pipeline: IG2MVSDXLPipeline
        pipeline = IG2MVSDXLPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path, **pipeline_kwargs
        )

        init_adapter_kwargs = OmegaConf.to_container(self.cfg.init_adapter_kwargs)
        if "self_attn_processor" in init_adapter_kwargs:
            self_attn_processor = init_adapter_kwargs["self_attn_processor"]
            if self_attn_processor is not None and isinstance(self_attn_processor, str):
                self_attn_processor = find(self_attn_processor)
            init_adapter_kwargs["self_attn_processor"] = self_attn_processor
        pipeline.init_custom_adapter(**init_adapter_kwargs)

        if self.cfg.pretrained_adapter_name_or_path:
            pretrained_path = os.path.dirname(self.cfg.pretrained_adapter_name_or_path)
            adapter_name = os.path.basename(self.cfg.pretrained_adapter_name_or_path)
            pipeline.load_custom_adapter(pretrained_path, weight_name=adapter_name)
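
        # Training noise scheduler. When cfg.shift_noise is set, the schedule is
        # wrapped by ShiftSNRScheduler, which (as the name and the scale/mode options
        # suggest) rescales the signal-to-noise ratio of the timesteps; see the
        # scheduler implementation for the exact behaviour of the "interpolated" mode.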
        noise_scheduler = DDPMScheduler.from_config(
            pipeline.scheduler.config, **self.cfg.noise_scheduler_kwargs
        )
        if self.cfg.shift_noise:
            noise_scheduler = ShiftSNRScheduler.from_scheduler(
                noise_scheduler,
                shift_mode=self.cfg.shift_noise_mode,
                shift_scale=self.cfg.shift_noise_scale,
                scheduler_class=DDPMScheduler,
            )
        pipeline.scheduler = noise_scheduler

        # Prepare models
        self.pipeline: IG2MVSDXLPipeline = pipeline
        self.vae = self.pipeline.vae.to(
            dtype=torch.float16 if self.cfg.use_fp16_vae else torch.float32
        )
        self.tokenizer = self.pipeline.tokenizer
        self.tokenizer_2 = self.pipeline.tokenizer_2
        self.text_encoder = self.pipeline.text_encoder.to(
            dtype=torch.float16 if self.cfg.use_fp16_clip else torch.float32
        )
        self.text_encoder_2 = self.pipeline.text_encoder_2.to(
            dtype=torch.float16 if self.cfg.use_fp16_clip else torch.float32
        )
        self.feature_extractor = self.pipeline.feature_extractor
        self.cond_encoder = self.pipeline.cond_encoder
        self.unet = self.pipeline.unet
        self.noise_scheduler = self.pipeline.scheduler

        self.inference_scheduler = DDPMScheduler.from_config(
            self.noise_scheduler.config
        )
        self.pipeline.scheduler = self.inference_scheduler

        if self.cfg.prediction_type is not None:
            self.noise_scheduler.register_to_config(
                prediction_type=self.cfg.prediction_type
            )

        # Prepare trainable / non-trainable modules
        trainable_modules = self.cfg.trainable_modules
        if trainable_modules and len(trainable_modules) > 0:
            self.unet.requires_grad_(False)
            for name, module in self.unet.named_modules():
                for trainable_module in trainable_modules:
                    if trainable_module in name:
                        module.requires_grad_(True)
        else:
            self.unet.requires_grad_(True)
        self.cond_encoder.requires_grad_(self.cfg.train_cond_encoder)
        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.text_encoder_2.requires_grad_(False)

        # Others
        # Prepare gradient checkpointing
        if self.cfg.gradient_checkpointing:
            self.unet.enable_gradient_checkpointing()

    def forward(
        self,
        noisy_latents: Tensor,
        conditioning_pixel_values: Tensor,
        timesteps: Tensor,
        ref_latents: Tensor,
        prompts: List[str],
        num_views: int,
        **kwargs,
    ) -> Dict[str, Any]:
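        """Run one denoising pass over a multi-view batch.

        Per-sample prompt / reference-image drop masks implement conditioning dropout
        for classifier-free guidance. The reference latents are first pushed through
        the UNet at timestep 0 with ``cache_hidden_states`` so the custom attention
        processors can record reference features; those cached features (zeroed for
        dropped samples and repeated per view) are then fed back, together with the
        control-image features from ``cond_encoder``, when denoising ``noisy_latents``.
        """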
        bsz = noisy_latents.shape[0]
        b_samples = bsz // num_views
        num_batch_images = num_views

        prompt_drop_mask = (
            torch.rand(b_samples, device=noisy_latents.device)
            < self.cfg.prompt_drop_prob
        )
        image_drop_mask = (
            torch.rand(b_samples, device=noisy_latents.device)
            < self.cfg.image_drop_prob
        )
        cond_drop_mask = (
            torch.rand(b_samples, device=noisy_latents.device) < self.cfg.cond_drop_prob
        )
        prompt_drop_mask = prompt_drop_mask | cond_drop_mask
        image_drop_mask = image_drop_mask | cond_drop_mask

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            # Here, we compute not just the text embeddings but also the additional
            # embeddings needed for the SDXL UNet to operate.
            additional_embeds = compute_embeddings(
                prompts,
                prompt_drop_mask,
                [self.text_encoder, self.text_encoder_2],
                [self.tokenizer, self.tokenizer_2],
                **kwargs,
            )

        # Process reference latents to obtain reference features
        with torch.no_grad():
            ref_timesteps = torch.zeros_like(timesteps[:b_samples])
            ref_hidden_states = {}
            self.unet(
                ref_latents,
                ref_timesteps,
                encoder_hidden_states=additional_embeds["prompt_embeds"],
                added_cond_kwargs={
                    "text_embeds": additional_embeds["text_embeds"],
                    "time_ids": additional_embeds["time_ids"],
                },
                cross_attention_kwargs={
                    "cache_hidden_states": ref_hidden_states,
                    "use_mv": False,
                    "use_ref": False,
                },
                return_dict=False,
            )
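
        # Zero out cached reference features for samples whose reference image was
        # dropped, then repeat them once per view so they align with the (B * N)
        # multi-view latent batch.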
        for k, v in ref_hidden_states.items():
            v_ = v
            v_[image_drop_mask] = 0.0
            ref_hidden_states[k] = v_.repeat_interleave(num_batch_images, dim=0)

        # Repeat additional embeddings for each image in the batch
        for key, value in additional_embeds.items():
            kwargs[key] = value.repeat_interleave(num_batch_images, dim=0)

        conditioning_features = self.cond_encoder(conditioning_pixel_values)

        added_cond_kwargs = {
            "text_embeds": kwargs["text_embeds"],
            "time_ids": kwargs["time_ids"],
        }
        noise_pred = self.unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=kwargs["prompt_embeds"],
            added_cond_kwargs=added_cond_kwargs,
            down_intrablock_additional_residuals=conditioning_features,
            cross_attention_kwargs={
                "ref_hidden_states": ref_hidden_states,
                "num_views": num_views,
            },
        ).sample

        return {"noise_pred": noise_pred}

    def training_step(self, batch, batch_idx):
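        """Standard latent-diffusion training step.

        Ground-truth views are VAE-encoded in slices of at most ``vae_max_slice``
        images to bound memory, and the reference image is encoded separately. One
        timestep is sampled per sample and shared across its views; the loss is an
        (optionally min-SNR-weighted) MSE between the UNet prediction and the
        epsilon / velocity target. Entries excluded by ``noise_mask`` are kept at
        timestep 0 with clean latents and ignored in the loss.
        """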
        num_views = batch["num_views"]
        vae_max_slice = 8

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            latents = []
            for i in range(0, batch["rgb"].shape[0], vae_max_slice):
                latents.append(
                    vae_encode(
                        self.vae,
                        batch["rgb"][i : i + vae_max_slice].to(self.vae.dtype) * 2 - 1,
                        sample=True,
                        apply_scale=True,
                    ).float()
                )
            latents = torch.cat(latents, dim=0)

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            ref_latents = vae_encode(
                self.vae,
                batch["reference_rgb"].to(self.vae.dtype) * 2 - 1,
                sample=True,
                apply_scale=True,
            ).float()

        bsz = latents.shape[0]
        b_samples = bsz // num_views

        noise = torch.randn_like(latents)
        if self.cfg.noise_offset is not None:
            # https://www.crosslabs.org//blog/diffusion-with-offset-noise
            noise += self.cfg.noise_offset * torch.randn(
                (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
            )

        noise_mask = (
            batch["noise_mask"]
            if "noise_mask" in batch
            else torch.ones((bsz,), dtype=torch.bool, device=latents.device)
        )

        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (b_samples,),
            device=latents.device,
            dtype=torch.long,
        )
        timesteps = timesteps.repeat_interleave(num_views)
        timesteps[~noise_mask] = 0

        if self.cfg.input_perturbation is not None:
            new_noise = noise + self.cfg.input_perturbation * torch.randn_like(noise)
            noisy_latents = self.noise_scheduler.add_noise(
                latents, new_noise, timesteps
            )
        else:
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        noisy_latents[~noise_mask] = latents[~noise_mask]

        if self.noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.noise_scheduler.config.prediction_type == "v_prediction":
            target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(
                f"Unsupported prediction type {self.noise_scheduler.config.prediction_type}"
            )

        model_pred = self(
            noisy_latents, batch["source_rgb"], timesteps, ref_latents, **batch
        )["noise_pred"]
        model_pred = model_pred[noise_mask]
        target = target[noise_mask]

        if self.cfg.snr_gamma is None:
            loss = F.mse_loss(model_pred, target, reduction="mean")
        else:
            # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
            # Since we predict the noise instead of x_0, the original formulation is slightly changed.
            # This is discussed in Section 4.2 of the same paper.
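            # In short, the per-sample weight is w_t = min(SNR(t), snr_gamma) / SNR(t)
            # (with SNR(t) replaced by SNR(t) + 1 below when using v-prediction).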
            # Use only the timesteps that contribute to the loss so the weights line
            # up with model_pred / target after masking with noise_mask.
            snr = compute_snr(self.noise_scheduler, timesteps[noise_mask])
            if self.noise_scheduler.config.prediction_type == "v_prediction":
                # Velocity objective requires that we add one to SNR values before we
                # divide by them.
                snr = snr + 1
            mse_loss_weights = (
                torch.stack(
                    [snr, self.cfg.snr_gamma * torch.ones_like(snr)], dim=1
                ).min(dim=1)[0]
                / snr
            )
            loss = F.mse_loss(model_pred, target, reduction="none")
            loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
            loss = loss.mean()

        self.log("train/loss", loss, prog_bar=True)

        # will execute self.on_check_train every self.cfg.check_train_every_n_steps steps
        self.check_train(batch)

        return {"loss": loss}

    def on_train_batch_end(self, outputs, batch, batch_idx):
        pass

    def get_input_visualizations(self, batch):
        return [
            {
                "type": "rgb",
                "img": rearrange(
                    batch["source_rgb"],
                    "(B N) C H W -> (B H) (N W) C",
                    N=batch["num_views"],
                ),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(batch["reference_rgb"], "B C H W -> (B H) W C"),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(
                    batch["rgb"], "(B N) C H W -> (B H) (N W) C", N=batch["num_views"]
                ),
                "kwargs": {"data_format": "HWC"},
            },
        ]

    def get_output_visualizations(self, batch, outputs):
        images = [
            {
                "type": "rgb",
                "img": rearrange(
                    batch["source_rgb"],
                    "(B N) C H W -> (B H) (N W) C",
                    N=batch["num_views"],
                ),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(
                    batch["rgb"], "(B N) C H W -> (B H) (N W) C", N=batch["num_views"]
                ),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(batch["reference_rgb"], "B C H W -> (B H) W C"),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(
                    outputs, "(B N) C H W -> (B H) (N W) C", N=batch["num_views"]
                ),
                "kwargs": {"data_format": "HWC"},
            },
        ]
        return images

    def generate_images(self, batch, **kwargs):
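        """Run the full IG2MV pipeline with the evaluation settings from the config
        and return the generated views as tensors (``output_type="pt"``)."""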
        return self.pipeline(
            prompt=batch["prompts"],
            control_image=batch["source_rgb"],
            num_images_per_prompt=batch["num_views"],
            generator=torch.Generator(device=self.device).manual_seed(
                self.cfg.eval_seed
            ),
            num_inference_steps=self.cfg.eval_num_inference_steps,
            guidance_scale=self.cfg.eval_guidance_scale,
            height=self.cfg.eval_height,
            width=self.cfg.eval_width,
            reference_image=batch["reference_rgb"],
            output_type="pt",
        ).images

    def on_save_checkpoint(self, checkpoint):
        if self.global_rank == 0:
            self.pipeline.save_custom_adapter(
                os.path.dirname(self.get_save_dir()),
                "step1x-3d-ig2v.safetensors",
                safe_serialization=True,
                include_keys=self.cfg.trainable_modules,
            )

    def on_check_train(self, batch):
        self.save_image_grid(
            f"it{self.true_global_step}-train.jpg",
            self.get_input_visualizations(batch),
            name="train_step_input",
            step=self.true_global_step,
        )

    def validation_step(self, batch, batch_idx):
        out = self.generate_images(batch)
        if (
            self.cfg.check_val_limit_rank > 0
            and self.global_rank < self.cfg.check_val_limit_rank
        ):
            self.save_image_grid(
                f"it{self.true_global_step}-validation-{self.global_rank}_{batch_idx}.jpg",
                self.get_output_visualizations(batch, out),
                name=f"validation_step_output_{self.global_rank}_{batch_idx}",
                step=self.true_global_step,
            )

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        out = self.generate_images(batch)
        self.save_image_grid(
            f"it{self.true_global_step}-test-{self.global_rank}_{batch_idx}.jpg",
            self.get_output_visualizations(batch, out),
            name=f"test_step_output_{self.global_rank}_{batch_idx}",
            step=self.true_global_step,
        )

    def on_test_end(self):
        pass