Spaces:
Paused
Paused
import torch | |
import numpy as np | |
from torch_utils.ops import bias_act | |
from torch_utils import misc | |
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    """Rescale ``x`` so that its mean square along ``dim`` is (approximately) one.

    Args:
        x: Input tensor.
        dim: Dimension over which the mean of squares is taken.
        eps: Small constant added before the reciprocal square root to avoid
            division by zero (an all-zero slice stays all-zero).

    Returns:
        Tensor of the same shape as ``x``.
    """
    mean_sq = x.square().mean(dim=dim, keepdim=True)
    return x * torch.rsqrt(mean_sq + eps)
class FullyConnectedLayer_normal(torch.nn.Module):
    """Standard ``torch.nn.Linear`` layer whose bias starts at a constant value.

    No runtime weight/bias gains and no activation are applied; the wrapped
    linear layer is exposed as ``self.fc``.
    """

    def __init__(self,
        in_features,        # Number of input features.
        out_features,       # Number of output features.
        bias        = True, # Apply additive bias before the activation function?
        bias_init   = 0,    # Initial value for the additive bias.
    ):
        super().__init__()
        linear = torch.nn.Linear(in_features, out_features, bias=bias)
        if bias:
            # Replace nn.Linear's default random bias with the requested constant.
            torch.nn.init.constant_(linear.bias, bias_init)
        self.fc = linear

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        return self.fc(x)
class MappingNetwork_normal(torch.nn.Module):
    """Plain MLP of (Linear, LeakyReLU(0.2)) pairs with optional input normalization.

    The first Linear maps ``in_features -> int_dim``; each further Linear maps
    ``int_dim -> int_dim``. At least one pair is always built, even when
    ``num_layers < 1``.
    """

    def __init__(self,
        in_features,                    # Number of input features.
        int_dim,                        # Width of every Linear output.
        num_layers              = 8,    # Number of mapping layers.
        mapping_normalization   = False # 2nd-moment normalization of the input.
    ):
        super().__init__()
        modules = [torch.nn.Linear(in_features, int_dim), torch.nn.LeakyReLU(0.2)]
        for _ in range(num_layers - 1):
            modules += [torch.nn.Linear(int_dim, int_dim), torch.nn.LeakyReLU(0.2)]
        self.net = torch.nn.Sequential(*modules)
        self.normalization = mapping_normalization

    def forward(self, x):
        """Optionally apply ``normalize_2nd_moment`` to ``x``, then run the MLP."""
        h = normalize_2nd_moment(x) if self.normalization else x
        return self.net(h)
class DecodingNetwork(torch.nn.Module):
    """MLP decoder applied to L2-normalized inputs.

    ``forward`` L2-normalizes ``x`` along dim 1. It then applies
    ``num_layers - 1`` hidden (Linear ``in_features -> in_features``, ReLU)
    pairs, followed by a final Linear ``in_features -> out_dim``. When
    ``num_layers <= 1`` only the final Linear is built.
    """

    def __init__(self,
        in_features,        # Number of input features.
        out_dim,            # Number of output features.
        num_layers  = 8,    # Number of mapping layers.
    ):
        super().__init__()
        hidden = []
        for _ in range(num_layers - 1):
            hidden += [torch.nn.Linear(in_features, in_features), torch.nn.ReLU()]
        self.net = torch.nn.Sequential(*hidden, torch.nn.Linear(in_features, out_dim))

    def forward(self, x):
        """Normalize ``x`` to unit L2 norm along dim 1, then decode it."""
        return self.net(torch.nn.functional.normalize(x, dim=1))
class FullyConnectedLayer(torch.nn.Module):
    """Fully connected layer with runtime weight/bias scaling.

    The weight is stored as ``randn / lr_multiplier`` and multiplied at run
    time by ``lr_multiplier / sqrt(in_features)``. The bias, if any, is
    multiplied by ``lr_multiplier``. Non-linear activations, and the no-bias
    case, are delegated to the project's ``bias_act.bias_act`` helper.
    """

    def __init__(self,
        in_features,                # Number of input features.
        out_features,               # Number of output features.
        bias            = True,     # Apply additive bias before the activation function?
        activation      = 'linear', # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 1,        # Learning rate multiplier.
        bias_init       = 0,        # Initial value for the additive bias.
    ):
        super().__init__()
        self.activation = activation
        init_weight = torch.randn([out_features, in_features]) / lr_multiplier
        self.weight = torch.nn.Parameter(init_weight)
        if bias:
            self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
        else:
            self.bias = None
        # Runtime gains that undo/replace the scaling baked into the stored params.
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        """Compute ``x @ (gain * W)^T`` plus the scaled bias, then the activation."""
        weight = self.weight.to(x.dtype) * self.weight_gain
        bias = None
        if self.bias is not None:
            bias = self.bias.to(x.dtype)
            if self.bias_gain != 1:
                bias = bias * self.bias_gain
        if bias is not None and self.activation == 'linear':
            # Fused bias-add + matmul for the plain linear case.
            return torch.addmm(bias.unsqueeze(0), x, weight.t())
        return bias_act.bias_act(x.matmul(weight.t()), bias, act=self.activation)
class MappingNetwork(torch.nn.Module):
    """Mapping network from an input latent z to intermediate latent(s) w.

    ``z`` is cast to float32 and optionally normalized with
    ``normalize_2nd_moment``. It is then passed through ``num_layers``
    ``FullyConnectedLayer`` modules named ``fc0 .. fc{num_layers-1}``.
    During training, an exponential moving average of w (``w_avg``) can be
    tracked. The output can then be broadcast to ``num_ws`` copies and
    truncated toward ``w_avg``.

    Class conditioning is not supported: ``forward`` raises ``ValueError``
    when the network was built with ``c_dim > 0``. The ``embed`` layer is
    still created in that case but never used.
    """

    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.995,    # Decay for tracking the moving average of W during training, None = do not track.
        normalization   = None,     # Truthy = normalize z with normalize_2nd_moment before the first layer.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta
        self.normalization = normalization

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        # Widths at each layer boundary: input, (num_layers - 1) hidden widths, output.
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            layer = FullyConnectedLayer(features_list[idx], features_list[idx + 1],
                                        activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        # The running average of w is only allocated when both broadcasting
        # (num_ws) and tracking (w_avg_beta) are enabled.
        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        """Map latent ``z`` to w.

        Args:
            z: Latent codes; shape is checked against ``[None, z_dim]`` when
                ``z_dim > 0``.
            c: Conditioning labels. Unsupported by this implementation.
            truncation_psi: Interpolation factor from ``w_avg`` toward the
                computed w; 1 disables truncation.
            truncation_cutoff: When set (and ``num_ws`` is not None), only the
                first ``truncation_cutoff`` broadcast latents are truncated.
            skip_w_avg_update: If True, ``w_avg`` is not updated even in
                training mode.

        Returns:
            Tensor of shape ``[batch, w_dim]``, or ``[batch, num_ws, w_dim]``
            when ``num_ws`` is not None.

        Raises:
            ValueError: if the network was built with ``c_dim > 0``.
        """
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                # Cast once; normalization (if enabled) is applied to the float32 tensor.
                x = z.to(torch.float32)
                if self.normalization:
                    x = normalize_2nd_moment(x)
            if self.c_dim > 0:
                raise ValueError("This implementation does not need class index")

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W: w_avg <- beta * w_avg + (1 - beta) * batch_mean(x).
        # w_avg only exists when num_ws is not None (see __init__); without this
        # guard, training with num_ws=None raised AttributeError.
        if (self.w_avg_beta is not None and self.num_ws is not None
                and self.training and not skip_w_avg_update):
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation: w <- w_avg + psi * (w - w_avg).
        # NOTE: w_avg only exists when num_ws is not None, so truncation with
        # num_ws=None still raises AttributeError here.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x