|
|
|
|
|
import platform
|
|
|
|
import pytest
|
|
import torch
|
|
from mmengine.structures import InstanceData
|
|
|
|
from mmaction.registry import MODELS
|
|
from mmaction.structures import ActionDataSample
|
|
from mmaction.testing import get_localizer_cfg
|
|
from mmaction.utils import register_all_modules
|
|
|
|
# Module-level side effect: runs mmaction's registration helper once at import
# time. Presumably this is what lets MODELS.build() in the test below resolve
# the types named in the config -- confirm against mmaction.utils.
register_all_modules()
|
|
|
|
|
|
def get_localization_data_sample():
    """Build an ``ActionDataSample`` carrying random ground-truth fields.

    The returned sample's ``gt_instances`` is an ``InstanceData`` holding
    two random tensors:

    * ``'bsp_feature'`` -- ``torch.rand(100, 32)``
    * ``'reference_temporal_iou'`` -- ``torch.rand(100)``

    Returns:
        ActionDataSample: A fresh sample with ``gt_instances`` populated.
    """
    gt = InstanceData()
    # Draw order (feature first, then IoU) is kept so that a seeded RNG
    # produces the same values as before.
    gt['bsp_feature'] = torch.rand(100, 32)
    gt['reference_temporal_iou'] = torch.rand(100)

    sample = ActionDataSample()
    sample.gt_instances = gt
    return sample
|
|
|
|
|
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_pem():
    """Smoke-test the BSN PEM localizer in loss, predict and tensor modes.

    The model is built from the
    ``bsn/bsn_pem_1xb16-400x100-20e_activitynet-feature.py`` config and then
    called once per mode on random inputs. Only the ``loss`` output is
    checked (it must be a ``dict``); the ``predict`` and ``tensor`` calls
    merely have to run without raising.
    """
    cfg = get_localizer_cfg(
        'bsn/bsn_pem_1xb16-400x100-20e_activitynet-feature.py')
    localizer_pem = MODELS.build(cfg.model)

    # --- loss mode ---------------------------------------------------------
    # NOTE: ``[x] * 8`` repeats a reference to ONE tensor / ONE data sample
    # eight times; the batch entries are not independent draws.
    batch_features = [torch.rand(100, 32)] * 8
    batch_samples = [get_localization_data_sample()] * 8
    losses = localizer_pem(batch_features, batch_samples, mode='loss')
    assert isinstance(losses, dict)

    # --- predict mode ------------------------------------------------------
    # Random per-proposal fields, drawn in the order tmin, tmax, tmin_score,
    # tmax_score (dicts preserve insertion order, so they are also attached
    # to gt_instances in that order below).
    proposal_fields = {
        key: torch.rand(100)
        for key in ('tmin', 'tmax', 'tmin_score', 'tmax_score')
    }

    video_meta = {
        'video_name': 'v_test',
        'duration_second': 100,
        'duration_frame': 1000,
        'annotations': [{
            'segment': [0.3, 0.6],
            'label': 'Rock climbing',
        }],
        'feature_frame': 900,
    }

    with torch.no_grad():
        raw_feature = [torch.rand(100, 32)]

        data_sample = get_localization_data_sample()
        data_sample.set_metainfo(video_meta)
        gt_instances = data_sample.gt_instances
        for key, value in proposal_fields.items():
            gt_instances[key] = value
        data_samples = [data_sample]

        localizer_pem(raw_feature, data_samples, mode='predict')

    # --- tensor mode -------------------------------------------------------
    # No data samples are supplied for the raw forward pass.
    with torch.no_grad():
        raw_feature = [torch.rand(100, 32)]
        localizer_pem(raw_feature, data_samples=None, mode='tensor')
|
|
|