|
|
|
import platform
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from mmaction.registry import MODELS
|
|
from mmaction.structures import ActionDataSample
|
|
from mmaction.testing import get_similarity_cfg
|
|
from mmaction.utils import register_all_modules
|
|
|
|
|
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_clip_similarity():
    """Smoke-test the CLIP4Clip similarity model.

    Covers one train step (similarity losses + optimizer call), one test
    step (512-d video/text feature outputs), and the ``frozen_layers``
    parameter-freezing policy applied by ``model.train()``.
    """
    register_all_modules()
    cfg = get_similarity_cfg(
        'clip4clip/'
        'clip4clip_vit-base-p32-res224-clip-pre_8xb16-u12-5e_msrvtt-9k-rgb.py')
    # -1 disables freezing so the initial train step can update all params.
    cfg.model.frozen_layers = -1
    model = MODELS.build(cfg.model)
    model.train()

    # Minimal batch: 2 frames of random uint8-range pixels plus one tokenized
    # caption (CLIP vocab size 49408, context length 77).
    data_batch = {
        'inputs': {
            'imgs': [torch.randint(0, 256, (2, 3, 224, 224))],
            'text': [torch.randint(0, 49408, (77, ))]
        },
        'data_samples': [ActionDataSample()]
    }

    # train_step must yield the symmetric similarity losses and push them
    # through the (mocked) optimizer wrapper exactly once.
    optim_wrapper = MagicMock()
    loss_vars = model.train_step(data_batch, optim_wrapper)
    assert 'loss' in loss_vars
    assert 'sim_loss_v2t' in loss_vars
    assert 'sim_loss_t2v' in loss_vars
    optim_wrapper.update_params.assert_called_once()

    # Inference path: one prediction carrying 512-d video/text features.
    with torch.no_grad():
        predictions = model.test_step(data_batch)
        features = predictions[0].features
    assert len(predictions) == 1
    assert features.video_feature.size() == (512, )
    assert features.text_feature.size() == (512, )

    def check_frozen_layers(mdl, frozen_layers):
        """Assert ``requires_grad`` flags match the freezing policy.

        For ``frozen_layers >= 0``: top (projection/norm) layers always stay
        trainable, transformer resblocks with index >= ``frozen_layers`` stay
        trainable, resblocks below it are frozen, and every other parameter
        is frozen. A negative value means nothing is frozen.
        """
        if frozen_layers < 0:
            # Guard clause: negative setting leaves everything trainable.
            assert all(p.requires_grad for p in mdl.clip.parameters())
            return

        # str.startswith accepts a tuple, replacing any(name.find(n) == 0 ...)
        top_layers = ('ln_final', 'text_projection', 'logit_scale',
                      'visual.ln_post', 'visual.proj')
        mid_layers = ('visual.transformer.resblocks', 'transformer.resblocks')

        for name, param in mdl.clip.named_parameters():
            if name.startswith(top_layers):
                assert param.requires_grad is True
            elif name.startswith(mid_layers):
                layer_n = int(name.split('.resblocks.')[1].split('.')[0])
                # Blocks at or above the threshold remain trainable.
                if layer_n >= frozen_layers:
                    assert param.requires_grad is True
                else:
                    assert param.requires_grad is False
            else:
                assert param.requires_grad is False

    check_frozen_layers(model, -1)

    # Re-running train() re-applies the freezing policy for each new setting.
    model.frozen_layers = 0
    model.train()
    check_frozen_layers(model, 0)

    model.frozen_layers = 6
    model.train()
    check_frozen_layers(model, 6)

    model.frozen_layers = 12
    model.train()
    check_frozen_layers(model, 12)
|
|
|