from transformers import PretrainedConfig
from typing import List, Optional


class CompressionConfig(PretrainedConfig):
    """Configuration for the ``compression_head`` model type.

    Stores the hyperparameters below as plain attributes. Every other keyword
    argument is forwarded unchanged to ``PretrainedConfig.__init__``.

    Args:
        input_size: Input feature size; defaults to 768. (Presumably the
            hidden size of the upstream encoder — TODO confirm.)
        compression_sizes: List of ints; defaults to
            ``[512, 256, 128, 64, 32]``. A fresh copy is stored, so the
            caller's list is never aliased.
        dropout: Dropout value; defaults to 0.1.
        loss_k_vals: List of ints; defaults to an empty list. A fresh copy is
            stored, so the caller's list is never aliased.
    """

    model_type = "compression_head"

    def __init__(
        self,
        input_size: int = 768,
        compression_sizes: Optional[List[int]] = None,
        dropout: float = 0.1,
        loss_k_vals: Optional[List[int]] = None,
        **kwargs,
    ):
        # Use None sentinels rather than list literals. A list default is
        # built once at definition time and shared by every instance, so
        # mutating one config's list would silently change later defaults.
        if compression_sizes is None:
            compression_sizes = [512, 256, 128, 64, 32]
        if loss_k_vals is None:
            loss_k_vals = []
        self.input_size = input_size
        # Copy so the config and the caller never share (and mutate) one list.
        self.compression_sizes = list(compression_sizes)
        self.dropout = dropout
        self.loss_k_vals = list(loss_k_vals)
        super().__init__(**kwargs)