"""MERaLiON AudioLLM model configuration""" from transformers import Gemma2Config, WhisperConfig from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging logger = logging.get_logger(__name__) class MERaLiONConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`MERaLiONForConditionalGeneration`]. It is used to instantiate an MERaLiON model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the MERaLiON. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: audio_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): The config object or dictionary of the audio backbone. text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): The config object or dictionary of the text backbone. audio_token_index (`int`, *optional*, defaults to 151646): The image token index to encode the image prompt. """ model_type = "meralion" is_composition = False def __init__( self, speech_config=None, text_config=None, speech_mlp_scale_factor=15, speech_token_index=255999, **kwargs, ): if isinstance(speech_config, dict): speech_config = WhisperConfig(**speech_config) elif speech_config is None: speech_config = WhisperConfig( d_model=1280, encoder_attention_heads=20, encoder_ffn_dim=5120, encoder_layerdrop=0.0, encoder_layers=32, num_mel_bins=128, max_source_positions=1500, scale_embedding=False, activation_function="gelu", ) self.speech_config = speech_config if isinstance(text_config, dict): text_config = Gemma2Config(**text_config) elif text_config is None: text_config = Gemma2Config() self.text_config = text_config self.speech_mlp_scale_factor = speech_mlp_scale_factor self.speech_token_index = speech_token_index self.sliding_window = self.text_config.sliding_window self.hidden_size = self.text_config.hidden_size self.num_attention_heads = self.text_config.num_attention_heads self.num_hidden_layers = self.text_config.num_hidden_layers self.num_key_value_heads = self.text_config.num_key_value_heads self.head_dim = self.text_config.head_dim self.intermediate_size = self.text_config.intermediate_size super().__init__(**kwargs)