Spaces:
Running
on
Zero
Running
on
Zero
from typing import NamedTuple | |
import torch | |
import torch.nn.functional as F | |
def get_mean_shifted_latents(
    latents: torch.Tensor,
    shift: float = 0.11,
    delta_shift: float = 0.1,
    channels: list[int] = [0, 1, 1, 0],  # list of {-1, 0, 1}; never mutated here
) -> torch.Tensor:
    """Shift selected latent channels until their share of positive values moves by ``shift``.

    For every channel whose entry in ``channels`` is non-zero, a constant
    offset is added in steps of ``delta_shift * sign`` until the fraction of
    strictly positive elements has risen (``sign > 0``) or fallen
    (``sign < 0``) to ``positive_ratio + shift * sign``. At least one step is
    always applied to a selected channel.

    Args:
        latents: Tensor indexed as ``[:, channel, :, :]``; it is not modified.
        shift: Desired change of the positive-element ratio per channel.
        delta_shift: Size of each additive step; must be positive.
        channels: One sign per channel index: ``0`` leaves the channel
            untouched, a positive value shifts it up, a negative value
            shifts it down.

    Returns:
        A shifted copy of ``latents``.

    Raises:
        ValueError: If ``delta_shift`` is not positive (the stepping loop
            could not make progress).
    """
    if delta_shift <= 0:
        raise ValueError(f"delta_shift must be positive, got {delta_shift}")

    shifted_latents = latents.clone()
    print("channels", channels)

    for idx, sign in enumerate(channels):
        if sign == 0:
            # skip
            continue

        # Basic slicing returns a view, so the in-place additions below
        # write straight into `shifted_latents`.
        latent_channel = shifted_latents[:, idx, :, :]
        if latent_channel.numel() == 0:
            # The mean of an empty channel is NaN and would never satisfy
            # the stop condition.
            continue

        positive_ratio = (latent_channel > 0).float().mean()
        # A ratio can only live in [0, 1]; an unclamped target outside that
        # range could never be reached and the loop would not terminate.
        target_ratio = (positive_ratio + shift * sign).clamp(0.0, 1.0)

        # gradually shift latent_channel
        while True:
            latent_channel += delta_shift * sign
            new_positive_ratio = (latent_channel > 0).float().mean()
            # The comparison has to follow the shift direction: when moving
            # down, the ratio must drop to the target, not stay above it.
            if sign > 0:
                reached = new_positive_ratio >= target_ratio
            else:
                reached = new_positive_ratio <= target_ratio
            if reached:
                break

    return shifted_latents
def get_2d_gaussian(
    latent_height: int,
    latent_width: int,
    std_dev: float,
    device: torch.device,
    center_x: float = 0.0,
    center_y: float = 0.0,
    factor: int = 8,  # the grid is sampled at 1/factor of the given size
):
    """Build an unnormalised 2D Gaussian bump on a coarse [-1, 1] x [-1, 1] grid.

    Args:
        latent_height: Height that is integer-divided by ``factor`` to get
            the number of grid rows.
        latent_width: Width that is integer-divided by ``factor`` to get the
            number of grid columns.
        std_dev: Standard deviation of the Gaussian in grid coordinates.
        device: Device on which the grid is created.
        center_x: Horizontal position of the peak in grid coordinates.
        center_y: Vertical position of the peak in grid coordinates.
        factor: Divisor applied to both sizes.

    Returns:
        Tensor of shape ``(1, 1, latent_height // factor,
        latent_width // factor)`` holding ``exp(-d^2 / (2 * std_dev^2))``,
        where ``d`` is the distance of each grid point from the centre.
    """
    rows = latent_height // factor
    cols = latent_width // factor

    # Per-axis offsets from the requested centre.
    offset_y = torch.linspace(-1, 1, steps=rows, device=device) - center_y
    offset_x = torch.linspace(-1, 1, steps=cols, device=device) - center_x

    # Broadcasting a row vector against a column vector yields the same
    # (rows, cols) grid of squared distances as an explicit "ij" meshgrid.
    squared_distance = offset_x.pow(2).unsqueeze(0) + offset_y.pow(2).unsqueeze(1)

    two_variance = 2 * std_dev**2
    bump = torch.exp(-(squared_distance / two_variance))

    # Prepend batch and channel dimensions.
    return bump.unsqueeze(0).unsqueeze(0)
def apply_tkg_noise(
    latents: torch.Tensor,
    shift: float = 0.11,
    delta_shift: float = 0.1,
    std_dev: float = 0.5,
    factor: int = 8,
    channels: list[int] = [0, 1, 1, 0],  # never mutated here
) -> torch.Tensor:
    """Blend channel-shifted latents with the originals using a centred Gaussian mask.

    The mask is close to 1 around the centre and decays towards the borders,
    so the original latents are kept in the middle while the shifted latents
    take over towards the edges.

    Args:
        latents: 4-D tensor unpacked as ``(batch, channels, height, width)``.
        shift: Forwarded to ``get_mean_shifted_latents``.
        delta_shift: Forwarded to ``get_mean_shifted_latents``.
        std_dev: Standard deviation of the Gaussian mask.
        factor: Down-sampling factor of the grid on which the mask is first
            built before it is resized to the latent resolution.
        channels: Per-channel shift signs, forwarded to
            ``get_mean_shifted_latents``.

    Returns:
        Tensor with the same shape and dtype as ``latents``.
    """
    batch_size, num_channels, latent_height, latent_width = latents.shape

    shifted_latents = get_mean_shifted_latents(
        latents,
        shift=shift,
        delta_shift=delta_shift,
        channels=channels,
    )

    gauss_mask = get_2d_gaussian(
        latent_height=latent_height,
        latent_width=latent_width,
        std_dev=std_dev,
        center_x=0.0,
        center_y=0.0,
        factor=factor,
        device=latents.device,
    )
    # The mask is built on a coarse grid; resize it to the latent resolution.
    gauss_mask = F.interpolate(
        gauss_mask,
        size=(latent_height, latent_width),
        mode="bilinear",
        align_corners=False,
    )
    gauss_mask = gauss_mask.expand(batch_size, num_channels, -1, -1)

    noised_latents = shifted_latents * (1 - gauss_mask) + latents * gauss_mask

    # The mask is created in the default float dtype, so half-precision
    # latents would otherwise come back promoted to float32. Blend in the
    # promoted dtype for accuracy, then hand back the caller's dtype.
    return noised_latents.to(dtype=latents.dtype)
class ColorSet(NamedTuple):
    """Immutable pairing of a color label with a list of per-channel integers."""

    # Human-readable color label; used as the lookup key for this entry.
    name: str
    # One integer per channel; presumably the shift signs (-1, 0 or 1)
    # consumed as `channels` by the latent-shifting functions -- TODO confirm.
    channels: list[int]
# Named channel patterns, one entry per background color.
# ref: Figure 28. Additional Result in various color Background with SD
COLOR_SETS: list[ColorSet] = [
    ColorSet(name=color, channels=signs)
    for color, signs in (
        ("green", [0, 1, 1, 0]),
        ("cyan", [0, 1, 0, 0]),
        ("magenta", [0, -1, -1, -1]),
        ("purple", [0, 0, -1, -1]),
        ("black", [-1, 0, 0, 1]),
        ("orange", [-1, -1, 1, 0]),
        ("white", [0, 0, 0, -1]),
        ("yellow", [0, -1, 1, -1]),
    )
]

# Lookup table from color name to its ColorSet entry.
COLOR_SET_MAP: dict[str, ColorSet] = {entry.name: entry for entry in COLOR_SETS}