from typing import List, Optional

from transformers import PretrainedConfig


class CompressionConfig(PretrainedConfig):
    """Configuration for the ``compression_head`` model type.

    Stores the hyper-parameters given to ``__init__`` as instance attributes;
    any extra keyword arguments are forwarded to ``PretrainedConfig.__init__``.

    Args:
        input_size: Size of the input features (presumably the hidden size of
            the upstream encoder -- TODO confirm against the model code).
            Defaults to 768.
        compression_sizes: Sequence of integer compressed sizes. Defaults to
            ``[512, 256, 128, 64, 32]``. A new list is built for every
            instance, so mutating one config never affects another.
        dropout: Dropout rate (presumably a probability in [0, 1]).
            Defaults to 0.1.
        loss_k_vals: Integer ``k`` values (their meaning is defined by the
            consuming training/loss code -- TODO confirm). Defaults to an
            empty list, freshly created per instance.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "compression_head"

    def __init__(
        self,
        input_size: int = 768,
        compression_sizes: Optional[List[int]] = None,
        dropout: float = 0.1,
        loss_k_vals: Optional[List[int]] = None,
        **kwargs,
    ):
        self.input_size = input_size
        # Build/copy the lists here instead of using list literals as default
        # arguments: a default list is created once at function definition and
        # would be shared by every instance, so an in-place edit on one config
        # would leak into all configs created afterwards.
        self.compression_sizes = (
            [512, 256, 128, 64, 32]
            if compression_sizes is None
            else list(compression_sizes)
        )
        self.dropout = dropout
        self.loss_k_vals = [] if loss_k_vals is None else list(loss_k_vals)
        super().__init__(**kwargs)