Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import pytest | |
| import random | |
| import torch | |
| from audiocraft.adversarial import ( | |
| AdversarialLoss, | |
| get_adv_criterion, | |
| get_real_criterion, | |
| get_fake_criterion, | |
| FeatureMatchingLoss, | |
| MultiScaleDiscriminator, | |
| ) | |
class TestAdversarialLoss:
    """End-to-end checks of ``AdversarialLoss`` wrapping a ``MultiScaleDiscriminator``."""

    @staticmethod
    def _make_adv_loss(feat_loss=None):
        """Build an ``AdversarialLoss`` with MSE criteria around a fresh discriminator.

        Args:
            feat_loss: optional feature-matching loss. It is forwarded as the sixth
                positional argument only when given, so the no-feature-loss case
                constructs ``AdversarialLoss`` with the same five arguments as before.

        Returns:
            The configured ``AdversarialLoss`` instance.
        """
        adv = MultiScaleDiscriminator()
        optimizer = torch.optim.Adam(adv.parameters(), lr=1e-4)
        args = [
            adv,
            optimizer,
            get_adv_criterion('mse'),
            get_real_criterion('mse'),
            get_fake_criterion('mse'),
        ]
        if feat_loss is not None:
            args.append(feat_loss)
        return AdversarialLoss(*args)

    @staticmethod
    def _random_batch():
        """Return ``(fake, real)`` random tensors of shape (4, 1, T), T drawn from [1000, 5000]."""
        B, C, T = 4, 1, random.randint(1000, 5000)
        # Same draw order as the original inline setup: ``real`` first, then ``fake``.
        real = torch.randn(B, C, T)
        fake = torch.randn(B, C, T)
        return fake, real

    def test_adversarial_single_multidiscriminator(self):
        adv_loss = self._make_adv_loss()
        fake, real = self._random_batch()
        disc_loss = adv_loss.train_adv(fake, real)
        assert isinstance(disc_loss, torch.Tensor) and isinstance(disc_loss.item(), float)
        gen_loss, loss_feat = adv_loss(fake, real)
        assert isinstance(gen_loss, torch.Tensor) and isinstance(gen_loss.item(), float)
        # No feature-matching loss was configured, so its term is expected to be zero.
        assert loss_feat.item() == 0.

    def test_adversarial_feat_loss(self):
        adv_loss = self._make_adv_loss(FeatureMatchingLoss())
        fake, real = self._random_batch()
        gen_loss, loss_feat = adv_loss(fake, real)
        assert isinstance(gen_loss, torch.Tensor) and isinstance(gen_loss.item(), float)
        # Validate the feature-matching term itself (the original re-checked the
        # generator loss here by mistake, leaving ``loss_feat.item()`` untested).
        assert isinstance(loss_feat, torch.Tensor) and isinstance(loss_feat.item(), float)
class TestGeneratorAdversarialLoss:
    """Generator-side criteria from ``get_adv_criterion`` checked on hand-computed inputs."""

    def test_hinge_generator_adv_loss(self):
        criterion = get_adv_criterion(loss_type='hinge')
        cases = [
            # An empty logits tensor is expected to give a zero loss.
            (torch.randn(1, 2, 0), 0.0),
            # -2.0 for [1, 2, 3] is consistent with -mean(logits).
            (torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32), -2.0),
        ]
        for logits, expected in cases:
            assert criterion(logits).item() == expected

    def test_mse_generator_adv_loss(self):
        criterion = get_adv_criterion(loss_type='mse')
        cases = [
            (torch.randn(1, 2, 0), 0.0),
            # Expectations below are consistent with mean((logits - 1) ** 2).
            (torch.tensor([1.0, 1.0, 1.0], dtype=torch.float32), 0.0),
            (torch.tensor([2.0, 5.0, 5.0], dtype=torch.float32), 11.0),
        ]
        for logits, expected in cases:
            assert criterion(logits).item() == expected
class TestDiscriminatorAdversarialLoss:
    """Discriminator-side real/fake criteria, summed and compared to hand-computed values."""

    def _disc_loss(self, loss_type: str, fake: torch.Tensor, real: torch.Tensor):
        """Return ``fake_criterion(fake) + real_criterion(real)`` for the given ``loss_type``."""
        on_real = get_real_criterion(loss_type)
        on_fake = get_fake_criterion(loss_type)
        return on_fake(fake) + on_real(real)

    def test_hinge_discriminator_adv_loss(self):
        zeros = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
        ramp = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
        # (fake, real, expected total loss)
        expectations = ((zeros, zeros, 2.0), (ramp, ramp, 3.0))
        for fake, real, expected in expectations:
            assert self._disc_loss('hinge', fake, real).item() == expected

    def test_mse_discriminator_adv_loss(self):
        zeros = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
        ones = torch.tensor([1.0, 1.0, 1.0], dtype=torch.float32)
        # (fake, real, expected total loss)
        expectations = ((zeros, zeros, 1.0), (ones, zeros, 2.0))
        for fake, real, expected in expectations:
            assert self._disc_loss('mse', fake, real).item() == expected
class TestFeatureMatchingLoss:
    """``FeatureMatchingLoss`` on identical maps, malformed inputs and hand-computed values."""

    def test_features_matching_loss_base(self):
        criterion = FeatureMatchingLoss()
        n = random.randrange(1, 100_000)
        fmap = torch.randn(1, 2, n)
        result = criterion([fmap], [fmap])
        assert isinstance(result, torch.Tensor)
        # Identical feature maps are expected to give exactly zero.
        assert result.item() == 0.0

    def test_features_matching_loss_raises_exception(self):
        criterion = FeatureMatchingLoss()
        n = random.randrange(1, 100_000)
        short = torch.randn(1, 2, n)
        longer = torch.randn(1, 2, n + 1)
        invalid_inputs = (
            ([], []),                   # no feature maps at all
            ([short], [short, short]),  # lists of different length
            ([short], [longer]),        # mismatched map shapes
        )
        for fake_maps, real_maps in invalid_inputs:
            with pytest.raises(AssertionError):
                criterion(fake_maps, real_maps)

    def test_features_matching_loss_output(self):
        summed = FeatureMatchingLoss(normalize=False)
        averaged = FeatureMatchingLoss(normalize=True)
        n = random.randrange(1, 100_000)
        a = torch.randn(1, 2, n)
        b = torch.randn(1, 2, n)
        for criterion in (summed, averaged):
            assert criterion([a, b], [a, b]).item() == 0.0
        x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
        y = torch.tensor([2.0, 10.0, 3.0], dtype=torch.float32)
        # A single pair scores 3.0 (consistent with mean |x - y| = (1 + 8 + 0) / 3).
        # Two pairs add up to 6.0 without normalization and stay at 3.0 with it.
        assert summed([x], [y]).item() == 3.0
        assert summed([x, x], [y, y]).item() == 6.0
        assert averaged([x], [y]).item() == 3.0
        assert averaged([x, x], [y, y]).item() == 3.0