|
|
|
from unittest.mock import MagicMock
|
|
|
|
import torch
|
|
|
|
from mmaction.registry import MODELS
|
|
from mmaction.structures import ActionDataSample
|
|
from mmaction.testing import get_recognizer_cfg
|
|
from mmaction.utils import register_all_modules
|
|
|
|
|
|
def test_omni_resnet():
    """Exercise ``train_step`` and ``test_step`` of the omnisource recognizer.

    The recognizer is built from the
    ``omnisource/slowonly_r50_8xb16-8x8x1-256e_imagenet-kinetics400-rgb.py``
    config. ``train_step`` receives a pair of batches (one with 5-D inputs,
    one with 4-D inputs) in both orders and must report ``loss_cls_0`` and
    ``loss_cls_1`` each time. ``test_step`` on the 5-D batch must return two
    predictions, and the scores of the first one must lie in ``[0, 1]``.
    """
    register_all_modules()
    cfg = get_recognizer_cfg(
        'omnisource/slowonly_r50_8xb16-8x8x1-256e_imagenet-kinetics400-rgb.py')
    model = MODELS.build(cfg.model)

    def make_batch(shape):
        # Two random integer tensors of ``shape`` paired with two data
        # samples whose ground-truth label is 2. The tensors are created
        # before the data samples, matching the order of the dict entries.
        tensors = [torch.randint(0, 255, shape) for _ in range(2)]
        samples = [ActionDataSample().set_gt_label(2) for _ in range(2)]
        return {'inputs': tensors, 'data_samples': samples}

    # --- train_step ------------------------------------------------------
    video_batch = make_batch((1, 3, 8, 224, 224))
    image_batch = make_batch((1, 3, 224, 224))

    # The optimizer wrapper is mocked; only the returned loss dict is
    # inspected here.
    mock_optim = MagicMock()
    orderings = (
        [video_batch, image_batch],
        [image_batch, video_batch],
    )
    for batches in orderings:
        losses = model.train_step(batches, mock_optim)
        for loss_key in ('loss_cls_0', 'loss_cls_1'):
            assert loss_key in losses

    # --- test_step -------------------------------------------------------
    with torch.no_grad():
        preds = model.test_step(video_batch)
    first_score = preds[0].pred_score
    assert len(preds) == 2
    assert torch.min(first_score) >= 0
    assert torch.max(first_score) <= 1
|
|
|