# Source: mmaction2 (commit d3dbf03), uploaded by niobures.
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import numpy as np
import pytest
import torch
from mmcv.transforms import to_tensor
from mmengine.structures import InstanceData
from mmaction.registry import MODELS
from mmaction.structures import ActionDataSample
from mmaction.testing import get_localizer_cfg
from mmaction.utils import register_all_modules
# Register mmaction's modules into its registries at import time (presumably
# including ``MODELS`` used below) so config-driven ``MODELS.build`` calls can
# resolve mmaction types.
register_all_modules()
def get_localization_data_sample():
    """Build one ``ActionDataSample`` fixture for localizer tests.

    The sample holds two ground-truth rows in ``gt_instances.gt_bbox``
    (presumably normalized [start, end] pairs -- confirm against the
    localizer's loss) plus fixed video meta information.

    Returns:
        ActionDataSample: a freshly constructed, populated data sample.
    """
    meta = dict(
        video_name='v_test',
        duration_second=100,
        duration_frame=960,
        feature_frame=960)
    boxes = to_tensor(np.array([[0.1, 0.3], [0.375, 0.625]]))

    instances = InstanceData()
    instances['gt_bbox'] = boxes

    sample = ActionDataSample()
    sample.gt_instances = instances
    sample.set_metainfo(meta)
    return sample
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_tem():
    """Smoke-test the BSN TEM localizer in ``loss`` and ``predict`` modes.

    Builds the model from the BSN TEM config and runs one ``loss`` forward
    on a random batch. It then runs ``predict`` on each batch element
    separately under ``torch.no_grad()``.
    """
    batch_size = 8
    # Per-sample feature shape; presumably (channels, temporal length) to
    # match the "400x100" in the config name -- confirm against the config.
    feat_shape = (400, 100)

    model_cfg = get_localizer_cfg(
        'bsn/bsn_tem_1xb16-400x100-20e_activitynet-feature.py')
    localizer_tem = MODELS.build(model_cfg.model)
    raw_feature = torch.rand(batch_size, *feat_shape)

    # Test forward loss. Build one independent sample per batch element:
    # ``[sample] * n`` would repeat a single object n times, so an in-place
    # change the model makes to one sample would silently affect all of them.
    train_samples = [
        get_localization_data_sample() for _ in range(batch_size)
    ]
    losses = localizer_tem(raw_feature, train_samples, mode='loss')
    assert isinstance(losses, dict)

    # Test forward predict, one video at a time. Slicing keeps the leading
    # batch dimension, giving shape (1, *feat_shape).
    with torch.no_grad():
        for idx in range(batch_size):
            single_feature = raw_feature[idx:idx + 1]
            localizer_tem(
                single_feature, [get_localization_data_sample()],
                mode='predict')