# coding=utf-8
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
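
"""PyTorch openPangu-Embedded dense model.

This module builds on the Llama implementation in `transformers` and overrides
only the parts that differ: bias handling in the attention/MLP projections and
an optional fused-attention path for Ascend NPU inference.
"""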
from typing import Callable, Optional, Tuple
import torch
from torch import nn

try:
    import torch_npu
    from torch_npu.contrib import transfer_to_npu  # noqa: F401  (redirects torch CUDA calls to the NPU backend)

    # Use the fused inference attention kernel only on Ascend 910-series devices.
    if "910" in torch.npu.get_device_name():
        NPU_ATTN_INFR = True
        print("[INFO] torch_npu detected. Using NPU fused infer attention.")
    else:
        NPU_ATTN_INFR = False
except ImportError:
    NPU_ATTN_INFR = False

from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.processing_utils import Unpack
from transformers.utils import logging
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaMLP,
    LlamaModel,
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from .configuration_openpangu_dense import PanguEmbeddedConfig
logger = logging.get_logger(__name__)


class PanguEmbeddedMLP(LlamaMLP):
    def __init__(self, config):
        super().__init__(config)
        # Same gated MLP as Llama, but the projections are always bias-free.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)


class PanguEmbeddedAttention(LlamaAttention):
    def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
        # Initialize nn.Module directly: every attribute and projection is rebuilt
        # below, using the single `bias` flag from PanguEmbeddedConfig rather than
        # Llama's `attention_bias`.
        nn.Module.__init__(self)
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.bias)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        if not self.training and NPU_ATTN_INFR:
            # Inference on Ascend NPUs: call the fused attention kernel directly.
            # The kernel expects a boolean mask with True at positions to be masked out.
            q_len = input_shape[1]
            if attention_mask is not None:
                attention_mask = ~attention_mask.bool()
            elif q_len > 1:
                # No mask was provided for a multi-token prefill, so build a causal mask.
                attention_mask = (
                    torch.triu(torch.ones([q_len, q_len]), diagonal=1)
                    .bool()
                    .unsqueeze(0)
                    .unsqueeze(0)
                    .to(query_states.device)
                )
            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                query_states, key_states, value_states,
                num_heads=self.num_heads, num_key_value_heads=self.num_key_value_heads,
                input_layout="BNSD", atten_mask=attention_mask, scale=self.scaling)
            attn_output = attn_output.transpose(1, 2)
            attn_weights = None
        else:
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling,
                **kwargs,
            )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class PanguEmbeddedDecoderLayer(LlamaDecoderLayer):
    pass


class PanguEmbeddedModel(LlamaModel):
    pass


class PanguEmbeddedForCausalLM(LlamaForCausalLM):
    pass
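

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes PanguEmbeddedConfig
# accepts the Llama-style fields referenced above (hidden_size,
# num_attention_heads, num_key_value_heads, bias, ...); the tiny values below
# are placeholders, not the released model's hyperparameters.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_config = PanguEmbeddedConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        bias=False,
    )
    demo_model = PanguEmbeddedForCausalLM(demo_config)
    demo_input_ids = torch.randint(0, demo_config.vocab_size, (1, 8))
    with torch.no_grad():
        demo_logits = demo_model(input_ids=demo_input_ids).logits
    # A successful forward pass yields logits of shape (batch, seq_len, vocab_size).
    print(demo_logits.shape)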