Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from ...dist_utils import master_only | |
| from ..hook import HOOKS | |
| from .base import LoggerHook | |
@HOOKS.register_module()
class MlflowLoggerHook(LoggerHook):
    """Logger hook that records training metrics, and optionally the final
    model, to MLflow.

    See ``__init__`` for the constructor arguments.
    """

    def __init__(self,
                 exp_name=None,
                 tags=None,
                 log_model=True,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True):
        """Class to log metrics and (optionally) a trained model to MLflow.

        It requires `MLflow`_ to be installed.

        Args:
            exp_name (str, optional): Name of the experiment to be used.
                Default None. If not None, set the active experiment.
                If experiment does not exist, an experiment with provided name
                will be created.
            tags (dict of str: str, optional): Tags for the current run.
                Default None. If not None, set tags for the current run.
            log_model (bool, optional): Whether to log an MLflow artifact.
                Default True. If True, log runner.model as an MLflow artifact
                for the current run.
            interval (int): Logging interval (every k iterations).
                Default 10.
            ignore_last (bool): Ignore the log of last iterations in each
                epoch if less than `interval`. Default True.
            reset_flag (bool): Whether to clear the output buffer after
                logging. Default False.
            by_epoch (bool): Whether EpochBasedRunner is used. Default True.

        Raises:
            ImportError: If ``mlflow`` cannot be imported.

        .. _MLflow:
            https://www.mlflow.org/docs/latest/index.html
        """
        super(MlflowLoggerHook, self).__init__(interval, ignore_last,
                                               reset_flag, by_epoch)
        # Import eagerly so a missing dependency is reported when the hook is
        # constructed rather than later, inside a training callback.
        self.import_mlflow()
        self.exp_name = exp_name
        self.tags = tags
        self.log_model = log_model

    def import_mlflow(self):
        """Import ``mlflow`` and ``mlflow.pytorch`` and store them on the
        instance as ``self.mlflow`` and ``self.mlflow_pytorch``.

        Raises:
            ImportError: If ``mlflow`` is not installed. The original import
                error is chained as the cause.
        """
        try:
            import mlflow
            import mlflow.pytorch as mlflow_pytorch
        except ImportError as err:
            raise ImportError(
                'Please run "pip install mlflow" to install mlflow') from err
        self.mlflow = mlflow
        self.mlflow_pytorch = mlflow_pytorch

    # NOTE(review): ``master_only`` comes from ``dist_utils``; presumably it
    # restricts the decorated call to the master (rank-0) process so that
    # distributed workers do not each talk to MLflow -- confirm there.
    @master_only
    def before_run(self, runner):
        """Run the base-class setup, then activate the configured experiment
        (if ``exp_name`` is set) and apply run tags (if ``tags`` is set)."""
        super(MlflowLoggerHook, self).before_run(runner)
        if self.exp_name is not None:
            self.mlflow.set_experiment(self.exp_name)
        if self.tags is not None:
            self.mlflow.set_tags(self.tags)

    @master_only
    def log(self, runner):
        """Send the current loggable values to MLflow as metrics.

        The values come from ``self.get_loggable_tags(runner)`` (inherited,
        not defined in this class). They are recorded at step
        ``self.get_iter(runner)``. Nothing is sent when the result is empty.
        """
        tags = self.get_loggable_tags(runner)
        if tags:
            self.mlflow.log_metrics(tags, step=self.get_iter(runner))

    @master_only
    def after_run(self, runner):
        """When ``log_model`` is enabled, log ``runner.model`` with
        ``mlflow.pytorch.log_model`` under the artifact path ``'models'``."""
        if self.log_model:
            self.mlflow_pytorch.log_model(runner.model, 'models')