Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
class Loss(nn.Module):
    """Toy differentiable criterion: the signed mean of ``input - target``.

    Both tensors are flattened with ``view(-1)`` before subtracting, so they
    must be viewable as 1-D and broadcast-compatible once flattened
    (normally they hold the same number of elements). The result is a
    0-dim tensor suitable for calling ``backward()`` on.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Return the scalar mean of the flattened ``input - target``."""
        flat_pred = input.view(-1)
        flat_target = target.view(-1)
        residual = flat_pred - flat_target
        return residual.mean()
class TestCrissCrossAttention:
    """Regression test for ``mmcv.ops.CrissCrossAttention``."""

    def test_cc_attention(self):
        """Run one forward/backward pass and compare against a stored output.

        Reads a float32 input and the expected output from binary fixtures
        (the paths are relative to the current working directory), reshapes
        both to ``(1, 32, 45, 45)``, and runs a freshly built
        ``CrissCrossAttention`` in training mode. It backpropagates the
        module-level ``Loss`` against an all-ones target, then asserts that
        the forward output matches the fixture and keeps the input shape.
        """
        dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        from mmcv.ops import CrissCrossAttention
        criterion = Loss()

        fixture_shape = (1, 32, 45, 45)
        raw_in = np.fromfile(
            'tests/data/for_ccattention/ccattention_input.bin',
            dtype=np.float32).reshape(fixture_shape)
        raw_out = np.fromfile(
            'tests/data/for_ccattention/ccattention_output.bin',
            dtype=np.float32).reshape(fixture_shape)
        target = torch.ones(fixture_shape)

        x = torch.FloatTensor(raw_in)
        expected = torch.FloatTensor(raw_out)
        # Mark the host-side input as requiring grad so backward() runs
        # through the whole graph.
        x.requires_grad = True

        in_shape = x.shape
        # The channel count for the module comes from dim 1 of the input.
        module = CrissCrossAttention(in_shape[1])
        module.to(dev)
        x = x.to(dev)
        target = target.to(dev)
        module.train()

        y = module(x)
        criterion(y, target).backward()

        actual = y.detach().cpu().numpy()
        assert np.allclose(actual, expected.numpy())
        assert actual.shape == in_shape