# modeling/network.py
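"""Denoising UNet for VMem.

Defines :class:`VMemModelParams` (architecture hyper-parameters),
:class:`VMemModel` (a UNet whose attention stages are multiview
transformer blocks), and :class:`VMemWrapper` (a thin adapter that
unpacks a conditioning dict into the model's forward call).
"""
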
from dataclasses import dataclass, field
import torch
import torch.nn as nn
from typing import Union
from modeling.modules.layers import (
Downsample,
GroupNorm32,
ResBlock,
TimestepEmbedSequential,
Upsample,
timestep_embedding,
)
from modeling.modules.transformer import MultiviewTransformer


@dataclass
class VMemModelParams(object):
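    """Architecture hyper-parameters for :class:`VMemModel`.

    ``channel_mult`` sets the channel multiplier per resolution level,
    ``attention_resolutions`` lists the downsampling factors at which
    ``MultiviewTransformer`` blocks are inserted, and ``transformer_depth``
    gives the transformer depth per level; ``channel_mult`` and
    ``transformer_depth`` must have the same length (checked in
    ``__post_init__``).
    """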
in_channels: int = 11
model_channels: int = 320
out_channels: int = 4
num_frames: int = 8
num_res_blocks: int = 2
attention_resolutions: list[int] = field(default_factory=lambda: [4, 2, 1])
channel_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
num_head_channels: int = 64
transformer_depth: list[int] = field(default_factory=lambda: [1, 1, 1, 1])
context_dim: int = 1024
dense_in_channels: int = 6
dropout: float = 0.0
unflatten_names: list[str] = field(
default_factory=lambda: ["middle_ds8", "output_ds4", "output_ds2"]
    )

    def __post_init__(self):
        assert len(self.channel_mult) == len(self.transformer_depth)


class VMemModel(nn.Module):
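    """UNet denoiser with multiview attention.

    The encoder and decoder follow the usual UNet layout of ``ResBlock``s
    with ``Downsample``/``Upsample`` stages; at the resolutions listed in
    ``attention_resolutions`` a ``MultiviewTransformer`` additionally
    attends across the ``num_frames`` views, conditioned on the
    cross-attention context passed to :meth:`forward`.
    """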
def __init__(self, params: VMemModelParams) -> None:
super().__init__()
self.params = params
self.model_channels = params.model_channels
self.out_channels = params.out_channels
self.num_head_channels = params.num_head_channels
time_embed_dim = params.model_channels * 4
self.time_embed = nn.Sequential(
nn.Linear(params.model_channels, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
nn.Conv2d(params.in_channels, params.model_channels, 3, padding=1)
)
]
)
self._feature_size = params.model_channels
input_block_chans = [params.model_channels]
ch = params.model_channels
ds = 1
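        # Encoder: at every resolution level, stack ResBlocks and, at the
        # configured attention resolutions, a MultiviewTransformer; levels are
        # separated by Downsample blocks and their widths recorded for skips.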
for level, mult in enumerate(params.channel_mult):
for _ in range(params.num_res_blocks):
input_layers: list[Union[ResBlock, MultiviewTransformer, Downsample]] = [
ResBlock(
channels=ch,
emb_channels=time_embed_dim,
out_channels=mult * params.model_channels,
dense_in_channels=params.dense_in_channels,
dropout=params.dropout,
)
]
ch = mult * params.model_channels
if ds in params.attention_resolutions:
num_heads = ch // params.num_head_channels
dim_head = params.num_head_channels
input_layers.append(
MultiviewTransformer(
ch,
num_heads,
dim_head,
name=f"input_ds{ds}",
depth=params.transformer_depth[level],
context_dim=params.context_dim,
unflatten_names=params.unflatten_names,
)
)
self.input_blocks.append(TimestepEmbedSequential(*input_layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(params.channel_mult) - 1:
ds *= 2
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(Downsample(ch, out_channels=out_ch))
)
ch = out_ch
input_block_chans.append(ch)
self._feature_size += ch
num_heads = ch // params.num_head_channels
dim_head = params.num_head_channels
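        # Bottleneck at the lowest resolution: ResBlock -> MultiviewTransformer -> ResBlock.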
self.middle_block = TimestepEmbedSequential(
ResBlock(
channels=ch,
emb_channels=time_embed_dim,
out_channels=None,
dense_in_channels=params.dense_in_channels,
dropout=params.dropout,
),
MultiviewTransformer(
ch,
num_heads,
dim_head,
name=f"middle_ds{ds}",
depth=params.transformer_depth[-1],
context_dim=params.context_dim,
unflatten_names=params.unflatten_names,
),
ResBlock(
channels=ch,
emb_channels=time_embed_dim,
out_channels=None,
dense_in_channels=params.dense_in_channels,
dropout=params.dropout,
),
)
self._feature_size += ch
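        # Decoder: mirror the encoder, concatenating the matching skip connection
        # (popped from input_block_chans) before each ResBlock and upsampling
        # between levels.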
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(params.channel_mult))[::-1]:
for i in range(params.num_res_blocks + 1):
ich = input_block_chans.pop()
output_layers: list[Union[ResBlock, MultiviewTransformer, Upsample]] = [
ResBlock(
channels=ch + ich,
emb_channels=time_embed_dim,
out_channels=params.model_channels * mult,
dense_in_channels=params.dense_in_channels,
dropout=params.dropout,
)
]
ch = params.model_channels * mult
if ds in params.attention_resolutions:
num_heads = ch // params.num_head_channels
dim_head = params.num_head_channels
output_layers.append(
MultiviewTransformer(
ch,
num_heads,
dim_head,
name=f"output_ds{ds}",
depth=params.transformer_depth[level],
context_dim=params.context_dim,
unflatten_names=params.unflatten_names,
)
)
if level and i == params.num_res_blocks:
out_ch = ch
ds //= 2
output_layers.append(Upsample(ch, out_ch))
self.output_blocks.append(TimestepEmbedSequential(*output_layers))
self._feature_size += ch
self.out = nn.Sequential(
GroupNorm32(32, ch),
nn.SiLU(),
nn.Conv2d(self.model_channels, params.out_channels, 3, padding=1),
        )

    def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
y: torch.Tensor,
dense_y: torch.Tensor,
num_frames: Union[int, None] = None,
) -> torch.Tensor:
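        """Denoise ``x``.

        ``x`` is the network input, ``t`` the diffusion timesteps, ``y`` the
        cross-attention context, and ``dense_y`` the dense conditioning passed
        to every block alongside the timestep embedding. ``num_frames``
        defaults to ``params.num_frames`` when not given.
        """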
num_frames = num_frames or self.params.num_frames
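        # Embed the timesteps and project them with the time_embed MLP.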
t_emb = timestep_embedding(t, self.model_channels)
t_emb = self.time_embed(t_emb)
hs = []
h = x
for module in self.input_blocks:
h = module(
h,
emb=t_emb,
context=y,
dense_emb=dense_y,
num_frames=num_frames,
)
hs.append(h)
h = self.middle_block(
h,
emb=t_emb,
context=y,
dense_emb=dense_y,
num_frames=num_frames,
)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(
h,
emb=t_emb,
context=y,
dense_emb=dense_y,
num_frames=num_frames,
)
h = h.type(x.dtype)
        return self.out(h)


class VMemWrapper(nn.Module):
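    """Adapter that maps a conditioning dict onto :meth:`VMemModel.forward`.

    ``c["concat"]`` (if present) is concatenated to ``x`` along the channel
    dimension, ``c["crossattn"]`` becomes the cross-attention context, and
    ``c["dense_vector"]`` the dense conditioning.
    """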
def __init__(self, module: VMemModel):
super().__init__()
        self.module = module

    def forward(
self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
) -> torch.Tensor:
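        # Channel-wise concatenation of the optional "concat" conditioning with the input.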
x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
return self.module(
x,
t=t,
y=c["crossattn"],
dense_y=c["dense_vector"],
**kwargs,
)