from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel, PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import KwargsForCausalLM
from transformers.processing_utils import Unpack
from configuration_speechunit import SpeechUnitConfig


# Adapted from transformers.models.llama.modeling_llama.LlamaPreTrainedModel
class SpeechUnitPreTrainedModel(PreTrainedModel):
    config_class = SpeechUnitConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, SpeechUnitModel):
            # Copy pretrained LLaMA weights into the backbone, parameter by parameter.
            src_model = LlamaModel.from_pretrained(self.config.base_model_id)
            with torch.no_grad():
                for name, param in module.llama_model.named_parameters():
                    param.copy_(src_model.state_dict()[name])


class SpeechUnitModel(SpeechUnitPreTrainedModel):
    def __init__(self, config: SpeechUnitConfig):
        super().__init__(config)
        # Initialize LLaMA model and load weights
        llama_config = LlamaConfig.from_pretrained(config.base_model_id)
        llama_config.num_hidden_layers = config.num_hidden_layers
        self.llama_model = LlamaModel._from_config(llama_config)
        # Embedding layers
        original_vocab_size, embed_dim = self.llama_model.embed_tokens.weight.shape
        # Audio embeddings (16400 = codebook size + 2 for BOS and EOS tokens)
        self.audio_embed = nn.Embedding(16400, embed_dim)
        nn.init.xavier_uniform_(self.audio_embed.weight.data)
        # Learnable weights for token integration
        self.token_weights = nn.Parameter(torch.ones(config.num_heads))
        # Prediction heads
        self.heads = nn.ModuleList([nn.Linear(embed_dim, config.output_dim) for _ in range(config.num_heads)])
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
# ๅƒ่€ƒ https://github.com/huggingface/transformers/blob/b05df6611e6e3e6834acca2b50baeb7cdd5fbe3c/src/transformers/models/llama/modeling_llama.py#L784
pass
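        # ---------------------------------------------------------------------
        # Sketch implementation (not part of the original upload); it mirrors the
        # structure of the referenced modeling_llama.py forward and uses only the
        # modules defined in __init__. Assumptions to verify against the real
        # training code:
        #   * input_ids / labels have shape (batch, seq_len, num_heads), one audio
        #     token per codebook head;
        #   * per-head embeddings are mixed with the learnable `token_weights`
        #     before the LLaMA backbone;
        #   * each head in `self.heads` predicts logits over `config.output_dim`.
        # ---------------------------------------------------------------------
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            # (batch, seq_len, num_heads, embed_dim), then a weighted sum over heads.
            head_embeds = self.audio_embed(input_ids)
            inputs_embeds = (head_embeds * self.token_weights.view(1, 1, -1, 1)).sum(dim=2)

        outputs = self.llama_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
        )
        hidden_states = outputs.last_hidden_state
        if num_logits_to_keep > 0:
            # Keep only the trailing positions needed at generation time.
            hidden_states = hidden_states[:, -num_logits_to_keep:, :]

        # One set of logits per prediction head: (num_heads, batch, seq_len, output_dim).
        logits = torch.stack([head(hidden_states) for head in self.heads], dim=0).float()

        loss = None
        if labels is not None:
            # Shift so that the token at position t predicts position t + 1, per head.
            shift_logits = logits[:, :, :-1, :].contiguous()
            shift_labels = labels.permute(2, 0, 1)[:, :, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.output_dim), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )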