Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass | |
| import os | |
| import copy | |
| import json | |
| from omegaconf import OmegaConf | |
| import torch | |
| import torch.nn as nn | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.utils import ( | |
| extract_commit_hash, | |
| ) | |
| from step1x3d_geometry.utils.config import parse_structured | |
| from step1x3d_geometry.utils.misc import get_device, load_module_weights | |
| from step1x3d_geometry.utils.typing import * | |
class Configurable:
    """Mixin that parses a raw configuration into ``self.cfg`` on construction."""

    # ``@dataclass`` is required: ``parse_structured`` is handed this class to
    # build ``self.cfg`` from (presumably via OmegaConf structured configs —
    # confirm), and subclasses declaring ``@dataclass class Config(...Config)``
    # only inherit fields from a dataclass base.
    @dataclass
    class Config:
        pass

    def __init__(self, cfg: Optional[dict] = None) -> None:
        """Parse ``cfg`` into an instance-level structured config.

        Args:
            cfg: Raw configuration passed to ``parse_structured`` together with
                ``self.Config``.
        """
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
class Updateable:
    """Mixin that propagates per-step update hooks through an object graph.

    ``do_update_step`` / ``do_update_step_end`` first recurse into every public
    attribute that is itself ``Updateable`` and then invoke this object's own
    ``update_step`` / ``update_step_end`` hook. Subclasses override the hooks.
    """

    def _updateable_children(self):
        """Yield each public attribute value of ``self`` that is ``Updateable``.

        Attributes are visited in ``self.__dir__()`` order, lazily, so each
        child is updated before the next attribute is looked up. Attributes whose
        lookup raises an ``Exception`` (e.g. properties that cannot be evaluated
        yet) are skipped. Only ``Exception`` subclasses are swallowed, so
        ``KeyboardInterrupt`` / ``SystemExit`` still propagate.
        """
        for attr in self.__dir__():
            if attr.startswith("_"):
                continue
            try:
                value = getattr(self, attr)
            except Exception:
                # ignore attributes like properties that can't be retrieved via getattr
                continue
            if isinstance(value, Updateable):
                yield value

    def do_update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ):
        """Run ``do_update_step`` on all Updateable children, then ``update_step`` on self.

        Args:
            epoch: Current epoch.
            global_step: Current global step.
            on_load_weights: True when called while restoring from a checkpoint.
        """
        for child in self._updateable_children():
            child.do_update_step(epoch, global_step, on_load_weights=on_load_weights)
        self.update_step(epoch, global_step, on_load_weights=on_load_weights)

    def do_update_step_end(self, epoch: int, global_step: int):
        """Run ``do_update_step_end`` on all Updateable children, then ``update_step_end`` on self."""
        for child in self._updateable_children():
            child.do_update_step_end(epoch, global_step)
        self.update_step_end(epoch, global_step)

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        # override this method to implement custom update logic
        # if on_load_weights is True, you should be careful doing things related to model evaluations,
        # as the models and tensors are not guaranteed to be on the same device
        pass

    def update_step_end(self, epoch: int, global_step: int):
        # override this method to run logic at the end of a step
        pass
def update_if_possible(module: Any, epoch: int, global_step: int) -> None:
    """Trigger ``do_update_step`` on ``module`` if it is ``Updateable``.

    Any other object is silently ignored.
    """
    if not isinstance(module, Updateable):
        return
    module.do_update_step(epoch, global_step)
def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None:
    """Trigger ``do_update_step_end`` on ``module`` if it is ``Updateable``.

    Any other object is silently ignored.
    """
    if not isinstance(module, Updateable):
        return
    module.do_update_step_end(epoch, global_step)
class BaseObject(Updateable):
    """Non-module base object built from a structured ``Config``.

    ``__init__`` parses ``cfg`` against ``self.Config``, stores the device
    returned by ``get_device()`` and then calls ``configure`` with any extra
    positional/keyword arguments.
    """

    # ``@dataclass`` is required so ``parse_structured`` can build ``self.cfg``
    # from it and so subclass Configs inherit its fields.
    @dataclass
    class Config:
        pass

    cfg: Config  # add this to every subclass of BaseObject to enable static type checking

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        """Parse ``cfg``, record the device, then call ``configure(*args, **kwargs)``."""
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self.device = get_device()
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        """Hook for subclasses to build their state; the default does nothing."""
        pass
class BaseModule(ModelMixin, Updateable, nn.Module):
    """Base class for trainable modules built from a structured ``Config``.

    Combines diffusers' ``ModelMixin`` with the ``Updateable`` step hooks. The
    config hooks below (``config_name``, ``load_config``, ``from_config``,
    ``register_to_config``, ``save_config``) mirror the ``ConfigMixin`` API but
    read/write a plain JSON dump of ``self.cfg``.
    """

    # ``@dataclass`` is required so ``parse_structured`` can build ``self.cfg``
    # from it and so subclass Configs inherit the ``weights`` field.
    @dataclass
    class Config:
        # Optional "path/to/weights:module_name" spec; when set, those weights
        # are loaded at the end of ``__init__``.
        weights: Optional[str] = None

    cfg: Config  # add this to every subclass of BaseModule to enable static type checking

    # File name looked up by ``load_config`` and written by ``save_config``.
    config_name = "config.json"

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        """Parse ``cfg``, build the module via ``configure``, optionally restore weights.

        Args:
            cfg: Raw configuration parsed against ``self.Config``.
            *args, **kwargs: Forwarded to ``configure``.
        """
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        # NOTE(review): unlike BaseObject, ``self.device`` is not assigned here —
        # presumably because ModelMixin already exposes ``device``; confirm.
        self.configure(*args, **kwargs)
        if self.cfg.weights is not None:
            # format: path/to/weights:module_name
            # Split on the LAST colon so weight paths containing ":" still parse.
            weights_path, module_name = self.cfg.weights.rsplit(":", 1)
            state_dict, epoch, global_step = load_module_weights(
                weights_path, module_name=module_name, map_location="cpu"
            )
            self.load_state_dict(state_dict)
            # restore step-dependent states for the checkpoint's epoch/global_step
            self.do_update_step(epoch, global_step, on_load_weights=True)
        # dummy tensor to indicate model state (non-persistent: not saved in state_dict)
        self._dummy: Float[Tensor, "..."]
        self.register_buffer("_dummy", torch.zeros(0).float(), persistent=False)

    def configure(self, *args, **kwargs) -> None:
        """Hook for subclasses to create submodules; the default does nothing."""
        pass

    @classmethod
    def load_config(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        return_unused_kwargs=False,
        return_commit_hash=False,
        **kwargs,
    ):
        """Locate and read this model's JSON config from a local path.

        Args:
            pretrained_model_name_or_path: A config file, or a directory containing
                ``cls.config_name`` (optionally inside ``subfolder``).
            return_unused_kwargs: Also return the remaining ``kwargs``.
            return_commit_hash: Also return ``extract_commit_hash(config_file)``.
            **kwargs: ``subfolder`` is consumed; everything else is only
                returned back when ``return_unused_kwargs`` is set.

        Returns:
            ``(config_dict,)`` extended with ``kwargs`` and/or the commit hash,
            in that order, according to the flags.

        Raises:
            EnvironmentError: The directory contains no config file.
            ValueError: The path is neither an existing file nor a directory.
        """
        subfolder = kwargs.pop("subfolder", None)
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        if os.path.isfile(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        elif os.path.isdir(pretrained_model_name_or_path):
            root_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
            subfolder_file = (
                os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
                if subfolder is not None
                else None
            )
            # Prefer the subfolder config; fall back to the directory root.
            if subfolder_file is not None and os.path.isfile(subfolder_file):
                config_file = subfolder_file
            elif os.path.isfile(root_file):
                config_file = root_file
            else:
                raise EnvironmentError(
                    f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
                )
        else:
            raise ValueError(
                f"{pretrained_model_name_or_path} is neither an existing file nor a "
                f"directory; {cls.__name__}.load_config only supports local paths."
            )
        # Read with UTF-8 to match how save_config writes the file.
        with open(config_file, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        commit_hash = extract_commit_hash(config_file)
        outputs = (config_dict,)
        if return_unused_kwargs:
            outputs += (kwargs,)
        if return_commit_hash:
            outputs += (commit_hash,)
        return outputs

    @classmethod
    def from_config(cls, config: Optional[Dict[str, Any]] = None, **kwargs):
        """Instantiate the model from a config dict; extra ``kwargs`` are ignored."""
        model = cls(config)
        return model

    def register_to_config(self, **kwargs):
        # Intentional no-op: the configuration lives in ``self.cfg`` (that is
        # what save_config serializes), not in a registered diffusers config.
        pass

    def save_config(self, save_directory: Union[str, os.PathLike], **kwargs):
        """
        Save ``self.cfg`` as JSON in `save_directory` so it can be reloaded with
        ``load_config`` / ``from_config``.

        Keys starting with ``pretrained`` and the ``weights`` entry are dropped
        before writing.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file is saved (will be created if it does not exist).
            kwargs (`Dict[str, Any]`, *optional*):
                Accepted for API compatibility with ``ConfigMixin.save_config``; ignored.

        Raises:
            AssertionError: `save_directory` is an existing file.
        """
        if os.path.isfile(save_directory):
            raise AssertionError(
                f"Provided path ({save_directory}) should be a directory, not a file"
            )
        os.makedirs(save_directory, exist_ok=True)
        # If we save using the predefined names, we can load using `from_config`
        output_config_file = os.path.join(save_directory, self.config_name)
        config_dict = OmegaConf.to_container(self.cfg, resolve=True)
        # Snapshot the matching keys first; no need to deep-copy the whole config.
        for key in [k for k in config_dict if k.startswith("pretrained")]:
            del config_dict[key]
        # Presumably dropped so reloading via from_config does not re-trigger the
        # checkpoint loading done in __init__; tolerate its absence.
        config_dict.pop("weights", None)
        with open(output_config_file, "w", encoding="utf-8") as f:
            json.dump(config_dict, f, ensure_ascii=False, indent=4)
        print(f"Configuration saved in {output_config_file}")