File size: 528 Bytes
b68ef84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch


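# "Fake" VAE for pixel-space models: encode() and decode() return their input
# unchanged. IMAGE tensors are B, H, W, C with values on the scale of 0..1 and
# LATENT tensors are B, C, H, W with values on the scale of -1..1; that layout
# and range conversion is not performed by this class (presumably the caller
# does it).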
class PixelspaceConversionVAE(torch.nn.Module):
    """Identity stand-in for a VAE, for models that work directly on pixels.

    Both ``encode`` and ``decode`` hand back the very tensor they receive
    (same object, no copy, no reshaping, no rescaling). Any additional
    positional or keyword arguments are accepted and silently ignored, so the
    methods can be called with whatever extra arguments a caller passes.

    NOTE(review): no B, H, W, C <-> B, C, H, W layout change and no 0..1 <->
    -1..1 range change happens here; presumably the calling wrapper performs
    them — confirm against the caller.
    """

    def __init__(self):
        super().__init__()
        # Scalar parameter fixed at 1.0. Neither encode nor decode reads it;
        # presumably it exists so the module exposes a "pixel_space_vae"
        # entry in parameters()/state_dict() — TODO confirm with callers.
        self.pixel_space_vae = torch.nn.Parameter(torch.ones(()))

    def encode(
        self,
        pixels: torch.Tensor,
        *_unused_args,
        **_unused_kwargs,
    ) -> torch.Tensor:
        """Return ``pixels`` unchanged; extra arguments are ignored."""
        return pixels

    def decode(
        self,
        samples: torch.Tensor,
        *_unused_args,
        **_unused_kwargs,
    ) -> torch.Tensor:
        """Return ``samples`` unchanged; extra arguments are ignored."""
        return samples