from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel, PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import KwargsForCausalLM
from transformers.processing_utils import Unpack

from configuration_speechunit import SpeechUnitConfig


# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel
class SpeechUnitPreTrainedModel(PreTrainedModel):
    config_class = SpeechUnitConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, SpeechUnitModel):
            # Copy pretrained LLaMA weights into the wrapped backbone.
            src_model = LlamaModel.from_pretrained(self.config.base_model_id)
            src_state_dict = src_model.state_dict()
            with torch.no_grad():
                for name, param in module.llama_model.named_parameters():
                    param.copy_(src_state_dict[name])


class SpeechUnitModel(SpeechUnitPreTrainedModel):
    def __init__(self, config: SpeechUnitConfig):
        super().__init__(config)

        # Initialize a (possibly truncated) LLaMA backbone from its config;
        # pretrained weights are loaded in SpeechUnitPreTrainedModel._init_weights.
        llama_config = LlamaConfig.from_pretrained(config.base_model_id)
        llama_config.num_hidden_layers = config.num_hidden_layers
        self.llama_model = LlamaModel._from_config(llama_config)

        # Embedding dimensions of the backbone
        original_vocab_size, embed_dim = self.llama_model.embed_tokens.weight.shape

        # Audio embeddings (16400 = codebook size + 2 for BOS and EOS tokens)
        self.audio_embed = nn.Embedding(16400, embed_dim)
        nn.init.xavier_uniform_(self.audio_embed.weight.data)

        # Learnable weights for token integration across heads
        self.token_weights = nn.Parameter(torch.ones(config.num_heads))

        # Prediction heads, one linear projection per head
        self.heads = nn.ModuleList(
            [nn.Linear(embed_dim, config.output_dim) for _ in range(config.num_heads)]
        )

        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        # Reference: https://github.com/huggingface/transformers/blob/b05df6611e6e3e6834acca2b50baeb7cdd5fbe3c/src/transformers/models/llama/modeling_llama.py#L784
        pass
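
# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original module): the forward pass above is
# left as a stub, so the function below shows one plausible way it could be
# wired up, reusing this module's imports. It assumes the backbone consumes
# audio-token embeddings from `audio_embed`, that each head predicts one
# codebook stream, and that labels have shape (batch, num_heads, seq_len).
# The helper name `run_speech_unit_forward`, the label layout, and the
# token-weighted loss are illustrative assumptions, not the author's
# implementation.
def run_speech_unit_forward(
    model: "SpeechUnitModel",
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.LongTensor] = None,
) -> CausalLMOutputWithPast:
    # Embed audio tokens with the dedicated audio embedding table.
    inputs_embeds = model.audio_embed(input_ids)

    # Run the (truncated) LLaMA backbone on the embedded sequence.
    backbone_out = model.llama_model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
    )
    hidden_states = backbone_out.last_hidden_state  # (batch, seq_len, embed_dim)

    # One set of logits per prediction head: (batch, num_heads, seq_len, output_dim).
    logits = torch.stack([head(hidden_states) for head in model.heads], dim=1)

    loss = None
    if labels is not None:
        # Average per-head cross-entropy losses, weighted by the learnable
        # token weights (assumption: weights are normalized with a softmax).
        loss_fct = nn.CrossEntropyLoss()
        per_head_losses = [
            loss_fct(
                logits[:, h].reshape(-1, logits.size(-1)),
                labels[:, h].reshape(-1),
            )
            for h in range(len(model.heads))
        ]
        weights = torch.softmax(model.token_weights, dim=0)
        loss = sum(w * l for w, l in zip(weights, per_head_losses))

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=backbone_out.past_key_values,
        hidden_states=backbone_out.hidden_states,
        attentions=backbone_out.attentions,
    )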