# makeavid-sd-jax / makeavid_sd/flax_impl/flax_unet_pseudo3d_condition.py
# Last change by lopho, commit b2f876f: "forgot about the nested package structure"
from typing import Tuple, Union

import jax
import jax.numpy as jnp
import flax.linen as nn
import flax.struct
from flax.core.frozen_dict import FrozenDict
from diffusers.configuration_utils import ConfigMixin, flax_register_to_config
from diffusers.models.modeling_flax_utils import FlaxModelMixin
from diffusers.utils import BaseOutput
#from flax_embeddings import (
# TimestepEmbedding,
# Timesteps
#)
from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps

from .flax_unet_pseudo3d_blocks import (
    CrossAttnDownBlockPseudo3D,
    CrossAttnUpBlockPseudo3D,
    DownBlockPseudo3D,
    UpBlockPseudo3D,
    UNetMidBlockPseudo3DCrossAttn
)
from .flax_resnet_pseudo3d import ConvPseudo3D
@flax.struct.dataclass
class UNetPseudo3DConditionOutput(BaseOutput):
    """Output container returned by ``UNetPseudo3DConditionModel`` when
    ``return_dict`` is True.

    It is decorated as a Flax struct dataclass, mirroring diffusers' own Flax
    UNet output. This generates the dataclass ``__init__`` that
    ``BaseOutput`` relies on, and registers the class as a JAX pytree so it
    can be returned from jitted / pmapped ``apply`` calls.
    """

    # Model prediction, laid out (batch, channels, frames, height, width),
    # i.e. the same layout as the model's input sample.
    sample: jax.Array
@flax_register_to_config
class UNetPseudo3DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    """Conditional pseudo-3D UNet built from the pseudo-3D Flax blocks.

    The forward pass takes a sample laid out (batch, channels, frames,
    height, width). It moves channels last for the internal layers and moves
    them back to the input layout before returning.

    The fields below are registered as config values through
    ``flax_register_to_config``. Block type names accept both the Pseudo3D
    names and the legacy 2D names (e.g. ``CrossAttnDownBlock2D``), so configs
    of older 3D checkpoints that use the 2D names can still be loaded.
    """

    # Spatial (height, width) of the dummy sample built in init_weights;
    # a single int means a square size.
    sample_size: Union[int, Tuple[int, int]] = (64, 64)
    in_channels: int = 4
    out_channels: int = 4
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlockPseudo3D",
        "CrossAttnDownBlockPseudo3D",
        "CrossAttnDownBlockPseudo3D",
        "DownBlockPseudo3D"
    )
    up_block_types: Tuple[str] = (
        "UpBlockPseudo3D",
        "CrossAttnUpBlockPseudo3D",
        "CrossAttnUpBlockPseudo3D",
        "CrossAttnUpBlockPseudo3D"
    )
    block_out_channels: Tuple[int] = (
        320,
        640,
        1280,
        1280
    )
    layers_per_block: int = 2
    # Attention head setting per down block; a single int is broadcast to
    # every down block (the up blocks use the reversed sequence).
    attention_head_dim: Union[int, Tuple[int]] = 8
    # Last dimension of encoder_hidden_states (used for the init dummy input).
    cross_attention_dim: int = 768
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    use_memory_efficient_attention: bool = False
    # Computation dtype forwarded to the sub-modules.
    dtype: jnp.dtype = jnp.float32
    # Name of the dtype used for the dummy inputs in init_weights:
    # 'bfloat16', 'float16' or 'float32'. It is not forwarded to the
    # sub-modules in this file.
    param_dtype: str = 'float32'

    def init_weights(self, rng: jax.Array) -> FrozenDict:
        """Initialise parameters by tracing the model on dummy inputs.

        Args:
            rng: PRNG key. It is split into the ``params`` and ``dropout``
                RNG streams. (Annotated as ``jax.Array``:
                ``jax.random.KeyArray`` was removed from JAX, and annotations
                are evaluated at class-definition time.)

        Returns:
            The ``params`` collection produced by ``self.init``.

        Raises:
            ValueError: if ``param_dtype`` is not a supported dtype name.
        """
        supported_dtypes = {
            'bfloat16': jnp.bfloat16,
            'float16': jnp.float16,
            'float32': jnp.float32,
        }
        param_dtype = supported_dtypes.get(self.param_dtype)
        if param_dtype is None:
            raise ValueError(f'unknown parameter type: {self.param_dtype}')
        sample_size = self.sample_size
        if isinstance(sample_size, int):
            sample_size = (sample_size, sample_size)
        # Dummy inputs: one batch item with a single frame (b, c, f, h, w),
        # one timestep, and a length-1 conditioning sequence.
        sample_shape = (1, self.in_channels, 1, *sample_size)
        sample = jnp.zeros(sample_shape, dtype = param_dtype)
        timesteps = jnp.ones((1, ), dtype = jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype = param_dtype)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = { "params": params_rng, "dropout": dropout_rng }
        return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]

    def setup(self) -> None:
        """Create the sub-modules described by the config fields.

        Raises:
            NotImplementedError: if a down or up block type name is unknown.
        """
        # Normalise the head setting to one entry per down block.
        if isinstance(self.attention_head_dim, int):
            attention_head_dim = (self.attention_head_dim, ) * len(self.down_block_types)
        else:
            attention_head_dim = self.attention_head_dim
        time_embed_dim = self.block_out_channels[0] * 4
        num_levels = len(self.block_out_channels)

        self.conv_in = ConvPseudo3D(
            features = self.block_out_channels[0],
            kernel_size = (3, 3),
            strides = (1, 1),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )
        self.time_proj = FlaxTimesteps(
            dim = self.block_out_channels[0],
            flip_sin_to_cos = self.flip_sin_to_cos,
            freq_shift = self.freq_shift
        )
        self.time_embedding = FlaxTimestepEmbedding(
            time_embed_dim = time_embed_dim,
            dtype = self.dtype
        )

        down_blocks = []
        output_channels = self.block_out_channels[0]
        for i, down_block_type in enumerate(self.down_block_types):
            input_channels = output_channels
            output_channels = self.block_out_channels[i]
            # The deepest level does not downsample.
            is_final_block = i == num_levels - 1
            # allows loading 3d models with old layer type names in their configs
            # eg. 2D instead of Pseudo3D, like lxj's timelapse model
            if down_block_type in ['CrossAttnDownBlockPseudo3D', 'CrossAttnDownBlock2D']:
                down_block = CrossAttnDownBlockPseudo3D(
                    in_channels = input_channels,
                    out_channels = output_channels,
                    num_layers = self.layers_per_block,
                    attn_num_head_channels = attention_head_dim[i],
                    add_downsample = not is_final_block,
                    use_memory_efficient_attention = self.use_memory_efficient_attention,
                    dtype = self.dtype
                )
            elif down_block_type in ['DownBlockPseudo3D', 'DownBlock2D']:
                down_block = DownBlockPseudo3D(
                    in_channels = input_channels,
                    out_channels = output_channels,
                    num_layers = self.layers_per_block,
                    add_downsample = not is_final_block,
                    dtype = self.dtype
                )
            else:
                raise NotImplementedError(f'Unimplemented down block type: {down_block_type}')
            down_blocks.append(down_block)
        self.down_blocks = down_blocks

        self.mid_block = UNetMidBlockPseudo3DCrossAttn(
            in_channels = self.block_out_channels[-1],
            attn_num_head_channels = attention_head_dim[-1],
            use_memory_efficient_attention = self.use_memory_efficient_attention,
            dtype = self.dtype
        )

        up_blocks = []
        reversed_block_out_channels = list(reversed(self.block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        output_channels = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(self.up_block_types):
            prev_output_channels = output_channels
            output_channels = reversed_block_out_channels[i]
            input_channels = reversed_block_out_channels[min(i + 1, num_levels - 1)]
            # The shallowest level does not upsample.
            is_final_block = i == num_levels - 1
            if up_block_type in ['CrossAttnUpBlockPseudo3D', 'CrossAttnUpBlock2D']:
                up_block = CrossAttnUpBlockPseudo3D(
                    in_channels = input_channels,
                    out_channels = output_channels,
                    prev_output_channels = prev_output_channels,
                    num_layers = self.layers_per_block + 1,
                    attn_num_head_channels = reversed_attention_head_dim[i],
                    add_upsample = not is_final_block,
                    use_memory_efficient_attention = self.use_memory_efficient_attention,
                    dtype = self.dtype
                )
            elif up_block_type in ['UpBlockPseudo3D', 'UpBlock2D']:
                up_block = UpBlockPseudo3D(
                    in_channels = input_channels,
                    out_channels = output_channels,
                    prev_output_channels = prev_output_channels,
                    num_layers = self.layers_per_block + 1,
                    add_upsample = not is_final_block,
                    dtype = self.dtype
                )
            else:
                raise NotImplementedError(f'Unimplemented up block type: {up_block_type}')
            up_blocks.append(up_block)
        self.up_blocks = up_blocks

        self.conv_norm_out = nn.GroupNorm(
            num_groups = 32,
            epsilon = 1e-5
        )
        self.conv_out = ConvPseudo3D(
            features = self.out_channels,
            kernel_size = (3, 3),
            strides = (1, 1),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )

    def __call__(self,
            sample: jax.Array,
            timesteps: jax.Array,
            encoder_hidden_states: jax.Array,
            return_dict: bool = True
    ) -> Union[UNetPseudo3DConditionOutput, Tuple[jax.Array]]:
        """Run the UNet on a batch of frame sequences.

        Args:
            sample: 5-D input laid out (batch, channels, frames, height, width).
            timesteps: diffusion timestep(s). A scalar (Python number or 0-d
                array) or a 1-D sequence; it is converted to a 1-D float32 array.
            encoder_hidden_states: conditioning passed to the cross-attention
                blocks and the mid block. Presumably (batch, seq_len,
                cross_attention_dim), matching the init_weights dummy
                (1, 1, cross_attention_dim); verify against callers.
            return_dict: if True, wrap the result in
                ``UNetPseudo3DConditionOutput``; otherwise return a 1-tuple.

        Returns:
            The prediction, laid out (batch, channels, frames, height, width).

        Raises:
            ValueError: if ``sample`` is not 5-D.
            NotImplementedError: if a sub-block has an unexpected type.
        """
        if sample.ndim != 5:
            raise ValueError(
                f'expected a 5-D sample (batch, channels, frames, height, width), got shape {sample.shape}'
            )
        # Accept Python scalars / sequences as well as arrays, and always give
        # the time projection a 1-D float32 array.
        timesteps = jnp.atleast_1d(jnp.asarray(timesteps, dtype = jnp.float32))
        # b,c,f,h,w -> b,f,h,w,c
        sample = sample.transpose((0, 2, 3, 4, 1))
        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)
        sample = self.conv_in(sample)

        # Collect skip connections; each down block contributes its own.
        down_block_res_samples = (sample, )
        for down_block in self.down_blocks:
            if isinstance(down_block, CrossAttnDownBlockPseudo3D):
                sample, res_samples = down_block(
                    hidden_states = sample,
                    temb = t_emb,
                    encoder_hidden_states = encoder_hidden_states
                )
            elif isinstance(down_block, DownBlockPseudo3D):
                sample, res_samples = down_block(
                    hidden_states = sample,
                    temb = t_emb
                )
            else:
                raise NotImplementedError(f'Unimplemented down block type: {down_block.__class__.__name__}')
            down_block_res_samples += res_samples

        sample = self.mid_block(
            hidden_states = sample,
            temb = t_emb,
            encoder_hidden_states = encoder_hidden_states
        )

        # Each up block consumes layers_per_block + 1 skip tensors, taken
        # from the end of the collected tuple (deepest first).
        skips_per_up_block = self.layers_per_block + 1
        for up_block in self.up_blocks:
            res_samples = down_block_res_samples[-skips_per_up_block:]
            down_block_res_samples = down_block_res_samples[:-skips_per_up_block]
            if isinstance(up_block, CrossAttnUpBlockPseudo3D):
                sample = up_block(
                    hidden_states = sample,
                    temb = t_emb,
                    encoder_hidden_states = encoder_hidden_states,
                    res_hidden_states_tuple = res_samples
                )
            elif isinstance(up_block, UpBlockPseudo3D):
                sample = up_block(
                    hidden_states = sample,
                    temb = t_emb,
                    res_hidden_states_tuple = res_samples
                )
            else:
                raise NotImplementedError(f'Unimplemented up block type: {up_block.__class__.__name__}')

        sample = self.conv_norm_out(sample)
        sample = nn.silu(sample)
        sample = self.conv_out(sample)
        # b,f,h,w,c -> b,c,f,h,w
        sample = sample.transpose((0, 4, 1, 2, 3))
        if not return_dict:
            return (sample, )
        return UNetPseudo3DConditionOutput(sample = sample)