File size: 1,171 Bytes
c3bcb92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch.nn as nn
from transformers import PretrainedConfig

class cceVAEConfig(PretrainedConfig):
    """Configuration class for a ceVAE-style variational autoencoder.

    Holds the architectural hyperparameters of the model. Follows the
    standard Hugging Face pattern: plain attribute assignment, then
    ``super().__init__(**kwargs)`` so remaining keyword arguments are
    handled by :class:`~transformers.PretrainedConfig`.

    Args:
        d: Spatial dimensionality of the network
            (presumably 2 -> 2D convs, 3 -> 3D convs — confirm against the model).
        input_size: Input tensor size; presumably (channels, height, width)
            for the default ``(1, 256, 256)`` — TODO confirm.
        z_dim: Size of the latent code.
        fmap_sizes: Feature-map (channel) sizes per encoder/decoder stage.
        to_1x1: Whether the encoder reduces the spatial extent to 1x1
            before the latent projection — NOTE(review): inferred from the
            name; verify against the model implementation.
        conv_params: Extra keyword arguments for the convolution layers.
        tconv_params: Extra keyword arguments for the transposed convolutions.
        normalization_op: Normalization layer/operator to use (``None`` = model default).
        normalization_params: Extra keyword arguments for ``normalization_op``.
        activation_op: Activation to use; defaults to ``"prelu"``.
        activation_params: Extra keyword arguments for the activation.
        block_op: Building-block operator for the conv stages (``None`` = model default).
        block_params: Extra keyword arguments for ``block_op``.
        **kwargs: Forwarded to :class:`~transformers.PretrainedConfig`.
    """

    model_type = "cceVAE"

    def __init__(
        self,
        d=2,
        input_size=(1, 256, 256),
        z_dim=1024,
        fmap_sizes=(16, 64, 256, 1024),
        to_1x1=True,
        conv_params=None,
        tconv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op="prelu",
        activation_params=None,
        block_op=None,
        block_params=None,
        **kwargs,
    ):
        self.d = d
        self.input_size = input_size
        self.z_dim = z_dim
        self.fmap_sizes = fmap_sizes
        self.to_1x1 = to_1x1
        self.conv_params = conv_params
        self.tconv_params = tconv_params
        self.normalization_op = normalization_op
        self.normalization_params = normalization_params
        self.activation_op = activation_op
        self.activation_params = activation_params
        self.block_op = block_op
        self.block_params = block_params
        # Let PretrainedConfig consume any remaining kwargs last, so our
        # explicit attributes above are not clobbered by base-class defaults.
        super().__init__(**kwargs)