# WOUAF-Text-to-Image / attribution.py
# Uploaded by wouaf; duplicated from mpatel57/WOUAF-Text-to-Image (revision b673263).
import torch
import numpy as np
from torch_utils.ops import bias_act
from torch_utils import misc
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    """Rescale ``x`` so that its mean square along ``dim`` is (approximately) one.

    Computes ``x / sqrt(mean(x ** 2, dim) + eps)``; ``eps`` keeps the
    reciprocal square root finite for all-zero slices.
    """
    mean_sq = x.square().mean(dim=dim, keepdim=True)
    return x * (mean_sq + eps).rsqrt()
class FullyConnectedLayer_normal(torch.nn.Module):
    """Thin wrapper around ``torch.nn.Linear`` whose bias starts at a constant.

    The weight keeps ``nn.Linear``'s default initialization; when a bias is
    present it is overwritten with ``bias_init``.
    """

    def __init__(self,
        in_features,    # Number of input features.
        out_features,   # Number of output features.
        bias = True,    # Whether the layer has an additive bias.
        bias_init = 0,  # Constant the bias is initialized to.
    ):
        super().__init__()
        self.fc = torch.nn.Linear(in_features, out_features, bias=bias)
        if not bias:
            return
        # In-place fill of a leaf Parameter must happen outside autograd.
        with torch.no_grad():
            self.fc.bias.fill_(bias_init)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        return self.fc(x)
class MappingNetwork_normal(torch.nn.Module):
    """Plain MLP of stacked ``Linear`` + ``LeakyReLU(0.2)`` pairs.

    The first ``Linear`` maps ``in_features -> int_dim``; every later one maps
    ``int_dim -> int_dim``. At least one pair is always built, even when
    ``num_layers < 1``. If ``mapping_normalization`` is truthy, the input is
    passed through ``normalize_2nd_moment`` before the MLP.
    """

    def __init__(self,
        in_features,                    # Number of input features.
        int_dim,                        # Output width of every Linear layer.
        num_layers = 8,                 # Number of mapping layers.
        mapping_normalization = False,  # Apply normalize_2nd_moment to the input?
    ):
        super().__init__()
        # range(1, num_layers) in the loop form always adds the first layer,
        # hence the lower bound of one layer here.
        widths = [in_features] + [int_dim] * max(num_layers, 1)
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(torch.nn.Linear(fan_in, fan_out))
            modules.append(torch.nn.LeakyReLU(0.2))
        self.net = torch.nn.Sequential(*modules)
        self.normalization = mapping_normalization

    def forward(self, x):
        """Optionally 2nd-moment-normalize ``x``, then run the MLP."""
        h = normalize_2nd_moment(x) if self.normalization else x
        return self.net(h)
class DecodingNetwork(torch.nn.Module):
    """MLP that L2-normalizes its input along dim 1, then maps it to ``out_dim``.

    Built from ``num_layers - 1`` hidden ``Linear(in_features, in_features)`` +
    ``ReLU`` pairs followed by one output ``Linear(in_features, out_dim)``.
    When ``num_layers <= 1`` only the output layer is built.
    """

    def __init__(self,
        in_features,     # Number of input features.
        out_dim,         # Number of output features.
        num_layers = 8,  # Total Linear layers: num_layers - 1 hidden + 1 output.
    ):
        super().__init__()
        hidden = []
        for _ in range(num_layers - 1):
            hidden += [torch.nn.Linear(in_features, in_features), torch.nn.ReLU()]
        self.net = torch.nn.Sequential(*hidden, torch.nn.Linear(in_features, out_dim))

    def forward(self, x):
        """L2-normalize ``x`` along dim 1 and decode it with the MLP."""
        unit_x = torch.nn.functional.normalize(x, dim=1)
        return self.net(unit_x)
class FullyConnectedLayer(torch.nn.Module):
    """Fully connected layer with run-time weight/bias gains and an optional fused activation.

    The weight is stored as ``randn / lr_multiplier`` and multiplied by
    ``lr_multiplier / sqrt(in_features)`` on every forward pass; the bias is
    multiplied by ``lr_multiplier``. With ``activation == 'linear'`` and a bias,
    the layer is a single ``torch.addmm``. Otherwise the matmul result is handed
    to ``bias_act.bias_act`` together with the (possibly None) bias.
    """

    def __init__(self,
        in_features,                 # Number of input features.
        out_features,                # Number of output features.
        bias          = True,        # Apply additive bias before the activation function?
        activation    = 'linear',    # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier = 1,           # Learning rate multiplier.
        bias_init     = 0,           # Initial value for the additive bias.
    ):
        super().__init__()
        self.activation = activation
        init_w = torch.randn([out_features, in_features]) / lr_multiplier
        self.weight = torch.nn.Parameter(init_w)
        if bias:
            init_b = torch.full([out_features], np.float32(bias_init))
            self.bias = torch.nn.Parameter(init_b)
        else:
            self.bias = None
        # Gains are applied in forward() rather than baked into the parameters.
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        """Compute ``act(x @ (w * weight_gain).T + b * bias_gain)``."""
        scaled_w = self.weight.to(x.dtype) * self.weight_gain
        scaled_b = None
        if self.bias is not None:
            scaled_b = self.bias.to(x.dtype)
            if self.bias_gain != 1:
                scaled_b = scaled_b * self.bias_gain
        use_addmm = self.activation == 'linear' and scaled_b is not None
        if not use_addmm:
            return bias_act.bias_act(x.matmul(scaled_w.t()), scaled_b, act=self.activation)
        # Linear activation with a bias: fuse the bias add into one addmm call.
        return torch.addmm(scaled_b.unsqueeze(0), x, scaled_w.t())
class MappingNetwork(torch.nn.Module):
    """Mapping network z -> w built from ``num_layers`` stacked ``FullyConnectedLayer``s.

    Conditioning labels are not supported: ``forward`` raises ``ValueError``
    when ``c_dim > 0``. The ``embed`` layer is still constructed in that case,
    so the module's parameter layout stays unchanged. When both ``num_ws`` and
    ``w_avg_beta`` are given, a ``w_avg`` buffer is kept. In training mode it
    tracks a moving average of ``w`` and serves as the target of the truncation
    trick.
    """

    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.995,    # Decay for tracking the moving average of W during training, None = do not track.
        normalization   = None,     # Truthy = apply normalize_2nd_moment to z before the layers.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta
        self.normalization = normalization

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        # Kept for parameter-layout compatibility even though forward() rejects c_dim > 0.
        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        # The buffer exists only when both num_ws and w_avg_beta are set;
        # forward() checks for it instead of assuming it.
        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        """Map latents ``z`` to intermediate latents ``w``.

        Args:
            z: latent codes, shape-checked as ``[N, z_dim]`` when ``z_dim > 0``.
            c: conditioning labels. Unsupported: any ``c_dim > 0`` network raises.
            truncation_psi: interpolation weight from ``w_avg`` towards ``w``;
                1 disables truncation.
            truncation_cutoff: if set (and ``num_ws`` is not None), truncate only
                the first ``truncation_cutoff`` broadcast copies.
            skip_w_avg_update: if True, do not update ``w_avg`` even in training mode.

        Returns:
            ``[N, w_dim]`` tensor, or ``[N, num_ws, w_dim]`` when ``num_ws`` is not None.

        Raises:
            ValueError: if ``c_dim > 0``, or if truncation is requested but the
                network has no ``w_avg`` buffer.
        """
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                # Cast once, then optionally normalize, so the normalized value
                # is what reaches the layers.
                x = z.to(torch.float32)
                if self.normalization:
                    x = normalize_2nd_moment(x)
            if self.c_dim > 0:
                raise ValueError("This implementation does not need class index")

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        has_w_avg = hasattr(self, 'w_avg')

        # Update moving average of W (only if the buffer was allocated in __init__).
        if self.w_avg_beta is not None and has_w_avg and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                if self.w_avg_beta is None or not has_w_avg:
                    raise ValueError(
                        'Truncation requires a w_avg buffer (num_ws and w_avg_beta must not be None)'
                    )
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x