# mmaction2/tests/models/recognizers/recognizer_omni.py
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock

import torch

from mmaction.registry import MODELS
from mmaction.structures import ActionDataSample
from mmaction.testing import get_recognizer_cfg
from mmaction.utils import register_all_modules


def test_omni_resnet():
    register_all_modules()
    config = get_recognizer_cfg(
        'omnisource/slowonly_r50_8xb16-8x8x1-256e_imagenet-kinetics400-rgb.py')
    recognizer = MODELS.build(config.model)
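
    # The omnisource SlowOnly recognizer is trained jointly on image and video
    # data, so the test builds one fake batch per source: video inputs of shape
    # (N, C, T, H, W) and image inputs of shape (N, C, H, W).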
    # test train_step
    video_sample = {
        'inputs': [
            torch.randint(0, 255, (1, 3, 8, 224, 224)),
            torch.randint(0, 255, (1, 3, 8, 224, 224))
        ],
        'data_samples': [
            ActionDataSample().set_gt_label(2),
            ActionDataSample().set_gt_label(2)
        ]
    }

    image_sample = {
        'inputs': [
            torch.randint(0, 255, (1, 3, 224, 224)),
            torch.randint(0, 255, (1, 3, 224, 224))
        ],
        'data_samples': [
            ActionDataSample().set_gt_label(2),
            ActionDataSample().set_gt_label(2)
        ]
    }
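
    # train_step receives a list with one batch per data source; the returned
    # dict is expected to hold a separate classification loss for each source.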
    optim_wrapper = MagicMock()
    loss_vars = recognizer.train_step([video_sample, image_sample],
                                      optim_wrapper)
    assert 'loss_cls_0' in loss_vars
    assert 'loss_cls_1' in loss_vars
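
    # Swapping the order of the two sources should still yield both losses.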
    loss_vars = recognizer.train_step([image_sample, video_sample],
                                      optim_wrapper)
    assert 'loss_cls_0' in loss_vars
    assert 'loss_cls_1' in loss_vars

    # test test_step
    with torch.no_grad():
        predictions = recognizer.test_step(video_sample)
        score = predictions[0].pred_score

    assert len(predictions) == 2
    assert torch.min(score) >= 0
    assert torch.max(score) <= 1
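

# Hypothetical convenience entry point (an addition, not part of the upstream
# test suite): the test is normally collected by pytest, but it can also be
# run directly as a script.
if __name__ == '__main__':
    test_omni_resnet()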