mmaction2 / tests /datasets /test_repeataug_dataset.py
niobures's picture
mmaction2
d3dbf03 verified
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
from mmengine.testing import assert_dict_has_keys
from mmaction.datasets import RepeatAugDataset
from mmaction.utils import register_all_modules
from .base import BaseTestDataset
class TestVideoDataset(BaseTestDataset):
register_all_modules()
def test_video_dataset(self):
with pytest.raises(AssertionError):
# Currently only support decord backend
video_dataset = RepeatAugDataset(
self.video_ann_file,
self.video_pipeline,
data_prefix={'video': self.data_prefix},
start_index=3)
video_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames', clip_len=4, frame_interval=2,
num_clips=1),
dict(type='DecordDecode')
]
video_dataset = RepeatAugDataset(
self.video_ann_file,
video_pipeline,
data_prefix={'video': self.data_prefix},
start_index=3)
assert len(video_dataset) == 2
assert video_dataset.start_index == 3
video_dataset = RepeatAugDataset(
self.video_ann_file,
video_pipeline,
data_prefix={'video': self.data_prefix})
assert video_dataset.start_index == 0
def test_video_dataset_multi_label(self):
video_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames', clip_len=4, frame_interval=2,
num_clips=1),
dict(type='DecordDecode')
]
video_dataset = RepeatAugDataset(
self.video_ann_file_multi_label,
video_pipeline,
data_prefix={'video': self.data_prefix},
multi_class=True,
num_classes=100)
assert video_dataset.start_index == 0
def test_video_pipeline(self):
video_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames', clip_len=4, frame_interval=2,
num_clips=1),
dict(type='DecordDecode')
]
target_keys = ['filename', 'label', 'start_index', 'modality']
# RepeatAugDataset not in test mode
video_dataset = RepeatAugDataset(
self.video_ann_file,
video_pipeline,
data_prefix={'video': self.data_prefix})
result = video_dataset[0]
assert isinstance(result, (list, tuple))
assert assert_dict_has_keys(result[0], target_keys)