Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,192 Bytes
e518b27 7c61bf1 e518b27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
from typing import NamedTuple
import torch
import torch.nn.functional as F
def get_mean_shifted_latents(
    latents: torch.Tensor,
    shift: float = 0.11,
    delta_shift: float = 0.1,
    channels: list[int] = [0, 1, 1, 0],  # list of {-1, 0, 1}; never mutated here
) -> torch.Tensor:
    """Return a copy of ``latents`` with selected channels shifted by a constant.

    For every channel whose sign is non-zero, a constant offset is added in
    steps of ``delta_shift * sign`` until the fraction of strictly positive
    elements in that channel has moved by at least ``shift`` in the direction
    of ``sign`` (or has saturated at 0 or 1).

    Args:
        latents: 4-D tensor indexed as ``[:, channel, :, :]``. It is cloned,
            never modified in place.
        shift: Desired change of the positive-element ratio.
        delta_shift: Size of a single offset step; must be positive.
        channels: Per-channel direction. ``0`` leaves the channel untouched,
            a positive value raises it, a negative value lowers it.

    Returns:
        A new tensor with the same shape as ``latents``.

    Raises:
        ValueError: If ``delta_shift`` is not positive (the search could
            never make progress).
    """
    if delta_shift <= 0:
        raise ValueError(f"delta_shift must be positive, got {delta_shift}")

    shifted_latents = latents.clone()

    for idx, sign in enumerate(channels):
        if sign == 0:
            # skip
            continue

        # View into the clone: in-place updates below write straight into
        # ``shifted_latents``.
        latent_channel = shifted_latents[:, idx, :, :]
        if latent_channel.numel() == 0:
            # The mean of an empty tensor is NaN, which would never satisfy
            # the stop condition.
            continue

        positive_ratio = (latent_channel > 0).float().mean()
        # A ratio can only live in [0, 1]; an unclamped target outside that
        # range is unreachable and would loop forever.
        target_ratio = (positive_ratio + shift * sign).clamp(0.0, 1.0)

        # gradually shift latent_channel
        while True:
            latent_channel += delta_shift * sign
            new_positive_ratio = (latent_channel > 0).float().mean()
            # The stop test must follow the shift direction: moving up waits
            # for the ratio to rise to the target, moving down waits for it
            # to fall to the target.
            if sign > 0:
                reached = new_positive_ratio >= target_ratio
            else:
                reached = new_positive_ratio <= target_ratio
            if reached:
                break

    return shifted_latents
def get_2d_gaussian(
    latent_height: int,
    latent_width: int,
    std_dev: float,
    device: torch.device,
    center_x: float = 0.0,
    center_y: float = 0.0,
    factor: int = 8,  # the grid is built at 1/factor of the requested size
):
    """Build an unnormalised 2-D Gaussian bump on a ``[-1, 1]`` square grid.

    The grid has ``latent_height // factor`` rows and
    ``latent_width // factor`` columns. Each cell holds
    ``exp(-((x - center_x)**2 + (y - center_y)**2) / (2 * std_dev**2))``.

    Returns:
        Tensor of shape ``(1, 1, rows, cols)`` on ``device``.
    """
    rows = latent_height // factor
    cols = latent_width // factor

    # Squared offset from the centre along each axis.
    dy_sq = (torch.linspace(-1, 1, steps=rows, device=device) - center_y) ** 2
    dx_sq = (torch.linspace(-1, 1, steps=cols, device=device) - center_x) ** 2

    # Row vector + column vector broadcasts to the full (rows, cols) grid.
    squared_distance = dx_sq[None, :] + dy_sq[:, None]

    bump = torch.exp(-(squared_distance / (2 * std_dev**2)))
    # add batch and channel dimensions
    return bump[None, None, :, :]
def apply_tkg_noise(
    latents: torch.Tensor,
    shift: float = 0.11,
    delta_shift: float = 0.1,
    std_dev: float = 0.5,
    factor: int = 8,
    channels: list[int] = [0, 1, 1, 0],  # only forwarded, never mutated
) -> torch.Tensor:
    """Blend channel-shifted latents with the originals using a Gaussian mask.

    The mask is generated by ``get_2d_gaussian`` at ``1 / factor`` of the
    latent resolution and bilinearly resized to the latent size. The original
    latents are weighted by the mask and the shifted latents by its
    complement.

    Args:
        latents: 4-D tensor ``(batch, channels, height, width)``.
        shift: Forwarded to ``get_mean_shifted_latents``.
        delta_shift: Forwarded to ``get_mean_shifted_latents``.
        std_dev: Standard deviation of the Gaussian mask on the ``[-1, 1]`` grid.
        factor: Downscale factor of the grid the mask is first built on.
        channels: Per-channel shift direction, forwarded to
            ``get_mean_shifted_latents``.

    Returns:
        Tensor with the same shape and dtype as ``latents``.
    """
    batch_size, num_channels, latent_height, latent_width = latents.shape

    shifted_latents = get_mean_shifted_latents(
        latents,
        shift=shift,
        delta_shift=delta_shift,
        channels=channels,
    )

    gauss_mask = get_2d_gaussian(
        latent_height=latent_height,
        latent_width=latent_width,
        std_dev=std_dev,
        center_x=0.0,
        center_y=0.0,
        factor=factor,
        device=latents.device,
    )
    # Resize the coarse mask to the latent resolution.
    gauss_mask = F.interpolate(
        gauss_mask,
        size=(latent_height, latent_width),
        mode="bilinear",
        align_corners=False,
    )
    gauss_mask = gauss_mask.expand(batch_size, num_channels, -1, -1)

    noised_latents = shifted_latents * (1 - gauss_mask) + latents * gauss_mask

    # The mask is created in the default float dtype, so type promotion would
    # otherwise turn half-precision latents into a float32 result. Blend in
    # the promoted dtype, then hand back the caller's dtype.
    return noised_latents.to(dtype=latents.dtype)
class ColorSet(NamedTuple):
    """A named background colour and its per-channel shift directions."""

    # Human-readable colour name; used as the key in COLOR_SET_MAP.
    name: str
    # One entry per latent channel, each in {-1, 0, 1}.
    channels: list[int]
# Channel shift directions per background colour, in the format expected by
# the ``channels`` argument of get_mean_shifted_latents / apply_tkg_noise.
# ref: Figure 28. Additional Result in various color Background with SD
COLOR_SETS: list[ColorSet] = [
    ColorSet(name="green", channels=[0, 1, 1, 0]),
    ColorSet(name="cyan", channels=[0, 1, 0, 0]),
    ColorSet(name="magenta", channels=[0, -1, -1, -1]),
    ColorSet(name="purple", channels=[0, 0, -1, -1]),
    ColorSet(name="black", channels=[-1, 0, 0, 1]),
    ColorSet(name="orange", channels=[-1, -1, 1, 0]),
    ColorSet(name="white", channels=[0, 0, 0, -1]),
    ColorSet(name="yellow", channels=[0, -1, 1, -1]),
]

# Lookup table: colour name -> ColorSet.
COLOR_SET_MAP: dict[str, ColorSet] = {
    color_set.name: color_set for color_set in COLOR_SETS
}
|