from __future__ import annotations

import functools
import importlib
import os
from functools import partial
from inspect import isfunction

import fsspec
import torch
from einops import repeat


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def get_string_from_tuple(s):
    """Return the first element of a tuple written as a string.

    Example: ``"('a', 'b')"`` -> ``'a'``.

    Args:
        s: Candidate string. Non-strings, strings not wrapped in parentheses,
            strings that fail to evaluate, and strings that do not evaluate to
            a non-empty tuple are returned unchanged.

    Returns:
        ``t[0]`` of the evaluated tuple, otherwise ``s`` itself.
    """
    # Explicit guards instead of relying on ``s[0]`` raising inside a bare
    # ``except`` (empty strings / non-string inputs are returned unchanged).
    if not (isinstance(s, str) and s.startswith("(") and s.endswith(")")):
        return s
    try:
        # NOTE(security): ``eval`` executes arbitrary code. Only pass trusted
        # strings here; consider ``ast.literal_eval`` if input can come from
        # users or external files.
        t = eval(s)
    except Exception:
        # Deliberate best-effort parse: anything that does not evaluate is
        # treated as a plain string.
        return s
    if isinstance(t, tuple) and len(t) > 0:
        return t[0]
    return s


def is_power_of_two(n):
    """Return True if n is a power of 2, otherwise return False."""
    if n <= 0:
        return False
    # A positive power of two has exactly one set bit, so clearing the lowest
    # set bit (n & (n - 1)) leaves zero.
    return (n & (n - 1)) == 0


def autocast(f, enabled=True):
    """Wrap ``f`` so every call runs inside a CUDA autocast context.

    The context reuses the currently configured autocast GPU dtype and the
    current autocast cache setting; only ``enabled`` is set explicitly.

    Args:
        f: Callable to wrap.
        enabled: Passed to the autocast context as its ``enabled`` flag.

    Returns:
        A wrapper with the same name/docstring as ``f``.
    """

    @functools.wraps(f)
    def do_autocast(*args, **kwargs):
        # torch.autocast("cuda", ...) is the documented replacement for the
        # deprecated torch.cuda.amp.autocast(...).
        with torch.autocast(
            device_type="cuda",
            enabled=enabled,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)

    return do_autocast


def load_partial_from_config(config):
    """Return ``functools.partial`` of ``config["target"]`` bound to
    ``config["params"]`` (an empty dict when absent)."""
    return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))


def repeat_as_img_seq(x, num_frames):
    """Repeat every batch entry ``num_frames`` times, keeping copies adjacent.

    - ``None`` is passed through.
    - A list ``[a, b]`` becomes ``[a, a, b, b]`` for ``num_frames=2``.
    - A tensor of shape ``(b, ...)`` becomes ``(b * num_frames, ...)``, with
      the copies of each batch element consecutive.
    """
    if x is None:
        return None
    if isinstance(x, list):
        return [item for item in x for _ in range(num_frames)]
    x = x.unsqueeze(1)
    return repeat(x, "b 1 ... -> (b t) ...", t=num_frames)


def partialclass(cls, *args, **kwargs):
    """Return a subclass of ``cls`` whose ``__init__`` has ``args``/``kwargs``
    pre-bound (a class-level analogue of ``functools.partial``)."""

    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)

    return NewCls


def make_path_absolute(path):
    """Return ``path`` as an absolute path if it lives on the local filesystem.

    Paths resolved by fsspec to a non-local filesystem are returned unchanged.
    """
    fs, p = fsspec.core.url_to_fs(path)
    # ``fs.protocol`` may be a single string or a tuple of aliases (e.g.
    # LocalFileSystem uses ("file", "local") in recent fsspec releases), so a
    # plain ``== "file"`` comparison misses local paths.
    protocols = (fs.protocol,) if isinstance(fs.protocol, str) else tuple(fs.protocol)
    if "file" in protocols:
        return os.path.abspath(p)
    return path


def ismap(x):
    """True for a 4-D tensor whose dim 1 (presumably channels) is > 3."""
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)


def isimage(x):
    """True for a 4-D tensor whose dim 1 (presumably channels) is 1 or 3."""
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def isheatmap(x):
    """True for a 2-D tensor."""
    if not isinstance(x, torch.Tensor):
        return False
    return x.ndim == 2


def isneighbors(x):
    """True for a 5-D tensor whose dim 2 is 1 or 3."""
    if not isinstance(x, torch.Tensor):
        return False
    return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)


def exists(x):
    """Return True if ``x`` is not None."""
    return x is not None


def expand_dims_like(x, y):
    """Append trailing singleton dims to ``x`` until it has as many dims as ``y``.

    Raises:
        ValueError: if ``x`` already has more dimensions than ``y`` (appending
            dims can never match, and the loop would not terminate).
    """
    if x.dim() > y.dim():
        raise ValueError(
            f"Input has {x.dim()} dims but reference has {y.dim()}, which is less"
        )
    while x.dim() < y.dim():
        x = x.unsqueeze(-1)
    return x


def default(val, d):
    """Return ``val`` if it is not None, otherwise ``d``.

    If ``d`` is a plain function it is called (lazily) to produce the default.
    """
    if exists(val):
        return val
    return d() if isfunction(d) else d


def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def count_params(model, verbose=False):
    """Return the total number of elements across ``model.parameters()``.

    When ``verbose`` is True the count is also printed in millions.
    """
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params")
    return total_params


def instantiate_from_config(config):
    """Instantiate ``config["target"]`` with ``config["params"]`` as kwargs.

    The sentinel strings ``"__is_first_stage__"`` and
    ``"__is_unconditional__"`` yield None.

    Raises:
        KeyError: if ``config`` has no ``target`` and is not a sentinel.
    """
    # For the sentinel strings, ``in`` is a substring test, which is False,
    # so they fall into this branch.
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False, invalidate_cache=True):
    """Resolve a dotted path such as ``"pkg.mod.Name"`` to ``pkg.mod.Name``.

    Args:
        string: Dotted path; the last component is the attribute name.
        reload: If True, re-execute the module via ``importlib.reload`` first.
        invalidate_cache: If True, call ``importlib.invalidate_caches()`` so
            modules created after interpreter start can be found.
    """
    module, cls = string.rsplit(".", 1)
    if invalidate_cache:
        importlib.invalidate_caches()
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def append_zero(x):
    """Concatenate a single zero (same dtype/device as ``x``) onto ``x`` along
    dim 0; ``x`` is expected to be 1-D."""
    return torch.cat((x, x.new_zeros([1])))


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"Input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]


def get_configs_path() -> str:
    """Get the `configs` directory.

    Looks next to this file first, then one directory up.

    Raises:
        FileNotFoundError: if neither candidate directory exists.
    """
    this_dir = os.path.dirname(__file__)
    candidates = (
        os.path.join(this_dir, "configs"),
        os.path.join(this_dir, "..", "configs"),
    )
    for candidate in candidates:
        candidate = os.path.abspath(candidate)
        if os.path.isdir(candidate):
            return candidate
    raise FileNotFoundError(f"Could not find configs in {candidates}")