File size: 636 Bytes
9700d2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from typing import List, Optional

from transformers import PretrainedConfig
class CompressionConfig(PretrainedConfig):
    """Configuration for the ``compression_head`` model type.

    Stores the hyperparameters below as instance attributes and forwards any
    remaining keyword arguments unchanged to ``PretrainedConfig.__init__``.

    Args:
        input_size: Input feature size (default ``768``). Presumably the
            hidden size of whatever feeds this head -- TODO confirm.
        compression_sizes: Sizes of the compression stages (default
            ``[512, 256, 128, 64, 32]``). ``None`` selects the default. The
            value is copied into a new list owned by this instance.
        dropout: Dropout rate (default ``0.1``); presumably applied inside the
            head -- confirm against the model code.
        loss_k_vals: List of integer ``k`` values (default empty). The value is
            copied into a new list owned by this instance. NOTE(review): the
            meaning of these values is defined by the model/loss code, not
            here -- confirm.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "compression_head"

    def __init__(
        self,
        input_size: int = 768,
        compression_sizes: Optional[List[int]] = None,
        dropout: float = 0.1,
        loss_k_vals: Optional[List[int]] = None,
        **kwargs,
    ) -> None:
        # A list literal in the signature would be a single object shared by
        # every instance created with defaults. Mutating one config's list
        # would then leak into all other (and future) default configs.
        # Build fresh lists per call instead.
        if compression_sizes is None:
            compression_sizes = [512, 256, 128, 64, 32]
        if loss_k_vals is None:
            loss_k_vals = []

        self.input_size = input_size
        # Copy so this config does not alias a list owned by the caller.
        self.compression_sizes = list(compression_sizes)
        self.dropout = dropout
        self.loss_k_vals = list(loss_k_vals)
        super().__init__(**kwargs)
|