Upload losses.py
Browse files- audiocraft/adversarial/losses.py +228 -0
    	
        audiocraft/adversarial/losses.py
    ADDED
    
    | @@ -0,0 +1,228 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            """
         | 
| 8 | 
            +
            Utility module to handle adversarial losses without requiring to mess up the main training loop.
         | 
| 9 | 
            +
            """
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import typing as tp
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            import flashy
         | 
| 14 | 
            +
            import torch
         | 
| 15 | 
            +
            import torch.nn as nn
         | 
| 16 | 
            +
            import torch.nn.functional as F
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
         | 
| 23 | 
            +
            FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            class AdversarialLoss(nn.Module):
         | 
| 27 | 
            +
                """Adversary training wrapper.
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                Args:
         | 
| 30 | 
            +
                    adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
         | 
| 31 | 
            +
                        We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
         | 
| 32 | 
            +
                        where the first item is a list of logits and the second item is a list of feature maps.
         | 
| 33 | 
            +
                    optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
         | 
| 34 | 
            +
                    loss (AdvLossType): Loss function for generator training.
         | 
| 35 | 
            +
                    loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
         | 
| 36 | 
            +
                    loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
         | 
| 37 | 
            +
                    loss_feat (FeatLossType): Feature matching loss function for generator training.
         | 
| 38 | 
            +
                    normalize (bool): Whether to normalize by number of sub-discriminators.
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                Example of usage:
         | 
| 41 | 
            +
                    adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
         | 
| 42 | 
            +
                    for real in loader:
         | 
| 43 | 
            +
                        noise = torch.randn(...)
         | 
| 44 | 
            +
                        fake = model(noise)
         | 
| 45 | 
            +
                        adv_loss.train_adv(fake, real)
         | 
| 46 | 
            +
                        loss, _ = adv_loss(fake, real)
         | 
| 47 | 
            +
                        loss.backward()
         | 
| 48 | 
            +
                """
         | 
| 49 | 
            +
                def __init__(self,
         | 
| 50 | 
            +
                             adversary: nn.Module,
         | 
| 51 | 
            +
                             optimizer: torch.optim.Optimizer,
         | 
| 52 | 
            +
                             loss: AdvLossType,
         | 
| 53 | 
            +
                             loss_real: AdvLossType,
         | 
| 54 | 
            +
                             loss_fake: AdvLossType,
         | 
| 55 | 
            +
                             loss_feat: tp.Optional[FeatLossType] = None,
         | 
| 56 | 
            +
                             normalize: bool = True):
         | 
| 57 | 
            +
                    super().__init__()
         | 
| 58 | 
            +
                    self.adversary: nn.Module = adversary
         | 
| 59 | 
            +
                    flashy.distrib.broadcast_model(self.adversary)
         | 
| 60 | 
            +
                    self.optimizer = optimizer
         | 
| 61 | 
            +
                    self.loss = loss
         | 
| 62 | 
            +
                    self.loss_real = loss_real
         | 
| 63 | 
            +
                    self.loss_fake = loss_fake
         | 
| 64 | 
            +
                    self.loss_feat = loss_feat
         | 
| 65 | 
            +
                    self.normalize = normalize
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                def _save_to_state_dict(self, destination, prefix, keep_vars):
         | 
| 68 | 
            +
                    # Add the optimizer state dict inside our own.
         | 
| 69 | 
            +
                    super()._save_to_state_dict(destination, prefix, keep_vars)
         | 
| 70 | 
            +
                    destination[prefix + 'optimizer'] = self.optimizer.state_dict()
         | 
| 71 | 
            +
                    return destination
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         | 
| 74 | 
            +
                    # Load optimizer state.
         | 
| 75 | 
            +
                    self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
         | 
| 76 | 
            +
                    super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def get_adversary_pred(self, x):
         | 
| 79 | 
            +
                    """Run adversary model, validating expected output format."""
         | 
| 80 | 
            +
                    logits, fmaps = self.adversary(x)
         | 
| 81 | 
            +
                    assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
         | 
| 82 | 
            +
                        f'Expecting a list of tensors as logits but {type(logits)} found.'
         | 
| 83 | 
            +
                    assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
         | 
| 84 | 
            +
                    for fmap in fmaps:
         | 
| 85 | 
            +
                        assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
         | 
| 86 | 
            +
                            f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
         | 
| 87 | 
            +
                    return logits, fmaps
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
         | 
| 90 | 
            +
                    """Train the adversary with the given fake and real example.
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                    We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
         | 
| 93 | 
            +
                    The first item being the logits and second item being a list of feature maps for each sub-discriminator.
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                    This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
         | 
| 96 | 
            +
                    and call the optimizer.
         | 
| 97 | 
            +
                    """
         | 
| 98 | 
            +
                    loss = torch.tensor(0., device=fake.device)
         | 
| 99 | 
            +
                    all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
         | 
| 100 | 
            +
                    all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
         | 
| 101 | 
            +
                    n_sub_adversaries = len(all_logits_fake_is_fake)
         | 
| 102 | 
            +
                    for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
         | 
| 103 | 
            +
                        loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    if self.normalize:
         | 
| 106 | 
            +
                        loss /= n_sub_adversaries
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    self.optimizer.zero_grad()
         | 
| 109 | 
            +
                    with flashy.distrib.eager_sync_model(self.adversary):
         | 
| 110 | 
            +
                        loss.backward()
         | 
| 111 | 
            +
                    self.optimizer.step()
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    return loss
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
         | 
| 116 | 
            +
                    """Return the loss for the generator, i.e. trying to fool the adversary,
         | 
| 117 | 
            +
                    and feature matching loss if provided.
         | 
| 118 | 
            +
                    """
         | 
| 119 | 
            +
                    adv = torch.tensor(0., device=fake.device)
         | 
| 120 | 
            +
                    feat = torch.tensor(0., device=fake.device)
         | 
| 121 | 
            +
                    with flashy.utils.readonly(self.adversary):
         | 
| 122 | 
            +
                        all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
         | 
| 123 | 
            +
                        all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
         | 
| 124 | 
            +
                        n_sub_adversaries = len(all_logits_fake_is_fake)
         | 
| 125 | 
            +
                        for logit_fake_is_fake in all_logits_fake_is_fake:
         | 
| 126 | 
            +
                            adv += self.loss(logit_fake_is_fake)
         | 
| 127 | 
            +
                        if self.loss_feat:
         | 
| 128 | 
            +
                            for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
         | 
| 129 | 
            +
                                feat += self.loss_feat(fmap_fake, fmap_real)
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    if self.normalize:
         | 
| 132 | 
            +
                        adv /= n_sub_adversaries
         | 
| 133 | 
            +
                        feat /= n_sub_adversaries
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    return adv, feat
         | 
| 136 | 
            +
             | 
| 137 | 
            +
             | 
| 138 | 
            +
            def get_adv_criterion(loss_type: str) -> tp.Callable:
         | 
| 139 | 
            +
                assert loss_type in ADVERSARIAL_LOSSES
         | 
| 140 | 
            +
                if loss_type == 'mse':
         | 
| 141 | 
            +
                    return mse_loss
         | 
| 142 | 
            +
                elif loss_type == 'hinge':
         | 
| 143 | 
            +
                    return hinge_loss
         | 
| 144 | 
            +
                elif loss_type == 'hinge2':
         | 
| 145 | 
            +
                    return hinge2_loss
         | 
| 146 | 
            +
                raise ValueError('Unsupported loss')
         | 
| 147 | 
            +
             | 
| 148 | 
            +
             | 
| 149 | 
            +
            def get_fake_criterion(loss_type: str) -> tp.Callable:
         | 
| 150 | 
            +
                assert loss_type in ADVERSARIAL_LOSSES
         | 
| 151 | 
            +
                if loss_type == 'mse':
         | 
| 152 | 
            +
                    return mse_fake_loss
         | 
| 153 | 
            +
                elif loss_type in ['hinge', 'hinge2']:
         | 
| 154 | 
            +
                    return hinge_fake_loss
         | 
| 155 | 
            +
                raise ValueError('Unsupported loss')
         | 
| 156 | 
            +
             | 
| 157 | 
            +
             | 
| 158 | 
            +
            def get_real_criterion(loss_type: str) -> tp.Callable:
         | 
| 159 | 
            +
                assert loss_type in ADVERSARIAL_LOSSES
         | 
| 160 | 
            +
                if loss_type == 'mse':
         | 
| 161 | 
            +
                    return mse_real_loss
         | 
| 162 | 
            +
                elif loss_type in ['hinge', 'hinge2']:
         | 
| 163 | 
            +
                    return hinge_real_loss
         | 
| 164 | 
            +
                raise ValueError('Unsupported loss')
         | 
| 165 | 
            +
             | 
| 166 | 
            +
             | 
| 167 | 
            +
            def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
         | 
| 168 | 
            +
                return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
         | 
| 169 | 
            +
             | 
| 170 | 
            +
             | 
| 171 | 
            +
            def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
         | 
| 172 | 
            +
                return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x))
         | 
| 173 | 
            +
             | 
| 174 | 
            +
             | 
| 175 | 
            +
            def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
         | 
| 176 | 
            +
                return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
         | 
| 177 | 
            +
             | 
| 178 | 
            +
             | 
| 179 | 
            +
            def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
         | 
| 180 | 
            +
                return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x)))
         | 
| 181 | 
            +
             | 
| 182 | 
            +
             | 
| 183 | 
            +
            def mse_loss(x: torch.Tensor) -> torch.Tensor:
         | 
| 184 | 
            +
                if x.numel() == 0:
         | 
| 185 | 
            +
                    return torch.tensor([0.0], device=x.device)
         | 
| 186 | 
            +
                return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
         | 
| 187 | 
            +
             | 
| 188 | 
            +
             | 
| 189 | 
            +
            def hinge_loss(x: torch.Tensor) -> torch.Tensor:
         | 
| 190 | 
            +
                if x.numel() == 0:
         | 
| 191 | 
            +
                    return torch.tensor([0.0], device=x.device)
         | 
| 192 | 
            +
                return -x.mean()
         | 
| 193 | 
            +
             | 
| 194 | 
            +
             | 
| 195 | 
            +
            def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
         | 
| 196 | 
            +
                if x.numel() == 0:
         | 
| 197 | 
            +
                    return torch.tensor([0.0])
         | 
| 198 | 
            +
                return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
         | 
| 199 | 
            +
             | 
| 200 | 
            +
             | 
| 201 | 
            +
            class FeatureMatchingLoss(nn.Module):
         | 
| 202 | 
            +
                """Feature matching loss for adversarial training.
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                Args:
         | 
| 205 | 
            +
                    loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1).
         | 
| 206 | 
            +
                    normalize (bool): Whether to normalize the loss.
         | 
| 207 | 
            +
                        by number of feature maps.
         | 
| 208 | 
            +
                """
         | 
| 209 | 
            +
                def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
         | 
| 210 | 
            +
                    super().__init__()
         | 
| 211 | 
            +
                    self.loss = loss
         | 
| 212 | 
            +
                    self.normalize = normalize
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
         | 
| 215 | 
            +
                    assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
         | 
| 216 | 
            +
                    feat_loss = torch.tensor(0., device=fmap_fake[0].device)
         | 
| 217 | 
            +
                    feat_scale = torch.tensor(0., device=fmap_fake[0].device)
         | 
| 218 | 
            +
                    n_fmaps = 0
         | 
| 219 | 
            +
                    for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
         | 
| 220 | 
            +
                        assert feat_fake.shape == feat_real.shape
         | 
| 221 | 
            +
                        n_fmaps += 1
         | 
| 222 | 
            +
                        feat_loss += self.loss(feat_fake, feat_real)
         | 
| 223 | 
            +
                        feat_scale += torch.mean(torch.abs(feat_real))
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                    if self.normalize:
         | 
| 226 | 
            +
                        feat_loss /= n_fmaps
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                    return feat_loss
         |