File size: 636 Bytes
9700d2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from typing import List, Optional

from transformers import PretrainedConfig

class CompressionConfig(PretrainedConfig):
    """Configuration for the ``compression_head`` model type.

    Stores the hyperparameters below as instance attributes and forwards any
    remaining keyword arguments to ``PretrainedConfig.__init__``.

    Args:
        input_size: Stored as ``self.input_size`` (default ``768``).
            Presumably the width of the incoming hidden states — confirm
            against the consuming model.
        compression_sizes: Stored as ``self.compression_sizes``. Defaults to
            ``[512, 256, 128, 64, 32]``. Presumably the target sizes the head
            compresses to — confirm against the consuming model.
        dropout: Stored as ``self.dropout`` (default ``0.1``).
        loss_k_vals: Stored as ``self.loss_k_vals``; defaults to an empty
            list. Its meaning is defined by the consuming model (TODO: document).
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "compression_head"

    def __init__(
        self,
        input_size: int = 768,
        compression_sizes: Optional[List[int]] = None,
        dropout: float = 0.1,
        loss_k_vals: Optional[List[int]] = None,
        **kwargs,
    ):
        # Create the default lists per call: list literals in the signature
        # would be evaluated once and shared by every config instance, so
        # mutating one config's list would silently change all the others.
        if compression_sizes is None:
            compression_sizes = [512, 256, 128, 64, 32]
        if loss_k_vals is None:
            loss_k_vals = []

        self.input_size = input_size
        # Copy caller-supplied sequences so later mutation by the caller does
        # not leak into this config (and vice versa).
        self.compression_sizes = list(compression_sizes)
        self.dropout = dropout
        self.loss_k_vals = list(loss_k_vals)

        # Custom attributes are assigned before the base initializer runs,
        # matching the original ordering.
        super().__init__(**kwargs)