Spaces:
Build error
Build error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import os.path as osp | |
| import tempfile | |
| import unittest.mock as mock | |
| from collections import OrderedDict | |
| from unittest.mock import MagicMock, patch | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| from mmcv.runner import EpochBasedRunner, build_optimizer | |
| from mmcv.utils import get_logger | |
| from torch.utils.data import DataLoader, Dataset | |
| from mmpose.core import DistEvalHook, EvalHook | |
class ExampleDataset(Dataset):
    """Minimal one-sample dataset used to drive the eval hooks under test.

    ``index`` and ``eval_result`` are not used by this class itself; they are
    the cursor and the sequence of metric values replayed by subclasses such
    as ``EvalDataset``.
    """

    def __init__(self):
        self.index = 0
        # Metric value reported on each successive evaluation by subclasses.
        self.eval_result = [0.1, 0.4, 0.3, 0.7, 0.2, 0.05, 0.4, 0.6]

    def __getitem__(self, idx):
        """Return the same single-element ``imgs`` tensor for any ``idx``."""
        results = dict(imgs=torch.tensor([1]))
        return results

    def __len__(self):
        return 1

    def evaluate(self, results, res_folder=None, logger=None):
        """Report no metrics.

        Returns an empty mapping rather than ``None``: the eval hooks treat
        the result of ``dataset.evaluate`` as a dict of metrics (they log its
        items), so a bare ``pass`` returning ``None`` would make them fail
        before the ``save_best=None`` assertions can run.
        """
        return OrderedDict()
class EvalDataset(ExampleDataset):
    """``ExampleDataset`` whose ``evaluate`` replays ``eval_result``.

    Every call reports the entry of ``eval_result`` at the current ``index``
    under both the ``acc`` and ``score`` keys, together with that position,
    and then moves ``index`` forward by one.
    """

    def evaluate(self, results, res_folder=None, logger=None):
        position = self.index
        # Read before advancing, so an out-of-range position leaves
        # ``index`` untouched.
        value = self.eval_result[position]
        self.index = position + 1
        return OrderedDict(acc=value, index=position, score=value)
class ExampleModel(nn.Module):
    """Stand-in model with the methods an ``EpochBasedRunner`` calls.

    ``forward`` echoes its input and ``train_step`` returns a constant output
    dict without looking at the batch or the optimizer.
    """

    def __init__(self):
        super().__init__()
        # Gives the model a parameter so ``build_optimizer`` has something
        # to optimise.
        self.conv = nn.Linear(1, 1)
        self.test_cfg = None

    def forward(self, imgs, return_loss=False):
        """Return ``imgs`` unchanged; ``return_loss`` is accepted and ignored."""
        return imgs

    def train_step(self, data_batch, optimizer, **kwargs):
        """Return a fixed loss / log-vars / sample-count dict."""
        return dict(
            loss=0.5,
            log_vars=dict(accuracy=0.98),
            num_samples=1,
        )
@patch('mmpose.apis.single_gpu_test', MagicMock)
@patch('mmpose.apis.multi_gpu_test', MagicMock)
@pytest.mark.parametrize('EvalHookCls', (EvalHook, DistEvalHook))
def test_eval_hook(EvalHookCls):
    """Check argument validation and best-checkpoint tracking of eval hooks.

    Runs once for ``EvalHook`` and once for ``DistEvalHook``. The
    ``mmpose.apis`` single/multi GPU test functions are patched with
    ``MagicMock`` so that no real inference or distributed result collection
    is attempted (NOTE(review): assumes the hooks look their test function up
    in ``mmpose.apis`` at construction time - confirm). Metric values come
    from ``EvalDataset.eval_result`` = [0.1, 0.4, 0.3, 0.7, 0.2, 0.05, 0.4,
    0.6], one per epoch.

    Covered cases:
      * constructor errors for a non-DataLoader input, an unknown ``rule``
        and an unknown ``save_best`` key;
      * ``save_best=None`` records no best score / checkpoint;
      * ``save_best='auto'``, ``'acc'`` and ``'score'`` with the ``greater``
        and ``less`` rules select the expected epoch;
      * resuming from ``latest.pth`` keeps tracking the best checkpoint.
    """
    # Constructor validation. The DataLoader is built once, with valid
    # arguments, outside the `pytest.raises` blocks so each expected error
    # can only come from the hook constructor itself (a misspelled DataLoader
    # kwarg would otherwise raise the TypeError and mask the real check).
    test_dataset = ExampleDataset()
    data_loader = DataLoader(
        test_dataset,
        batch_size=1,
        sampler=None,
        num_workers=0,
        shuffle=False)

    with pytest.raises(TypeError):
        # dataloader must be a pytorch DataLoader, not a list of them
        EvalHookCls([data_loader])

    with pytest.raises(KeyError):
        # rule must be in keys of rule_map
        EvalHookCls(data_loader, save_best='auto', rule='unsupport')

    with pytest.raises(ValueError):
        # save_best must be valid when rule_map is None
        EvalHookCls(data_loader, save_best='unsupport')

    optimizer_cfg = dict(
        type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)

    # save_best=None: nothing is recorded in the runner's hook messages.
    test_dataset = ExampleDataset()
    loader = DataLoader(test_dataset, batch_size=1)
    model = ExampleModel()
    optimizer = build_optimizer(model, optimizer_cfg)

    data_loader = DataLoader(test_dataset, batch_size=1)
    eval_hook = EvalHookCls(data_loader, save_best=None)
    with tempfile.TemporaryDirectory() as tmpdir:
        logger = get_logger('test_eval')
        runner = EpochBasedRunner(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=tmpdir,
            logger=logger,
            max_epochs=1)
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)])
        assert runner.meta is None or 'best_score' not in runner.meta[
            'hook_msgs']
        assert runner.meta is None or 'best_ckpt' not in runner.meta[
            'hook_msgs']

    # when `save_best` is set to 'auto', first metric will be used.
    # `acc` peaks at 0.7 on the 4th evaluation -> epoch 4.
    loader = DataLoader(EvalDataset(), batch_size=1)
    model = ExampleModel()
    data_loader = DataLoader(EvalDataset(), batch_size=1)
    eval_hook = EvalHookCls(data_loader, interval=1, save_best='auto')
    with tempfile.TemporaryDirectory() as tmpdir:
        logger = get_logger('test_eval')
        runner = EpochBasedRunner(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=tmpdir,
            logger=logger,
            max_epochs=8)
        runner.register_checkpoint_hook(dict(interval=1))
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)])
        real_path = osp.join(tmpdir, 'best_acc_epoch_4.pth')
        assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(real_path)
        assert runner.meta['hook_msgs']['best_score'] == 0.7

    # Explicit key 'acc' with its default rule: same best epoch as 'auto'.
    loader = DataLoader(EvalDataset(), batch_size=1)
    model = ExampleModel()
    data_loader = DataLoader(EvalDataset(), batch_size=1)
    eval_hook = EvalHookCls(data_loader, interval=1, save_best='acc')
    with tempfile.TemporaryDirectory() as tmpdir:
        logger = get_logger('test_eval')
        runner = EpochBasedRunner(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=tmpdir,
            logger=logger,
            max_epochs=8)
        runner.register_checkpoint_hook(dict(interval=1))
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)])
        real_path = osp.join(tmpdir, 'best_acc_epoch_4.pth')
        assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(real_path)
        assert runner.meta['hook_msgs']['best_score'] == 0.7

    # Key 'score' with an explicit 'greater' rule.
    data_loader = DataLoader(EvalDataset(), batch_size=1)
    eval_hook = EvalHookCls(
        data_loader, interval=1, save_best='score', rule='greater')
    with tempfile.TemporaryDirectory() as tmpdir:
        logger = get_logger('test_eval')
        # max_epochs is given to the runner like every other case here;
        # passing it to `run()` instead is deprecated in mmcv.
        runner = EpochBasedRunner(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=tmpdir,
            logger=logger,
            max_epochs=8)
        runner.register_checkpoint_hook(dict(interval=1))
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)])
        real_path = osp.join(tmpdir, 'best_score_epoch_4.pth')
        assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(real_path)
        assert runner.meta['hook_msgs']['best_score'] == 0.7

    # rule='less': the minimum, 0.05 on the 6th evaluation, wins.
    data_loader = DataLoader(EvalDataset(), batch_size=1)
    eval_hook = EvalHookCls(data_loader, save_best='acc', rule='less')
    with tempfile.TemporaryDirectory() as tmpdir:
        logger = get_logger('test_eval')
        runner = EpochBasedRunner(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=tmpdir,
            logger=logger,
            max_epochs=8)
        runner.register_checkpoint_hook(dict(interval=1))
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)])
        real_path = osp.join(tmpdir, 'best_acc_epoch_6.pth')
        assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(real_path)
        assert runner.meta['hook_msgs']['best_score'] == 0.05

    # Two epochs only (best 0.4 at epoch 2), then resume to 8 epochs.
    data_loader = DataLoader(EvalDataset(), batch_size=1)
    eval_hook = EvalHookCls(data_loader, save_best='acc')
    with tempfile.TemporaryDirectory() as tmpdir:
        logger = get_logger('test_eval')
        runner = EpochBasedRunner(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=tmpdir,
            logger=logger,
            max_epochs=2)
        runner.register_checkpoint_hook(dict(interval=1))
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)])
        real_path = osp.join(tmpdir, 'best_acc_epoch_2.pth')
        assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(real_path)
        assert runner.meta['hook_msgs']['best_score'] == 0.4

        # Resume inside the same tmpdir so latest.pth still exists. The eval
        # data_loader is reused, so its dataset continues from index 2 and
        # the remaining evaluations peak at 0.7 on epoch 4.
        resume_from = osp.join(tmpdir, 'latest.pth')
        loader = DataLoader(ExampleDataset(), batch_size=1)
        eval_hook = EvalHookCls(data_loader, save_best='acc')
        runner = EpochBasedRunner(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=tmpdir,
            logger=logger,
            max_epochs=8)
        runner.register_checkpoint_hook(dict(interval=1))
        runner.register_hook(eval_hook)
        runner.resume(resume_from)
        runner.run([loader], [('train', 1)])
        real_path = osp.join(tmpdir, 'best_acc_epoch_4.pth')
        assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(real_path)
        assert runner.meta['hook_msgs']['best_score'] == 0.7