import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import transformers
from typing import Optional, Dict, Tuple, List, Union, Unpack, Sequence, Any
from flash_attn import (
    flash_attn_kvpacked_func,
    flash_attn_varlen_func
)
from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb
from flash_attn.ops.triton.layer_norm import RMSNorm
from flash_attn.modules.mlp import GatedMlp
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from einops import rearrange
from itertools import chain
from flash_attn.bert_padding import unpad_input
from .configuration_sophie0 import Sophie0Config
from transformers.modeling_utils import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast


#########################################################
# --- basic functions ---
#########################################################
class Cache(transformers.cache_utils.Cache):
    """
    A cache used for storing hidden states produced by flash linear attention models.

    **Input:**
        - attn_state: Cache for standard attention, tuple(size(bsz, k_len/v_len, dmodel) * 2)
    """

    is_compileable = True

    def __init__(self, cache_position: int = 0):
        super().__init__()
        self.states: List[Dict[str, Any]] = []
        # Used in `generate` to keep a tally of how many tokens the cache has seen.
        self._cache_position = [cache_position]

    def __getitem__(self, layer_idx: int) -> Dict[str, Any]:
        if layer_idx < len(self):
            return self.states[layer_idx]
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        for state in self.states:
            yield state

    def __len__(self):
        return len(self.states)

    def update(
        self,
        attn_state: Tuple[torch.Tensor, torch.Tensor] = None,
        layer_idx: int = 0,
        offset: Optional[int] = 1,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Updates the cache with the new `attn_state` for the layer `layer_idx`.

        Args:
            attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
                The new attention key/value states to cache.
            layer_idx (`int`, defaults to 0):
                The index of the layer to cache the states for.
            offset (`int`, `optional`, defaults to 1):
                The number of new tokens being processed.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass, e.g. `window_size` for sliding-window attention.

        Return:
            Dictionary of the updated state.
        """
        if cache_kwargs is None:
            cache_kwargs = {}

        # Update the number of seen tokens.
        if len(self._cache_position) <= layer_idx:
            self._cache_position.append(0)
        self._cache_position[layer_idx] += offset

        if attn_state is not None:
            if not isinstance(attn_state, tuple) or len(attn_state) != 2:
                raise ValueError("`attn_state` must be a tuple of two tensors for key/value states")
            input_size = attn_state[0].shape[-2]
            window_size = cache_kwargs.get('window_size', None)

        if len(self.states) <= layer_idx:
            if attn_state is not None:
                if window_size is not None and input_size > window_size:
                    attn_state = (attn_state[0][..., -window_size:, :].contiguous(),
                                  attn_state[1][..., -window_size:, :].contiguous())
            state = dict(attn_state=attn_state)
            self.states.append(state)
        else:
            state = self.states[layer_idx]
            if attn_state is not None:
                if state['attn_state'] is None:
                    if window_size is not None and input_size > window_size:
                        attn_state = (attn_state[0][..., -window_size:, :].contiguous(),
                                      attn_state[1][..., -window_size:, :].contiguous())
                else:
                    key_state, value_state = state['attn_state']
                    if window_size is not None and key_state.shape[-2] == window_size:
                        # DO NOT allocate new memory if the cache is full:
                        # roll the key/value states to the left by `input_size`, then
                        # replace the last `input_size` tokens with the new key/value states.
                        key_state = key_state.roll(-input_size, -2)
                        value_state = value_state.roll(-input_size, -2)
                        key_state[..., -input_size:, :] = attn_state[0]
                        value_state[..., -input_size:, :] = attn_state[1]
                        attn_state = (key_state, value_state)
                    else:
                        attn_state = (torch.cat([key_state, attn_state[0]], -2),
                                      torch.cat([value_state, attn_state[1]], -2))
                state['attn_state'] = attn_state
        return state

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.states) <= layer_idx:
            return 0
        return self._cache_position[layer_idx]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. This cache does not have a maximum length."""
        return None

    def to_legacy_cache(self) -> Tuple:
        return tuple(self.states)

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.states)):
            for k in self.states[layer_idx].keys():
                if isinstance(self.states[layer_idx][k], torch.Tensor):
                    device = self.states[layer_idx][k].device
                    self.states[layer_idx][k] = self.states[layer_idx][k].index_select(0, beam_idx.to(device))
                elif isinstance(self.states[layer_idx][k], tuple):
                    _temp = []
                    for i in range(len(self.states[layer_idx][k])):
                        device = self.states[layer_idx][k][i].device
                        _temp.append(self.states[layer_idx][k][i].index_select(0, beam_idx.to(device)))
                    self.states[layer_idx][k] = tuple(_temp)

    @classmethod
    @torch.compiler.disable
    def from_legacy_cache(
        cls,
        past_key_values: Optional[Tuple] = None,
        cache_position: int = 0
    ):
        """Converts a cache in the legacy cache format into an equivalent `Cache`."""
        cache = cls(cache_position)
        if isinstance(past_key_values, list):
            for layer_idx in range(len(past_key_values)):
                cache.states.append(past_key_values[layer_idx])
        return cache


class VarlenCache(transformers.cache_utils.Cache):
    """
    A varlen cache used for storing hidden states produced by varlen batch inference.

    **Input:**
        - attn_state: Cache for standard attention, tuple(size(total_nnz, dmodel) * 2)
    """

    is_compileable = True

    def __init__(self, cache_position: int = 0, batch_size: int = 1, device: str | torch.device = None):
        super().__init__()
        self.states: List[Dict[str, Any]] = []
        # Used in `generate` to keep a tally of how many tokens the cache has seen (one counter per sequence).
        self._cache_position = [torch.full((batch_size,), cache_position, dtype=torch.int64, device=device)]
        self.batch_size = batch_size
        self.device = device

    def __getitem__(self, layer_idx: int) -> Dict[str, Any]:
        if layer_idx < len(self):
            return self.states[layer_idx]
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        for state in self.states:
            yield state

    def __len__(self):
        return len(self.states)

    def update(
        self,
        attn_state: Tuple[torch.Tensor, torch.Tensor] = None,
        cu_seqlens: torch.LongTensor = None,
        layer_idx: int = 0,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Updates the cache with the new `attn_state` for the layer `layer_idx`.

        Args:
            attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
                The new attention key/value states to cache, sizes (total_nnz, hidden_size).
            cu_seqlens (`torch.LongTensor`):
                The cumulative sequence lengths of the current states, sizes (bsz + 1,).
            layer_idx (`int`, defaults to 0):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass.

        Return:
            Dictionary of the updated state.
        """
        if attn_state is not None:
            if not isinstance(attn_state, tuple) or len(attn_state) != 2:
                raise ValueError("`attn_state` must be a tuple of two tensors for key/value states")
            dtype = attn_state[0].dtype
            device = attn_state[0].device
            hidden_size = attn_state[0].size(-1)

            if len(self._cache_position) <= layer_idx:
                self._cache_position.append(
                    torch.zeros((cu_seqlens.size(0) - 1,), dtype=torch.int64, device=cu_seqlens.device)
                )
            kv_seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
            kv_seqlens_cpu = kv_seqlens.cpu().tolist()
            self._cache_position[layer_idx] += kv_seqlens

            # Case 1: prefill at the first step.
            if len(self.states) <= layer_idx:
                key_state, value_state = list(map(lambda x: torch.split(x, kv_seqlens_cpu), attn_state))
                state = dict(
                    attn_state=(key_state, value_state),
                    cu_seqlens=cu_seqlens,
                    max_seqlen=kv_seqlens.max().item()
                )
                self.states.append(state)
            # Case 2: append the current step's kv cache.
            else:
                state = self.states[layer_idx]
                if state["attn_state"] is not None:
                    key_state, value_state = list(map(lambda x: torch.split(x, kv_seqlens_cpu), attn_state))
                    key_cache, value_cache = state['attn_state']
                    old_cu_seqlens = state['cu_seqlens']
                    key_cache = tuple(map(lambda x, y: torch.cat([x, y], dim=0), key_cache, key_state))
                    value_cache = tuple(map(lambda x, y: torch.cat([x, y], dim=0), value_cache, value_state))
                    new_cu_seqlens = old_cu_seqlens + cu_seqlens
                    state.update(
                        attn_state=(key_cache, value_cache),
                        cu_seqlens=new_cu_seqlens,
                        max_seqlen=(new_cu_seqlens[1:] - new_cu_seqlens[:-1]).max().item()
                    )
        return state

    def get_kv_cache(self, state: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
        return tuple(map(lambda x: torch.cat(x, 0), state['attn_state']))

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> torch.Tensor:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.states) <= layer_idx:
            return torch.zeros(self.batch_size, dtype=torch.int64, device=self.device)
        return self._cache_position[layer_idx]

    def get_cu_seq_length(self, layer_idx: Optional[int] = 0) -> torch.Tensor:
        """Returns the cumulative sequence lengths of the cached states. A layer index can be optionally passed."""
        if len(self.states) <= layer_idx:
            return torch.zeros(self.batch_size + 1, dtype=torch.int64, device=self.device)
        return self.states[layer_idx]['cu_seqlens']

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. This cache does not have a maximum length."""
        return None

    def to_legacy_cache(self) -> Tuple:
        return tuple(self.states)

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        raise NotImplementedError("Varlen batch inference does not support beam search for now.")

    @classmethod
    @torch.compiler.disable
    def from_legacy_cache(
        cls,
        past_key_values: Optional[Tuple] = None,
        cache_position: int = 0,
        batch_size: int = 1,
        device: str | torch.device = None
    ):
        """Converts a cache in the legacy cache format into an equivalent `VarlenCache`."""
        cache = cls(cache_position, batch_size=batch_size, device=device)
        if isinstance(past_key_values, list):
            for layer_idx in range(len(past_key_values)):
                cache.states.append(past_key_values[layer_idx])
        return cache
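

# --- Hedged usage sketches (illustrative only, not part of the model) ---
# The two helpers below are minimal CPU-only walkthroughs of the cache classes above.
# All shapes, lengths, and the `window_size` value are assumptions chosen for the
# example, not requirements of the classes.
def _example_sliding_window_cache():
    """Sketch: `Cache.update` keeps at most `window_size` timesteps by rolling in place."""
    cache = Cache()
    bsz, dmodel, window_size = 2, 8, 4
    # Prefill with 6 tokens: only the last `window_size` of them are kept.
    k, v = torch.randn(bsz, 6, dmodel), torch.randn(bsz, 6, dmodel)
    state = cache.update(attn_state=(k, v), layer_idx=0, offset=6,
                         cache_kwargs={'window_size': window_size})
    assert state['attn_state'][0].shape[-2] == window_size
    # Decoding one more token rolls the window left and overwrites the last slot.
    k1, v1 = torch.randn(bsz, 1, dmodel), torch.randn(bsz, 1, dmodel)
    state = cache.update(attn_state=(k1, v1), layer_idx=0, offset=1,
                         cache_kwargs={'window_size': window_size})
    assert state['attn_state'][0].shape[-2] == window_size
    assert cache.get_seq_length(0) == 7  # the tally still counts every token seen


def _example_varlen_cache():
    """Sketch: build `cu_seqlens` from per-sequence lengths, then prefill and decode."""
    lengths = torch.tensor([3, 5])                                  # two packed sequences
    cu_seqlens = F.pad(lengths.cumsum(0), (1, 0)).to(torch.int32)   # -> [0, 3, 8]
    total_nnz, hidden_size = int(lengths.sum()), 8
    cache = VarlenCache(batch_size=lengths.numel())
    # Prefill: keys/values for all packed tokens at once.
    k, v = torch.randn(total_nnz, hidden_size), torch.randn(total_nnz, hidden_size)
    state = cache.update(attn_state=(k, v), cu_seqlens=cu_seqlens, layer_idx=0)
    assert state['max_seqlen'] == 5
    # Decode: one new token per sequence, so every segment has length 1.
    step_cu_seqlens = torch.arange(lengths.numel() + 1, dtype=torch.int32)
    k1, v1 = torch.randn(lengths.numel(), hidden_size), torch.randn(lengths.numel(), hidden_size)
    cache.update(attn_state=(k1, v1), cu_seqlens=step_cu_seqlens, layer_idx=0)
    assert torch.equal(cache.get_seq_length(0), lengths + 1)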


@torch.no_grad()
def linear_init(
    linear: nn.Linear,
    distribution: Optional[str] = 'normal',
    zero_bias: Optional[bool] = False,
    gain: Optional[float] = 1.0
) -> None:
    if distribution == 'normal':
        nn.init.xavier_normal_(linear.weight, gain=gain)
    elif distribution == 'uniform':
        nn.init.xavier_uniform_(linear.weight, gain=gain)
    if linear.bias is not None:
        if zero_bias:
            nn.init.zeros_(linear.bias)
        else:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(linear.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(linear.bias, -bound, bound)


@torch.no_grad()
def embedding_init(embedding: nn.Embedding) -> None:
    fan_out = embedding.weight.size(1)
    std = 1.0 * math.sqrt(1.0 / float(fan_out))
    nn.init.normal_(embedding.weight, 0., std)
    if embedding.padding_idx is not None:
        embedding.weight[embedding.padding_idx].fill_(0)


def sparse_to_dense(src: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
    maxLength = length.max().item()
    length = length.cpu().numpy()
    broadcastIdx = np.arange(length[0], dtype=np.int64)
    for i in range(1, length.shape[0]):
        broadcastIdx = np.concatenate([broadcastIdx, np.arange(length[i], dtype=np.int64) + maxLength * i], axis=0)
    broadcastIdx = torch.tensor(broadcastIdx, dtype=torch.int64, device=src.device)
    tgt = torch.zeros((length.shape[0] * maxLength, src.size(-1)), dtype=src.dtype, device=src.device)
    tgt[broadcastIdx] = src
    tgt = tgt.reshape(length.shape[0], maxLength, -1).contiguous()
    return tgt
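

# Hedged usage sketch (illustrative only): `sparse_to_dense` scatters a packed
# (total_nnz, D) tensor back into a zero-padded (B, max_len, D) batch, i.e. the
# inverse of unpadding. The lengths below are assumptions chosen for the example.
def _example_sparse_to_dense():
    lengths = torch.tensor([2, 3])
    src = torch.arange(5 * 4, dtype=torch.float32).reshape(5, 4)   # 5 packed rows, D=4
    dense = sparse_to_dense(src, lengths)
    assert dense.shape == (2, 3, 4)
    assert torch.equal(dense[0, :2], src[:2])   # first sequence occupies its first 2 slots
    assert torch.all(dense[0, 2:] == 0)         # the remainder is zero padding
    assert torch.equal(dense[1], src[2:])       # second sequence fills its row completely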


#########################################################
# --- model ---
#########################################################
class FullAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rotary_base: int,
        dropout: float,
        layer_idx: int,
        **kwargs
    ):
        super(FullAttention, self).__init__()
        self.hidden_size = hidden_size
        self.num_q_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_size = hidden_size // num_heads
        self.dropout = dropout
        self.layer_idx = layer_idx
        self.qkv = nn.Linear(hidden_size, hidden_size + 2 * num_kv_heads * self.head_size, bias=False)
        self.out = nn.Linear(hidden_size, hidden_size, bias=False)
        self.rotary = RotaryEmbedding(dim=self.head_size, base=rotary_base)
        self._init_weights()

    def _init_weights(self):
        for k, v in self.named_modules():
            if isinstance(v, nn.Linear):
                linear_init(v, zero_bias=True)

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor = None,
        max_seqlen: int = None,
        causal: bool = True,
        past_key_values: Cache | VarlenCache = None
    ):
        """
        Training with varlen:
            x -> size(B*L, D)
            cu_seqlens -> size(B+1)
        Generating with padding:
            x -> size(B, L, D)
            cu_seqlens -> None
        """
        if cu_seqlens is None:
            qkv: torch.Tensor = self.qkv(x)
            qkv = rearrange(qkv, "B L (H D) -> B L H D", H=(self.num_q_heads + 2 * self.num_kv_heads), D=self.head_size)
            q, kv = torch.split(qkv, [self.num_q_heads, 2 * self.num_kv_heads], dim=-2)
            kv = rearrange(kv, "B L (C H) D -> B L C H D", C=2, H=self.num_kv_heads)
            if past_key_values is not None:
                seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
                _max_seqlen = q.size(1) + seqlen_offset
                q, kv = self.rotary(q, kv, seqlen_offset=seqlen_offset, max_seqlen=_max_seqlen, num_heads_q=self.num_q_heads)
                k, v = kv.unbind(dim=2)
                k, v = past_key_values.update(
                    attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
                    layer_idx=self.layer_idx,
                    offset=q.size(1),
                    cache_kwargs=dict()
                )["attn_state"]
                k = rearrange(k, "... (H D) -> ... H D", H=self.num_kv_heads, D=self.head_size)
                v = rearrange(v, "... (H D) -> ... H D", H=self.num_kv_heads, D=self.head_size)
                kv = torch.cat([k.unsqueeze(2), v.unsqueeze(2)], dim=2)
            else:
                q, kv = self.rotary(q, kv)
            out = flash_attn_kvpacked_func(q, kv, dropout_p=self.dropout if self.training else 0, causal=causal)
            out = self.out(rearrange(out, "B L H D -> B L (H D)"))
        else:
            qkv: torch.Tensor = self.qkv(x)
            qkv = rearrange(qkv, "L (H D) -> L H D", H=(self.num_q_heads + 2 * self.num_kv_heads), D=self.head_size)
            q, k, v = torch.split(qkv, [self.num_q_heads, self.num_kv_heads, self.num_kv_heads], dim=-2)
            if past_key_values is not None:
                assert isinstance(past_key_values, VarlenCache)
                seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
                _seqlen = cu_seqlens[1:] - cu_seqlens[:-1]
                _max_seqlen = (seqlen_offset + _seqlen).max().item()
                self.rotary._update_cos_sin_cache(seqlen=_max_seqlen, device=q.device, dtype=q.dtype)
                q = apply_rotary_emb(q, self.rotary._cos_cached, self.rotary._sin_cached,
                                     seqlen_offsets=seqlen_offset, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
                k = apply_rotary_emb(k, self.rotary._cos_cached, self.rotary._sin_cached,
                                     seqlen_offsets=seqlen_offset, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
                new_cache = past_key_values.update(
                    attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
                    cu_seqlens=cu_seqlens,
                    layer_idx=self.layer_idx,
                    cache_kwargs=dict()
                )
                k, v = past_key_values.get_kv_cache(new_cache)
                k = rearrange(k, "... (H D) -> ... H D", H=self.num_kv_heads, D=self.head_size)
                v = rearrange(v, "... (H D) -> ... H D", H=self.num_kv_heads, D=self.head_size)
                kv_cu_seqlens, kv_max_seqlen = new_cache['cu_seqlens'], new_cache['max_seqlen']
                out = flash_attn_varlen_func(q, k, v, cu_seqlens, kv_cu_seqlens, max_seqlen, kv_max_seqlen,
                                             dropout_p=self.dropout if self.training else 0, causal=causal)
            else:
                self.rotary._update_cos_sin_cache(seqlen=max_seqlen, device=q.device, dtype=q.dtype)
                q = apply_rotary_emb(q, self.rotary._cos_cached, self.rotary._sin_cached, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
                k = apply_rotary_emb(k, self.rotary._cos_cached, self.rotary._sin_cached, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
                out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
                                             dropout_p=self.dropout if self.training else 0, causal=causal)
            out = self.out(rearrange(out, "L H D -> L (H D)"))
        return out, None, past_key_values
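

# Hedged shape sketch (illustrative only): how the fused qkv projection in FullAttention
# is split for grouped-query attention. The sizes below are assumptions; the flash-attn
# kernels additionally expect num_heads to be a multiple of num_kv_heads.
def _example_gqa_split():
    hidden_size, num_heads, num_kv_heads = 512, 8, 2
    head_size = hidden_size // num_heads                        # 64
    qkv_dim = hidden_size + 2 * num_kv_heads * head_size        # 512 (q) + 256 (k and v) = 768
    qkv = torch.randn(4, 16, qkv_dim)                           # stand-in for self.qkv(x)
    qkv = rearrange(qkv, "B L (H D) -> B L H D", H=num_heads + 2 * num_kv_heads, D=head_size)
    q, kv = torch.split(qkv, [num_heads, 2 * num_kv_heads], dim=-2)
    assert q.shape == (4, 16, num_heads, head_size)
    assert kv.shape == (4, 16, 2 * num_kv_heads, head_size)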


class TransformerBlock(nn.Module):
    def __init__(self, config: Sophie0Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.eps)
        self.attn = FullAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            rotary_base=config.rope_base,
            dropout=config.dropout,
            layer_idx=self.layer_idx
        )
        self.ffn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.eps)
        self.ffn = GatedMlp(
            in_features=config.hidden_size,
            hidden_features=config.intermediate_size,
            activation=F.silu,
            bias1=False,
            bias2=False,
            multiple_of=1
        )
        self._init_weights()

    def _init_weights(self):
        for k, v in self.ffn.named_modules():
            if isinstance(v, nn.Linear):
                linear_init(v, zero_bias=True)

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor = None,
        max_seqlen: int = None,
        causal: bool = True,
        past_key_values: Cache = None
    ):
        out, _, past_key_values = self.attn(self.attn_norm(x), cu_seqlens, max_seqlen, causal, past_key_values)
        x = x + out
        x = x + self.ffn(self.ffn_norm(x))
        return (x, _, past_key_values)


class Sophie0PretraindModel(PreTrainedModel):
    config_class = Sophie0Config
    supports_gradient_checkpointing = True
    _supports_cache_class = True
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module):
        if isinstance(module, nn.Embedding):
            embedding_init(module)
        elif isinstance(module, nn.Linear):
            linear_init(module, zero_bias=True)


class Sophie0Model(Sophie0PretraindModel):
    def __init__(self, config: Sophie0Config, **kwargs):
        super().__init__(config, **kwargs)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([TransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.eps)
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        cu_seqlens: Optional[torch.LongTensor] = None,
        max_seqlen: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, VarlenCache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_hidden_states = output_hidden_states if output_hidden_states is not None else getattr(self.config, "output_hidden_states", False)
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", False)

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if cu_seqlens is not None:
            if use_cache and not isinstance(past_key_values, VarlenCache):
                past_key_values = VarlenCache.from_legacy_cache(past_key_values, batch_size=cu_seqlens.size(0) - 1, device=cu_seqlens.device)
        else:
            if use_cache and not isinstance(past_key_values, Cache):
                past_key_values = Cache.from_legacy_cache(past_key_values)

        if kwargs.get("use_gradient_checkpoint", False) is True and self.supports_gradient_checkpointing and self.training:
            self.gradient_checkpointing = True
        else:
            self.gradient_checkpointing = False

        all_hidden_states = () if output_hidden_states else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.gradient_checkpointing and self.training:
                hidden_states, _, past_key_values = checkpoint.checkpoint(
                    layer.__call__,
                    hidden_states, cu_seqlens, max_seqlen, True, past_key_values,
                    use_reentrant=False
                )
            else:
                hidden_states, _, past_key_values = layer(hidden_states, cu_seqlens, max_seqlen, True, past_key_values)

        hidden_states = self.norm(hidden_states)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, past_key_values] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=None
        )
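

# Hedged usage sketch (illustrative only, never called on import): packs a padded batch
# with `unpad_input` and runs Sophie0Model in varlen mode, exercising the optional
# `use_gradient_checkpoint` kwarg read inside `forward`. A CUDA device and bf16/fp16
# weights are required by the flash-attn kernels; the config and shapes are assumptions.
def _example_varlen_forward(config: Sophie0Config):
    model = Sophie0Model(config).cuda().to(torch.bfloat16)
    input_ids = torch.randint(0, config.vocab_size, (2, 16), device="cuda")
    attention_mask = torch.ones_like(input_ids)
    attention_mask[0, 12:] = 0                                  # first sequence is shorter
    flat_ids, _, cu_seqlens, max_seqlen, _ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
    out = model(
        input_ids=flat_ids.squeeze(-1),
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        use_gradient_checkpoint=True,                           # only honored while model.training is True
    )
    return out.last_hidden_state                                # shape: (total_nnz, hidden_size)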


class Sophie0ForCausalLM(Sophie0PretraindModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Sophie0Config):
        super().__init__(config)
        self.model = Sophie0Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.criterion = None
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        cache_position: Optional[int] = None,
        use_cache: Optional[bool] = True,
        logits_to_keep=None,
        cu_seqlens: Optional[torch.LongTensor] = None,
        max_seqlen: Optional[int] = None,
        use_varlen_inference: Optional[bool] = False,
        **kwargs
    ):
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            if past_key_values is not None and len(past_key_values) > 0:
                # Decoding: only the last token is needed once the cache holds the prefix.
                input_ids = input_ids[:, -1:]
                if isinstance(past_key_values, VarlenCache):
                    input_ids = input_ids.squeeze(-1)
                    cu_seqlens = torch.arange(past_key_values.batch_size + 1, dtype=torch.int32, device=input_ids.device)
                    max_seqlen = 1
            else:
                if use_varlen_inference:
                    # Prefill: drop the padding and switch to the packed (varlen) layout.
                    input_ids, _, cu_seqlens, max_seqlen, _ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
                    input_ids = input_ids.squeeze(-1)
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'cu_seqlens': cu_seqlens,
            'max_seqlen': max_seqlen
        })
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        cu_seqlens: Optional[torch.LongTensor] = None,
        max_seqlen: Optional[int] = None,
        use_varlen_inference: Optional[bool] = False,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, VarlenCache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        labels_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else getattr(self.config, "output_attentions", False)
        output_hidden_states = output_hidden_states if output_hidden_states is not None else getattr(self.config, "output_hidden_states", False)
        return_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", False)

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )
        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        past_key_values = outputs.past_key_values

        loss = None
        if labels is not None:
            self.criterion = CrossEntropyLoss(ignore_index=self.config.pad_token_id,
                                              reduction="mean" if labels_mask is None else "none")
            if logits.dim() == 2:
                # varlen: logits and labels are already packed to (total_nnz, ...)
                assert labels.dim() == 1
                loss = self.criterion(logits, labels)
                if labels_mask is not None:
                    loss = loss * labels_mask
                    loss = loss.sum() / labels_mask.sum()
                else:
                    loss = loss.mean()
            else:
                assert labels.dim() == 2
                if self.config.right_shift:
                    labels = labels[:, 1:]
                    logits = logits[:, :-1].contiguous()
                loss = self.criterion(logits.flatten(0, 1), labels.flatten(0, 1))
                if labels_mask is not None:
                    loss = loss * labels_mask.flatten(0, 1)
                    loss = loss.sum() / labels_mask.sum()
                else:
                    loss = loss.mean()
        else:
            if isinstance(past_key_values, VarlenCache):
                kv_cu_seqlens = past_key_values.get_cu_seq_length()
                if logits.size(0) > past_key_values.batch_size:
                    # Varlen prefill: keep only the logits of each sequence's last token.
                    logits = logits.index_select(0, kv_cu_seqlens[1:] - 1)
                logits = logits.unsqueeze(1)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )
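

# Hedged usage sketch (illustrative only, never called on import): end-to-end generation
# with the varlen-inference path enabled. The checkpoint path, tokenizer behavior, and
# generation settings are assumptions; a CUDA device with bf16/fp16 weights is required
# by the flash-attn kernels.
def _example_generate():
    from transformers import AutoTokenizer
    model_path = "path/to/sophie0-checkpoint"                   # hypothetical path
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = Sophie0ForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).cuda().eval()
    batch = tokenizer(
        ["Hello world", "Varlen batch inference keeps sequences packed"],
        return_tensors="pt", padding=True
    ).to("cuda")
    out = model.generate(
        input_ids=batch.input_ids,
        attention_mask=batch.attention_mask,
        max_new_tokens=32,
        use_varlen_inference=True,                              # consumed by prepare_inputs_for_generation
    )
    return tokenizer.batch_decode(out, skip_special_tokens=True)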