|
|
|
import pytest
|
|
from mmengine.testing import assert_dict_has_keys
|
|
|
|
from mmaction.datasets import RepeatAugDataset
|
|
from mmaction.utils import register_all_modules
|
|
from .base import BaseTestDataset
|
|
|
|
|
|
class TestVideoDataset(BaseTestDataset):
    """Unit tests for :class:`RepeatAugDataset`.

    Fixtures such as ``video_ann_file``, ``video_ann_file_multi_label``,
    ``data_prefix`` and ``video_pipeline`` are not defined here; they come
    from ``BaseTestDataset``.
    """

    # Executed once, when this class body is evaluated (i.e. at import of
    # the test module). Presumably registers the mmaction modules so the
    # string ``type`` names used in the pipeline configs can be resolved.
    register_all_modules()

    @staticmethod
    def _decord_pipeline():
        """Return a new Decord-based pipeline config list.

        A fresh list is built on every call, so separate tests never share
        the same mutable config object.
        """
        return [
            dict(type='DecordInit'),
            dict(
                type='SampleFrames', clip_len=4, frame_interval=2,
                num_clips=1),
            dict(type='DecordDecode')
        ]

    def test_video_dataset(self):
        """Check construction, ``len`` and ``start_index`` handling."""
        # The same arguments (including start_index=3) succeed below once
        # the local Decord pipeline is used, so this AssertionError is
        # triggered by ``self.video_pipeline`` rather than by start_index.
        with pytest.raises(AssertionError):
            RepeatAugDataset(
                self.video_ann_file,
                self.video_pipeline,
                data_prefix={'video': self.data_prefix},
                start_index=3)

        # One list object is shared by both constructions below, exactly as
        # in the original test.
        video_pipeline = self._decord_pipeline()

        # An explicitly passed start_index is kept as given.
        video_dataset = RepeatAugDataset(
            self.video_ann_file,
            video_pipeline,
            data_prefix={'video': self.data_prefix},
            start_index=3)
        # NOTE(review): 2 presumably equals the number of samples in the
        # ``video_ann_file`` fixture -- confirm if the fixture changes.
        assert len(video_dataset) == 2
        assert video_dataset.start_index == 3

        # When start_index is omitted it is expected to be 0.
        video_dataset = RepeatAugDataset(
            self.video_ann_file,
            video_pipeline,
            data_prefix={'video': self.data_prefix})
        assert video_dataset.start_index == 0

    def test_video_dataset_multi_label(self):
        """Check that a multi-label annotation file can be loaded."""
        video_dataset = RepeatAugDataset(
            self.video_ann_file_multi_label,
            self._decord_pipeline(),
            data_prefix={'video': self.data_prefix},
            multi_class=True,
            num_classes=100)
        assert video_dataset.start_index == 0

    def test_video_pipeline(self):
        """Check the structure of a sample produced by ``__getitem__``."""
        target_keys = ['filename', 'label', 'start_index', 'modality']

        video_dataset = RepeatAugDataset(
            self.video_ann_file,
            self._decord_pipeline(),
            data_prefix={'video': self.data_prefix})
        result = video_dataset[0]
        # A single index yields a list/tuple of result dicts (presumably
        # one per repeated augmentation -- confirm); only the first entry
        # is inspected for the expected keys.
        assert isinstance(result, (list, tuple))
        assert assert_dict_has_keys(result[0], target_keys)
|
|
|