# Source: mmaction2/tests/datasets/test_pose_dataset.py
# Uploaded by niobures; commit d3dbf03 (verified).
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
from mmaction.datasets import PoseDataset
from .base import BaseTestDataset
class TestPoseDataset(BaseTestDataset):
    """Tests for :class:`PoseDataset` path prefixing and sample filtering."""

    def test_pose_dataset(self):
        """Check PoseDataset sizes, prefixes and box/valid-ratio filtering.

        ``self.pose_ann_file`` is presumably provided by ``BaseTestDataset``
        (not visible here -- confirm).
        """
        ann_file = self.pose_ann_file

        # With box_thr=0.5 and a data prefix, all 100 samples are kept and
        # each sample's ``frame_dir`` is prefixed with the video root.
        data_prefix = dict(video='root')
        dataset = PoseDataset(
            ann_file=ann_file,
            pipeline=[],
            split='train',
            box_thr=0.5,
            data_prefix=data_prefix)
        assert len(dataset) == 100
        item = dataset[0]
        assert item['frame_dir'].startswith(data_prefix['video'])

        # Each case: (valid_ratio, box_thr, expected number of kept samples).
        filter_cases = [
            (0.2, 0.9, 84),
            (0.3, 0.7, 87),
        ]
        for valid_ratio, box_thr, expected_len in filter_cases:
            dataset = PoseDataset(
                ann_file=ann_file,
                pipeline=[],
                split='train',
                valid_ratio=valid_ratio,
                box_thr=box_thr)
            case = f'valid_ratio={valid_ratio}, box_thr={box_thr}'
            assert len(dataset) == expected_len, case
            for item in dataset:
                # Every annotation index kept for a sample must have a box
                # score at or above the threshold.
                box_scores = item['box_score'][item['anno_inds']]
                assert np.all(box_scores >= box_thr), case
                # ``valid`` is indexed by the threshold; the fraction of
                # valid frames must reach the requested ratio.
                kept_ratio = item['valid'][box_thr] / item['total_frames']
                assert kept_ratio >= valid_ratio, case

        # NOTE(review): 0.55 is presumably not a box threshold PoseDataset
        # accepts, so construction is expected to fail its validation
        # assertion -- confirm against PoseDataset.
        with pytest.raises(AssertionError):
            PoseDataset(
                ann_file=ann_file, pipeline=[], valid_ratio=0.2, box_thr=0.55)