File size: 618 Bytes
853e052 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from transformers import PretrainedConfig
class SpeechUnitConfig(PretrainedConfig):
    """Hugging Face configuration object for the ``"speechunit"`` model type.

    The five hyperparameters below are stored as same-named attributes on the
    instance. Every other keyword argument is passed unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        base_model_id: Stored as ``self.base_model_id``. The default
            ``"meta-llama/Llama-3.2-1B"`` looks like a Hugging Face Hub repo
            id (presumably the backbone model; confirm against the model code).
        num_hidden_layers: Stored as ``self.num_hidden_layers`` (default 3).
        output_dim: Stored as ``self.output_dim`` (default 2048).
        num_heads: Stored as ``self.num_heads`` (default 8).
        initializer_range: Stored as ``self.initializer_range`` (default 0.02).
            NOTE(review): presumably the std of weight initialization, as in
            other HF configs; confirm.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "speechunit"

    def __init__(
        self,
        base_model_id: str = "meta-llama/Llama-3.2-1B",
        num_hidden_layers: int = 3,
        output_dim: int = 2048,
        num_heads: int = 8,
        initializer_range: float = 0.02,
        **kwargs,
    ) -> None:
        # Identifier of the model this config builds on.
        self.base_model_id = base_model_id

        # Architecture sizes.
        self.num_hidden_layers = num_hidden_layers
        self.output_dim = output_dim
        self.num_heads = num_heads

        # Weight-initialization setting.
        self.initializer_range = initializer_range

        # Own fields are assigned first, then the base class handles the
        # remaining kwargs. This follows the pattern in the HF custom-config docs.
        super().__init__(**kwargs)