# granite-speech-3.2-8b / configuration_granite_speech.py
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING, AutoConfig


class GraniteSpeechEncoderConfig(PretrainedConfig):
    model_type = "granite_speech_encoder"

    def __init__(
        self,
        input_dim=160,
        num_layers=10,
        hidden_dim=1024,
        feedforward_mult=4,
        num_heads=8,
        dim_head=128,
        output_dim=42,
        context_size=200,
        dropout=0.1,
        conv_kernel_size=15,
        conv_expansion_factor=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.feedforward_mult = feedforward_mult
        self.num_heads = num_heads
        self.dim_head = dim_head
        self.output_dim = output_dim
        self.context_size = context_size
        self.dropout = dropout
        self.conv_kernel_size = conv_kernel_size
        self.conv_expansion_factor = conv_expansion_factor


# Adapted from transformers.models.blip_2.configuration_blip_2.Blip2QFormerConfig
class GraniteSpeechProjectorConfig(PretrainedConfig):
    model_type = "granite_speech_qformer"

    def __init__(
        self,
        llm_dim=4096,
        downsample_rate=5,
        window_size=15,
        hidden_size=1024,
        num_attention_heads=16,
        intermediate_size=4096,
        num_hidden_layers=2,
        encoder_hidden_size=1024,
        cross_attention_frequency=1,
        max_position_embeddings=2048,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_qformer_text_input=False,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.cross_attention_frequency = cross_attention_frequency
        self.encoder_hidden_size = encoder_hidden_size
        self.use_qformer_text_input = use_qformer_text_input
        self.downsample_rate = downsample_rate
        self.window_size = window_size
        self.llm_dim = llm_dim


class GraniteSpeechConfig(PretrainedConfig):
    model_type = "granite_speech"
    sub_configs = {
        "text_config": AutoConfig,
        "encoder_config": GraniteSpeechEncoderConfig,
        "projector_config": GraniteSpeechProjectorConfig,
    }

    def __init__(
        self,
        encoder_config=None,
        text_config=None,
        projector_config=None,
        audio_token_index=49155,
        initializer_range=0.02,
        has_lora_adapter=True,
        **kwargs,
    ):
        # Accept sub-configs either as config objects or as plain dicts; fall back
        # to defaults ("granite" for the text backbone) when they are omitted.
        if isinstance(text_config, dict):
            text_config["model_type"] = text_config.get("model_type", "granite")
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["granite"]()

        if isinstance(projector_config, dict):
            # TODO - In the future, we should make this generic.
            projector_config = GraniteSpeechProjectorConfig(**projector_config)
        elif projector_config is None:
            projector_config = GraniteSpeechProjectorConfig()

        if not isinstance(encoder_config, GraniteSpeechEncoderConfig):
            encoder_config = {} if encoder_config is None else encoder_config
            encoder_config = GraniteSpeechEncoderConfig(**encoder_config)

        self.text_config = text_config
        self.encoder_config = encoder_config
        self.projector_config = projector_config
        self.audio_token_index = audio_token_index
        self.initializer_range = initializer_range
        self.has_lora_adapter = has_lora_adapter
        super().__init__(**kwargs)


__all__ = ["GraniteSpeechEncoderConfig", "GraniteSpeechProjectorConfig", "GraniteSpeechConfig"]
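

# Illustrative usage sketch (not part of the upstream file): shows how the composite
# config is built from defaults and from plain dict sub-configs. Assumes a recent
# `transformers` installation whose CONFIG_MAPPING registers the "granite" text config;
# the printed values simply echo the defaults defined above.
if __name__ == "__main__":
    # Build everything from defaults; text_config falls back to CONFIG_MAPPING["granite"]().
    config = GraniteSpeechConfig()
    print(config.encoder_config.hidden_dim)   # 1024
    print(config.projector_config.llm_dim)    # 4096
    print(config.audio_token_index)           # 49155

    # Sub-configs may also be passed as dicts; they are promoted to config objects.
    config = GraniteSpeechConfig(
        encoder_config={"num_layers": 12},
        projector_config={"downsample_rate": 10},
    )
    print(config.encoder_config.num_layers)         # 12
    print(config.projector_config.downsample_rate)  # 10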