Spaces:
Runtime error
Runtime error
| from abc import ABC, abstractmethod | |
| import torch | |
class DenoisingLoss(ABC):
    """Abstract interface shared by all denoising losses.

    Concrete losses override ``__call__``. Every loss receives the same set
    of arguments so callers can switch between them without changing the
    call site; a given loss is free to ignore the arguments it does not need.
    """

    @abstractmethod
    def __call__(
        self, x: torch.Tensor, x_pred: torch.Tensor,
        noise: torch.Tensor, noise_pred: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        timestep: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """
        Compute the denoising loss.
        Input:
            - x: the clean data with shape [B, F, C, H, W]
            - x_pred: the predicted clean data with shape [B, F, C, H, W]
            - noise: the noise with shape [B, F, C, H, W]
            - noise_pred: the predicted noise with shape [B, F, C, H, W]
            - alphas_cumprod: the cumulative product of alphas (defining the noise schedule) with shape [T]
            - timestep: the current timestep with shape [B, F]
            - kwargs: extra, loss-specific inputs
        Output:
            - the loss as a tensor
        """
        # Marked abstract so the base class cannot be instantiated and then
        # silently return None where a tensor is expected.
        raise NotImplementedError
class X0PredLoss(DenoisingLoss):
    """Mean squared error between the clean data and its prediction."""

    def __call__(
        self,
        x: torch.Tensor,
        x_pred: torch.Tensor,
        noise: torch.Tensor,
        noise_pred: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        timestep: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Return ``mean((x - x_pred) ** 2)`` over all elements.

        Only ``x`` and ``x_pred`` are used; the remaining arguments are
        accepted so the signature matches the other losses.
        """
        residual = x - x_pred
        return residual.pow(2).mean()
class VPredLoss(DenoisingLoss):
    """Squared error on the clean data, re-weighted by the noise schedule."""

    def __call__(
        self,
        x: torch.Tensor,
        x_pred: torch.Tensor,
        noise: torch.Tensor,
        noise_pred: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        timestep: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Return ``mean(w * (x - x_pred) ** 2)`` with ``w = 1 / (1 - alphas_cumprod[timestep])``.

        ``noise`` and ``noise_pred`` are not used.
        """
        # Look up the schedule value for every timestep entry, then append
        # three singleton axes so the weight broadcasts against the data.
        alpha_bar = alphas_cumprod[timestep]
        alpha_bar = alpha_bar.reshape(*timestep.shape, 1, 1, 1)
        weights = 1 / (1 - alpha_bar)

        squared_error = (x - x_pred).pow(2)
        return (weights * squared_error).mean()
class NoisePredLoss(DenoisingLoss):
    """Mean squared error between the noise and its prediction."""

    def __call__(
        self,
        x: torch.Tensor,
        x_pred: torch.Tensor,
        noise: torch.Tensor,
        noise_pred: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        timestep: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Return ``mean((noise - noise_pred) ** 2)`` over all elements.

        Only ``noise`` and ``noise_pred`` are used.
        """
        residual = noise - noise_pred
        return residual.pow(2).mean()
class FlowPredLoss(DenoisingLoss):
    """Mean squared error between a predicted flow and ``noise - x``."""

    def __call__(
        self,
        x: torch.Tensor,
        x_pred: torch.Tensor,
        noise: torch.Tensor,
        noise_pred: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        timestep: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Return ``mean((flow_pred - (noise - x)) ** 2)``.

        The prediction is read from the ``flow_pred`` keyword argument; a
        ``KeyError`` is raised if the caller does not supply it.
        """
        flow_pred = kwargs["flow_pred"]
        flow_target = noise - x
        return (flow_pred - flow_target).pow(2).mean()
# Registry mapping a loss-type name to the loss class implementing it.
# The values are classes, not instances.
NAME_TO_CLASS = {
    "x0": X0PredLoss,
    "v": VPredLoss,
    "noise": NoisePredLoss,
    "flow": FlowPredLoss,
}
def get_denoising_loss(loss_type: str) -> "type[DenoisingLoss]":
    """Look up the denoising loss class registered under ``loss_type``.

    Args:
        loss_type: a key of ``NAME_TO_CLASS``.

    Returns:
        The loss *class* (not an instance); call the result to construct
        the loss object.

    Raises:
        KeyError: if ``loss_type`` is not a registered name.
    """
    # The annotation is quoted so it is never evaluated at definition time
    # and therefore does not depend on the interpreter version.
    try:
        return NAME_TO_CLASS[loss_type]
    except KeyError:
        valid_names = ", ".join(repr(name) for name in NAME_TO_CLASS)
        raise KeyError(
            f"Unknown denoising loss type {loss_type!r}; expected one of: {valid_names}"
        ) from None