import torch
import torch.nn as nn
from torch import Tensor
from typing import List, Optional, Tuple
from itertools import chain


def expand_t_like_x(t: Tensor, x: Tensor) -> Tensor:
    """Reshape a per-sample time vector so it broadcasts against ``x``.

    Args:
        t: [batch_dim,], time vector.
        x: [batch_dim, ...], data point.

    Returns:
        ``t`` viewed as ``[batch_dim, 1, ..., 1]`` with as many dimensions as
        ``x``, so element-wise ops like ``t * x`` broadcast per sample.
    """
    # One singleton axis for every non-batch dimension of x.
    dims = [1] * (x.dim() - 1)
    return t.view(t.size(0), *dims)


def build_mlp(hidden_size: int, projector_dim: int, z_dim: int) -> nn.Sequential:
    """Build a 3-layer MLP: hidden_size -> projector_dim -> projector_dim -> z_dim.

    SiLU activations sit between the linear layers; the output layer has no
    activation.
    """
    return nn.Sequential(
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    )


def modulate(x, shift, scale):
    """Apply an affine modulation ``x * (1 + scale) + shift``.

    Broadcasting follows standard PyTorch rules; callers are responsible for
    shaping ``shift`` and ``scale`` so they broadcast against ``x``.
    """
    return x * (1 + scale) + shift


def get_parameter_dtype(parameter: torch.nn.Module) -> Optional[torch.dtype]:
    """Return the dtype of the first tensor found on a module.

    Lookup order:
      1. the first parameter yielded by ``parameter.parameters()`` (recursive);
      2. otherwise the first buffer yielded by ``parameter.buffers()``;
      3. otherwise the first plain tensor attribute stored in the ``__dict__``
         of the module or any of its submodules (kept for
         ``torch.nn.DataParallel`` compatibility in PyTorch 1.5, where
         tensors may presumably not be registered as parameters).

    Args:
        parameter: module to inspect.

    Returns:
        The dtype of the first tensor found, or ``None`` if the module holds
        no tensors at all.
    """
    # next(..., None) reads only the first element instead of materializing
    # every parameter/buffer of the whole module tree.
    first_param = next(parameter.parameters(), None)
    if first_param is not None:
        return first_param.dtype

    first_buffer = next(parameter.buffers(), None)
    if first_buffer is not None:
        return first_buffer.dtype

    # Fallback: tensors stored as plain attributes are neither parameters nor
    # buffers, so scan each (sub)module's instance dict directly. Uses the
    # public modules() API rather than the private _named_members helper.
    for module in parameter.modules():
        for value in module.__dict__.values():
            if torch.is_tensor(value):
                return value.dtype

    return None