from typing import List

import hydra
from omegaconf import DictConfig
from pytorch_lightning import Callback
from pytorch_lightning.loggers import Logger

from .logger import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config.

    :param callbacks_cfg: A DictConfig containing callback configurations.
    :return: A list of instantiated callbacks.
    """
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping...")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    for _, cb_conf in callbacks_cfg.items():
        # Only instantiate entries that are configs with a Hydra _target_.
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config.

    :param logger_cfg: A DictConfig containing logger configurations.
    :return: A list of instantiated loggers.
    """
    logger: List[Logger] = []

    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return logger

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for _, lg_conf in logger_cfg.items():
        # Only instantiate entries that are configs with a Hydra _target_.
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

    return logger
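

# Usage sketch (illustrative, not part of this module): inside a Hydra-decorated
# entry point, the "callbacks" and "logger" sub-configs — config group names and
# paths assumed here — can be passed to the helpers above and forwarded to the
# Lightning Trainer.
#
# @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
# def main(cfg: DictConfig) -> None:
#     callbacks = instantiate_callbacks(cfg.get("callbacks"))
#     loggers = instantiate_loggers(cfg.get("logger"))
#     trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=loggers)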