Spaces:
Running
Running
| import torch | |
| from typing import Literal, Optional | |
| from toolkit.basic import value_map | |
| from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO | |
| from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds | |
| from toolkit.stable_diffusion_model import StableDiffusion | |
| from toolkit.train_tools import get_torch_dtype | |
| from toolkit.config_modules import TrainConfig | |
# Guidance strategies that get_guidance_loss dispatches on. "tnt" and
# "targeted_flow" are handled by the dispatcher, so they belong in the type too.
GuidanceType = Literal[
    "targeted",
    "polarity",
    "targeted_polarity",
    "direct",
    "tnt",
    "targeted_flow",
]

# Fraction of the conditional/unconditional latent difference that
# get_targeted_polarity_loss mixes into the noise.
# Alternative value kept for reference: 0.25
DIFFERENTIAL_SCALER = 0.2
def get_differential_mask(
    conditional_latents: torch.Tensor,
    unconditional_latents: torch.Tensor,
    threshold: float = 0.2,
    gradient: bool = False,
):
    """Build a mask highlighting where two latent tensors differ.

    The absolute difference is normalised per sample (dim 0) by its largest
    value over all remaining dims, so each sample's strongest difference is 1.0.

    Args:
        conditional_latents: Latents of the conditional image.
        unconditional_latents: Latents of the unconditional image; must be
            broadcast-compatible with ``conditional_latents``.
        threshold: With ``gradient=False`` the cut-off below which the mask is
            0.0. With ``gradient=True`` the margin added on both sides of the
            0-1 range before clipping.
        gradient: If True return a soft mask clipped to [0, 1]; otherwise a
            hard 0/1 mask.

    Returns:
        A tensor shaped like the difference of the inputs.

    Raises:
        ValueError: If the difference has fewer than 2 dims, so there is no
            per-sample extent to normalise over.
    """
    differential_mask = torch.abs(conditional_latents - unconditional_latents)

    if differential_mask.dim() < 2:
        raise ValueError(
            "get_differential_mask expects at least 2 dims (batch, ...), "
            f"got shape {tuple(differential_mask.shape)}"
        )

    # per-sample peak over every non-batch dim (covers 4D, 5D and other ranks)
    reduce_dims = tuple(range(1, differential_mask.dim()))
    max_differential = differential_mask.amax(dim=reduce_dims, keepdim=True)

    # identical latents give a peak of 0; dividing by it would yield NaN
    # (0 * inf), which the hard threshold below would turn into an all-ones mask.
    # Dividing by 1 instead keeps the mask at 0 for that sample.
    max_differential = torch.where(
        max_differential > 0,
        max_differential,
        torch.ones_like(max_differential),
    )

    differential_scaler = 1.0 / max_differential
    differential_mask = differential_mask * differential_scaler

    if gradient:
        # remap the observed range onto [-threshold, 1 + threshold] and clip,
        # which saturates both ends of the mask.
        # NOTE(review): value_map is defined elsewhere; if it is a plain linear
        # remap, a constant mask (min == max) presumably still divides by zero
        # here - confirm and guard if needed.
        differential_mask = value_map(
            differential_mask,
            differential_mask.min(),
            differential_mask.max(),
            0 - threshold,
            1 + threshold
        )
        differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
    else:
        # everything below the threshold becomes 0.0 and everything else 1.0
        differential_mask = torch.where(
            differential_mask < threshold,
            torch.zeros_like(differential_mask),
            torch.ones_like(differential_mask)
        )
    return differential_mask
def get_targeted_polarity_loss(
    noisy_latents: torch.Tensor,
    conditional_embeds: 'PromptEmbeds',
    match_adapter_assist: bool,
    network_weight_list: list,
    timesteps: torch.Tensor,
    pred_kwargs: dict,
    batch: 'DataLoaderBatchDTO',
    noise: torch.Tensor,
    sd: 'StableDiffusion',
    **kwargs
):
    """Compute and backpropagate the "targeted_polarity" guidance loss.

    Each image's noise is shifted by ``DIFFERENTIAL_SCALER`` times the latent
    difference pointing toward the *other* image. Both noised images are then
    predicted in one doubled batch with the same prompt embeds, and each
    prediction is regressed (MSE) onto its own shifted noise.

    ``backward()`` is called here. The returned loss is detached and flagged
    ``requires_grad`` so a later ``backward()`` by the caller does not raise.

    Args:
        noisy_latents: Unused; kept for a uniform guidance-loss signature.
        conditional_embeds: Prompt embeds used for both halves of the batch.
        match_adapter_assist: Unused here.
        network_weight_list: Unused here.
        timesteps: Timesteps passed to ``sd.add_noise`` and the prediction.
        pred_kwargs: Extra keyword arguments forwarded to ``sd.predict_noise``.
        batch: Provides ``latents`` and ``unconditional_latents``.
        noise: Base noise tensor; the per-sample mean over dims [1, 2, 3]
            implies 4 dims.
        sd: Model wrapper providing dtype/device, ``add_noise``,
            ``predict_noise`` and ``unet``.
        **kwargs: Ignored.

    Returns:
        Scalar detached loss tensor with ``requires_grad`` set to True.
    """
    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch

    with torch.no_grad():
        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # scaled latent differences, one pointing each way
        unconditional_diff_noise = (
            (unconditional_latents - conditional_latents) * DIFFERENTIAL_SCALER
        ).detach()
        conditional_diff_noise = (
            (conditional_latents - unconditional_latents) * DIFFERENTIAL_SCALER
        ).detach()

        # each side's target noise carries the difference toward the other side
        conditional_noise = noise + unconditional_diff_noise
        unconditional_noise = noise + conditional_diff_noise

        # The former "baseline" noisy latents were computed but never consumed,
        # so those two add_noise calls were dropped.
        conditional_noisy_latents = sd.add_noise(
            conditional_latents,
            conditional_noise,
            timesteps
        ).detach()
        unconditional_noisy_latents = sd.add_noise(
            unconditional_latents,
            unconditional_noise,
            timesteps
        ).detach()

        # double up everything to run it through all at once
        cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
        cat_timesteps = torch.cat([timesteps, timesteps], dim=0)

    # gradients are needed from here on
    sd.unet.train()

    prediction = sd.predict_noise(
        latents=cat_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
        timestep=cat_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)

    pred_loss = torch.nn.functional.mse_loss(
        pred_pos.float(),
        conditional_noise.float(),
        reduction="none"
    ).mean([1, 2, 3])

    pred_neg_loss = torch.nn.functional.mse_loss(
        pred_neg.float(),
        unconditional_noise.float(),
        reduction="none"
    ).mean([1, 2, 3])

    loss = (pred_loss + pred_neg_loss).mean()
    loss.backward()

    # detach it so parent class can run backward on no grads without throwing error
    loss = loss.detach()
    loss.requires_grad_(True)

    return loss
def get_direct_guidance_loss(
    noisy_latents: torch.Tensor,
    conditional_embeds: 'PromptEmbeds',
    match_adapter_assist: bool,
    network_weight_list: list,
    timesteps: torch.Tensor,
    pred_kwargs: dict,
    batch: 'DataLoaderBatchDTO',
    noise: torch.Tensor,
    sd: 'StableDiffusion',
    unconditional_embeds: Optional['PromptEmbeds'] = None,
    mask_multiplier=None,
    prior_pred=None,
    **kwargs
):
    """Compute and backpropagate the "direct" guidance loss.

    The unconditional and conditional latents are noised with the same noise
    and predicted together in one doubled batch. The two predictions are
    blended as ``uncond + 1.1 * (cond - uncond)`` and that blend is regressed
    (MSE) onto the noise, optionally weighted by ``mask_multiplier``.

    ``backward()`` is called here. The returned loss is detached and flagged
    ``requires_grad`` so a later ``backward()`` by the caller does not raise.

    Args:
        noisy_latents: Unused here.
        conditional_embeds: Prompt embeds used for both halves of the batch.
        match_adapter_assist: Unused here.
        network_weight_list: Unused here.
        timesteps: Timesteps for noising and prediction.
        pred_kwargs: Extra keyword arguments forwarded to ``sd.predict_noise``.
        batch: Provides ``latents`` and ``unconditional_latents``.
        noise: Noise added to both latents; also the regression target.
        sd: Model wrapper.
        unconditional_embeds: Optional embeds, doubled and forwarded as
            ``unconditional_embeddings``.
        mask_multiplier: Optional element-wise weight applied to the loss.
        prior_pred: Unused here.
        **kwargs: Ignored.

    Returns:
        Scalar detached loss tensor with ``requires_grad`` set to True.
    """
    with torch.no_grad():
        torch_dtype = get_torch_dtype(sd.torch_dtype)
        torch_device = sd.device_torch

        cond_latents = batch.latents.to(torch_device, dtype=torch_dtype).detach()
        uncond_latents = batch.unconditional_latents.to(torch_device, dtype=torch_dtype).detach()

        cond_noisy = sd.add_noise(cond_latents, noise, timesteps).detach()
        uncond_noisy = sd.add_noise(uncond_latents, noise, timesteps).detach()

    # predictions below must track gradients
    sd.unet.train()

    if unconditional_embeds is not None:
        moved_uncond = unconditional_embeds.to(torch_device, dtype=torch_dtype).detach()
        unconditional_embeds = concat_prompt_embeds([moved_uncond, moved_uncond])

    # unconditional half first, conditional half second
    stacked_latents = torch.cat([uncond_noisy, cond_noisy]).to(torch_device, dtype=torch_dtype).detach()
    stacked_embeds = concat_prompt_embeds(
        [conditional_embeds, conditional_embeds]
    ).to(torch_device, dtype=torch_dtype).detach()

    prediction = sd.predict_noise(
        latents=stacked_latents,
        conditional_embeddings=stacked_embeds,
        unconditional_embeddings=unconditional_embeds,
        timestep=torch.cat([timesteps, timesteps]),
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    noise_pred_uncond, noise_pred_cond = torch.chunk(prediction, 2, dim=0)

    # push slightly past the conditional prediction, away from the unconditional one
    guidance_scale = 1.1
    pred_delta = noise_pred_cond - noise_pred_uncond
    guidance_pred = noise_pred_uncond + guidance_scale * pred_delta

    per_element = torch.nn.functional.mse_loss(
        guidance_pred.float(),
        noise.detach().float(),
        reduction="none"
    )
    if mask_multiplier is not None:
        per_element = per_element * mask_multiplier

    loss = per_element.mean([1, 2, 3]).mean()
    loss.backward()

    # hand back a grad-less copy that still tolerates a caller-side backward()
    loss = loss.detach()
    loss.requires_grad_(True)
    return loss
| # targeted | |
def get_targeted_guidance_loss(
    noisy_latents: torch.Tensor,
    conditional_embeds: 'PromptEmbeds',
    match_adapter_assist: bool,
    network_weight_list: list,
    timesteps: torch.Tensor,
    pred_kwargs: dict,
    batch: 'DataLoaderBatchDTO',
    noise: torch.Tensor,
    sd: 'StableDiffusion',
    **kwargs
):
    """Compute and backpropagate the "targeted" guidance loss.

    Steps:
      1. With ``sd.network`` disabled, predict on the noised unconditional
         latents and record that prediction's element-wise MSE against the
         noise (the "prior" loss).
      2. With ``sd.network`` enabled, predict on the noised conditional and
         unconditional latents in one doubled batch.
      3. Loss = mean(|conditional MSE - prior MSE| * differential scaler)
         + mean(|conditional MSE - unconditional MSE|).

    The differential scaler maps each element's absolute latent difference
    from the per-sample [min, max] onto [1.0, 2.0] via ``value_map``.

    ``backward()`` is called here. The returned loss is detached and flagged
    ``requires_grad``. ``sd.network.multiplier`` is reset to
    ``network_weight_list`` before returning.

    Args:
        noisy_latents: Unused here.
        conditional_embeds: Prompt embeds used for every prediction.
        match_adapter_assist: Unused here.
        network_weight_list: Network multipliers for the batch.
        timesteps: Timesteps for noising and prediction.
        pred_kwargs: Extra keyword arguments forwarded to ``sd.predict_noise``.
        batch: Provides ``latents`` and ``unconditional_latents``.
        noise: Noise added to the latents and used as regression target.
        sd: Model wrapper; ``sd.network`` must not be None.
        **kwargs: Ignored.

    Returns:
        Scalar detached loss tensor with ``requires_grad`` set to True.
    """
    with torch.no_grad():
        dtype = get_torch_dtype(sd.torch_dtype)
        device = sd.device_torch

        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # NOTE(review): this function noises through sd.noise_scheduler directly,
        # whereas the sibling guidance losses go through sd.add_noise - confirm
        # that is intended.
        unconditional_noisy_latents = sd.noise_scheduler.add_noise(
            unconditional_latents,
            noise,
            timesteps
        )
        conditional_noisy_latents = sd.noise_scheduler.add_noise(
            conditional_latents,
            noise,
            timesteps
        )

        # bypass the network for the baseline prediction
        sd.network.is_active = False
        sd.unet.eval()

        target_differential_abs = (unconditional_latents - conditional_latents).abs()

        # Per-sample min and max over all non-batch dims. The minimum used to
        # be `.min(dim=1)...max(dim=2)...max(dim=3)`, i.e. not a minimum at
        # all, which let the scaler drop below min_guidance.
        reduce_dims = (1, 2, 3)
        target_differential_abs_min = target_differential_abs.amin(dim=reduce_dims, keepdim=True)
        target_differential_abs_max = target_differential_abs.amax(dim=reduce_dims, keepdim=True)

        min_guidance = 1.0
        max_guidance = 2.0

        differential_scaler = value_map(
            target_differential_abs,
            target_differential_abs_min,
            target_differential_abs_max,
            min_guidance,
            max_guidance
        ).detach()

        # With the network bypassed, predict noise to get a baseline of what the
        # base model wants to do with the unconditional latents + noise.
        target_unconditional = sd.predict_noise(
            latents=unconditional_noisy_latents.to(device, dtype=dtype).detach(),
            conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
            timestep=timesteps,
            guidance_scale=1.0,
            **pred_kwargs  # adapter residuals in here
        ).detach()

        prior_prediction_loss = torch.nn.functional.mse_loss(
            target_unconditional.float(),
            noise.float(),
            reduction="none"
        ).detach().clone()

    # turn the network back on; gradients are tracked from here
    sd.unet.train()
    sd.network.is_active = True
    # NOTE(review): the second half uses `x + -1.0` (weight shifted by -1), while
    # get_guided_loss_polarity uses `weight * -1.0` - confirm which is intended.
    sd.network.multiplier = network_weight_list + [x + -1.0 for x in network_weight_list]

    prediction = sd.predict_noise(
        latents=torch.cat(
            [conditional_noisy_latents, unconditional_noisy_latents], dim=0
        ).to(device, dtype=dtype).detach(),
        conditional_embeddings=concat_prompt_embeds(
            [conditional_embeds, conditional_embeds]
        ).to(device, dtype=dtype).detach(),
        timestep=torch.cat([timesteps, timesteps], dim=0),
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    prediction_conditional, prediction_unconditional = torch.chunk(prediction, 2, dim=0)

    conditional_loss = torch.nn.functional.mse_loss(
        prediction_conditional.float(),
        noise.float(),
        reduction="none"
    )
    unconditional_loss = torch.nn.functional.mse_loss(
        prediction_unconditional.float(),
        noise.float(),
        reduction="none"
    )

    positive_loss = torch.abs(
        conditional_loss.float() - prior_prediction_loss.float(),
    )
    # weight the loss by how much the two images differ at each element
    positive_loss = positive_loss * differential_scaler
    positive_loss = positive_loss.mean([1, 2, 3])

    polar_loss = torch.abs(
        conditional_loss.float() - unconditional_loss.float(),
    ).mean([1, 2, 3])

    positive_loss = positive_loss.mean() + polar_loss.mean()
    positive_loss.backward()

    loss = positive_loss.detach()
    # add a grad so other backward does not fail
    loss.requires_grad_(True)

    # restore network
    sd.network.multiplier = network_weight_list

    return loss
def get_guided_loss_polarity(
    noisy_latents: torch.Tensor,
    conditional_embeds: 'PromptEmbeds',
    match_adapter_assist: bool,
    network_weight_list: list,
    timesteps: torch.Tensor,
    pred_kwargs: dict,
    batch: 'DataLoaderBatchDTO',
    noise: torch.Tensor,
    sd: 'StableDiffusion',
    train_config: 'TrainConfig',
    scaler=None,
    **kwargs
):
    """Compute and backpropagate the "polarity" guidance loss.

    The conditional and unconditional latents are noised with the same noise
    and predicted in one doubled batch. ``sd.network.multiplier`` is set to the
    given weights for the first half and their negation for the second half.
    Each half is regressed (MSE) onto its target: the noise itself, or
    ``noise - latents`` when ``sd.is_flow_matching`` is true.

    ``backward()`` is called here, through ``scaler.scale`` when a scaler is
    given. The returned loss is detached and flagged ``requires_grad``.

    Args:
        noisy_latents: Unused here.
        conditional_embeds: Prompt embeds used for both halves.
        match_adapter_assist: Unused here.
        network_weight_list: Network multipliers for the batch.
        timesteps: Timesteps for noising and prediction.
        pred_kwargs: Extra keyword arguments forwarded to ``sd.predict_noise``.
        batch: Provides ``latents`` and ``unconditional_latents``; also passed
            to ``sd.condition_noisy_latents``.
        noise: Noise tensor.
        sd: Model wrapper; ``sd.network`` must not be None.
        train_config: Read only for flow-matching timestep settings.
        scaler: Optional gradient scaler exposing ``scale(loss)``.
        **kwargs: Ignored.

    Returns:
        Scalar detached loss tensor with ``requires_grad`` set to True.
    """
    torch_dtype = get_torch_dtype(sd.torch_dtype)
    torch_device = sd.device_torch

    with torch.no_grad():
        torch_dtype = get_torch_dtype(torch_dtype)
        noise = noise.to(torch_device, dtype=torch_dtype).detach()

        pos_latents = batch.latents.to(torch_device, dtype=torch_dtype).detach()
        neg_latents = batch.unconditional_latents.to(torch_device, dtype=torch_dtype).detach()

        if sd.is_flow_matching:
            use_linear = any([
                train_config.linear_timesteps,
                train_config.linear_timesteps2,
                train_config.timestep_type == 'linear',
            ])
            schedule_kind = 'linear' if use_linear else train_config.timestep_type

            sd.noise_scheduler.set_train_timesteps(
                1000,
                device=torch_device,
                timestep_type=schedule_kind,
                latents=pos_latents
            )
            # flow-matching targets are noise minus the clean latents
            target_pos = (noise - pos_latents).detach()
            target_neg = (noise - neg_latents).detach()
        else:
            target_pos = noise
            target_neg = noise

        pos_noisy = sd.add_noise(pos_latents, noise, timesteps).detach()
        pos_noisy = sd.condition_noisy_latents(pos_noisy, batch)

        neg_noisy = sd.add_noise(neg_latents, noise, timesteps).detach()
        neg_noisy = sd.condition_noisy_latents(neg_noisy, batch)

        # stack both halves so a single forward pass covers them
        stacked_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        stacked_latents = torch.cat([pos_noisy, neg_noisy], dim=0)
        stacked_timesteps = torch.cat([timesteps, timesteps], dim=0)

        flipped_weights = [weight * -1.0 for weight in network_weight_list]
        plain_weights = [weight * 1.0 for weight in network_weight_list]
        stacked_weights = plain_weights + flipped_weights

    # gradients are tracked from here
    sd.unet.train()
    sd.network.is_active = True
    sd.network.multiplier = stacked_weights

    prediction = sd.predict_noise(
        latents=stacked_latents.to(torch_device, dtype=torch_dtype).detach(),
        conditional_embeddings=stacked_embeds.to(torch_device, dtype=torch_dtype).detach(),
        timestep=stacked_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)

    pos_term = torch.nn.functional.mse_loss(
        pred_pos.float(),
        target_pos.float(),
        reduction="none"
    )
    neg_term = torch.nn.functional.mse_loss(
        pred_neg.float(),
        target_neg.float(),
        reduction="none"
    )

    loss = (pos_term + neg_term).mean([1, 2, 3]).mean()

    if scaler is None:
        loss.backward()
    else:
        scaler.scale(loss).backward()

    # hand back a grad-less copy that still tolerates a caller-side backward()
    loss = loss.detach()
    loss.requires_grad_(True)
    return loss
def get_guided_tnt(
    noisy_latents: torch.Tensor,
    conditional_embeds: 'PromptEmbeds',
    match_adapter_assist: bool,
    network_weight_list: list,
    timesteps: torch.Tensor,
    pred_kwargs: dict,
    batch: 'DataLoaderBatchDTO',
    noise: torch.Tensor,
    sd: 'StableDiffusion',
    prior_pred: torch.Tensor = None,
    **kwargs
):
    """Compute and backpropagate the "tnt" guidance loss.

    Both the conditional ("this") and unconditional ("that") latents are noised
    with the same noise and predicted in one doubled batch. The "this" MSE is
    minimised. The "that" MSE enters with a negative sign, rescaled to the
    magnitude of the "this" loss and weighted by 0.01.

    ``backward()`` is called here. The returned loss is detached and flagged
    ``requires_grad``.

    Args:
        noisy_latents: Unused here.
        conditional_embeds: Prompt embeds used for both halves.
        match_adapter_assist: Unused here.
        network_weight_list: Network multipliers; repeated for the doubled
            batch when ``sd.network`` is present.
        timesteps: Timesteps for noising and prediction.
        pred_kwargs: Extra keyword arguments forwarded to ``sd.predict_noise``.
        batch: Provides ``latents`` and ``unconditional_latents``.
        noise: Noise tensor and regression target.
        sd: Model wrapper; ``sd.network`` may be None.
        prior_pred: Unused here.
        **kwargs: Ignored.

    Returns:
        Scalar detached loss tensor with ``requires_grad`` set to True.
    """
    torch_dtype = get_torch_dtype(sd.torch_dtype)
    torch_device = sd.device_torch

    with torch.no_grad():
        torch_dtype = get_torch_dtype(torch_dtype)
        noise = noise.to(torch_device, dtype=torch_dtype).detach()

        this_latents = batch.latents.to(torch_device, dtype=torch_dtype).detach()
        that_latents = batch.unconditional_latents.to(torch_device, dtype=torch_dtype).detach()

        this_noisy = sd.add_noise(this_latents, noise, timesteps).detach()
        that_noisy = sd.add_noise(that_latents, noise, timesteps).detach()

        # stack both halves so a single forward pass covers them
        stacked_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        stacked_latents = torch.cat([this_noisy, that_noisy], dim=0)
        stacked_timesteps = torch.cat([timesteps, timesteps], dim=0)

    # gradients are tracked from here
    sd.unet.train()

    if sd.network is not None:
        # same weights for both halves of the doubled batch
        sd.network.multiplier = list(network_weight_list * 2)
        sd.network.is_active = True

    prediction = sd.predict_noise(
        latents=stacked_latents.to(torch_device, dtype=torch_dtype).detach(),
        conditional_embeddings=stacked_embeds.to(torch_device, dtype=torch_dtype).detach(),
        timestep=stacked_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    this_prediction, that_prediction = torch.chunk(prediction, 2, dim=0)

    this_loss = torch.nn.functional.mse_loss(
        this_prediction.float(),
        noise.float(),
        reduction="none"
    ).mean([1, 2, 3])

    # the "that" term is negated
    that_loss = -torch.nn.functional.mse_loss(
        that_prediction.float(),
        noise.float(),
        reduction="none"
    ).mean([1, 2, 3])

    with torch.no_grad():
        # ratio that brings the "that" term to the magnitude of the "this" term
        magnitude_ratio = torch.abs(this_loss) / torch.abs(that_loss)

    that_loss = that_loss * magnitude_ratio * 0.01

    loss = (this_loss + that_loss).mean()
    loss.backward()

    # hand back a grad-less copy that still tolerates a caller-side backward()
    loss = loss.detach()
    loss.requires_grad_(True)
    return loss
def targeted_flow_guidance(
    noisy_latents: torch.Tensor,
    conditional_embeds: 'PromptEmbeds',
    match_adapter_assist: bool,
    network_weight_list: list,
    timesteps: torch.Tensor,
    pred_kwargs: dict,
    batch: 'DataLoaderBatchDTO',
    noise: torch.Tensor,
    sd: 'StableDiffusion',
    unconditional_embeds: Optional['PromptEmbeds'] = None,
    mask_multiplier=None,
    prior_pred=None,
    scaler=None,
    train_config=None,
    **kwargs
):
    """Compute the "targeted_flow" guidance loss for flow-matching models.

    A baseline prediction is made on the noised unconditional latents with
    ``sd.network`` disabled. Adding the unconditional latents back to it gives
    the noise the base model implies. That noise is blended with the real noise
    through a soft differential mask: real noise where the two images differ,
    baseline noise elsewhere. The target is the blend minus the conditional
    latents. The network-enabled prediction on the noised conditional latents
    is regressed (MSE) onto that target.

    Unlike the sibling guidance losses, this function does not call
    ``backward()``; it returns the loss with its graph attached.

    Args:
        noisy_latents: Unused here.
        conditional_embeds: Prompt embeds used for both predictions.
        match_adapter_assist: Unused here.
        network_weight_list: Multipliers assigned to ``sd.network``.
        timesteps: Timesteps for noising and prediction.
        pred_kwargs: Extra keyword arguments forwarded to ``sd.predict_noise``.
        batch: Provides ``latents`` and ``unconditional_latents``; also passed
            to ``sd.condition_noisy_latents``.
        noise: Noise tensor.
        sd: Model wrapper; must be flow matching and have a network.
        unconditional_embeds: Unused here.
        mask_multiplier: Unused here.
        prior_pred: Unused here.
        scaler: Unused here.
        train_config: Unused here.
        **kwargs: Ignored.

    Returns:
        Scalar MSE loss tensor.

    Raises:
        ValueError: If ``sd.is_flow_matching`` is false.
    """
    if not sd.is_flow_matching:
        raise ValueError("targeted_flow only works on flow matching models")

    torch_dtype = get_torch_dtype(sd.torch_dtype)
    torch_device = sd.device_torch

    with torch.no_grad():
        torch_dtype = get_torch_dtype(torch_dtype)
        noise = noise.to(torch_device, dtype=torch_dtype).detach()

        cond_latents = batch.latents.to(torch_device, dtype=torch_dtype).detach()
        uncond_latents = batch.unconditional_latents.to(torch_device, dtype=torch_dtype).detach()

        # soft 0-1 mask, largest where the two latents differ most
        diff_mask = get_differential_mask(
            cond_latents,
            uncond_latents,
            gradient=True
        )

        uncond_noisy = sd.add_noise(uncond_latents, noise, timesteps).detach()
        uncond_noisy = sd.condition_noisy_latents(uncond_noisy, batch)

        cond_noisy = sd.add_noise(cond_latents, noise, timesteps).detach()
        cond_noisy = sd.condition_noisy_latents(cond_noisy, batch)

        # baseline: network switched off, unconditional image as input
        sd.network.is_active = False
        sd.unet.eval()

        baseline_prediction = sd.predict_noise(
            latents=uncond_noisy.to(torch_device, dtype=torch_dtype).detach(),
            conditional_embeddings=conditional_embeds.to(torch_device, dtype=torch_dtype).detach(),
            timestep=timesteps,
            guidance_scale=1.0,
            **pred_kwargs
        ).detach()

        # the flow-matching target is `noise - latents`, so adding the latents
        # back recovers the noise implied by the baseline prediction
        implied_noise = baseline_prediction + uncond_latents

        # keep the implied noise where the images agree, real noise where they differ
        kept_baseline = (1 - diff_mask) * implied_noise
        kept_real = diff_mask * noise
        target_noise = kept_real + kept_baseline

        # express the blended noise as a flow-matching target for the conditional latents
        target_pred = (target_noise - cond_latents).detach()

    # network back on; gradients are tracked from here
    sd.unet.train()
    sd.network.is_active = True
    sd.network.multiplier = network_weight_list

    prediction = sd.predict_noise(
        latents=cond_noisy.to(torch_device, dtype=torch_dtype).detach(),
        conditional_embeddings=conditional_embeds.to(torch_device, dtype=torch_dtype).detach(),
        timestep=timesteps,
        guidance_scale=1.0,
        **pred_kwargs
    )

    return torch.nn.functional.mse_loss(
        prediction.float(),
        target_pred.float()
    )
| # this processes all guidance losses based on the batch information | |
def get_guidance_loss(
    noisy_latents: torch.Tensor,
    conditional_embeds: 'PromptEmbeds',
    match_adapter_assist: bool,
    network_weight_list: list,
    timesteps: torch.Tensor,
    pred_kwargs: dict,
    batch: 'DataLoaderBatchDTO',
    noise: torch.Tensor,
    sd: 'StableDiffusion',
    unconditional_embeds: Optional['PromptEmbeds'] = None,
    mask_multiplier=None,
    prior_pred=None,
    scaler=None,
    train_config=None,
    **kwargs
):
    """Dispatch to the guidance loss selected by the batch's dataset config.

    The guidance type is read from the first file item of the batch only, so
    the whole batch is treated as a single type.

    Args:
        noisy_latents: Forwarded to the selected loss function.
        conditional_embeds: Forwarded to the selected loss function.
        match_adapter_assist: Forwarded to the selected loss function.
        network_weight_list: Forwarded to the selected loss function.
        timesteps: Forwarded to the selected loss function.
        pred_kwargs: Forwarded to the selected loss function.
        batch: Batch whose ``file_items[0].dataset_config.guidance_type``
            picks the strategy.
        noise: Forwarded to the selected loss function.
        sd: Forwarded to the selected loss function.
        unconditional_embeds: Forwarded only for "direct" and "targeted_flow";
            every other type requires it to be None.
        mask_multiplier: Forwarded for "direct" and "targeted_flow".
        prior_pred: Forwarded for "tnt", "direct" and "targeted_flow".
        scaler: Forwarded for "polarity" and "targeted_flow".
        train_config: Forwarded for "polarity" and "targeted_flow".
        **kwargs: Forwarded to the selected loss function.

    Returns:
        The loss tensor returned by the selected guidance function.

    Raises:
        AssertionError: If ``unconditional_embeds`` is given for a guidance
            type that does not support it.
        NotImplementedError: If the guidance type is not recognised.
    """
    # TODO add others and process individual batch items separately
    guidance_type: 'GuidanceType' = batch.file_items[0].dataset_config.guidance_type

    # positional arguments shared by every guidance loss
    common_args = (
        noisy_latents,
        conditional_embeds,
        match_adapter_assist,
        network_weight_list,
        timesteps,
        pred_kwargs,
        batch,
        noise,
        sd,
    )

    if guidance_type == "targeted":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted guidance"
        return get_targeted_guidance_loss(
            *common_args,
            **kwargs
        )
    elif guidance_type == "polarity":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
        return get_guided_loss_polarity(
            *common_args,
            scaler=scaler,
            train_config=train_config,
            **kwargs
        )
    elif guidance_type == "tnt":
        # message previously said "polarity guidance" (copy-paste slip)
        assert unconditional_embeds is None, "Unconditional embeds are not supported for tnt guidance"
        return get_guided_tnt(
            *common_args,
            prior_pred=prior_pred,
            **kwargs
        )
    elif guidance_type == "targeted_polarity":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance"
        return get_targeted_polarity_loss(
            *common_args,
            **kwargs
        )
    elif guidance_type == "direct":
        return get_direct_guidance_loss(
            *common_args,
            unconditional_embeds=unconditional_embeds,
            mask_multiplier=mask_multiplier,
            prior_pred=prior_pred,
            **kwargs
        )
    elif guidance_type == "targeted_flow":
        return targeted_flow_guidance(
            *common_args,
            unconditional_embeds=unconditional_embeds,
            mask_multiplier=mask_multiplier,
            prior_pred=prior_pred,
            scaler=scaler,
            train_config=train_config,
            **kwargs
        )
    else:
        raise NotImplementedError(f"Guidance type {guidance_type} is not implemented")