import math
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    logging,
)
from typing import List, Optional, Tuple, Union

from .configuration_gpt_refact import GPTRefactConfig

# Module-level logger obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)


@torch.jit.script
def upcast_masked_softmax(
        x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, softmax_dtype: torch.dtype
):
    """Masked softmax over the last dimension, evaluated in ``softmax_dtype``.

    ``x`` is first cast to ``softmax_dtype``; entries where ``mask`` is False are
    replaced by ``mask_value``; the softmax result is cast back to the dtype of ``x``.
    """
    original_dtype = x.dtype
    filled = torch.where(mask, x.to(softmax_dtype), mask_value)
    probs = torch.nn.functional.softmax(filled, dim=-1)
    return probs.to(original_dtype)


@torch.jit.script
def upcast_softmax(x: torch.Tensor, softmax_dtype: torch.dtype):
    """Softmax over the last dimension computed in ``softmax_dtype``; the result keeps the dtype of ``x``."""
    probs = torch.nn.functional.softmax(x.to(softmax_dtype), dim=-1)
    return probs.to(x.dtype)


@torch.jit.script
def _get_slopes(attn_heads: int, dev: torch.device) -> torch.Tensor:
    """
    ## Get head-specific slope $m$ for each head
    * `n_heads` is the number of heads in the attention layer $n$
    The slope for first head is
    $$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$
    The slopes for the rest of the heads are in a geometric series with a ratio same as above.
    For instance when the number of heads is $8$ the slopes are
    $$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$
    """

    # Get the closest power of 2 to `n_heads`.
    # If `n_heads` is not a power of 2, then we first calculate slopes to the closest (smaller) power of 2,
    # and then add the remaining slopes.
    n = 2 ** math.floor(math.log(attn_heads, 2))
    # $2^{-\frac{8}{n}}$
    m_0 = 2.0 ** (-8.0 / n)
    # $2^{-1\frac{8}{n}}, 2^{-2 \frac{8}{n}}, 2^{-3 \frac{8}{n}}, \dots$
    m = torch.pow(m_0, torch.arange(1, 1 + n, device=dev))

    # If `n_heads` is not a power of 2, then we add the remaining slopes.
    # We calculate the remaining slopes for $n * 2$ (avoiding slopes added previously).
    # And pick the slopes upto `n_heads`.
    if n < attn_heads:
        # $2^{-\frac{8}{2n}}$
        m_hat_0 = 2.0 ** (-4.0 / n)
        # $2^{-1\frac{8}{2n}}, 2^{-3 \frac{8}{2n}}, 2^{-5 \frac{8}{2n}}, \dots$
        # Note that we take steps by $2$ to avoid slopes added previously.
        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (attn_heads - n), 2, device=dev))
        # Concatenate the slopes with the remaining slopes.
        m = torch.cat([m, m_hat])
    return m

@torch.jit.script
def get_alibi_biases(
        B: int,
        T: int,
        attn_heads: int,
        dev: torch.device,
        dtype: torch.dtype) -> torch.Tensor:
    """
    ## Calculate the ALiBi attention biases
    * `B` is the batch size; it is not used (the biases are shared across the batch) and is kept
      only for call-site compatibility
    * `T` is the number of key positions (past + current tokens)
    * `attn_heads` is the number of heads in the attention layer
    * `dev` / `dtype` are the device and dtype of the returned tensor
    This returns a contiguous tensor of shape `[1, attn_heads, T, T]` whose entry `[0, h, i, j]`
    is `(j + 1) * m_h`, i.e. the bias depends only on the key position `j`. Under a causal mask this
    differs from the textbook `-(i - j) * m_h` only by a per-query-row constant, to which softmax
    is invariant.
    """

    # Get slopes $m$ for each head
    m = _get_slopes(attn_heads, dev).to(dtype)

    # Key positions $[1, 2, \dots, T]$, identical for every query row. Built directly with
    # arange instead of a cumulative sum over a T x T mask, which avoided nothing but memory.
    distance = torch.arange(1, T + 1, device=dev).to(dtype)

    # (heads, 1, T) products, broadcast over the query axis and materialised as (1, heads, T, T).
    biases = (m[:, None, None] * distance[None, None, :]).expand([attn_heads, T, T])
    return biases.unsqueeze(0).contiguous()


class Attention(nn.Module):
    """Multi-query self-attention with additive ALiBi biases.

    Every query head shares one key/value head: ``self.kv`` projects to ``2 * head_dim``.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        # Lazily-created scalar fill value for masked scores (see ``_get_mask_value``).
        self.mask_value = None

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.kv_attn_heads = 1

        self.scale_factor = self.head_dim ** -0.5

        if self.embed_dim != self.num_heads * self.head_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.layer_idx = layer_idx
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        self.scale_attention_softmax_in_fp32 = (
                config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
        )
        self.attention_bias_in_fp32 = config.attention_bias_in_fp32

        self.q = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.kv = nn.Linear(self.embed_dim, self.head_dim * 2, bias=False)
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

    def _get_mask_value(self, device, dtype):
        """Return a 0-dim tensor holding the smallest finite value of ``dtype``.

        ``torch.where`` needs a tensor operand; it is cached and rebuilt only when
        the requested device or dtype changes.
        """
        cached = self.mask_value
        stale = cached is None or cached.dtype != dtype or cached.device != device
        if stale:
            self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
        return self.mask_value

    def _attn(self, query, key, value, attention_mask=None, alibi=None):
        """Compute biased, masked attention for all heads with a single batched matmul.

        query: (batch, query_length, num_heads * head_dim)
        key:   (batch, head_dim, key_length) -- already transposed by ``forward``
        value: (batch, key_length, head_dim)
        alibi: bias indexed as (batch-or-1, num_heads, query_length, key_length)
        attention_mask: boolean, broadcastable to (batch, query_length, num_heads, key_length)
        Returns ``(attn_output shaped like query, attn_weights (batch, q, heads, k))``.
        """
        in_dtype = query.dtype
        softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else in_dtype
        fill_value = self._get_mask_value(query.device, softmax_dtype)

        out_shape = query.shape
        bsz = out_shape[0]
        q_len = out_shape[1]
        k_len = key.size(-1)

        # Fold the heads into the query axis: one bmm then serves every head, because all
        # heads attend over the same single key/value head.
        flat_query = query.reshape(bsz, q_len * self.num_heads, self.head_dim)
        flat_alibi = alibi.transpose(2, 1).reshape(alibi.shape[0], -1, alibi.shape[-1])

        # alibi + scale * (q @ k), optionally evaluated in fp32, then cast back.
        bias_dtype = torch.float32 if self.attention_bias_in_fp32 else in_dtype
        scores = flat_alibi.baddbmm(
            batch1=flat_query.to(bias_dtype),
            batch2=key.to(bias_dtype),
            beta=1,
            alpha=self.scale_factor,
        )
        attn_weights = scores.view(bsz, q_len, self.num_heads, k_len).to(in_dtype)

        if softmax_dtype == in_dtype:
            if attention_mask is not None:
                # Unfused path: the fused kernel is slow when key_length is not a multiple of 8.
                attn_weights = torch.where(attention_mask, attn_weights, fill_value)
            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        elif attention_mask is None:
            # Scripted helpers keep the cast + softmax together to limit overhead.
            attn_weights = upcast_softmax(attn_weights, softmax_dtype)
        else:
            attn_weights = upcast_masked_softmax(attn_weights, attention_mask, fill_value, softmax_dtype)

        flat_weights = attn_weights.view(bsz, q_len * self.num_heads, k_len)
        attn_output = torch.bmm(flat_weights, value).view(out_shape)
        return attn_output, attn_weights

    def forward(
            self,
            hidden_states: torch.Tensor,
            layer_past: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            alibi: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = False,
            output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
    ]:
        """Self-attention over ``hidden_states`` plus any cached keys/values.

        ``layer_past`` is a ``(past_key, past_value)`` pair whose tensors are concatenated in front
        of the new ones along dim -2. Returns ``(attn_output, present)`` where ``present`` is
        ``(key, value)`` only when ``use_cache is True`` (otherwise None), followed by the
        attention weights as (batch, heads, query_length, key_length) when ``output_attentions``.
        """
        query = self.q(hidden_states)
        key, value = self.kv(hidden_states).split(self.head_dim, dim=-1)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = (key, value) if use_cache is True else None

        attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, alibi)
        outputs = (self.c_proj(attn_output), present)

        if output_attentions:
            outputs += (attn_weights.transpose(1, 2),)
        return outputs  # a, present, (attentions)


class MLP(nn.Module):
    """Gated feed-forward block: ``c_proj(silu(gate) * up)``.

    ``gate`` and ``up`` come from a single fused projection of width ``2 * hidden_dim``.
    """

    def __init__(self, intermediate_size, config, multiple_of: int = 256):
        super().__init__()
        embed_dim = config.hidden_size
        # Use two thirds of the nominal intermediate width, rounded up to a multiple of `multiple_of`.
        shrunk = int(2 * intermediate_size / 3)
        self.hidden_dim = (shrunk + multiple_of - 1) // multiple_of * multiple_of
        self.gate_up_proj = nn.Linear(embed_dim, self.hidden_dim * 2, bias=False)
        self.c_proj = nn.Linear(self.hidden_dim, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).split(self.hidden_dim, dim=-1)
        return self.c_proj(F.silu(gate) * up)


class LayerNormNoBias(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and no additive bias.

    The scale is allocated uninitialised; it is expected to be filled by the model's weight
    initialisation or by loading a checkpoint.
    """

    def __init__(self, shape: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.shape = (shape,)
        self.weight = nn.Parameter(torch.empty(self.shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(x, normalized_shape=self.shape, weight=self.weight, bias=None, eps=self.eps)


class GPTRefactBlock(nn.Module):
    """One pre-norm decoder layer: ``h = x + attn(ln_1(x))`` then ``h + mlp(ln_2(h))``."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        # Feed-forward width falls back to 4x the hidden size when the config leaves it unset.
        self.inner_dim = 4 * hidden_size if config.n_inner is None else config.n_inner

        self.ln_1 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = Attention(config, layer_idx=layer_idx)
        self.ln_2 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = MLP(self.inner_dim, config)

    def forward(
            self,
            hidden_states: Optional[Tuple[torch.Tensor]],
            layer_past: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            alibi: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = False,
            output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ]:
        """Apply the layer.

        Returns ``(hidden_states, present, [attn_weights])`` when ``use_cache`` is truthy and
        ``(hidden_states, [attn_weights])`` otherwise; ``attn_weights`` is present only when
        ``output_attentions`` is set.
        """
        attn_output, *extras = self.attn(
            self.ln_1(hidden_states),
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # First residual connection around attention, second around the MLP.
        residual = attn_output + hidden_states
        hidden_states = residual + self.mlp(self.ln_2(residual))

        # `extras` is [present, (attn_weights)]; the present slot is dropped when not caching.
        if not use_cache:
            extras = extras[1:]
        return (hidden_states, *extras)


class GPTRefactPreTrainedModel(PreTrainedModel):
    """Common base for GPT-Refact models: config binding, HF loading flags and weight init."""

    config_class = GPTRefactConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GPTRefactBlock"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialise one sub-module's parameters in place."""
        if isinstance(module, (MLP, Attention)):
            # Output projections of residual branches get a depth-scaled std, following GPT-2:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            residual_std = self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
            proj = module.c_proj
            proj.weight.data.normal_(mean=0.0, std=residual_std)
            proj._is_hf_initialized = True
            return

        if isinstance(module, nn.Linear):
            # Plain normal init (the TF reference used a truncated normal; see pytorch PR #5617).
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, LayerNormNoBias):
            module.weight.data.fill_(1.0)


class GPTRefactModel(GPTRefactPreTrainedModel):
    """GPT-Refact decoder backbone: token embeddings followed by a stack of ``GPTRefactBlock`` layers.

    Positional information comes only from ALiBi biases added inside attention. The final layer
    norm and LM head are owned by ``GPTRefactForCausalLM``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.multi_query = config.multi_query
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)

        self.h = nn.ModuleList([GPTRefactBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])

        self.max_positions = config.max_position_embeddings
        self.attention_bias_in_fp32 = config.attention_bias_in_fp32
        # Lower-triangular causal mask, sliced per call; not stored in checkpoints.
        self.register_buffer(
            "bias", torch.tril(torch.ones((self.max_positions, self.max_positions), dtype=torch.bool)),
            persistent=False
        )

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        # Counterpart of get_input_embeddings, used by transformers utilities such as
        # resize_token_embeddings; this model is its own base model, so nothing delegates here.
        self.wte = new_embeddings

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            past_key_values: Optional[List[torch.Tensor]] = None,
            attention_mask: Optional[torch.Tensor] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        """Embed the inputs and run every decoder block.

        Args:
            input_ids: token ids; mutually exclusive with ``inputs_embeds``.
            past_key_values: per-layer ``(key, value)`` pairs returned by an earlier call with
                ``use_cache=True``; the size of the first key along dim -2 is the cached length.
            attention_mask: optional mask over all past + current positions; it is viewed as
                ``(batch_size, 1, key_length)`` and zero entries are masked out.
            inputs_embeds: embeddings used instead of ``self.wte(input_ids)``.
            use_cache, output_attentions, output_hidden_states, return_dict: default to the config.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions`` or, when ``return_dict`` is false, a tuple
            of its non-None fields. ``cross_attentions`` is always None because the blocks have no
            cross-attention.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given, or if the
                batch size is not positive.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if batch_size <= 0:
            raise ValueError("batch_size has to be defined and > 0")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)

        query_length = input_shape[-1]
        seq_length_with_past = past_length + query_length

        # Self-attention mask: causal rows for the new positions, optionally AND-ed with padding.
        key_length = past_length + query_length
        self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
        if attention_mask is not None:
            self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to(
                dtype=torch.bool, device=self_attention_mask.device
            )

        # MQA models: (batch_size, query_length, n_heads, key_length)
        attention_mask = self_attention_mask.unsqueeze(2)

        hidden_states = self.wte(input_ids) if inputs_embeds is None else inputs_embeds

        # ALiBi biases for all key positions, keeping only the rows of the new query positions.
        alibi_dtype = torch.float32 if self.attention_bias_in_fp32 else self.wte.weight.dtype
        alibi = get_alibi_biases(hidden_states.shape[0], seq_length_with_past,
                                 self.num_heads, device, alibi_dtype)[:, :, -query_length:, :]

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = [] if use_cache else None
        all_self_attentions = () if output_attentions else None
        # GPTRefactBlock never returns cross-attention weights, so there is nothing to collect.
        all_cross_attentions = None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    alibi
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    alibi=alibi,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache:
                presents.append(outputs[1])

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
    """GPT-Refact backbone plus a final bias-free layer norm and a bias-free LM head."""

    _tied_weights_keys = ["lm_head.weight", "ln_f.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPTRefactModel(config)
        self.ln_f = LayerNormNoBias(self.transformer.embed_dim, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # gradient checkpointing support for lower versions of transformers
        import transformers
        from packaging import version

        def _set_gradient_checkpointing(module, value=False):
            if isinstance(module, GPTRefactModel):
                module.gradient_checkpointing = value

        # Releases older than 4.35 (per this check) use a `_set_gradient_checkpointing(module, value)` hook.
        v = version.parse(transformers.__version__)
        if (v.major, v.minor) < (4, 35):
            self._set_gradient_checkpointing = _set_gradient_checkpointing

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Build the ``forward`` kwargs for one generation step.

        ``inputs_embeds`` are only used on the first step (no cache yet); once a cache exists only
        the newest token is fed. The attention mask is forwarded so padded positions stay masked
        during batched generation.
        """
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        elif past_key_values is not None:
            model_inputs = {"input_ids": input_ids[..., -1:]}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": kwargs.get("attention_mask"),
            }
        )
        return model_inputs

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
            attention_mask: Optional[torch.Tensor] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        x = self.ln_f(hidden_states)
        lm_logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(
            past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Each layer's cache entry is a `(key, value)` pair, so every tensor inside the pair is re-indexed along
        the batch dimension (dim 0).
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )