# AudioMorphix / src/utils/config.py
# Source: Hugging Face file view (author: JinhuaL1ANG, revision: v1, commit 9a6dac6, 5.82 kB).
import os
import argparse
from copy import deepcopy
from typing import Union, Optional
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
def _str_to_bool(text):
    r"""Parse a command-line token into a bool.

    Used as the argparse ``type`` for list-of-bool config entries, because
    ``bool("False")`` is ``True`` and would silently accept every token.

    Raises:
        argparse.ArgumentTypeError: If ``text`` is not a recognised boolean.
    """
    lowered = str(text).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {text!r}")


def add_args_from_config(config, parser, prefix=""):
    r"""Register one command-line argument per leaf of ``config``.

    Nested keys are joined with ``-``. For example, with
    ``config = {'a': {'b': 123}}`` the value ``config['a']['b']`` can be
    overridden with ``python custom.py --a-b 234``.

    How each leaf is exposed:

    * list  -> ``nargs="+"``. Every item is converted with the type of the
      first element. An empty list yields plain strings. Bools are parsed
      from true/false/yes/no/1/0.
    * bool  -> a flag that flips the default: ``store_false`` if the value
      is True, ``store_true`` otherwise.
    * None  -> a plain string argument.
    * other -> an argument converted with ``type(value)``.

    All arguments default to ``None`` so that `update_config_from_args` can
    tell which ones were actually given on the command line.

    Args:
        config: A DictConfig whose leaves become arguments.
        parser: The ``argparse.ArgumentParser`` to add arguments to.
        prefix: ``-``-joined key path of ``config`` inside the root config,
            starting with ``-``. Empty for the root; set during recursion.
    """
    for key, value in config.items():
        # Root-level keys become "--key"; nested ones become "--parent-key".
        arg_name = f"-{prefix}-{key}"
        if OmegaConf.is_dict(value):
            # Recurse into nested configs, extending the key path.
            add_args_from_config(value, parser, prefix=f"{prefix}-{key}")
        elif OmegaConf.is_list(value):
            items = OmegaConf.to_container(value)
            if not items:
                # Nothing to infer a type from; keep the raw strings.
                item_type = None
            elif isinstance(items[0], bool):
                item_type = _str_to_bool
            else:
                item_type = type(items[0])
            parser.add_argument(arg_name, type=item_type, nargs="+", default=None)
        elif isinstance(value, bool):
            # Passing the flag toggles the configured default.
            parser.add_argument(
                arg_name,
                action="store_false" if value else "store_true",
                default=None,
            )
        elif value is None:
            parser.add_argument(arg_name, default=None)
        else:
            parser.add_argument(arg_name, type=type(value), default=None)
def update_config_from_args(config, args):
    r"""Return a copy of ``config`` updated with command-line arguments.

    The arguments should be created by `add_args_from_config`. Two kinds of
    overrides are applied, in this order:

    1. Subgroup overrides. If ``__subgroup__-<name>`` was given a path, the
       yaml at that path is loaded with `load_config` and merged into
       ``config[<name>]``. The ``__subgroup__`` node is then removed. Configs
       without a ``__subgroup__`` node skip this step.
    2. Leaf overrides. Every non-dict entry whose matching argument exists
       and is not ``None`` takes the argument's value.

    Args:
        config: The config to update. It is not modified in place.
        args: An ``argparse.Namespace`` from a parser populated by
            `add_args_from_config`.

    Returns:
        The updated config.
    """

    def _recur_update_cfgs_from_args(config, args, prefix=""):
        # Build an updated copy of `config`, descending into nested dicts.
        cur_config = deepcopy(config)
        for key in config:
            if OmegaConf.is_dict(config[key]):
                updated_cfgs = _recur_update_cfgs_from_args(
                    config[key], args, prefix=f"{prefix}-{key}"
                )
                cur_config = OmegaConf.merge(cur_config, {key: updated_cfgs})
            else:
                # Mirror argparse's dest naming for "--<prefix>-<key>":
                # strip leading dashes, turn the remaining dashes into "_".
                arg_name = f"{prefix}-{key}".lstrip("-").replace("-", "_")
                if hasattr(args, arg_name):
                    override_v = getattr(args, arg_name)
                    cur_config[key] = (
                        override_v if override_v is not None else config[key]
                    )
        return cur_config

    # Work on a copy so the caller's config (including its `__subgroup__`
    # node) is never mutated by the deletion below.
    config = deepcopy(config)

    # Update config from each subgroup. `__subgroup__` is optional
    # (load_config only processes it when present).
    subgroups = config.get("__subgroup__") or {}
    for k in subgroups:
        sg_cfgs_path = getattr(args, f"__subgroup__-{k}".replace("-", "_"), None)
        if sg_cfgs_path is not None:
            updated_sg_cfgs = load_config(sg_cfgs_path)
            config = OmegaConf.merge(config, {k: updated_sg_cfgs})
    if "__subgroup__" in config:
        del config["__subgroup__"]

    # Update config from each leaf node.
    config = _recur_update_cfgs_from_args(config, args, prefix="")
    return config
def load_config(
    config_path: Union[dict, str, DictConfig], dump_path: Optional[str] = None
) -> DictConfig:
    r"""Load a config from a yaml file, a dict, or an existing DictConfig.

    Sub-configs listed under the optional ``__subgroup__`` key are also
    loaded and merged into the config under their group name, e.g.::

        [within `config_path`]
        __subgroup__:
            a: path_to_yaml_a
            b: path_to_yaml_b
            ...
        attribute 1:
            ...

    Relative subgroup paths are resolved against the directory of
    ``config_path`` when it is a file path, and against the current working
    directory otherwise. After loading, each ``__subgroup__`` entry is
    rewritten to the absolute path that was read.

    Args:
        config_path: Path to a yaml file, a plain dict, or a DictConfig.
        dump_path: If given, the fully-loaded config is also written to this
            path with `dump_config`.

    Returns:
        OmegaConf.DictConfig: The loaded, subgroup-merged config.

    Raises:
        TypeError: If ``config_path`` is none of the supported types.
    """
    if isinstance(config_path, (str, os.PathLike)):
        with open(config_path, "r", encoding="utf-8") as file:
            config = OmegaConf.load(file)
        cur_cfg_dir = os.path.dirname(os.path.abspath(config_path))
    else:
        if isinstance(config_path, dict):
            config = OmegaConf.create(config_path)
        elif OmegaConf.is_config(config_path):
            config = config_path
        else:
            raise TypeError(
                "config_path must be a config path, dict, or DictConfig, "
                f"got {type(config_path).__name__}"
            )
        # In-memory configs have no file location to resolve relative
        # subgroup paths against; fall back to the working directory.
        cur_cfg_dir = os.getcwd()

    if "__subgroup__" in config:
        subgroups = config.get("__subgroup__")
        for sg_name, sg_config_path in subgroups.items():
            # os.path.join keeps sg_config_path unchanged if it is absolute.
            sg_abs_pth = os.path.join(cur_cfg_dir, sg_config_path)
            sg_config = OmegaConf.load(sg_abs_pth)
            # merge returns a new config, so the input object is not mutated.
            config = OmegaConf.merge(config, {sg_name: sg_config})
            config.__subgroup__[sg_name] = sg_abs_pth  # record resolved path

    if dump_path is not None:
        dump_config(config, dump_path)
    return config
def dynamic_config(
    description: Optional[str] = None,
    verbose: bool = True,
    argv: Optional[list] = None,
):
    r"""Load configuration from both a yaml file and the command line.

    Values in the yaml file given by ``--config_path`` are overridden by
    arguments passed on the command line, e.g.::

        [Command line] python3 custom.py --config_path /path/to/config.yaml --a-b-c=123
        [Python file]  cfgs = dynamic_config('A demo for dynamic configuration.')
                       dump_config(cfgs, 'path/to/output/config.yaml')  # log this trial's config

    Parsing happens in two stages. First, only ``--config_path`` is read so
    the yaml can be loaded. Then one argument per config leaf is registered
    and the full command line is parsed. ``-h/--help`` is deferred to the
    second stage so the help text also lists the config-derived arguments.

    Args:
        description: Program description shown in ``--help``.
        verbose: If True, log the final configuration at INFO level.
        argv: Argument list to parse instead of ``sys.argv[1:]``.

    Returns:
        DictConfig: The configuration with command-line overrides applied.
    """
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "--config_path", type=str, required=True, help="Path to the yaml file."
    )

    # Stage 1: extract --config_path only. add_help=False so "-h" is left for
    # the full parser, once the config-derived arguments have been added.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("--config_path", type=str)
    pre_args, _ = pre_parser.parse_known_args(argv)
    if pre_args.config_path is None:
        # Let the full parser print help (for -h) or report the missing
        # required --config_path, and exit.
        parser.parse_args(argv)

    # Get predefined configs and add new args dynamically.
    cfgs = load_config(pre_args.config_path)
    add_args_from_config(cfgs, parser)

    # Stage 2: override values in `cfgs` where arguments were given.
    args = parser.parse_args(argv)
    cfgs = update_config_from_args(cfgs, args)

    if verbose:
        import logging

        log = logging.getLogger(__name__)
        log.info("Successfully setup the configuration:\n%s", OmegaConf.to_yaml(cfgs))
    return cfgs
def dump_config(cfgs, dump_path):
    r"""Save a config as yaml, creating parent directories as needed.

    Args:
        cfgs: The config to save, e.g. the DictConfig returned by
            `load_config` or `dynamic_config`.
        dump_path: Destination file path. An existing file is overwritten.
    """
    dump_dir = os.path.dirname(os.path.abspath(dump_path))
    os.makedirs(dump_dir, exist_ok=True)
    # Explicit UTF-8 so the output does not depend on the platform's default
    # locale encoding.
    with open(dump_path, "w", encoding="utf-8") as file:
        OmegaConf.save(cfgs, f=file)
if __name__ == "__main__":
    # Manual check: load the config named by --config_path, apply command-line
    # overrides, and print the result.
    cfgs = dynamic_config()
    print("Updated Configuration:")
    print(OmegaConf.to_yaml(cfgs))