| | import math |
| | import warnings |
| | from functools import partial |
| | from typing import List, Optional, Tuple, Union |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import torch.utils.checkpoint |
| | from torch import nn |
| | from torch.nn import CrossEntropyLoss |
| | from transformers.cache_utils import Cache |
| | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| | from transformers.utils import logging |
| | from transformers.models.llama.modeling_llama import ( |
| | LlamaRotaryEmbedding, |
| | LlamaLinearScalingRotaryEmbedding, |
| | LlamaDynamicNTKScalingRotaryEmbedding, |
| | apply_rotary_pos_emb, |
| | repeat_kv, |
| | LlamaMLP, |
| | LlamaRMSNorm, |
| | is_flash_attn_greater_or_equal_2_10, |
| | ) |
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| | import copy |
| | import os |
| | import sys |
| |
|
# Put the directory containing this file at the front of sys.path so sibling modules
# can be imported by plain (non-relative) module name.
dir_path = os.path.split(os.path.realpath(__file__))[0]
sys.path.insert(0, dir_path)
| |
|
| | import transformers |
| | from transformers.models.llama.modeling_llama import * |
| |
|
| | from transformers.configuration_utils import PretrainedConfig |
| | from transformers.utils import logging |
| |
|
| | from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa |
| | from .configuration_mplug_owl2 import LlamaConfig |
| |
|
| | def _get_unpad_data(attention_mask): |
| | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
| | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() |
| | max_seqlen_in_batch = seqlens_in_batch.max().item() |
| | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) |
| | return ( |
| | indices, |
| | cu_seqlens, |
| | max_seqlen_in_batch, |
| | ) |
| |
|
| |
|
class MultiwayNetwork(nn.Module):
    """Route every token through one of several parallel copies of a module.

    ``module_provider`` is called ``num_multiway`` times to build independent sub-modules
    (e.g. one ``nn.Linear`` or one RMSNorm per modality). In ``forward``, a token whose entry in
    ``multiway_indices`` equals ``i`` is processed by ``self.multiway[i]``.
    """

    def __init__(self, module_provider, num_multiway=2):
        super().__init__()
        self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)])

    def forward(self, hidden_states, multiway_indices):
        """Apply the sub-module selected per token.

        Args:
            hidden_states: input tensor; its leading dimensions must match the shape of
                ``multiway_indices`` (the last dimension is the feature dimension).
            multiway_indices: integer tensor choosing the sub-module for each token.

        Returns:
            Tensor with the leading dimensions of ``hidden_states`` and the feature width produced
            by the sub-modules, in ``hidden_states``' dtype. Tokens whose index matches no
            sub-module are zero.
        """
        if len(self.multiway) == 1:
            return self.multiway[0](hidden_states)

        output_hidden_states = None
        for idx, subway in enumerate(self.multiway):
            local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True)
            # Gather this sub-module's tokens into a (num_tokens, 1, dim) batch.
            hidden = hidden_states[local_indices].unsqueeze(1).contiguous()
            if not hidden.numel():
                continue
            output = subway(hidden)
            if isinstance(output, tuple):
                output = output[0]
            output = output.squeeze(1)
            if output_hidden_states is None:
                # Size the buffer from the sub-module's output width, which can differ from the input
                # width (e.g. key/value projections under grouped-query attention).
                output_hidden_states = hidden_states.new_zeros(hidden_states.shape[:-1] + output.shape[-1:])
            output_hidden_states[local_indices] = output

        if output_hidden_states is None:
            # No token was routed to any sub-module.
            output_hidden_states = torch.zeros_like(hidden_states)
        return output_hidden_states.contiguous()
| | |
| |
|
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Modality-adaptive variant: `k_proj` and `v_proj` are `MultiwayNetwork`s, so each token is projected with
    the key/value weights selected by its entry in `modality_indicators`. `q_proj` and `o_proj` are shared.
    The KV cache is the legacy `(key_states, value_states)` tuple.
    """

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        # One key/value projection per value of `modality_indicators` (MultiwayNetwork builds two by default).
        self.k_proj = MultiwayNetwork(module_provider=partial(
            nn.Linear,
            in_features=self.hidden_size,
            out_features=self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        ))
        self.v_proj = MultiwayNetwork(module_provider=partial(
            nn.Linear,
            in_features=self.hidden_size,
            out_features=self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        ))
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        """Build `self.rotary_emb` according to `config.rope_scaling` (none, "linear" or "dynamic")."""
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `(bsz, seq_len, num_heads * head_dim)` to `(bsz, num_heads, seq_len, head_dim)`."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        modality_indicators: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        padding_mask: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Eager (matmul + softmax) attention.

        Args:
            hidden_states: `(bsz, q_len, hidden_size)` input.
            modality_indicators: per-token index selecting the key/value projection; indexes the leading
                `(bsz, q_len)` dimensions of `hidden_states`.
            attention_mask: optional additive mask of shape `(bsz, 1, q_len, kv_seq_len)`.
            position_ids: positions used when applying rotary embeddings.
            past_key_value: optional cached `(key_states, value_states)` from earlier steps.
            output_attentions: whether to return the attention probabilities.
            use_cache: whether to return the updated `(key_states, value_states)` cache.
            padding_mask: unused; accepted for call-signature compatibility.

        Returns:
            `(attn_output, attn_weights or None, past_key_value or None)`.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states, modality_indicators)
        value_states = self.v_proj(hidden_states, modality_indicators)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # Reuse cached keys/values along the sequence dimension.
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        # Cache the un-repeated key/value heads.
        past_key_value = (key_states, value_states) if use_cache else None

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for numerical stability, then cast back.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        # Same attention dropout the SDPA and flash-attention paths apply (no-op in eval mode).
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
| | |
| |
|
class LlamaFlashAttention2(LlamaAttention):
    """
    Flash-attention variant of `LlamaAttention`. It reuses the parent's (modality-adaptive) projections and
    rotary embedding unchanged; only the forward pass differs: it calls the flash-attn kernels and unpads the
    batch when `attention_mask` marks padding tokens.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Older flash-attn releases (per `is_flash_attn_greater_or_equal_2_10`) align the causal mask to the
        # top-left corner; `_flash_attention_forward` compensates for that when q_len == 1.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        modality_indicators: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Union[Cache, Tuple[torch.Tensor]]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Flash-attention forward pass.

        Args:
            hidden_states: `(bsz, q_len, hidden_size)` input.
            modality_indicators: per-token index selecting the key/value projection.
            attention_mask: 2D padding mask `(bsz, kv_seq_len)` (0 = padding) or None when nothing is padded.
            position_ids: positions used when applying rotary embeddings.
            past_key_value: cached `(key_states, value_states)` tuple (as produced by `model_forward`), or a
                `Cache` object updated in place.
            output_attentions: ignored; flash attention never returns attention probabilities.
            use_cache: with a tuple cache, whether to return the updated `(key_states, value_states)`.

        Returns:
            `(attn_output, None, past_key_value)`.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
            # The deprecated argument overrides attention_mask.
            attention_mask = kwargs.pop("padding_mask")

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states, modality_indicators)
        value_states = self.v_proj(hidden_states, modality_indicators)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if isinstance(past_key_value, Cache):
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        elif past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if isinstance(past_key_value, Cache):
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        else:
            # `model_forward` hands each layer a legacy `(key, value)` tuple, which has no Cache API.
            if past_key_value is not None:
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
            # Cache in (bsz, heads, seq, head_dim) layout, before the flash-specific transpose below.
            past_key_value = (key_states, value_states) if use_cache else None

        # flash-attn expects (bsz, seq_len, num_heads, head_dim).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # Inputs can arrive upcast to float32 (e.g. from float32 layer norms); flash-attn kernels take
        # half-precision inputs, so cast back to the autocast / pre-quantization / weight dtype.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        # Flash attention never materializes the attention probabilities.
        attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            query_length (`int`):
                Number of query tokens per sequence.
            dropout (`float`, *optional*):
                Attention dropout probability.
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # With a top-left aligned causal mask a single query would only see the first key,
            # so causal masking is disabled when decoding one token.
            causal = self.is_causal and query_length != 1

        # Padded batch: unpad, run the varlen kernel, then scatter results back into the padded layout.
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """Drop padding tokens from q/k/v and compute the varlen metadata for `flash_attn_varlen_func`.

        Returns:
            `(query, key, value, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k))`.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            # Prefill without cache: queries share the key padding layout.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Single-token decoding: exactly one query per sequence.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -query_length slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
| |
|
| |
|
class LlamaSdpaAttention(LlamaAttention):
    """
    `LlamaAttention` computed with `torch.nn.functional.scaled_dot_product_attention`. The weights (including the
    modality-adaptive key/value projections) are inherited unchanged; only the forward pass differs.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        modality_indicators: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Union[Cache, Tuple[torch.Tensor]]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """SDPA forward pass; falls back to the eager parent when `output_attentions` is requested.

        Args:
            hidden_states: `(bsz, q_len, hidden_size)` input.
            modality_indicators: per-token index selecting the key/value projection.
            attention_mask: optional mask of shape `(bsz, 1, q_len, kv_seq_len)`, or None to let SDPA apply
                causal masking itself.
            position_ids: positions used when applying rotary embeddings.
            past_key_value: cached `(key_states, value_states)` tuple (as produced by `model_forward`), or a
                `Cache` object updated in place.
            output_attentions: if True, delegates to `LlamaAttention.forward`.
            use_cache: with a tuple cache, whether to return the updated `(key_states, value_states)`.

        Returns:
            `(attn_output, None, past_key_value)`.
        """
        if output_attentions:
            # SDPA cannot return attention probabilities; use the eager implementation instead.
            logger.warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                modality_indicators=modality_indicators,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states, modality_indicators)
        value_states = self.v_proj(hidden_states, modality_indicators)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if isinstance(past_key_value, Cache):
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        elif past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if isinstance(past_key_value, Cache):
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        else:
            # `model_forward` hands each layer a legacy `(key, value)` tuple, which has no Cache API.
            if past_key_value is not None:
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
            # Cache the un-repeated key/value heads.
            past_key_value = (key_states, value_states) if use_cache else None

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        # Make inputs contiguous on CUDA when a custom mask is used (workaround for SDPA's
        # memory-efficient backend with non-contiguous inputs).
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            # Let SDPA build the causal mask only when no explicit mask is given and q_len > 1
            # (a single query may attend to everything cached).
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
| |
|
| |
|
| |
|
# Attention implementation per `config._attn_implementation` value, looked up by `LlamaDecoderLayer`.
LLAMA_ATTENTION_CLASSES = dict(
    eager=LlamaAttention,
    flash_attention_2=LlamaFlashAttention2,
    sdpa=LlamaSdpaAttention,
)
| |
|
class LlamaDecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer with modality-adaptive RMSNorms.

    Both layer norms are `MultiwayNetwork`s, so each token is normalized with the weights selected by its
    entry in `modality_indicators`. The attention class is chosen from `LLAMA_ATTENTION_CLASSES` by
    `config._attn_implementation`; unknown or missing values fall back to eager attention.
    """

    def __init__(self, config: LlamaConfig, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Build only the selected attention module (no throwaway default instance).
        attn_implementation = getattr(config, '_attn_implementation', 'eager')
        attn_cls = LLAMA_ATTENTION_CLASSES.get(attn_implementation, LLAMA_ATTENTION_CLASSES['eager'])
        self.self_attn = attn_cls(config=config, layer_idx=layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = MultiwayNetwork(module_provider=partial(
            LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
        ))
        self.post_attention_layernorm = MultiwayNetwork(module_provider=partial(
            LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
        ))

    def forward(
        self,
        hidden_states: torch.Tensor,
        modality_indicators: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            modality_indicators (`torch.Tensor`): per-token index selecting the layer-norm and key/value
                weights; passed to every `MultiwayNetwork` in this layer.
            attention_mask (`torch.FloatTensor`, *optional*): mask in the format expected by the selected
                attention implementation (see `model_forward`).
            position_ids (`torch.LongTensor`, *optional*): positions used for rotary embeddings.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).

        Returns:
            `(hidden_states,)`, followed by the attention weights if `output_attentions`, followed by the
            present key/value cache if `use_cache`.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states, modality_indicators)

        # Self attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            modality_indicators=modality_indicators,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Feed-forward
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states, modality_indicators)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
| |
|
| |
|
def model_forward(
    self,
    input_ids: torch.LongTensor = None,
    modality_indicators: torch.Tensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Replacement for `LlamaModel.forward` that threads `modality_indicators` through every decoder layer.

    Installed by `replace_llama_modality_adaptive`. Uses the legacy cache format: `past_key_values[i]` is
    layer i's `(key, value)` pair, and the returned cache is a tuple of such pairs.

    Args:
        input_ids: token ids `(batch, seq_len)`; mutually exclusive with `inputs_embeds`.
        modality_indicators: per-token modality index forwarded unchanged to each decoder layer.
        attention_mask: 2D padding mask `(batch, past + seq_len)`; defaults to all ones. Converted to the
            format the configured attention backend expects.
        position_ids: defaults to `past_len .. past_len + seq_len - 1`.
        past_key_values: legacy per-layer `(key, value)` cache.
        inputs_embeds: precomputed embeddings `(batch, seq_len, hidden)`.
        use_cache / output_attentions / output_hidden_states / return_dict: default to the config values.

    Returns:
        `BaseModelOutputWithPast`, or a tuple of the non-None values among
        `(last_hidden_state, cache, all_hidden_states, all_attentions)` when `return_dict` is false.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Derive the attention-backend flags from the config if the model instance does not already carry them.
    if not hasattr(self, '_use_flash_attention_2'):
        self._use_flash_attention_2 = getattr(self.config, '_attn_implementation', 'eager') == 'flash_attention_2'
    if not hasattr(self, '_use_sdpa'):
        self._use_sdpa = getattr(self.config, '_attn_implementation', 'eager') == 'sdpa'

    # Exactly one of input_ids / inputs_embeds must be provided.
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

    seq_length_with_past = seq_length
    past_key_values_length = 0

    if past_key_values is not None:
        # Layer 0's cached key tensor; the attention layers concatenate the cache along dim 2.
        past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # Default: attend to every past and current position.
    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
        )

    if self._use_flash_attention_2:
        # Flash attention takes the 2D padding mask, or None when no position is padded.
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    elif self._use_sdpa and not output_attentions:
        # SDPA path: build the mask with the SDPA-specific helper (presumably it may return None so SDPA can
        # use its built-in causal masking -- see modeling_attn_mask_utils).
        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )
    else:
        # Eager path (also SDPA when attentions are requested, since it then falls back to eager):
        # 4D causal mask added to the attention scores.
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    # Accumulators for the decoder stack.
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # inputs = (hidden_states, modality_indicators, attention_mask, position_ids); past_key_value
                    # and output_attentions are appended positionally from the enclosing scope.
                    # NOTE(review): they are read when custom_forward runs, so a recomputation during backward
                    # sees the loop's final past_key_value -- harmless only while past_key_values is None
                    # during checkpointed training; confirm.
                    return module(*inputs, past_key_value, output_attentions)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                modality_indicators,
                attention_mask,
                position_ids,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                modality_indicators=modality_indicators,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            # The layer's cache entry follows its attention weights when those are returned.
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # Include the final (normalized) hidden states.
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
| |
|
| |
|
def causal_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    modality_indicators: torch.Tensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Replacement for `LlamaForCausalLM.forward` that passes `modality_indicators` through to `self.model`.

    Args:
        modality_indicators (`torch.Tensor`, *optional*):
            Per-token modality index, forwarded unchanged to the base model.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Targets for the next-token language-modeling loss. Values should lie in `[0, ..., config.vocab_size]`
            or be `-100`; positions labelled `-100` do not contribute to the loss.

    Returns:
        `CausalLMOutputWithPast` when `return_dict` is true; otherwise the tuple `(logits, *model_outputs[1:])`,
        prefixed with `loss` when `labels` is given.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, LlamaForCausalLM

    >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    cfg = self.config
    if output_attentions is None:
        output_attentions = cfg.output_attentions
    if output_hidden_states is None:
        output_hidden_states = cfg.output_hidden_states
    if return_dict is None:
        return_dict = cfg.use_return_dict

    outputs = self.model(
        input_ids=input_ids,
        modality_indicators=modality_indicators,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    hidden_states = outputs[0]

    tp_degree = cfg.pretraining_tp
    if tp_degree > 1:
        # Project onto each vocabulary shard of the LM head separately, then concatenate.
        shards = self.lm_head.weight.split(self.vocab_size // tp_degree, dim=0)
        logits = torch.cat([F.linear(hidden_states, shard) for shard in shards[:tp_degree]], dim=-1)
    else:
        logits = self.lm_head(hidden_states)
    logits = logits.float()

    loss = None
    if labels is not None:
        # Position t is scored against the label at t + 1: drop the last logit and the first label.
        flat_logits = logits[..., :-1, :].contiguous().view(-1, cfg.vocab_size)
        flat_labels = labels[..., 1:].contiguous().view(-1).to(flat_logits.device)
        loss = CrossEntropyLoss()(flat_logits, flat_labels)

    if return_dict:
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    output = (logits,) + outputs[1:]
    return output if loss is None else (loss,) + output
| |
|
def replace_llama_modality_adaptive():
    """Monkey-patch `transformers`' Llama implementation with the modality-adaptive versions in this module.

    Rebinds the config class, the attention and decoder-layer classes in
    `transformers.models.llama.modeling_llama`, and the `forward` methods of `LlamaModel` and
    `LlamaForCausalLM`. Call this before constructing a model so the patched classes are used.
    """
    config_module = transformers.models.llama.configuration_llama
    modeling_module = transformers.models.llama.modeling_llama

    config_module.LlamaConfig = LlamaConfig

    patched_classes = (
        ("LlamaAttention", LlamaAttention),
        ("LlamaFlashAttention2", LlamaFlashAttention2),
        ("LlamaSdpaAttention", LlamaSdpaAttention),
        ("LlamaDecoderLayer", LlamaDecoderLayer),
    )
    for attr_name, replacement in patched_classes:
        setattr(modeling_module, attr_name, replacement)

    modeling_module.LlamaModel.forward = model_forward
    modeling_module.LlamaForCausalLM.forward = causal_model_forward
| |
|
| | |
if __name__ == "__main__":
    # Smoke test: patch the Llama classes, build a randomly initialized model from a config directory
    # and print its module tree. The directory may be given as the first command-line argument.
    replace_llama_modality_adaptive()
    config_path = sys.argv[1] if len(sys.argv) > 1 else '/cpfs01/shared/public/test/vicuna-7b-v1.5/'
    config = transformers.LlamaConfig.from_pretrained(config_path)
    model = transformers.LlamaForCausalLM(config)
    print(model)