|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | """ PyTorch StableLM Epoch model. """ | 
					
						
						|  | import importlib | 
					
						
						|  | import math | 
					
						
						|  | from typing import Optional, Tuple, Union | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import torch.utils.checkpoint | 
					
						
						|  | from accelerate import init_empty_weights | 
					
						
						|  | from einops import rearrange | 
					
						
						|  | from flash_attn.flash_attn_interface import ( | 
					
						
						|  | flash_attn_varlen_qkvpacked_func, | 
					
						
						|  | ) | 
					
						
						|  | from torch import nn | 
					
						
						|  | from transformers import AutoConfig, AutoModelForCausalLM | 
					
						
						|  | from transformers.modeling_outputs import BaseModelOutputWithPast | 
					
						
						|  | from transformers.utils import logging | 
					
						
						|  |  | 
					
						
						|  | from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids | 
					
						
						|  |  | 
					
						
# Module-level logger from transformers' logging utilities; stablelm_model_forward
# uses it to warn when use_cache is turned off under gradient checkpointing.
logger = logging.get_logger(__name__)
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"): | 
					
						
						|  |  | 
					
						
						|  | model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) | 
					
						
						|  |  | 
					
						
						|  | with init_empty_weights(): | 
					
						
						|  | AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) | 
					
						
						|  | module_name = model_config.__class__.__module__.replace( | 
					
						
						|  | ".configuration_stablelm_epoch", ".modeling_stablelm_epoch" | 
					
						
						|  | ) | 
					
						
						|  | modeling_stablelm = importlib.import_module(module_name) | 
					
						
						|  | modeling_stablelm.Attention.forward = ( | 
					
						
						|  | flashattn_attn | 
					
						
						|  | ) | 
					
						
						|  | modeling_stablelm.StableLMEpochModel.forward = ( | 
					
						
						|  | stablelm_model_forward | 
					
						
						|  | ) | 
					
						
						|  | modeling_stablelm.DecoderLayer.forward = ( | 
					
						
						|  | decoder_layer_forward | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def rotate_half(x: torch.Tensor): | 
					
						
						|  | """Rotates half the hidden dims of the input.""" | 
					
						
						|  |  | 
					
						
						|  | x1, x2 = torch.chunk(x, 2, dim=-1) | 
					
						
						|  | return torch.cat((-x2, x1), dim=-1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | cos = cos.squeeze(1).squeeze(0) | 
					
						
						|  | sin = sin.squeeze(1).squeeze(0) | 
					
						
						|  | cos = cos[position_ids].unsqueeze(1) | 
					
						
						|  | sin = sin[position_ids].unsqueeze(1) | 
					
						
						|  | q_embed = (q * cos) + (rotate_half(q) * sin) | 
					
						
						|  | k_embed = (k * cos) + (rotate_half(k) * sin) | 
					
						
						|  | return q_embed, k_embed | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: | 
					
						
						|  | """ | 
					
						
						|  | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, | 
					
						
						|  | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) | 
					
						
						|  | """ | 
					
						
						|  | batch, num_key_value_heads, slen, head_dim = hidden_states.shape | 
					
						
						|  | if n_rep == 1: | 
					
						
						|  | return hidden_states | 
					
						
						|  | hidden_states = hidden_states[:, :, None, :, :].expand( | 
					
						
						|  | batch, num_key_value_heads, n_rep, slen, head_dim | 
					
						
						|  | ) | 
					
						
						|  | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def flashattn_attn( | 
					
						
						|  | self, | 
					
						
						|  | hidden_states: torch.FloatTensor, | 
					
						
						|  | attention_mask: torch.FloatTensor, | 
					
						
						|  | position_ids: torch.LongTensor, | 
					
						
						|  | past_key_value: Optional[Tuple[torch.Tensor]] = None, | 
					
						
						|  | output_attentions: Optional[bool] = False, | 
					
						
						|  | use_cache: Optional[bool] = False, | 
					
						
						|  | cu_seqlens: Optional[torch.Tensor] = None, | 
					
						
						|  | max_seqlen: Optional[torch.Tensor] = None, | 
					
						
						|  | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: | 
					
						
						|  | bsz, q_len, _ = hidden_states.size() | 
					
						
						|  |  | 
					
						
						|  | query_states = self.q_proj(hidden_states) | 
					
						
						|  | key_states = self.k_proj(hidden_states) | 
					
						
						|  | value_states = self.v_proj(hidden_states) | 
					
						
						|  |  | 
					
						
						|  | query_states = query_states.view( | 
					
						
						|  | bsz, q_len, self.num_heads, self.head_dim | 
					
						
						|  | ).transpose(1, 2) | 
					
						
						|  | key_states = key_states.view( | 
					
						
						|  | bsz, q_len, self.num_key_value_heads, self.head_dim | 
					
						
						|  | ).transpose(1, 2) | 
					
						
						|  | value_states = value_states.view( | 
					
						
						|  | bsz, q_len, self.num_key_value_heads, self.head_dim | 
					
						
						|  | ).transpose(1, 2) | 
					
						
						|  |  | 
					
						
						|  | query_rot = query_states[..., : self.rotary_ndims] | 
					
						
						|  | query_pass = query_states[..., self.rotary_ndims :] | 
					
						
						|  | key_rot = key_states[..., : self.rotary_ndims] | 
					
						
						|  | key_pass = key_states[..., self.rotary_ndims :] | 
					
						
						|  |  | 
					
						
						|  | kv_seq_len = key_states.shape[-2] | 
					
						
						|  | if past_key_value is not None: | 
					
						
						|  | kv_seq_len += past_key_value[0].shape[-2] | 
					
						
						|  | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) | 
					
						
						|  | query_states, key_states = apply_rotary_pos_emb( | 
					
						
						|  | query_rot, key_rot, cos, sin, position_ids | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | query_states = torch.cat((query_states, query_pass), dim=-1) | 
					
						
						|  | key_states = torch.cat((key_states, key_pass), dim=-1) | 
					
						
						|  |  | 
					
						
						|  | if past_key_value is not None: | 
					
						
						|  |  | 
					
						
						|  | key_states = torch.cat((past_key_value[0], key_states), dim=2) | 
					
						
						|  | value_states = torch.cat((past_key_value[1], value_states), dim=2) | 
					
						
						|  |  | 
					
						
						|  | past_key_value = (key_states, value_states) if use_cache else None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | key_states = repeat_kv(key_states, self.num_key_value_groups) | 
					
						
						|  | value_states = repeat_kv(value_states, self.num_key_value_groups) | 
					
						
						|  |  | 
					
						
						|  | if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: | 
					
						
						|  |  | 
					
						
						|  | qkv = torch.stack( | 
					
						
						|  | [query_states, key_states, value_states], dim=2 | 
					
						
						|  | ) | 
					
						
						|  | qkv = qkv.transpose(1, 3) | 
					
						
						|  | qkv = rearrange(qkv, "b s ... -> (b s) ...") | 
					
						
						|  | softmax_scale = None | 
					
						
						|  |  | 
					
						
						|  | output = flash_attn_varlen_qkvpacked_func( | 
					
						
						|  | qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=softmax_scale, causal=True | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | attn_output = rearrange(output, "(b s) ... -> b s ...", b=bsz) | 
					
						
						|  | attn_output = rearrange(attn_output, "b s h d -> b s (h d)") | 
					
						
						|  | else: | 
					
						
						|  | attn_weights = torch.matmul( | 
					
						
						|  | query_states, key_states.transpose(2, 3) | 
					
						
						|  | ) / math.sqrt(self.head_dim) | 
					
						
						|  |  | 
					
						
						|  | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" | 
					
						
						|  | f" {attn_weights.size()}" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if attention_mask is not None: | 
					
						
						|  | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" | 
					
						
						|  | ) | 
					
						
						|  | attn_weights = attn_weights + attention_mask | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | attn_weights = nn.functional.softmax( | 
					
						
						|  | attn_weights, dim=-1, dtype=torch.float32 | 
					
						
						|  | ).to(query_states.dtype) | 
					
						
						|  | attn_output = torch.matmul(attn_weights, value_states) | 
					
						
						|  |  | 
					
						
						|  | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" | 
					
						
						|  | f" {attn_output.size()}" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | attn_output = attn_output.transpose(1, 2).contiguous() | 
					
						
						|  | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | attn_output = self.o_proj(attn_output) | 
					
						
						|  |  | 
					
						
						|  | return attn_output, None, past_key_value | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def decoder_layer_forward( | 
					
						
						|  | self, | 
					
						
						|  | hidden_states: Optional[torch.FloatTensor], | 
					
						
						|  | attention_mask: Optional[torch.FloatTensor] = None, | 
					
						
						|  | position_ids: Optional[torch.LongTensor] = None, | 
					
						
						|  | past_key_value: Optional[Tuple[torch.Tensor]] = None, | 
					
						
						|  | output_attentions: Optional[bool] = False, | 
					
						
						|  | use_cache: Optional[bool] = False, | 
					
						
						|  | cu_seqlens: Optional[torch.Tensor] = None, | 
					
						
						|  | max_seqlen: Optional[torch.Tensor] = None, | 
					
						
						|  | ) -> Union[ | 
					
						
						|  | Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]] | 
					
						
						|  | ]: | 
					
						
						|  |  | 
					
						
						|  | residual = hidden_states | 
					
						
						|  |  | 
					
						
						|  | hidden_states = self.input_layernorm(hidden_states) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | hidden_states, self_attn_weights, present_key_value = self.self_attn( | 
					
						
						|  | hidden_states=hidden_states, | 
					
						
						|  | attention_mask=attention_mask, | 
					
						
						|  | position_ids=position_ids, | 
					
						
						|  | past_key_value=past_key_value, | 
					
						
						|  | output_attentions=output_attentions, | 
					
						
						|  | use_cache=use_cache, | 
					
						
						|  | cu_seqlens=cu_seqlens, | 
					
						
						|  | max_seqlen=max_seqlen, | 
					
						
						|  | ) | 
					
						
						|  | hidden_states = residual + hidden_states | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | residual = hidden_states | 
					
						
						|  | hidden_states = self.post_attention_layernorm(hidden_states) | 
					
						
						|  | hidden_states = self.mlp(hidden_states) | 
					
						
						|  | hidden_states = residual + hidden_states | 
					
						
						|  |  | 
					
						
						|  | outputs = (hidden_states,) | 
					
						
						|  |  | 
					
						
						|  | if output_attentions: | 
					
						
						|  | outputs += (self_attn_weights,) | 
					
						
						|  |  | 
					
						
						|  | if use_cache: | 
					
						
						|  | outputs += (present_key_value,) | 
					
						
						|  |  | 
					
						
						|  | return outputs | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def stablelm_model_forward( | 
					
						
						|  | self, | 
					
						
						|  | input_ids: Optional[torch.LongTensor] = None, | 
					
						
						|  | attention_mask: Optional[torch.FloatTensor] = None, | 
					
						
						|  | position_ids: Optional[torch.LongTensor] = None, | 
					
						
						|  | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, | 
					
						
						|  | inputs_embeds: Optional[torch.FloatTensor] = None, | 
					
						
						|  | use_cache: Optional[bool] = None, | 
					
						
						|  | output_attentions: Optional[bool] = None, | 
					
						
						|  | output_hidden_states: Optional[bool] = None, | 
					
						
						|  | return_dict: Optional[bool] = None, | 
					
						
						|  | ) -> Union[Tuple, BaseModelOutputWithPast]: | 
					
						
						|  |  | 
					
						
						|  | output_attentions = ( | 
					
						
						|  | output_attentions | 
					
						
						|  | if output_attentions is not None | 
					
						
						|  | else self.config.output_attentions | 
					
						
						|  | ) | 
					
						
						|  | output_hidden_states = ( | 
					
						
						|  | output_hidden_states | 
					
						
						|  | if output_hidden_states is not None | 
					
						
						|  | else self.config.output_hidden_states | 
					
						
						|  | ) | 
					
						
						|  | use_cache = use_cache if use_cache is not None else self.config.use_cache | 
					
						
						|  |  | 
					
						
						|  | return_dict = ( | 
					
						
						|  | return_dict if return_dict is not None else self.config.use_return_dict | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if input_ids is not None and inputs_embeds is not None: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" | 
					
						
						|  | ) | 
					
						
						|  | if input_ids is not None: | 
					
						
						|  | batch_size, seq_length = input_ids.shape | 
					
						
						|  | elif inputs_embeds is not None: | 
					
						
						|  | batch_size, seq_length, _ = inputs_embeds.shape | 
					
						
						|  | else: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | "You have to specify either decoder_input_ids or decoder_inputs_embeds" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | seq_length_with_past = seq_length | 
					
						
						|  | past_key_values_length = 0 | 
					
						
						|  |  | 
					
						
						|  | if past_key_values is not None: | 
					
						
						|  | past_key_values_length = past_key_values[0][0].shape[2] | 
					
						
						|  | seq_length_with_past = seq_length_with_past + past_key_values_length | 
					
						
						|  |  | 
					
						
						|  | cu_seqlens = None | 
					
						
						|  | max_seqlen = None | 
					
						
						|  | if position_ids is None: | 
					
						
						|  | device = input_ids.device if input_ids is not None else inputs_embeds.device | 
					
						
						|  | position_ids = torch.arange( | 
					
						
						|  | past_key_values_length, | 
					
						
						|  | seq_length + past_key_values_length, | 
					
						
						|  | dtype=torch.long, | 
					
						
						|  | device=device, | 
					
						
						|  | ) | 
					
						
						|  | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) | 
					
						
						|  | else: | 
					
						
						|  | position_ids = position_ids.view(-1, seq_length).long() | 
					
						
						|  | cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids) | 
					
						
						|  | cu_seqlens = cu_seqlens.squeeze() | 
					
						
						|  |  | 
					
						
						|  | if inputs_embeds is None: | 
					
						
						|  | inputs_embeds = self.embed_tokens(input_ids) | 
					
						
						|  |  | 
					
						
						|  | if attention_mask is None: | 
					
						
						|  | attention_mask = torch.ones( | 
					
						
						|  | (batch_size, seq_length_with_past), | 
					
						
						|  | dtype=torch.bool, | 
					
						
						|  | device=inputs_embeds.device, | 
					
						
						|  | ) | 
					
						
						|  | attention_mask = ( | 
					
						
						|  | self._prepare_decoder_attention_mask( | 
					
						
						|  | attention_mask, | 
					
						
						|  | (batch_size, seq_length), | 
					
						
						|  | inputs_embeds, | 
					
						
						|  | past_key_values_length, | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | hidden_states = inputs_embeds | 
					
						
						|  |  | 
					
						
						|  | if self.gradient_checkpointing and self.training: | 
					
						
						|  | if use_cache: | 
					
						
						|  | logger.warning( | 
					
						
						|  | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." | 
					
						
						|  | ) | 
					
						
						|  | use_cache = False | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | all_hidden_states = () if output_hidden_states else None | 
					
						
						|  | all_self_attns = () if output_attentions else None | 
					
						
						|  | next_decoder_cache = () if use_cache else None | 
					
						
						|  |  | 
					
						
						|  | for idx, decoder_layer in enumerate(self.layers): | 
					
						
						|  | if output_hidden_states: | 
					
						
						|  | all_hidden_states += (hidden_states,) | 
					
						
						|  |  | 
					
						
						|  | past_key_value = past_key_values[idx] if past_key_values is not None else None | 
					
						
						|  |  | 
					
						
						|  | if self.gradient_checkpointing and self.training: | 
					
						
						|  |  | 
					
						
						|  | def create_custom_forward(module): | 
					
						
						|  | def custom_forward(*inputs): | 
					
						
						|  |  | 
					
						
						|  | return module(*inputs) | 
					
						
						|  |  | 
					
						
						|  | return custom_forward | 
					
						
						|  |  | 
					
						
						|  | layer_outputs = torch.utils.checkpoint.checkpoint( | 
					
						
						|  | create_custom_forward(decoder_layer), | 
					
						
						|  | hidden_states, | 
					
						
						|  | attention_mask, | 
					
						
						|  | position_ids, | 
					
						
						|  | past_key_value, | 
					
						
						|  | output_attentions, | 
					
						
						|  | None, | 
					
						
						|  | cu_seqlens, | 
					
						
						|  | max_seqlen, | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | layer_outputs = decoder_layer( | 
					
						
						|  | hidden_states, | 
					
						
						|  | attention_mask=attention_mask, | 
					
						
						|  | position_ids=position_ids, | 
					
						
						|  | past_key_value=past_key_value, | 
					
						
						|  | output_attentions=output_attentions, | 
					
						
						|  | use_cache=use_cache, | 
					
						
						|  | cu_seqlens=cu_seqlens, | 
					
						
						|  | max_seqlen=max_seqlen, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | hidden_states = layer_outputs[0] | 
					
						
						|  |  | 
					
						
						|  | if use_cache: | 
					
						
						|  | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) | 
					
						
						|  |  | 
					
						
						|  | if output_attentions: | 
					
						
						|  | all_self_attns += (layer_outputs[1],) | 
					
						
						|  |  | 
					
						
						|  | hidden_states = self.norm(hidden_states) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if output_hidden_states: | 
					
						
						|  | all_hidden_states += (hidden_states,) | 
					
						
						|  |  | 
					
						
						|  | next_cache = next_decoder_cache if use_cache else None | 
					
						
						|  | if not return_dict: | 
					
						
						|  | return tuple( | 
					
						
						|  | v | 
					
						
						|  | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] | 
					
						
						|  | if v is not None | 
					
						
						|  | ) | 
					
						
						|  | return BaseModelOutputWithPast( | 
					
						
						|  | last_hidden_state=hidden_states, | 
					
						
						|  | past_key_values=next_cache, | 
					
						
						|  | hidden_states=all_hidden_states, | 
					
						
						|  | attentions=all_self_attns, | 
					
						
						|  | ) | 
					
						
						|  |  |