mmaction2 / tests /models /similarity /test_clip_similarity.py
niobures's picture
mmaction2
d3dbf03 verified
# Copyright (c) OpenMMLab. All rights reserved.
import platform
from unittest.mock import MagicMock
import pytest
import torch
from mmaction.registry import MODELS
from mmaction.structures import ActionDataSample
from mmaction.testing import get_similarity_cfg
from mmaction.utils import register_all_modules
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_clip_similarity():
    """Exercise the similarity model built from the clip4clip config.

    Three things are verified in sequence:

    1. ``train_step`` returns the total loss plus both directional
       similarity losses and triggers exactly one optimizer update.
    2. ``test_step`` yields a single prediction whose video and text
       features each have size ``(512, )``.
    3. After changing ``frozen_layers`` and calling ``train()`` again, the
       ``requires_grad`` flags of the ``clip`` sub-module parameters match
       the requested freeze depth.
    """
    register_all_modules()
    config = get_similarity_cfg(
        'clip4clip/'
        'clip4clip_vit-base-p32-res224-clip-pre_8xb16-u12-5e_msrvtt-9k-rgb.py')
    config.model.frozen_layers = -1  # no frozen layers
    model = MODELS.build(config.model)
    model.train()

    # One sample made of random integer tensors; the upper bounds look like
    # a pixel range and a token vocabulary size -- TODO confirm.
    video_input = torch.randint(0, 256, (2, 3, 224, 224))
    text_input = torch.randint(0, 49408, (77, ))
    data_batch = {
        'inputs': {
            'imgs': [video_input],
            'text': [text_input]
        },
        'data_samples': [ActionDataSample()]
    }

    # test train_step
    # A mock optimizer wrapper is enough: only the call count is inspected.
    mock_optim = MagicMock()
    losses = model.train_step(data_batch, mock_optim)
    for loss_key in ('loss', 'sim_loss_v2t', 'sim_loss_t2v'):
        assert loss_key in losses
    mock_optim.update_params.assert_called_once()

    # test test_step
    with torch.no_grad():
        outputs = model.test_step(data_batch)
    feats = outputs[0].features
    assert len(outputs) == 1
    expected_size = (512, )
    assert feats.video_feature.size() == expected_size
    assert feats.text_feature.size() == expected_size

    # test frozen layers
    def check_frozen_layers(mdl, frozen_layers):
        """Assert that ``mdl.clip`` parameters follow the freezing rule.

        Args:
            mdl: Model exposing a ``clip`` sub-module.
            frozen_layers (int): Freeze depth under test. A negative value
                means every parameter must require gradients.
        """
        if frozen_layers < 0:
            assert all(p.requires_grad for p in mdl.clip.parameters())
            return

        # Parameters under these prefixes must always stay trainable.
        always_trainable = ('ln_final', 'text_projection', 'logit_scale',
                            'visual.ln_post', 'visual.proj')
        # Residual blocks are trainable only from index ``frozen_layers`` on.
        resblock_prefixes = ('visual.transformer.resblocks',
                             'transformer.resblocks')

        for name, param in mdl.clip.named_parameters():
            if name.startswith(always_trainable):
                should_train = True
            elif name.startswith(resblock_prefixes):
                # The block index is the path component right after
                # ``.resblocks.`` in the parameter name.
                block_idx = int(name.split('.resblocks.')[1].split('.')[0])
                should_train = block_idx >= frozen_layers
            else:
                # Anything not listed above is expected to be frozen.
                should_train = False
            assert param.requires_grad is should_train

    # The model was built and put in train mode with frozen_layers = -1.
    check_frozen_layers(model, -1)

    # ``train()`` is called again after each change before the flags are
    # checked.
    for depth in (0, 6, 12):
        model.frozen_layers = depth
        model.train()
        check_frozen_layers(model, depth)