from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel, PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import KwargsForCausalLM
from transformers.processing_utils import Unpack

from configuration_speechunit import SpeechUnitConfig


class SpeechUnitPreTrainedModel(PreTrainedModel):
    config_class = SpeechUnitConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, SpeechUnitModel):
            # Copy the matching pretrained Llama weights into the truncated backbone.
            # post_init() applies this hook to submodules first and to the model itself
            # last, so the copied weights overwrite the random init done above.
            src_state_dict = LlamaModel.from_pretrained(self.config.base_model_id).state_dict()
            with torch.no_grad():
                for name, param in module.llama_model.named_parameters():
                    param.copy_(src_state_dict[name])


class SpeechUnitModel(SpeechUnitPreTrainedModel):
    def __init__(self, config: SpeechUnitConfig):
        super().__init__(config)

        # Llama backbone truncated to the requested number of decoder layers.
        llama_config = LlamaConfig.from_pretrained(config.base_model_id)
        llama_config.num_hidden_layers = config.num_hidden_layers
        self.llama_model = LlamaModel._from_config(llama_config)

        original_vocab_size, embed_dim = self.llama_model.embed_tokens.weight.shape

        # Dedicated embedding table for the 16400 audio-unit tokens, sized to the
        # backbone's hidden width.
        self.audio_embed = nn.Embedding(16400, embed_dim)
        nn.init.xavier_uniform_(self.audio_embed.weight.data)

        # Learnable per-head weights; the forward sketch below uses them to weight
        # each head's loss.
        self.token_weights = nn.Parameter(torch.ones(config.num_heads))

        # One linear prediction head per output stream.
        self.heads = nn.ModuleList(
            [nn.Linear(embed_dim, config.output_dim) for _ in range(config.num_heads)]
        )

        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        # Minimal forward sketch, always returning a CausalLMOutputWithPast: the
        # (batch, seq_len, num_heads) label layout and the token_weights-weighted
        # per-head loss are assumptions, not a confirmed design.
        if inputs_embeds is None:
            inputs_embeds = self.audio_embed(input_ids)

        outputs = self.llama_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        if num_logits_to_keep:
            hidden_states = hidden_states[:, -num_logits_to_keep:]

        # One set of logits per prediction head: (batch, seq_len, num_heads, output_dim).
        logits = torch.stack([head(hidden_states) for head in self.heads], dim=2)

        loss = None
        if labels is not None:
            per_head_loss = torch.stack([
                nn.functional.cross_entropy(
                    logits[:, :, i].reshape(-1, logits.size(-1)),
                    labels[..., i].reshape(-1),
                )
                for i in range(len(self.heads))
            ])
            loss = (self.token_weights * per_head_loss).sum()

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
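

# Illustrative usage sketch. The SpeechUnitConfig field values and the base
# checkpoint name below are hypothetical; instantiating the model loads the
# pretrained checkpoint named by base_model_id, so this is a sanity check
# rather than a test.
if __name__ == "__main__":
    config = SpeechUnitConfig(
        base_model_id="meta-llama/Llama-3.2-1B",  # hypothetical base checkpoint
        num_hidden_layers=4,
        num_heads=4,
        output_dim=4100,
    )
    model = SpeechUnitModel(config)

    # Dummy batch of audio-unit ids (batch=2, seq_len=8) and per-head targets.
    input_ids = torch.randint(0, 16400, (2, 8))
    labels = torch.randint(0, config.output_dim, (2, 8, config.num_heads))

    out = model(input_ids=input_ids, labels=labels)
    print(out.logits.shape)  # expected: torch.Size([2, 8, 4, 4100])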