MERaLiON-AudioLLM-Whisper-SEA-LION / configuration_meralion.py
YingxuHe's picture
Upload config
1eb7880 verified
raw
history blame
2.94 kB
"""MERaLiON AudioLLM model configuration"""
from transformers import Gemma2Config, WhisperConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
# Module-level logger obtained from transformers' logging utilities.
# NOTE(review): not referenced anywhere in the visible code of this file.
logger = logging.get_logger(__name__)
class MERaLiONConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MERaLiONForConditionalGeneration`]. It is used
    to instantiate a MERaLiON model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a configuration similar to that of MERaLiON.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        speech_config (`Union[WhisperConfig, dict]`, *optional*):
            The config object or dictionary of the speech (audio) backbone. A dictionary is converted with
            `WhisperConfig(**speech_config)`; any other non-`None` value (e.g. an existing `WhisperConfig`) is stored
            unchanged. When `None`, a `WhisperConfig` is built with `d_model=1280`, `encoder_attention_heads=20`,
            `encoder_ffn_dim=5120`, `encoder_layerdrop=0.0`, `encoder_layers=32`, `num_mel_bins=128`,
            `max_source_positions=1500`, `scale_embedding=False` and `activation_function="gelu"`; every other
            field keeps `WhisperConfig`'s own default.
        text_config (`Union[Gemma2Config, dict]`, *optional*):
            The config object or dictionary of the text backbone. A dictionary is converted with
            `Gemma2Config(**text_config)`; any other non-`None` value is stored unchanged; `None` yields a
            default `Gemma2Config()`.
        speech_mlp_scale_factor (`int`, *optional*, defaults to 15):
            Stored unchanged on the config. Presumably the scale factor used by the speech-to-text adapter MLP in
            the modeling code -- confirm there.
        speech_token_index (`int`, *optional*, defaults to 255999):
            Stored unchanged on the config. Presumably the vocabulary id of the placeholder token marking where
            speech embeddings are inserted into the prompt -- confirm against the modeling code.
        kwargs (*optional*):
            Forwarded unchanged to [`PretrainedConfig.__init__`].

    In addition to `speech_config` and `text_config`, the following fields are copied from the text backbone onto
    the top-level config: `sliding_window`, `hidden_size`, `num_attention_heads`, `num_hidden_layers`,
    `num_key_value_heads`, `head_dim` and `intermediate_size`.
    """

    model_type = "meralion"
    is_composition = False

    def __init__(
        self,
        speech_config=None,
        text_config=None,
        speech_mlp_scale_factor=15,
        speech_token_index=255999,
        **kwargs,
    ):
        # Normalize the speech backbone config: dict -> WhisperConfig, None -> built-in defaults.
        # Anything else (e.g. an already-constructed WhisperConfig) passes through untouched.
        if isinstance(speech_config, dict):
            speech_config = WhisperConfig(**speech_config)
        elif speech_config is None:
            speech_config = WhisperConfig(
                d_model=1280,
                encoder_attention_heads=20,
                encoder_ffn_dim=5120,
                encoder_layerdrop=0.0,
                encoder_layers=32,
                num_mel_bins=128,
                max_source_positions=1500,
                scale_embedding=False,
                activation_function="gelu",
            )
        self.speech_config = speech_config

        # Normalize the text backbone config the same way: dict -> Gemma2Config, None -> Gemma2Config defaults.
        if isinstance(text_config, dict):
            text_config = Gemma2Config(**text_config)
        elif text_config is None:
            text_config = Gemma2Config()
        self.text_config = text_config

        self.speech_mlp_scale_factor = speech_mlp_scale_factor
        self.speech_token_index = speech_token_index

        # Mirror selected text-backbone fields onto the top-level config.
        # NOTE(review): presumably so code that reads these from the composite config (e.g. generation/cache
        # utilities) sees the language model's values -- confirm against callers.
        self.sliding_window = self.text_config.sliding_window
        self.hidden_size = self.text_config.hidden_size
        self.num_attention_heads = self.text_config.num_attention_heads
        self.num_hidden_layers = self.text_config.num_hidden_layers
        self.num_key_value_heads = self.text_config.num_key_value_heads
        self.head_dim = self.text_config.head_dim
        self.intermediate_size = self.text_config.intermediate_size

        # Base-class initialization runs last.
        # NOTE(review): if PretrainedConfig assigns leftover kwargs as attributes, a serialized value for any
        # mirrored field above (e.g. `hidden_size` when reloading a saved config) would override the copied one.
        super().__init__(**kwargs)