import logging
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from meanaudio.ext.rotary_embeddings import compute_rope_rotations
from meanaudio.model.embeddings import TimestepEmbedder
from meanaudio.model.low_level import MLP, ChannelLastConv1d, ConvMLP
from meanaudio.model.transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock)
log = logging.getLogger()
@dataclass
class PreprocessedConditions:
text_f: torch.Tensor
text_f_c: torch.Tensor
class FluxAudio(nn.Module):
    # Flux-style latent transformer for text-to-audio (TTA), single timestep embedding
def __init__(self,
*,
latent_dim: int,
text_dim: int,
text_c_dim: int,
hidden_dim: int,
depth: int,
fused_depth: int,
num_heads: int,
mlp_ratio: float = 4.0,
latent_seq_len: int,
text_seq_len: int = 77,
latent_mean: Optional[torch.Tensor] = None,
latent_std: Optional[torch.Tensor] = None,
empty_string_feat: Optional[torch.Tensor] = None,
empty_string_feat_c: Optional[torch.Tensor] = None,
use_rope: bool = False) -> None:
super().__init__()
self.latent_dim = latent_dim
self._latent_seq_len = latent_seq_len
self._text_seq_len = text_seq_len
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.use_rope = use_rope
self.mm_depth = depth - fused_depth
self.audio_input_proj = nn.Sequential(
ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3),
nn.SELU(),
ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3),
)
self.text_input_proj = nn.Sequential(
nn.Linear(text_dim, hidden_dim),
MLP(hidden_dim, hidden_dim * 4),
)
self.text_cond_proj = nn.Sequential(
nn.Linear(text_c_dim, hidden_dim),
MLP(hidden_dim, hidden_dim*4)
)
self.final_layer = FinalBlock(hidden_dim, latent_dim)
self.t_embed = TimestepEmbedder(hidden_dim,
frequency_embedding_size=256,
max_period=10000)
self.joint_blocks = nn.ModuleList([
JointBlock(hidden_dim,
num_heads,
mlp_ratio=mlp_ratio,
                       pre_only=(i == depth - fused_depth - 1)) for i in range(depth - fused_depth)  # last joint layer is pre-only (applied only to the text and latent streams)
])
self.fused_blocks = nn.ModuleList([
MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1)
for i in range(fused_depth)
])
if latent_mean is None:
            # placeholder values; if mean/std are not provided here,
            # they are expected to be loaded later from a checkpoint
assert latent_std is None
latent_mean = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan'))
latent_std = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan'))
else:
assert latent_std is not None
assert latent_mean.numel() == latent_dim, f'{latent_mean.numel()=} != {latent_dim=}'
if empty_string_feat is None:
empty_string_feat = torch.zeros((text_seq_len, text_dim))
if empty_string_feat_c is None:
            empty_string_feat_c = torch.zeros(text_c_dim)
        assert empty_string_feat.shape[-1] == text_dim, f'{empty_string_feat.shape[-1]=} != {text_dim=}'
        assert empty_string_feat_c.shape[-1] == text_c_dim, f'{empty_string_feat_c.shape[-1]=} != {text_c_dim=}'
self.latent_mean = nn.Parameter(latent_mean.view(1, 1, -1), requires_grad=False) # (1, 1, d)
self.latent_std = nn.Parameter(latent_std.view(1, 1, -1), requires_grad=False) # (1, 1, d)
self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False)
self.empty_string_feat_c = nn.Parameter(empty_string_feat_c, requires_grad=False)
self.initialize_weights()
if self.use_rope:
log.info("Network: Enabling RoPE embeddings")
self.initialize_rotations()
else:
log.info("Network: RoPE embedding disabled")
self.latent_rot = None
self.text_rot = None
def initialize_rotations(self):
base_freq = 1.0
latent_rot = compute_rope_rotations(self._latent_seq_len,
self.hidden_dim // self.num_heads,
10000,
freq_scaling=base_freq,
device=self.device)
text_rot = compute_rope_rotations(self._text_seq_len,
self.hidden_dim // self.num_heads,
10000,
freq_scaling=base_freq,
device=self.device)
self.latent_rot = nn.Buffer(latent_rot, persistent=False) # will not be saved into state dict
self.text_rot = nn.Buffer(text_rot, persistent=False)
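        # NOTE: nn.Buffer is only available in recent PyTorch releases; on older
        # versions the equivalent registration would be
        # self.register_buffer('latent_rot', latent_rot, persistent=False).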
def update_seq_lengths(self, latent_seq_len: int) -> None:
self._latent_seq_len = latent_seq_len
if self.use_rope:
self.initialize_rotations() # after changing seq_len we need to re-initialize RoPE to match new seq_len
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
for block in self.joint_blocks:
nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0) # the linear layer -> 6 coefficients
nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0)
nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0)
for block in self.fused_blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.conv.weight, 0)
nn.init.constant_(self.final_layer.conv.bias, 0)
def normalize(self, x: torch.Tensor) -> torch.Tensor:
# return (x - self.latent_mean) / self.latent_std
return x.sub_(self.latent_mean).div_(self.latent_std)
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
# return x * self.latent_std + self.latent_mean
return x.mul_(self.latent_std).add_(self.latent_mean)
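
    # NOTE: normalize()/unnormalize() operate in place (sub_/div_, mul_/add_) and
    # mutate the tensor passed in. A minimal usage sketch, assuming the caller
    # wants to keep the raw latents around:
    #
    #   z = model.normalize(latents.clone())      # normalized copy; `latents` is untouched
    #   latents_restored = model.unnormalize(z)   # back to the original scale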
def preprocess_conditions(self, text_f: torch.Tensor, text_f_c: torch.Tensor) -> PreprocessedConditions:
"""
cache computations that do not depend on the latent/time step
i.e., the features are reused over steps during inference
"""
assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}'
bs = text_f.shape[0]
# get global and local text features
        # NOTE: the order of projection has been changed so that global and local features are projected separately
text_f_c = self.text_cond_proj(text_f_c) # (B, D)
text_f = self.text_input_proj(text_f) # (B, VN, D)
return PreprocessedConditions(text_f=text_f,
text_f_c=text_f_c)
def predict_flow(self, latent: torch.Tensor, t: torch.Tensor,
conditions: PreprocessedConditions) -> torch.Tensor:
"""
for non-cacheable computations
"""
assert latent.shape[1] == self._latent_seq_len, f'{latent.shape=} {self._latent_seq_len=}'
text_f = conditions.text_f
text_f_c = conditions.text_f_c
latent = self.audio_input_proj(latent) # (B, N, D)
global_c = self.t_embed(t).unsqueeze(1) + text_f_c.unsqueeze(1) # (B, 1, D)
        extended_c = global_c  # extended_c conditions the latent stream, global_c the text stream (identical here)
for block in self.joint_blocks:
latent, text_f = block(latent, text_f, global_c, extended_c, self.latent_rot, self.text_rot) # (B, N, D)
for block in self.fused_blocks:
latent = block(latent, extended_c, self.latent_rot)
        flow = self.final_layer(latent, extended_c)  # (B, N, out_dim)
return flow
def forward(self, latent: torch.Tensor, text_f: torch.Tensor, text_f_c: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
latent: (B, N, C)
text_f: (B, T, D)
        text_f_c: (B, D)
        t: (B,)
"""
        conditions = self.preprocess_conditions(text_f, text_f_c)  # cacheable operations
        flow = self.predict_flow(latent, t, conditions)  # non-cacheable operations
return flow
def get_empty_string_sequence(self, bs: int) -> tuple[torch.Tensor, torch.Tensor]:
return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1), \
self.empty_string_feat_c.unsqueeze(0).expand(bs, -1) # (b, d)
def get_empty_conditions(
self,
bs: int,
*,
negative_text_features: Optional[torch.Tensor] = None) -> PreprocessedConditions:
if negative_text_features is not None:
empty_string_feat, empty_string_feat_c = negative_text_features
else:
empty_string_feat, empty_string_feat_c = self.get_empty_string_sequence(1)
conditions = self.preprocess_conditions(empty_string_feat,
empty_string_feat_c) # use encoder's empty features
if negative_text_features is None:
conditions.text_f = conditions.text_f.expand(bs, -1, -1)
conditions.text_f_c = conditions.text_f_c.expand(bs, -1)
return conditions
def ode_wrapper(self, t: torch.Tensor, latent: torch.Tensor, conditions: PreprocessedConditions,
empty_conditions: PreprocessedConditions, cfg_strength: float) -> torch.Tensor:
t = t * torch.ones(len(latent), device=latent.device, dtype=latent.dtype)
if cfg_strength < 1.0:
return self.predict_flow(latent, t, conditions)
        else:
            # classifier-free guidance: uncond + cfg * (cond - uncond)
            return (cfg_strength * self.predict_flow(latent, t, conditions) +
                    (1 - cfg_strength) * self.predict_flow(latent, t, empty_conditions))
def load_weights(self, src_dict) -> None:
if 't_embed.freqs' in src_dict:
del src_dict['t_embed.freqs']
if 'latent_rot' in src_dict:
del src_dict['latent_rot']
if 'text_rot' in src_dict:
del src_dict['text_rot']
if 'empty_string_feat_c' not in src_dict.keys(): # FIXME: issue of version mismatch here
src_dict['empty_string_feat_c'] = src_dict['empty_string_feat'].mean(dim=0)
self.load_state_dict(src_dict, strict=True)
@property
def device(self) -> torch.device:
return self.latent_mean.device
@property
def latent_seq_len(self) -> int:
return self._latent_seq_len
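

# A minimal fixed-step Euler sampling sketch for FluxAudio, illustration only and
# not the project's inference pipeline. Assumptions (not taken from this file):
# the t: 0 -> 1 integration direction (noise at t=0, data at t=1, the common
# flow-matching convention), the step count, and the cfg_strength default.
@torch.no_grad()
def _fluxaudio_euler_sampling_sketch(net: FluxAudio,
                                     text_f: torch.Tensor,
                                     text_f_c: torch.Tensor,
                                     cfg_strength: float = 4.5,
                                     num_steps: int = 25) -> torch.Tensor:
    bs = text_f.shape[0]
    conditions = net.preprocess_conditions(text_f, text_f_c)  # computed once, reused at every step
    empty_conditions = net.get_empty_conditions(bs)
    latent = torch.randn(bs, net.latent_seq_len, net.latent_dim, device=net.device)
    ts = torch.linspace(0.0, 1.0, num_steps + 1, device=net.device)
    for i in range(num_steps):
        dt = ts[i + 1] - ts[i]
        flow = net.ode_wrapper(ts[i], latent, conditions, empty_conditions, cfg_strength)
        latent = latent + dt * flow  # Euler step along the predicted flow
    return net.unnormalize(latent)
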
class MeanAudio(nn.Module):
    # Flux-style latent transformer for text-to-audio (TTA), dual timestep embeddings (t and r)
def __init__(self,
*,
latent_dim: int,
text_dim: int,
text_c_dim: int,
hidden_dim: int,
depth: int,
fused_depth: int,
num_heads: int,
mlp_ratio: float = 4.0,
latent_seq_len: int,
text_seq_len: int = 77,
latent_mean: Optional[torch.Tensor] = None,
latent_std: Optional[torch.Tensor] = None,
empty_string_feat: Optional[torch.Tensor] = None,
empty_string_feat_c: Optional[torch.Tensor] = None,
use_rope: bool = False) -> None:
super().__init__()
self.latent_dim = latent_dim
self._latent_seq_len = latent_seq_len
self._text_seq_len = text_seq_len
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.use_rope = use_rope
self.mm_depth = depth - fused_depth
self.audio_input_proj = nn.Sequential(
ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3),
nn.SELU(),
ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3),
)
self.text_input_proj = nn.Sequential(
nn.Linear(text_dim, hidden_dim),
MLP(hidden_dim, hidden_dim * 4),
)
self.text_cond_proj = nn.Sequential(
nn.Linear(text_c_dim, hidden_dim),
MLP(hidden_dim, hidden_dim*4)
)
self.final_layer = FinalBlock(hidden_dim, latent_dim)
self.t_embed = TimestepEmbedder(hidden_dim,
frequency_embedding_size=256,
max_period=10000)
        # second embedder for the auxiliary timestep r
self.r_embed = TimestepEmbedder(hidden_dim,
frequency_embedding_size=256,
max_period=10000)
self.joint_blocks = nn.ModuleList([
JointBlock(hidden_dim,
num_heads,
mlp_ratio=mlp_ratio,
                       pre_only=(i == depth - fused_depth - 1)) for i in range(depth - fused_depth)  # last joint layer is pre-only (applied only to the text and latent streams)
])
self.fused_blocks = nn.ModuleList([
MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1)
for i in range(fused_depth)
])
if latent_mean is None:
            # placeholder values; if mean/std are not provided here,
            # they are expected to be loaded later from a checkpoint
assert latent_std is None
latent_mean = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan'))
latent_std = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan'))
else:
assert latent_std is not None
assert latent_mean.numel() == latent_dim, f'{latent_mean.numel()=} != {latent_dim=}'
if empty_string_feat is None:
empty_string_feat = torch.zeros((text_seq_len, text_dim))
if empty_string_feat_c is None:
            empty_string_feat_c = torch.zeros(text_c_dim)
        assert empty_string_feat.shape[-1] == text_dim, f'{empty_string_feat.shape[-1]=} != {text_dim=}'
        assert empty_string_feat_c.shape[-1] == text_c_dim, f'{empty_string_feat_c.shape[-1]=} != {text_c_dim=}'
self.latent_mean = nn.Parameter(latent_mean.view(1, 1, -1), requires_grad=False) # (1, 1, d)
self.latent_std = nn.Parameter(latent_std.view(1, 1, -1), requires_grad=False) # (1, 1, d)
self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False)
self.empty_string_feat_c = nn.Parameter(empty_string_feat_c, requires_grad=False)
self.initialize_weights()
if self.use_rope:
log.info("Network: Enabling RoPE embeddings")
self.initialize_rotations()
else:
log.info("Network: RoPE embedding disabled")
self.latent_rot = None
self.text_rot = None
def initialize_rotations(self):
base_freq = 1.0
latent_rot = compute_rope_rotations(self._latent_seq_len,
self.hidden_dim // self.num_heads,
10000,
freq_scaling=base_freq,
device=self.device)
text_rot = compute_rope_rotations(self._text_seq_len,
self.hidden_dim // self.num_heads,
10000,
freq_scaling=base_freq,
device=self.device)
self.latent_rot = nn.Buffer(latent_rot, persistent=False) # will not be saved into state dict
self.text_rot = nn.Buffer(text_rot, persistent=False)
def update_seq_lengths(self, latent_seq_len: int) -> None:
self._latent_seq_len = latent_seq_len
if self.use_rope:
self.initialize_rotations() # after changing seq_len we need to re-initialize RoPE to match new seq_len
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
        # Initialize timestep embedding MLPs (both the t and r embedders):
        nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02)
        nn.init.normal_(self.r_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.r_embed.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
for block in self.joint_blocks:
nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0) # the linear layer -> 6 coefficients
nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0)
nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0)
for block in self.fused_blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.conv.weight, 0)
nn.init.constant_(self.final_layer.conv.bias, 0)
def normalize(self, x: torch.Tensor) -> torch.Tensor:
# return (x - self.latent_mean) / self.latent_std
return x.sub_(self.latent_mean).div_(self.latent_std)
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
# return x * self.latent_std + self.latent_mean
return x.mul_(self.latent_std).add_(self.latent_mean)
def preprocess_conditions(self, text_f: torch.Tensor, text_f_c: torch.Tensor) -> PreprocessedConditions:
"""
cache computations that do not depend on the latent/time step
i.e., the features are reused over steps during inference
"""
assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}'
bs = text_f.shape[0]
# get global and local text features
        # NOTE: the order of projection has been changed so that global and local features are projected separately
text_f_c = self.text_cond_proj(text_f_c) # (B, D)
text_f = self.text_input_proj(text_f) # (B, VN, D)
return PreprocessedConditions(text_f=text_f,
text_f_c=text_f_c)
    def predict_flow(self, latent: torch.Tensor, t: torch.Tensor, r: torch.Tensor,  # requires r <= t
                     conditions: PreprocessedConditions) -> torch.Tensor:
"""
for non-cacheable computations
"""
        # assert torch.all(r <= t), 'r should be no greater than t'
assert latent.shape[1] == self._latent_seq_len, f'{latent.shape=} {self._latent_seq_len=}'
text_f = conditions.text_f
text_f_c = conditions.text_f_c
latent = self.audio_input_proj(latent) # (B, N, D)
        # simple first pass: share the conditioning pathway by summing the t and r embeddings
        global_c = self.t_embed(t).unsqueeze(1) + self.r_embed(r).unsqueeze(1) + text_f_c.unsqueeze(1)  # (B, 1, D)
        extended_c = global_c  # TODO: add fine-grained control
for block in self.joint_blocks:
latent, text_f = block(latent, text_f, global_c, extended_c, self.latent_rot, self.text_rot) # (B, N, D)
for block in self.fused_blocks:
latent = block(latent, extended_c, self.latent_rot)
        flow = self.final_layer(latent, extended_c)  # (B, N, out_dim)
return flow
    def forward(self, latent: torch.Tensor, text_f: torch.Tensor, text_f_c: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        latent: (B, N, C)
        text_f: (B, T, D)
        text_f_c: (B, D)
        r: (B,)
        t: (B,)
        """
        conditions = self.preprocess_conditions(text_f, text_f_c)  # cacheable operations
        flow = self.predict_flow(latent, t, r, conditions)  # non-cacheable operations
        return flow
def get_empty_string_sequence(self, bs: int) -> tuple[torch.Tensor, torch.Tensor]:
return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1), \
self.empty_string_feat_c.unsqueeze(0).expand(bs, -1) # (b, d)
def get_empty_conditions(
self,
bs: int,
*,
negative_text_features: Optional[torch.Tensor] = None) -> PreprocessedConditions:
if negative_text_features is not None:
empty_string_feat, empty_string_feat_c = negative_text_features
else:
empty_string_feat, empty_string_feat_c = self.get_empty_string_sequence(1)
conditions = self.preprocess_conditions(empty_string_feat,
empty_string_feat_c) # use encoder's empty features
if negative_text_features is None:
conditions.text_f = conditions.text_f.expand(bs, -1, -1)
conditions.text_f_c = conditions.text_f_c.expand(bs, -1)
return conditions
    def ode_wrapper(self, t: torch.Tensor, r: torch.Tensor, latent: torch.Tensor, conditions: PreprocessedConditions,
                    empty_conditions: PreprocessedConditions, cfg_strength: float) -> torch.Tensor:
        t = t * torch.ones(len(latent), device=latent.device, dtype=latent.dtype)
        r = r * torch.ones(len(latent), device=latent.device, dtype=latent.dtype)
        if cfg_strength < 1.0:
            return self.predict_flow(latent, t, r, conditions)
        else:
            # classifier-free guidance: uncond + cfg * (cond - uncond)
            return (cfg_strength * self.predict_flow(latent, t, r, conditions) +
                    (1 - cfg_strength) * self.predict_flow(latent, t, r, empty_conditions))
    def load_weights(self, src_dict) -> None:

        def remove_prefix(storage):
            # strip the "ema_model." prefix left by EMA-wrapped training checkpoints
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in storage.items():
                name = k.replace("ema_model.", "")
                new_state_dict[name] = v
            return new_state_dict

        src_dict = remove_prefix(src_dict)
if 't_embed.freqs' in src_dict:
del src_dict['t_embed.freqs']
if 'r_embed.freqs' in src_dict:
del src_dict['r_embed.freqs']
if 'latent_rot' in src_dict:
del src_dict['latent_rot']
if 'text_rot' in src_dict:
del src_dict['text_rot']
if 'empty_string_feat_c' not in src_dict.keys(): # FIXME: issue of version mismatch here
src_dict['empty_string_feat_c'] = src_dict['empty_string_feat'].mean(dim=0)
if '_extra_state' in src_dict:
del src_dict['_extra_state']
self.load_state_dict(src_dict, strict=True)
@property
def device(self) -> torch.device:
return self.latent_mean.device
@property
def latent_seq_len(self) -> int:
return self._latent_seq_len
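

# A minimal few-step sampling sketch for MeanAudio, illustration only and not the
# project's official sampler. It assumes the MeanFlow-style average-velocity
# parameterization (noise at t=1, data at t=0, with r <= t), where one update
# jumps from time t to time r via z_r = z_t - (t - r) * u(z_t, r, t);
# num_steps=1 gives one-step generation.
@torch.no_grad()
def _meanaudio_sampling_sketch(net: MeanAudio,
                               text_f: torch.Tensor,
                               text_f_c: torch.Tensor,
                               num_steps: int = 1) -> torch.Tensor:
    bs = text_f.shape[0]
    conditions = net.preprocess_conditions(text_f, text_f_c)  # cached across steps
    latent = torch.randn(bs, net.latent_seq_len, net.latent_dim, device=net.device)
    ts = torch.linspace(1.0, 0.0, num_steps + 1, device=net.device)
    for i in range(num_steps):
        t = ts[i].expand(bs)
        r = ts[i + 1].expand(bs)
        u = net.predict_flow(latent, t, r, conditions)  # average velocity over [r, t]
        latent = latent - (t - r).view(-1, 1, 1) * u    # jump from time t down to time r
    return net.unnormalize(latent)
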
def fluxaudio_s(**kwargs) -> FluxAudio:
num_heads = 7
return FluxAudio(latent_dim=20,
text_dim=1024,
hidden_dim=64 * num_heads,
depth=12,
fused_depth=8,
num_heads=num_heads,
latent_seq_len=312, # for 10s audio
**kwargs)
def meanaudio_s(**kwargs) -> MeanAudio:
num_heads = 7
return MeanAudio(latent_dim=20,
text_dim=1024,
hidden_dim=64 * num_heads,
depth=12,
fused_depth=8,
num_heads=num_heads,
latent_seq_len=312, # for 10s audio
**kwargs)
def get_mean_audio(name: str, **kwargs) -> MeanAudio:
if name == 'meanaudio_s':
return meanaudio_s(**kwargs)
if name == 'fluxaudio_s':
return fluxaudio_s(**kwargs)
raise ValueError(f'Unknown model name: {name}')
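
# A hedged checkpoint-loading sketch; the file name and how the state dict is
# stored are hypothetical, while load_weights() itself handles the "ema_model."
# prefix and the derived buffers (RoPE rotations, timestep-embedder freqs):
#
#   net = get_mean_audio('meanaudio_s', use_rope=False, text_c_dim=512)
#   state = torch.load('meanaudio_s.pth', map_location='cpu')  # hypothetical path
#   net.load_weights(state)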
if __name__ == '__main__':
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[
            # logging.FileHandler("main.log"),
            logging.StreamHandler()
        ]
    )
    network: MeanAudio = get_mean_audio('meanaudio_s',
                                        use_rope=False,
                                        text_c_dim=512)
    # smoke test: push a small random batch through the full forward pass
    bs = 2
    x = torch.randn(bs, 312, 20)        # latent: (B, N, C)
    text_f = torch.randn(bs, 77, 1024)  # per-token text features: (B, T, D)
    text_f_c = torch.randn(bs, 512)     # pooled text feature: (B, D)
    t = torch.rand(bs)
    r = t * torch.rand(bs)              # ensure r <= t
    flow = network(x, text_f, text_f_c, r, t)
    print(flow.shape)                   # expected: (2, 312, 20)
    print('Finish')