lastdefiance20 committed on
Commit 2a77950 · verified · 1 Parent(s): a6ceffb

Create modeling_exaone.py

Files changed (1)
  1. modeling_exaone.py +1376 -0
modeling_exaone.py ADDED
@@ -0,0 +1,1376 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The LG AI Research EXAONE Lab.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ """LG AI Research EXAONE Lab"""
22
+
23
+ import math
24
+ from typing import Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.utils.checkpoint
28
+ from packaging import version
29
+ from torch import nn
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
34
+ from transformers.generation import GenerationMixin
35
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
36
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
37
+ from transformers.modeling_outputs import (
38
+ BaseModelOutputWithPast,
39
+ BaseModelOutputWithPastAndCrossAttentions,
40
+ CausalLMOutputWithPast,
41
+ QuestionAnsweringModelOutput,
42
+ SequenceClassifierOutputWithPast,
43
+ )
44
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
45
+ from transformers.modeling_utils import PreTrainedModel
46
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
47
+ from transformers.utils import (
48
+ add_code_sample_docstrings,
49
+ add_start_docstrings,
50
+ add_start_docstrings_to_model_forward,
51
+ is_flash_attn_2_available,
52
+ logging,
53
+ )
54
+ from .configuration_exaone import ExaoneConfig
55
+
56
+
57
+ if is_flash_attn_2_available():
58
+ try:
59
+ import flash_attn
60
+
61
+ if version.parse(flash_attn.__version__) > version.parse("2.4.2"):
62
+ from flash_attn.ops.triton.layer_norm import rms_norm_fn
63
+ else:
64
+ from flash_attn.ops.triton.layernorm import rms_norm_fn
65
+ except ImportError:
66
+ pass
67
+
68
+
69
+ logger = logging.get_logger(__name__)
70
+
71
+ _CHECKPOINT_FOR_DOC = "exaone"
72
+ _CONFIG_FOR_DOC = "ExaoneConfig"
73
+
74
+ EXAONE_PRETRAINED_MODEL_ARCHIVE_LIST = [
75
+ "exaone",
76
+ ]
77
+
78
+
79
+ @torch.jit.script
80
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
81
+ """
82
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
83
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
84
+ """
85
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
86
+ if n_rep == 1:
87
+ return hidden_states
88
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
89
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
90
+
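A quick shape check may help make `repeat_kv` concrete; this is an illustrative sketch with hypothetical sizes, not part of the committed file.

```python
# Sketch only (hypothetical sizes): repeat_kv expands the key/value heads so
# they line up with the query heads under grouped-query attention.
import torch

kv = torch.randn(2, 8, 16, 64)            # (batch, num_key_value_heads, seq_len, head_dim)
expanded = repeat_kv(kv, 4)               # each KV head repeated 4 times
assert expanded.shape == (2, 32, 16, 64)  # now matches 32 query heads
```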
91
+
92
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
93
+ """Applies Rotary Position Embedding to the query and key tensors.
94
+ Args:
95
+ q (`torch.Tensor`): The query tensor.
96
+ k (`torch.Tensor`): The key tensor.
97
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
98
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
99
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
100
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
101
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
102
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
103
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
104
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
105
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
106
+ Returns:
107
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
108
+ """
109
+ cos = cos.unsqueeze(unsqueeze_dim)
110
+ sin = sin.unsqueeze(unsqueeze_dim)
111
+ q_embed = (q * cos) + (rotate_half(q) * sin)
112
+ k_embed = (k * cos) + (rotate_half(k) * sin)
113
+ return q_embed, k_embed
114
+
115
+
116
+ def rotate_half(x):
117
+ """Rotates half the hidden dims of the input."""
118
+ x1 = x[..., : x.shape[-1] // 2]
119
+ x2 = x[..., x.shape[-1] // 2 :]
120
+ return torch.cat((-x2, x1), dim=-1)
121
+
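As a hedged illustration of how `apply_rotary_pos_emb` and `rotate_half` are used, here is a minimal sketch with hypothetical sizes; the `cos`/`sin` tensors are built the same way `ExaoneRotaryEmbedding.forward` builds them further below.

```python
# Sketch only (hypothetical sizes): rotating query/key tensors with RoPE.
import torch

bsz, heads, seq, dim = 1, 2, 5, 8
q = torch.randn(bsz, heads, seq, dim)
k = torch.randn(bsz, heads, seq, dim)

inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)
freqs = torch.outer(torch.arange(seq).float(), inv_freq)             # (seq, dim/2)
emb = torch.cat((freqs, freqs), dim=-1)                              # (seq, dim)
cos, sin = emb.cos()[None], emb.sin()[None]                          # (1, seq, dim)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # unsqueeze_dim=1 broadcasts over heads
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```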
122
+
123
+ def _prepare_4d_causal_attention_mask_with_cache_position(
124
+ attention_mask: torch.Tensor,
125
+ sequence_length: int,
126
+ target_length: int,
127
+ dtype: torch.dtype,
128
+ device: torch.device,
129
+ min_dtype: float,
130
+ cache_position: torch.Tensor,
131
+ batch_size: int,
132
+ ):
133
+ """
134
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
135
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
136
+ Args:
137
+ attention_mask (`torch.Tensor`):
138
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
139
+ sequence_length (`int`):
140
+ The sequence length being processed.
141
+ target_length (`int`):
142
+ The target length: when generating with a static cache, the mask should be as long as the static cache to account for the zero padding, i.e. the part of the cache that is not filled yet.
143
+ dtype (`torch.dtype`):
144
+ The dtype to use for the 4D attention mask.
145
+ device (`torch.device`):
146
+ The device to place the 4D attention mask on.
147
+ min_dtype (`float`):
148
+ The minimum value representable with the dtype `dtype`.
149
+ cache_position (`torch.Tensor`):
150
+ Indices depicting the position of the input sequence tokens in the sequence.
151
+ batch_size (`int`):
152
+ Batch size.
153
+ """
154
+ if attention_mask is not None and attention_mask.dim() == 4:
155
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
156
+ causal_mask = attention_mask
157
+ else:
158
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
159
+ if sequence_length != 1:
160
+ causal_mask = torch.triu(causal_mask, diagonal=1)
161
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
162
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
163
+ if attention_mask is not None:
164
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
165
+ mask_length = attention_mask.shape[-1]
166
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
167
+ padding_mask = padding_mask == 0
168
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
169
+ padding_mask, min_dtype
170
+ )
171
+
172
+ return causal_mask
173
+
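A small, hedged example of what this helper produces from a 2D padding mask (hypothetical sizes; not part of the commit):

```python
# Sketch only: expanding a 2D padding mask into the additive 4D mask consumed by
# the eager/SDPA attention paths. 0.0 marks allowed positions, min_dtype masked ones.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])            # last position is padding
mask_4d = _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask,
    sequence_length=4,
    target_length=4,
    dtype=torch.float32,
    device=torch.device("cpu"),
    min_dtype=torch.finfo(torch.float32).min,
    cache_position=torch.arange(4),
    batch_size=1,
)
assert mask_4d.shape == (1, 1, 4, 4)
```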
174
+
175
+ class ExaoneRMSNorm(torch.nn.Module):
176
+ def __init__(self, hidden_size, eps=1e-6):
177
+ super().__init__()
178
+ self.eps = eps
179
+ self.weight = torch.nn.Parameter(torch.ones(hidden_size))
180
+
181
+ def forward(self, hidden_states):
182
+ input_dtype = hidden_states.dtype
183
+ hidden_states = hidden_states.to(torch.float32)
184
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
185
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
186
+ return self.weight * hidden_states.to(input_dtype)
187
+
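For intuition, a minimal check (hypothetical sizes) that `ExaoneRMSNorm` rescales each token vector to roughly unit RMS before applying the learned weight:

```python
# Sketch only: with the default weight of ones, the normalized vectors come out
# with root-mean-square ~1 along the hidden dimension.
import torch

norm = ExaoneRMSNorm(hidden_size=16, eps=1e-6)
x = torch.randn(2, 4, 16)
y = norm(x)
rms = y.pow(2).mean(-1).sqrt()
assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)
```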
188
+
189
+ class ExaoneTritonRMSNorm(torch.nn.Module):
190
+ def __init__(
191
+ self,
192
+ hidden_size: int = 0,
193
+ eps: float = 1e-5,
194
+ ):
195
+ super().__init__()
196
+ self.eps = eps
197
+ self.drop = None
198
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size))
199
+ self.register_parameter("bias", None)
200
+ self.reset_parameters()
201
+
202
+ def reset_parameters(self):
203
+ torch.nn.init.ones_(self.weight)
204
+
205
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
206
+ return rms_norm_fn(
207
+ x,
208
+ self.weight,
209
+ self.bias,
210
+ residual=residual,
211
+ eps=self.eps,
212
+ dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
213
+ prenorm=prenorm,
214
+ residual_in_fp32=residual_in_fp32,
215
+ )
216
+
217
+
218
+ ALL_LAYERNORM_LAYERS.append(ExaoneRMSNorm)
219
+ ALL_LAYERNORM_LAYERS.append(ExaoneTritonRMSNorm)
220
+
221
+
222
+ class ExaoneRotaryEmbedding(nn.Module):
223
+ def __init__(self, config: ExaoneConfig, device=None):
224
+ super().__init__()
225
+ if config.rope_scaling is not None:
226
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
227
+ else:
228
+ self.rope_type = "default"
229
+ self.rope_theta = config.rope_theta
230
+ self.max_seq_len = config.max_position_embeddings
231
+ self.original_max_seq_len = config.max_position_embeddings
232
+
233
+ self.config = config
234
+ if self.rope_type not in ROPE_INIT_FUNCTIONS:
235
+ raise KeyError(f"The EXAONE model does not support RoPE type: {self.rope_type}")
236
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
237
+
238
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
239
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
240
+ self.original_inv_freq = self.inv_freq
241
+
242
+ def _update_freq(self, position_ids, device):
243
+ """
244
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
245
+ 1 - growing beyond the cached sequence length (allow scaling)
246
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
247
+ """
248
+ seq_len = torch.max(position_ids) + 1
249
+ if seq_len > self.max_seq_len: # expand to seq_len
250
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
251
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
252
+ self.max_seq_len = seq_len
253
+
254
+ if seq_len < self.original_max_seq_len and self.max_seq_len > self.original_max_seq_len: # reset to original
255
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
256
+ self.max_seq_len = self.original_max_seq_len
257
+
258
+ @torch.no_grad()
259
+ def forward(self, x, position_ids):
260
+ if "dynamic" in self.rope_type:
261
+ self._update_freq(position_ids, device=x.device)
262
+
263
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
264
+ position_ids_expanded = position_ids[:, None, :].float()
265
+
266
+ device_type = x.device.type
267
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
268
+ with torch.autocast(device_type=device_type, enabled=False):
269
+ freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
270
+ emb = torch.cat((freqs, freqs), dim=-1)
271
+ cos, sin = emb.cos(), emb.sin()
272
+
273
+ cos, sin = cos * self.attention_scaling, sin * self.attention_scaling
274
+ return cos.to(x.dtype), sin.to(x.dtype)
275
+
276
+
277
+ class ExaoneSelfAttention(nn.Module):
278
+ def __init__(self, config: ExaoneConfig, layer_idx: Optional[int] = None):
279
+ super().__init__()
280
+ self.config = config
281
+ self.layer_idx = layer_idx
282
+ self.embed_dim = config.hidden_size
283
+ self.num_heads = config.num_attention_heads
284
+ self.head_dim = self.embed_dim // self.num_heads
285
+ self.num_key_value_heads = config.num_key_value_heads
286
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
287
+ self.attention_dropout_rate = config.attention_dropout
288
+
289
+ if self.head_dim * self.num_heads != self.embed_dim:
290
+ raise ValueError(
291
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
292
+ )
293
+
294
+ self.rotary = ExaoneRotaryEmbedding(config)
295
+
296
+ self.k_proj = nn.Linear(self.embed_dim, self.num_key_value_heads * self.head_dim, bias=False)
297
+ self.v_proj = nn.Linear(self.embed_dim, self.num_key_value_heads * self.head_dim, bias=False)
298
+ self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
299
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
300
+
301
+ def forward(
302
+ self,
303
+ hidden_states: torch.Tensor,
304
+ attention_mask: Optional[torch.Tensor] = None,
305
+ position_ids: Optional[torch.LongTensor] = None,
306
+ past_key_value: Optional[Cache] = None,
307
+ output_attentions: Optional[bool] = False,
308
+ use_cache: Optional[bool] = False,
309
+ cache_position: Optional[torch.LongTensor] = None,
310
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
311
+ **kwargs,
312
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
313
+ bsz, q_len, _ = hidden_states.size()
314
+ query_states = self.q_proj(hidden_states)
315
+ key_states = self.k_proj(hidden_states)
316
+ value_states = self.v_proj(hidden_states)
317
+
318
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
319
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
320
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
321
+
322
+ if position_embeddings is None:
323
+ cos, sin = self.rotary(value_states, position_ids=position_ids)
324
+ else:
325
+ cos, sin = position_embeddings
326
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
327
+
328
+ if past_key_value is not None:
329
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
330
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
331
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
332
+
333
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
334
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
335
+
336
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
337
+
338
+ if attention_mask is not None:
339
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
340
+ attn_weights = attn_weights + causal_mask
341
+
342
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
343
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout_rate, training=self.training)
344
+ attn_output = torch.matmul(attn_weights, value_states)
345
+
346
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
347
+ raise ValueError(
348
+ f"Attention outputs should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
349
+ f" {attn_output.size()}"
350
+ )
351
+
352
+ attn_output = attn_output.transpose(1, 2).contiguous()
353
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
354
+
355
+ attn_output = self.out_proj(attn_output)
356
+
357
+ if not output_attentions:
358
+ attn_weights = None
359
+
360
+ return attn_output, attn_weights, past_key_value
361
+
362
+
363
+ class ExaoneFlashAttention(ExaoneSelfAttention):
364
+ def __init__(self, *args, **kwargs):
365
+ super().__init__(*args, **kwargs)
366
+
367
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
368
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
369
+
370
+ def forward(
371
+ self,
372
+ hidden_states: torch.Tensor,
373
+ attention_mask: Optional[torch.Tensor] = None,
374
+ position_ids: Optional[torch.LongTensor] = None,
375
+ past_key_value: Optional[Cache] = None,
376
+ output_attentions: Optional[bool] = False,
377
+ use_cache: Optional[bool] = False,
378
+ cache_position: Optional[torch.LongTensor] = None,
379
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
380
+ **kwargs,
381
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
382
+ if isinstance(past_key_value, StaticCache):
383
+ raise ValueError(
384
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
385
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
386
+ )
387
+
388
+ output_attentions = False
389
+
390
+ bsz, q_len, h_size = hidden_states.size()
391
+
392
+ query_states = self.q_proj(hidden_states)
393
+ key_states = self.k_proj(hidden_states)
394
+ value_states = self.v_proj(hidden_states)
395
+
396
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
397
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
398
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
399
+
400
+ if position_embeddings is None:
401
+ cos, sin = self.rotary(value_states, position_ids=position_ids)
402
+ else:
403
+ cos, sin = position_embeddings
404
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
405
+
406
+ if past_key_value is not None:
407
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
408
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
409
+ # Only update the cache with tensors of shape [bsz, n_head, q_len, head_dim]
410
+ # TODO: need to be fixed when transformers' KV cache layout is changed
411
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
412
+
413
+ query_states = query_states.transpose(1, 2)
414
+ key_states = key_states.transpose(1, 2)
415
+ value_states = value_states.transpose(1, 2)
416
+
417
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
418
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
419
+ # cast them back in the correct dtype just to be sure everything works as expected.
420
+ input_dtype = query_states.dtype
421
+ if input_dtype == torch.float32:
422
+ if torch.is_autocast_enabled():
423
+ target_dtype = torch.get_autocast_gpu_dtype()
424
+ # Handle the case where the model is quantized
425
+ elif hasattr(self.config, "_pre_quantization_dtype"):
426
+ target_dtype = self.config._pre_quantization_dtype
427
+ else:
428
+ target_dtype = self.q_proj.weight.dtype
429
+
430
+ logger.warning_once(
431
+ f"The input hidden states seem to have been silently cast to float32; this might be related to"
432
+ f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
433
+ f" {target_dtype}."
434
+ )
435
+
436
+ query_states = query_states.to(target_dtype)
437
+ key_states = key_states.to(target_dtype)
438
+ value_states = value_states.to(target_dtype)
439
+
440
+ dropout_rate = self.attention_dropout_rate if self.training else 0.0
441
+
442
+ attn_output = _flash_attention_forward(
443
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, is_causal=True
444
+ )
445
+
446
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
447
+ attn_output = self.out_proj(attn_output)
448
+
449
+ if not output_attentions:
450
+ attn_weights = None
451
+
452
+ return attn_output, attn_weights, past_key_value
453
+
454
+
455
+ class ExaoneSdpaAttention(ExaoneSelfAttention):
456
+ def __init__(self, *args, **kwargs):
457
+ super().__init__(*args, **kwargs)
458
+
459
+ def forward(
460
+ self,
461
+ hidden_states: torch.Tensor,
462
+ attention_mask: Optional[torch.Tensor] = None,
463
+ position_ids: Optional[torch.LongTensor] = None,
464
+ past_key_value: Optional[Cache] = None,
465
+ output_attentions: Optional[bool] = False,
466
+ use_cache: Optional[bool] = False,
467
+ cache_position: Optional[torch.LongTensor] = None,
468
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
469
+ **kwargs,
470
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
471
+ if output_attentions:
472
+ logger.warning_once(
473
+ "ExaoneModel is using ExaoneSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
474
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
475
+ )
476
+ return super().forward(
477
+ hidden_states=hidden_states,
478
+ attention_mask=attention_mask,
479
+ position_ids=position_ids,
480
+ past_key_value=past_key_value,
481
+ output_attentions=output_attentions,
482
+ use_cache=use_cache,
483
+ cache_position=cache_position,
484
+ position_embeddings=position_embeddings,
485
+ **kwargs,
486
+ )
487
+
488
+ bsz, q_len, _ = hidden_states.size()
489
+
490
+ query_states = self.q_proj(hidden_states)
491
+ key_states = self.k_proj(hidden_states)
492
+ value_states = self.v_proj(hidden_states)
493
+
494
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
495
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
496
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
497
+
498
+ if position_embeddings is None:
499
+ cos, sin = self.rotary(value_states, position_ids=position_ids)
500
+ else:
501
+ cos, sin = position_embeddings
502
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
503
+
504
+ if past_key_value is not None:
505
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
506
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
507
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
508
+
509
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
510
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
511
+
512
+ causal_mask = attention_mask
513
+ if attention_mask is not None:
514
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
515
+
516
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
517
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
518
+ if query_states.device.type == "cuda" and causal_mask is not None:
519
+ query_states = query_states.contiguous()
520
+ key_states = key_states.contiguous()
521
+ value_states = value_states.contiguous()
522
+
523
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
524
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
525
+ is_causal = True if causal_mask is None and q_len > 1 else False
526
+
527
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
528
+ query_states,
529
+ key_states,
530
+ value_states,
531
+ attn_mask=causal_mask,
532
+ dropout_p=self.attention_dropout_rate if self.training else 0.0,
533
+ is_causal=is_causal,
534
+ )
535
+
536
+ attn_output = attn_output.transpose(1, 2).contiguous()
537
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
538
+
539
+ attn_output = self.out_proj(attn_output)
540
+
541
+ return attn_output, None, past_key_value
542
+
543
+
544
+ class ExaoneAttention(nn.Module):
545
+ def __init__(self, config, layer_id=0):
546
+ super().__init__()
547
+ self.layer_id = layer_id
548
+ if "flash" in config._attn_implementation:
549
+ self.attention = ExaoneFlashAttention(config, self.layer_id)
550
+ elif "sdpa" in config._attn_implementation:
551
+ self.attention = ExaoneSdpaAttention(config, self.layer_id)
552
+ else:
553
+ self.attention = ExaoneSelfAttention(config, self.layer_id)
554
+
555
+ def forward(
556
+ self,
557
+ hidden_states: torch.Tensor,
558
+ attention_mask: Optional[torch.Tensor] = None,
559
+ position_ids: Optional[torch.LongTensor] = None,
560
+ past_key_value: Optional[Cache] = None,
561
+ output_attentions: Optional[bool] = False,
562
+ use_cache: Optional[bool] = False,
563
+ cache_position: Optional[torch.LongTensor] = None,
564
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
565
+ **kwargs,
566
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
567
+ return self.attention(
568
+ hidden_states=hidden_states,
569
+ attention_mask=attention_mask,
570
+ position_ids=position_ids,
571
+ past_key_value=past_key_value,
572
+ output_attentions=output_attentions,
573
+ use_cache=use_cache,
574
+ cache_position=cache_position,
575
+ position_embeddings=position_embeddings,
576
+ **kwargs,
577
+ )
578
+
579
+
580
+ class ExaoneGatedMLP(nn.Module):
581
+ def __init__(self, intermediate_size, config):
582
+ super().__init__()
583
+ self.config = config
584
+ embed_dim = config.hidden_size
585
+ self.c_fc_0 = nn.Linear(embed_dim, intermediate_size, bias=False)
586
+ self.c_fc_1 = nn.Linear(embed_dim, intermediate_size, bias=False)
587
+ self.c_proj = nn.Linear(intermediate_size, embed_dim, bias=False)
588
+ self.act = ACT2FN[config.activation_function]
589
+
590
+ def forward(self, hidden_states):
591
+ output_proj = self.c_proj(self.act(self.c_fc_0(hidden_states)) * self.c_fc_1(hidden_states))
592
+ return output_proj
593
+
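A hedged sketch of the SwiGLU-style gating implemented by `ExaoneGatedMLP`; `types.SimpleNamespace` stands in for a real `ExaoneConfig` here and the sizes are hypothetical.

```python
# Sketch only: output = c_proj(act(c_fc_0(x)) * c_fc_1(x)).
import types
import torch

cfg = types.SimpleNamespace(hidden_size=64, activation_function="silu")
mlp = ExaoneGatedMLP(intermediate_size=256, config=cfg)
x = torch.randn(2, 5, 64)
assert mlp(x).shape == (2, 5, 64)
```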
594
+
595
+ class ExaoneBlock(nn.Module):
596
+ def __init__(self, config, layer_id):
597
+ super().__init__()
598
+ self.config = config
599
+ hidden_size = config.hidden_size
600
+ inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
601
+ self.ln_1 = ExaoneRMSNorm(hidden_size=hidden_size, eps=config.layer_norm_epsilon)
602
+ self.attn = ExaoneAttention(config, layer_id)
603
+ self.ln_2 = ExaoneRMSNorm(hidden_size=hidden_size, eps=config.layer_norm_epsilon)
604
+ self.mlp = ExaoneGatedMLP(inner_dim, config)
605
+
606
+ def forward(
607
+ self,
608
+ hidden_states: torch.Tensor,
609
+ attention_mask: Optional[torch.Tensor] = None,
610
+ position_ids: Optional[torch.LongTensor] = None,
611
+ past_key_value: Optional[Cache] = None,
612
+ output_attentions: Optional[bool] = False,
613
+ use_cache: Optional[bool] = False,
614
+ cache_position: Optional[torch.LongTensor] = None,
615
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
616
+ **kwargs,
617
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
618
+ residual = hidden_states
619
+ hidden_states = self.ln_1(hidden_states)
620
+
621
+ hidden_states, self_attn_weights, present_key_value = self.attn(
622
+ hidden_states=hidden_states,
623
+ attention_mask=attention_mask,
624
+ position_ids=position_ids,
625
+ past_key_value=past_key_value,
626
+ output_attentions=output_attentions,
627
+ use_cache=use_cache,
628
+ cache_position=cache_position,
629
+ position_embeddings=position_embeddings,
630
+ **kwargs,
631
+ )
632
+ # residual connection
633
+ hidden_states = residual + hidden_states
634
+
635
+ residual = hidden_states
636
+ hidden_states = self.ln_2(hidden_states)
637
+ hidden_states = self.mlp(hidden_states)
638
+
639
+ hidden_states = residual + hidden_states
640
+
641
+ outputs = (hidden_states,)
642
+
643
+ if output_attentions:
644
+ outputs += (self_attn_weights,)
645
+
646
+ if use_cache:
647
+ outputs += (present_key_value,)
648
+
649
+ return outputs
650
+
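The control flow above amounts to a standard pre-norm residual block; a compact sketch of the per-layer computation order (not a drop-in replacement for `ExaoneBlock.forward`):

```python
# Sketch only: the computation order implemented by ExaoneBlock.forward.
def exaone_block_sketch(h, attn, mlp, ln_1, ln_2):
    h = h + attn(ln_1(h))   # RMSNorm -> self-attention -> residual add
    h = h + mlp(ln_2(h))    # RMSNorm -> gated MLP      -> residual add
    return h
```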
651
+
652
+ class ExaonePreTrainedModel(PreTrainedModel):
653
+ """
654
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
655
+ models.
656
+ """
657
+
658
+ config_class = ExaoneConfig
659
+ base_model_prefix = "transformer"
660
+ supports_gradient_checkpointing = True
661
+ _no_split_modules = ["ExaoneBlock"]
662
+ _skip_keys_device_placement = "past_key_values"
663
+ _supports_flash_attn_2 = True
664
+ _supports_sdpa = True
665
+ _supports_cache_class = True
666
+
667
+ def __init__(self, *inputs, **kwargs):
668
+ super().__init__(*inputs, **kwargs)
669
+
670
+ def _init_weights(self, module):
671
+ """Initialize the weights."""
672
+ if isinstance(module, (nn.Linear,)):
673
+ # Slightly different from the TF version which uses truncated_normal for initialization
674
+ # cf https://github.com/pytorch/pytorch/pull/5617
675
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
676
+ if module.bias is not None:
677
+ module.bias.data.zero_()
678
+ elif isinstance(module, nn.Embedding):
679
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
680
+ if module.padding_idx is not None:
681
+ module.weight.data[module.padding_idx].zero_()
682
+ elif isinstance(module, ExaoneRMSNorm):
683
+ module.weight.data.fill_(1.0)
684
+
685
+
686
+ EXAONE_START_DOCSTRING = r"""
687
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
688
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
689
+ etc.)
690
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
691
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
692
+ and behavior.
693
+ Parameters:
694
+ config ([`ExaoneConfig`]): Model configuration class with all the parameters of the model.
695
+ Initializing with a config file does not load the weights associated with the model, only the
696
+ configuration. Check out the `PreTrainedModel.from_pretrained` method to load the model weights.
697
+ """
698
+
699
+ EXAONE_INPUTS_DOCSTRING = r"""
700
+ Args:
701
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
702
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
703
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
704
+ sequence tokens in the vocabulary.
705
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be
706
+ passed as `input_ids`.
707
+ `What are input IDs? <../glossary.html#input-ids>`__
708
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
709
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
710
+ - 1 for tokens that are **not masked**,
711
+ - 0 for tokens that are **masked**.
712
+ `What are attention masks? <../glossary.html#attention-mask>`__
713
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
714
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
715
+ config.max_position_embeddings - 1]`.
716
+ `What are position IDs? <../glossary.html#position-ids>`_
717
+ past_key_values (`Cache`, *optional*):
718
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
719
+ `past_key_values` output below). Can be used to speed up sequential decoding. This typically consists
720
+ in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or
721
+ `config.use_cache=True`.
722
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
723
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
724
+ This is useful if you want more control over how to convert `input_ids` indices into associated
725
+ vectors than the model's internal embedding lookup matrix.
726
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
727
+ `past_key_values`).
728
+ use_cache (`bool`, *optional*):
729
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up
730
+ decoding (see `past_key_values`).
731
+ output_attentions (`bool`, *optional*):
732
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
733
+ tensors for more detail.
734
+ output_hidden_states (`bool`, *optional*):
735
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
736
+ more detail.
737
+ return_dict (`bool`, *optional*):
738
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
739
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
740
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
741
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
742
+ the complete sequence length.
743
+ """
744
+
745
+
746
+ @add_start_docstrings(
747
+ "The bare EXAONE Model transformer outputting raw hidden-states without any specific head on top.",
748
+ EXAONE_START_DOCSTRING,
749
+ )
750
+ class ExaoneModel(ExaonePreTrainedModel):
751
+ def __init__(self, config):
752
+ super().__init__(config)
753
+ self.config = config
754
+ self.embed_dim = config.hidden_size
755
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim, self.config.pad_token_id)
756
+ self.drop = nn.Dropout(float(config.embed_dropout))
757
+ self.h = nn.ModuleList([ExaoneBlock(config, layer_id=i) for i in range(config.num_layers)])
758
+ self.ln_f = ExaoneRMSNorm(hidden_size=self.embed_dim, eps=config.layer_norm_epsilon)
759
+ self.rotary = ExaoneRotaryEmbedding(config)
760
+ self.gradient_checkpointing = False
761
+ # Initialize weights and apply final processing
762
+ self.post_init()
763
+
764
+ def get_input_embeddings(self):
765
+ return self.wte
766
+
767
+ def set_input_embeddings(self, new_embeddings):
768
+ self.wte = new_embeddings
769
+
770
+ @add_start_docstrings_to_model_forward(EXAONE_INPUTS_DOCSTRING)
771
+ @add_code_sample_docstrings(
772
+ checkpoint=_CHECKPOINT_FOR_DOC,
773
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
774
+ config_class=_CONFIG_FOR_DOC,
775
+ )
776
+ def forward(
777
+ self,
778
+ input_ids: Optional[torch.Tensor] = None,
779
+ attention_mask: Optional[torch.Tensor] = None,
780
+ position_ids: Optional[torch.Tensor] = None,
781
+ past_key_values: Optional[Cache] = None,
782
+ inputs_embeds: Optional[torch.Tensor] = None,
783
+ use_cache: Optional[bool] = None,
784
+ output_attentions: Optional[bool] = None,
785
+ output_hidden_states: Optional[bool] = None,
786
+ return_dict: Optional[bool] = None,
787
+ cache_position: Optional[torch.LongTensor] = None,
788
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
789
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
790
+ output_hidden_states = (
791
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
792
+ )
793
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
794
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
795
+
796
+ if self.gradient_checkpointing and self.training:
797
+ if use_cache:
798
+ logger.warning_once(
799
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
800
+ )
801
+ use_cache = False
802
+
803
+ if input_ids is not None and inputs_embeds is not None:
804
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
805
+ elif input_ids is not None:
806
+ batch_size, seq_length = input_ids.shape[:2]
807
+ elif inputs_embeds is not None:
808
+ batch_size, seq_length = inputs_embeds.shape[:2]
809
+ else:
810
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
811
+
812
+ return_legacy_cache = False
813
+ if (
814
+ use_cache and not isinstance(past_key_values, Cache) and not self.training
815
+ ): # kept for BC (non `Cache` `past_key_values` inputs)
816
+ return_legacy_cache = True
817
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
818
+ logger.warning_once(
819
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
820
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
821
+ )
822
+
823
+ if inputs_embeds is None:
824
+ inputs_embeds = self.wte(input_ids)
825
+
826
+ if cache_position is None:
827
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
828
+ cache_position = torch.arange(
829
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
830
+ )
831
+ if position_ids is None:
832
+ position_ids = cache_position.unsqueeze(0)
833
+
834
+ causal_mask = self._update_causal_mask(
835
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
836
+ )
837
+
838
+ hidden_states = inputs_embeds
839
+ hidden_states = self.drop(hidden_states)
840
+
841
+ position_embeddings = self.rotary(hidden_states, position_ids)
842
+
843
+ all_hidden_states = () if output_hidden_states else None
844
+ all_self_attns = () if output_attentions else None
845
+ next_decoder_cache = None
846
+
847
+ for block in self.h:
848
+ if output_hidden_states:
849
+ all_hidden_states = all_hidden_states + (hidden_states,)
850
+
851
+ if self.gradient_checkpointing and self.training:
852
+ outputs = self._gradient_checkpointing_func(
853
+ block.__call__,
854
+ hidden_states,
855
+ causal_mask,
856
+ position_ids,
857
+ past_key_values,
858
+ output_attentions,
859
+ use_cache,
860
+ cache_position,
861
+ position_embeddings,
862
+ )
863
+ else:
864
+ outputs = block(
865
+ hidden_states,
866
+ attention_mask=causal_mask,
867
+ position_ids=position_ids,
868
+ past_key_value=past_key_values,
869
+ output_attentions=output_attentions,
870
+ use_cache=use_cache,
871
+ cache_position=cache_position,
872
+ position_embeddings=position_embeddings,
873
+ )
874
+
875
+ hidden_states = outputs[0]
876
+ if use_cache:
877
+ next_decoder_cache = outputs[2 if output_attentions else 1]
878
+
879
+ if output_attentions:
880
+ all_self_attns += (outputs[1],)
881
+
882
+ hidden_states = self.ln_f(hidden_states)
883
+ # Add last hidden state
884
+ if output_hidden_states:
885
+ all_hidden_states += (hidden_states,)
886
+
887
+ next_cache = None
888
+ if use_cache:
889
+ next_cache = next_decoder_cache.to_legacy_cache() if return_legacy_cache else next_decoder_cache
890
+ if not return_dict:
891
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
892
+
893
+ return BaseModelOutputWithPast(
894
+ last_hidden_state=hidden_states,
895
+ past_key_values=next_cache,
896
+ hidden_states=all_hidden_states,
897
+ attentions=all_self_attns,
898
+ )
899
+
900
+ def _update_causal_mask(
901
+ self,
902
+ attention_mask: torch.Tensor,
903
+ input_tensor: torch.Tensor,
904
+ cache_position: torch.Tensor,
905
+ past_key_values: Cache,
906
+ output_attentions: bool,
907
+ ):
908
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
909
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
910
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
911
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
912
+
913
+ if self.config._attn_implementation == "flash_attention_2":
914
+ if attention_mask is not None and 0.0 in attention_mask:
915
+ return attention_mask
916
+ return None
917
+
918
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
919
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
920
+ # to infer the attention mask.
921
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
922
+ using_static_cache = isinstance(past_key_values, StaticCache)
923
+
924
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
925
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
926
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
927
+ attention_mask,
928
+ inputs_embeds=input_tensor,
929
+ past_key_values_length=past_seen_tokens,
930
+ is_training=self.training,
931
+ ):
932
+ return None
933
+
934
+ dtype, device = input_tensor.dtype, input_tensor.device
935
+ min_dtype = torch.finfo(dtype).min
936
+ sequence_length = input_tensor.shape[1]
937
+ if using_static_cache:
938
+ target_length = past_key_values.get_max_length()
939
+ else:
940
+ target_length = (
941
+ attention_mask.shape[-1]
942
+ if isinstance(attention_mask, torch.Tensor)
943
+ else past_seen_tokens + sequence_length + 1
944
+ )
945
+
946
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
947
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
948
+ attention_mask,
949
+ sequence_length=sequence_length,
950
+ target_length=target_length,
951
+ dtype=dtype,
952
+ device=device,
953
+ min_dtype=min_dtype,
954
+ cache_position=cache_position,
955
+ batch_size=input_tensor.shape[0],
956
+ )
957
+
958
+ if (
959
+ self.config._attn_implementation == "sdpa"
960
+ and attention_mask is not None
961
+ and attention_mask.device.type == "cuda"
962
+ and not output_attentions
963
+ ):
964
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
965
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
966
+ # Details: https://github.com/pytorch/pytorch/issues/110213
967
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
968
+
969
+ return causal_mask
970
+
971
+
972
+ @add_start_docstrings(
973
+ """
974
+ The EXAONE Model transformer with a language modeling head on top (linear layer with weights tied to the input
975
+ embeddings).
976
+ """,
977
+ EXAONE_START_DOCSTRING,
978
+ )
979
+ class ExaoneForCausalLM(ExaonePreTrainedModel, GenerationMixin):
980
+ _tied_weights_keys = ["lm_head.weight"]
981
+
982
+ def __init__(self, config):
983
+ super().__init__(config)
984
+ self.transformer = ExaoneModel(config)
985
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
986
+ self.config = config
987
+ # Initialize weights and apply final processing
988
+ self.post_init()
989
+
990
+ def get_output_embeddings(self):
991
+ return self.lm_head
992
+
993
+ def set_output_embeddings(self, new_embeddings):
994
+ self.lm_head = new_embeddings
995
+
996
+ @add_start_docstrings_to_model_forward(EXAONE_INPUTS_DOCSTRING)
997
+ @add_code_sample_docstrings(
998
+ checkpoint=_CHECKPOINT_FOR_DOC,
999
+ output_type=BaseModelOutputWithPast,
1000
+ config_class=_CONFIG_FOR_DOC,
1001
+ )
1002
+ def forward(
1003
+ self,
1004
+ input_ids: Optional[torch.Tensor] = None,
1005
+ attention_mask: Optional[torch.Tensor] = None,
1006
+ position_ids: Optional[torch.Tensor] = None,
1007
+ past_key_values: Optional[Cache] = None,
1008
+ inputs_embeds: Optional[torch.Tensor] = None,
1009
+ labels: Optional[torch.Tensor] = None,
1010
+ use_cache: Optional[bool] = None,
1011
+ output_attentions: Optional[bool] = None,
1012
+ output_hidden_states: Optional[bool] = None,
1013
+ return_dict: Optional[bool] = None,
1014
+ cache_position: Optional[torch.LongTensor] = None,
1015
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
1016
+ r"""
1017
+ Args:
1018
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1019
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1020
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
1021
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
1022
+ Example:
1023
+ ```python
1024
+ >>> from transformers import AutoModelForCausalLM, AutoTokenizer
1025
+ >>> model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
1026
+ trust_remote_code=True)
1027
+ >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct")
1028
+ >>> prompt = "Explain how wonderful you are"
1029
+ >>> messages = [
1030
+ {"role": "system", "content": "You are a helpful assistant."},
1031
+ {"role": "user", "content": prompt}
1032
+ ]
1033
+ >>> input_ids = tokenizer.apply_chat_template(
1034
+ messages,
1035
+ tokenize=True,
1036
+ add_generation_prompt=True,
1037
+ return_tensors="pt"
1038
+ )
1039
+ >>> output = model.generate(input_ids, max_new_tokens=128)
1040
+ >>> tokenizer.decode(output[0], skip_special_tokens=True)
1041
+ "[|system|]You are a helpful assistant.\n[|user|]Explain how wonderful you are\n[|assistant|]Thank you for your kind words! I'm here to assist you with information, answer questions, and help you in any way I can. My goal is to provide accurate, helpful, and timely responses. Whether you need help with a specific task, want to learn something new, or just need someone to talk to, I'm here for you. How can I assist you today?"
1042
+ ```
1043
+ """
1044
+
1045
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1046
+ output_hidden_states = (
1047
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1048
+ )
1049
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1050
+ transformer_outputs = self.transformer(
1051
+ input_ids,
1052
+ attention_mask=attention_mask,
1053
+ past_key_values=past_key_values,
1054
+ position_ids=position_ids,
1055
+ inputs_embeds=inputs_embeds,
1056
+ use_cache=use_cache,
1057
+ output_attentions=output_attentions,
1058
+ output_hidden_states=output_hidden_states,
1059
+ return_dict=return_dict,
1060
+ cache_position=cache_position,
1061
+ )
1062
+ hidden_states = transformer_outputs[0]
1063
+ lm_logits = self.lm_head(hidden_states)
1064
+ lm_logits = lm_logits.float()
1065
+ loss = None
1066
+ if labels is not None:
1067
+ lm_logits = lm_logits.to(torch.float32)
1068
+
1069
+ # Shift so that tokens < n predict n
1070
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1071
+ shift_labels = labels[..., 1:].contiguous()
1072
+ # Flatten the tokens
1073
+ loss_fct = CrossEntropyLoss()
1074
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1075
+
1076
+ lm_logits = lm_logits.to(hidden_states.dtype)
1077
+ loss = loss.to(hidden_states.dtype)
1078
+
1079
+ if not return_dict:
1080
+ output = (lm_logits,) + transformer_outputs[1:]
1081
+ return ((loss,) + output) if loss is not None else output
1082
+
1083
+ return CausalLMOutputWithPast(
1084
+ loss=loss,
1085
+ logits=lm_logits,
1086
+ past_key_values=transformer_outputs.past_key_values,
1087
+ hidden_states=transformer_outputs.hidden_states,
1088
+ attentions=transformer_outputs.attentions,
1089
+ )
1090
+
1091
+ def prepare_inputs_for_generation(
1092
+ self,
1093
+ input_ids,
1094
+ past_key_values=None,
1095
+ attention_mask=None,
1096
+ inputs_embeds=None,
1097
+ cache_position=None,
1098
+ position_ids=None,
1099
+ use_cache=True,
1100
+ **kwargs,
1101
+ ):
1102
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1103
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
1104
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1105
+ if past_key_values is not None:
1106
+ if inputs_embeds is not None: # Exception 1
1107
+ input_ids = input_ids[:, -cache_position.shape[0] :]
1108
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
1109
+ input_ids = input_ids[:, cache_position]
1110
+
1111
+ if attention_mask is not None and position_ids is None:
1112
+ # create position_ids on the fly for batch generation
1113
+ position_ids = attention_mask.long().cumsum(-1) - 1
1114
+ position_ids.masked_fill_(attention_mask == 0, 1)
1115
+ if past_key_values:
1116
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1117
+
1118
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
1119
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
1120
+
1121
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1122
+ if inputs_embeds is not None and cache_position[0] == 0:
1123
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1124
+ else:
1125
+ model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
1126
+
1127
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1128
+ if inputs_embeds is not None:
1129
+ batch_size, sequence_length, _ = inputs_embeds.shape
1130
+ device = inputs_embeds.device
1131
+ else:
1132
+ batch_size, sequence_length = input_ids.shape
1133
+ device = input_ids.device
1134
+
1135
+ dtype = self.lm_head.weight.dtype
1136
+ min_dtype = torch.finfo(dtype).min
1137
+
1138
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1139
+ attention_mask,
1140
+ sequence_length=sequence_length,
1141
+ target_length=past_key_values.get_max_length(),
1142
+ dtype=dtype,
1143
+ device=device,
1144
+ min_dtype=min_dtype,
1145
+ cache_position=cache_position,
1146
+ batch_size=batch_size,
1147
+ )
1148
+
1149
+ model_inputs.update(
1150
+ {
1151
+ "position_ids": position_ids,
1152
+ "cache_position": cache_position,
1153
+ "past_key_values": past_key_values,
1154
+ "use_cache": use_cache,
1155
+ "attention_mask": attention_mask,
1156
+ }
1157
+ )
1158
+ return model_inputs
1159
+
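For completeness, a hedged usage sketch showing how the attention backends dispatched in `ExaoneAttention` are typically selected at load time (checkpoint name taken from the docstring above):

```python
# Sketch only: choosing the attention backend when loading the checkpoint.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",   # or "flash_attention_2" / "eager"
    trust_remote_code=True,
)
```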
1160
+
1161
+ @add_start_docstrings(
1162
+ """
1163
+ The EXAONE Model transformer with a sequence classification head on top (linear layer).
1164
+ [`ExaoneForSequenceClassification`] uses the last token in order to do the classification, as
1165
+ other causal models (e.g. GPT-1) do.
1166
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1167
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
1168
+ row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
1169
+ guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take
1170
+ the last value in each row of the batch).
1171
+ """,
1172
+ EXAONE_START_DOCSTRING,
1173
+ )
1174
+ class ExaoneForSequenceClassification(ExaonePreTrainedModel):
1175
+ def __init__(self, config):
1176
+ super().__init__(config)
1177
+ self.num_labels = config.num_labels
1178
+ self.transformer = ExaoneModel(config)
1179
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1180
+
1181
+ # Initialize weights and apply final processing
1182
+ self.post_init()
1183
+
1184
+ @add_start_docstrings_to_model_forward(EXAONE_INPUTS_DOCSTRING)
1185
+ @add_code_sample_docstrings(
1186
+ checkpoint=_CHECKPOINT_FOR_DOC,
1187
+ output_type=SequenceClassifierOutputWithPast,
1188
+ config_class=_CONFIG_FOR_DOC,
1189
+ )
1190
+ def forward(
1191
+ self,
1192
+ input_ids: Optional[torch.Tensor] = None,
1193
+ attention_mask: Optional[torch.Tensor] = None,
1194
+ position_ids: Optional[torch.Tensor] = None,
1195
+ past_key_values: Optional[Cache] = None,
1196
+ inputs_embeds: Optional[torch.Tensor] = None,
1197
+ labels: Optional[torch.Tensor] = None,
1198
+ use_cache: Optional[bool] = None,
1199
+ output_attentions: Optional[bool] = None,
1200
+ output_hidden_states: Optional[bool] = None,
1201
+ return_dict: Optional[bool] = None,
1202
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         transformer_outputs = self.transformer(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         hidden_states = transformer_outputs[0]
+         logits = self.score(hidden_states)
+
+         if input_ids is not None:
+             batch_size, sequence_length = input_ids.shape[:2]
+         else:
+             batch_size, sequence_length = inputs_embeds.shape[:2]
+
+         if self.config.pad_token_id is None and batch_size != 1:
+             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+         if self.config.pad_token_id is None:
+             sequence_lengths = -1
+         else:
+             if input_ids is not None:
+                 # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                 sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+                 sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                 sequence_lengths = sequence_lengths.to(logits.device)
+             else:
+                 sequence_lengths = -1
+                 logger.warning(
+                     f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                     "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                 )
+
+         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+         loss = None
+         if labels is not None:
+             labels = labels.to(logits.device)
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(pooled_logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(pooled_logits, labels)
+         if not return_dict:
+             output = (pooled_logits,) + transformer_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutputWithPast(
+             loss=loss,
+             logits=pooled_logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+         )
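+
+
+ # Illustrative sketch (not part of the upstream EXAONE sources): a self-contained
+ # demonstration of the last-token pooling used by `ExaoneForSequenceClassification`
+ # above. For a right-padded batch, the classification logit is taken from the last
+ # non-padding position of each row. The tensors and `pad_token_id` are made up;
+ # `torch` is the module-level import from the top of this file.
+ def _example_last_token_pooling():
+     pad_token_id = 0  # assumed padding id, for illustration
+     input_ids = torch.tensor([[11, 12, 13, 0, 0], [21, 22, 23, 24, 0]])
+     logits = torch.randn(2, 5, 3)  # (batch_size, sequence_length, num_labels)
+     # Same arithmetic as in the forward pass above (modulo kept for ONNX-friendly indexing).
+     sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1
+     sequence_lengths = sequence_lengths % input_ids.shape[-1]
+     # Rows 0 and 1 are pooled at positions 2 and 3, their last non-padding tokens.
+     pooled_logits = logits[torch.arange(2), sequence_lengths]
+     return pooled_logits  # shape (batch_size, num_labels) == (2, 3)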
+
+
+ @add_start_docstrings(
+     """
+     The EXAONE Model transformer with a span classification head on top for extractive question-answering tasks like
+     SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+     """,
+     EXAONE_START_DOCSTRING,
+ )
+ class ExaoneForQuestionAnswering(ExaonePreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.transformer = ExaoneModel(config)
+         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+         # Model parallel
+         self.model_parallel = False
+         self.device_map = None
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         start_positions: Optional[torch.LongTensor] = None,
+         end_positions: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+         r"""
+         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for position (index) of the start of the labelled span for computing the token classification loss.
+             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
+             sequence are not taken into account for computing the loss.
+         end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for position (index) of the end of the labelled span for computing the token classification loss.
+             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
+             sequence are not taken into account for computing the loss.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.transformer(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+
+         logits = self.qa_outputs(sequence_output)
+         start_logits, end_logits = logits.split(1, dim=-1)
+         start_logits = start_logits.squeeze(-1).contiguous()
+         end_logits = end_logits.squeeze(-1).contiguous()
+
+         total_loss = None
+         if start_positions is not None and end_positions is not None:
+             # If we are on multi-GPU, split adds a dimension
+             if len(start_positions.size()) > 1:
+                 start_positions = start_positions.squeeze(-1).to(start_logits.device)
+             if len(end_positions.size()) > 1:
+                 end_positions = end_positions.squeeze(-1).to(end_logits.device)
+             # Sometimes the start/end positions are outside our model inputs; we ignore these terms
+             ignored_index = start_logits.size(1)
+             start_positions = start_positions.clamp(0, ignored_index)
+             end_positions = end_positions.clamp(0, ignored_index)
+
+             loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+             start_loss = loss_fct(start_logits, start_positions)
+             end_loss = loss_fct(end_logits, end_positions)
+             total_loss = (start_loss + end_loss) / 2
+
+         if not return_dict:
+             output = (start_logits, end_logits) + outputs[2:]
+             return ((total_loss,) + output) if total_loss is not None else output
+
+         return QuestionAnsweringModelOutput(
+             loss=total_loss,
+             start_logits=start_logits,
+             end_logits=end_logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
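+
+
+ # Illustrative sketch (not part of the upstream EXAONE sources): a self-contained
+ # demonstration of the span loss computed by `ExaoneForQuestionAnswering` above.
+ # Shapes and label positions are made up; `torch` and `CrossEntropyLoss` are the
+ # module-level imports from the top of this file.
+ def _example_span_loss():
+     batch_size, seq_len, num_labels = 2, 6, 2
+     logits = torch.randn(batch_size, seq_len, num_labels)
+     # Split per-token scores into start/end logits, as in the forward pass above.
+     start_logits, end_logits = logits.split(1, dim=-1)
+     start_logits = start_logits.squeeze(-1).contiguous()  # (batch_size, seq_len)
+     end_logits = end_logits.squeeze(-1).contiguous()
+     # One in-range span and one out-of-range label; the latter is clamped to
+     # `ignored_index` and therefore excluded from the loss.
+     ignored_index = seq_len
+     start_positions = torch.tensor([1, 50]).clamp(0, ignored_index)
+     end_positions = torch.tensor([3, 50]).clamp(0, ignored_index)
+     loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+     total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2
+     return total_loss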