# fffiloni's picture
# Upload 25 files
# e394497 verified
from pathlib import Path
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
class PoseNet(nn.Module):
    """A tiny conv network for introducing a pose sequence as the condition.

    Pose frames (3 input channels) pass through a stack of SiLU-activated
    convolutions. Three stride-2 layers reduce H and W by 8x, and the stack
    ends at 128 channels. A 1x1 convolution then projects to
    ``noise_latent_channels``. That projection is zero-initialised, so a
    freshly built network outputs all zeros. The result is multiplied by a
    learnable ``scale`` initialised to 2.
    """

    def __init__(self, noise_latent_channels=320, *args, **kwargs):
        """Build the network.

        Args:
            noise_latent_channels: number of output channels of the final
                1x1 projection.
            *args, **kwargs: forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Alternate channel-preserving 3x3 convs with 4x4/stride-2 convs; each
        # stride-2 layer maps a spatial size S to S // 2.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.SiLU(),
        )
        # Final 1x1 projection to the requested number of channels.
        self.final_proj = nn.Conv2d(
            in_channels=128, out_channels=noise_latent_channels, kernel_size=1
        )
        # Re-initialise after all conv layers exist (order kept so seeded
        # runs consume the RNG exactly as before).
        self._initialize_weights()
        # Learnable output gain, starts at 2.
        self.scale = nn.Parameter(torch.ones(1) * 2)

    def _initialize_weights(self):
        """Initialize weights with He initialization and zero out the biases.

        Every conv in ``conv_layers`` gets N(0, sqrt(2 / fan_in)) weights, where
        fan_in = kernel_h * kernel_w * in_channels. ``final_proj`` is fully
        zeroed so the network's initial output is exactly zero.
        """
        for layer in self.conv_layers:
            if isinstance(layer, nn.Conv2d):
                fan_in = layer.kernel_size[0] * layer.kernel_size[1] * layer.in_channels
                init.normal_(layer.weight, mean=0.0, std=np.sqrt(2. / fan_in))
                if layer.bias is not None:
                    init.zeros_(layer.bias)
        init.zeros_(self.final_proj.weight)
        if self.final_proj.bias is not None:
            init.zeros_(self.final_proj.bias)

    def forward(self, x):
        """Encode pose frames.

        Args:
            x: a 4-D ``(N, 3, H, W)`` tensor, or a 5-D ``(B, F, 3, H, W)``
                tensor whose batch and frame axes are merged into one.

        Returns:
            A ``(N, noise_latent_channels, H // 8, W // 8)`` tensor, scaled by
            ``self.scale``. For 5-D input ``N = B * F`` and the frame axis is
            NOT restored.
        """
        if x.ndim == 5:
            # Same as einops "b f c h w -> (b f) c h w".
            x = x.flatten(0, 1)
        x = self.conv_layers(x)
        x = self.final_proj(x)
        return x * self.scale

    @classmethod
    def from_pretrained(cls, pretrained_model_path, noise_latent_channels=320):
        """Load pretrained pose-net weights from a saved state dict.

        Args:
            pretrained_model_path: path to a checkpoint containing a state dict.
            noise_latent_channels: output channels of the model to build; must
                match the checkpoint (defaults to 320).

        Returns:
            An instance of ``cls`` with the weights loaded (tensors mapped to CPU).

        Raises:
            FileNotFoundError: if ``pretrained_model_path`` does not exist.
            RuntimeError: if the checkpoint's keys or shapes do not match the
                model (``load_state_dict`` is called with ``strict=True``).
        """
        if not Path(pretrained_model_path).exists():
            raise FileNotFoundError(f"There is no model file in {pretrained_model_path}")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(pretrained_model_path, map_location="cpu")
        model = cls(noise_latent_channels=noise_latent_channels)
        model.load_state_dict(state_dict, strict=True)
        # Report success only once the weights are actually in place.
        print(f"loaded PoseNet's pretrained weights from {pretrained_model_path}.")
        return model