sigridjineth committed (verified) · Commit 53500ff · 1 Parent(s): 2e5b18b

Upload 2 files

Files changed (2)
  1. modeling_hyperclovax.py +1259 -0
  2. modeling_hyperclovax_old.py +1199 -0
modeling_hyperclovax.py ADDED
@@ -0,0 +1,1259 @@
1
+ # coding=utf-8
2
+ # This file was created for the HyperCLOVA X SEED 14B Think architecture.
3
+ # partially copied and modified from https://github.com/huggingface/transformers
4
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
7
+ # and OPT implementations in this library. It has been modified from its
8
+ # original forms to accommodate minor architectural differences compared
9
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ from typing import Callable, Optional, Union
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache
30
+ from transformers.generation import GenerationMixin
31
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
32
+ from transformers.integrations import use_kernel_forward_from_hub
33
+ from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
34
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
+ from typing import List, Iterable, Optional, Union, Tuple
36
+ from collections import deque
37
+ import os
38
+ from transformers.modeling_layers import GradientCheckpointingLayer
39
+ from transformers.modeling_outputs import (
40
+ BaseModelOutputWithPast,
41
+ CausalLMOutputWithPast,
42
+ QuestionAnsweringModelOutput,
43
+ SequenceClassifierOutputWithPast,
44
+ TokenClassifierOutput,
45
+ )
46
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
47
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
48
+ from transformers.processing_utils import Unpack
49
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
50
+ from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
51
+ from .configuration_hyperclovax import HyperCLOVAXConfig
52
+ if is_torch_flex_attn_available():
53
+ from torch.nn.attention.flex_attention import BlockMask
54
+
55
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+ # ================= DeepConf: confidence-based online early stop =================
60
+ class DeepConfEOSLogitsProcessor(LogitsProcessor):
61
+ """
62
+ Per-sample early stop: at each step, compute token_conf = mean(logprob of top-r),
63
+ maintain group_conf = mean of last `window` token_conf; if group_conf < threshold,
64
+ force EOS for THAT sample by setting EOS logprob=0 and others to -inf.
65
+ """
66
+
67
+ def __init__(
68
+ self,
69
+ eos_token_ids: List[int],
70
+ window: int = 512,
71
+ top_r: int = 5,
72
+ threshold: float = -3.5,
73
+ warmup_tokens: int = 0,
74
+ prefer_eos_ids: Optional[List[int]] = None,
75
+ require_prev_id: Optional[int] = None,
76
+ im_end_id: Optional[int] = None,
77
+ require_im_end_count: int = 0,
78
+ threshold_think: Optional[float] = None,
79
+ threshold_answer: Optional[float] = None,
80
+ ):
81
+ self.eos_ids: List[int] = sorted({int(i) for i in (eos_token_ids or []) if i is not None and i >= 0})
82
+ self.window: int = max(int(window), 1)
83
+ self.top_r: int = max(int(top_r), 1)
84
+ self.threshold: float = float(threshold)
85
+ self.warmup_tokens: int = max(int(warmup_tokens), 0)
86
+ self.prefer_eos_ids: List[int] = sorted({int(i) for i in (prefer_eos_ids or []) if i is not None and i >= 0})
87
+ self.require_prev_id = require_prev_id
88
+ self.im_end_id = im_end_id
89
+ self.require_im_end_count = max(int(require_im_end_count), 0)
90
+ self.threshold_think = threshold_think
91
+ self.threshold_answer = threshold_answer
92
+ self._base_im_end_counts: Optional[List[int]] = None
93
+ self._buffers: Optional[List[deque]] = None
94
+ self._verbose: bool = os.getenv("HYPERCLOVA_DEEPCONF_VERBOSE", "0").strip().lower() in {"1", "on", "true"}
95
+ self._every: int = max(int(os.getenv("HYPERCLOVA_DEEPCONF_REPORT_EVERY", "64")), 1)
96
+ self._tick: int = 0
97
+ self._stops: int = 0
98
+
99
+ def _ensure(self, bsz: int) -> None:
100
+ if self._buffers is None or len(self._buffers) != bsz:
101
+ self._buffers = [deque(maxlen=self.window) for _ in range(bsz)]
102
+
103
+ @torch.no_grad()
104
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
105
+ bsz, vocab = scores.shape
106
+ self._ensure(bsz)
107
+
108
+ # --- im_end count (only in generated part) ---
109
+ gen_counts = [0] * bsz
110
+ if self.im_end_id is not None and input_ids is not None:
111
+ # Count im_end in the whole context
112
+ curr = (input_ids == self.im_end_id).sum(dim=1).tolist()
113
+ if self._base_im_end_counts is None:
114
+ self._base_im_end_counts = curr[:] # Set baseline
115
+ gen_counts = [curr[i] - self._base_im_end_counts[i] for i in range(bsz)]
116
+
117
+ logprobs = torch.log_softmax(scores, dim=-1)
118
+ k = min(self.top_r, vocab)
119
+ token_conf = torch.topk(logprobs, k=k, dim=-1).values.mean(dim=-1).tolist()
120
+
121
+ for i, c in enumerate(token_conf):
122
+ buf = self._buffers[i]
123
+ buf.append(c)
124
+ group_conf = sum(buf) / len(buf)
125
+ if len(buf) < self.warmup_tokens:
126
+ continue
127
+
128
+ # phase-aware threshold
129
+ if self.threshold_think is not None and gen_counts[i] <= 0:
130
+ thr = self.threshold_think
131
+ elif self.threshold_answer is not None and gen_counts[i] >= 1:
132
+ thr = self.threshold_answer
133
+ else:
134
+ thr = self.threshold
135
+
136
+ # ChatML protection: only force stop after enough im_end tokens
137
+ im_end_gate_ok = gen_counts[i] >= self.require_im_end_count
138
+
139
+ # (Optional) previous token gate
140
+ prev_ok = True
141
+ if self.require_prev_id is not None and input_ids is not None and input_ids.size(1) > 0:
142
+ prev_ok = int(input_ids[i, -1].item()) == self.require_prev_id
143
+
144
+ if group_conf < thr and (self.prefer_eos_ids or self.eos_ids) and im_end_gate_ok and prev_ok:
145
+ targets = self.prefer_eos_ids if self.prefer_eos_ids else self.eos_ids
146
+ scores[i].fill_(-float("inf"))
147
+ for eid in targets:
148
+ if 0 <= eid < vocab:
149
+ scores[i, eid] = 0.0
150
+ self._stops += 1
151
+
152
+ if self._verbose:
153
+ self._tick += 1
154
+ if self._tick % self._every == 0:
155
+ try:
156
+ gcs = [(sum(b) / len(b)) if b else float("nan") for b in (self._buffers or [])]
157
+ valid = [x for x in gcs if not (x != x)]
158
+ mean_gc = float(sum(valid) / max(1, len(valid)))
159
+ except Exception:
160
+ mean_gc = float("nan")
161
+
162
+ if os.getenv("HYPERCLOVA_DEEPCONF_VERBOSE_ATTACH", "0") in {"1", "on", "true"}:
163
+ print(f"[DeepConf] step={self._tick} mean_gc={mean_gc:.4f} stops={self._stops}")
164
+ return scores
165
+
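The processor above can be exercised without a full model. The sketch below drives it with random logits to show the bookkeeping: per-step top-r confidence, the sliding-window group confidence, and the forced-EOS behaviour once the window mean falls below the threshold. Sizes, token ids, and the (deliberately high) threshold are illustrative only.

```python
# Minimal sketch (illustrative sizes/ids): exercising DeepConfEOSLogitsProcessor
# on synthetic logits instead of real model scores.
import torch

proc = DeepConfEOSLogitsProcessor(
    eos_token_ids=[2],                 # hypothetical EOS id
    window=4, top_r=3, threshold=-1.0, warmup_tokens=2,
)

batch, vocab = 2, 16
input_ids = torch.zeros(batch, 1, dtype=torch.long)
for _ in range(8):
    scores = torch.randn(batch, vocab)   # stand-in for the model's next-token logits
    scores = proc(input_ids, scores)     # low group confidence forces EOS (id 2) per row
    next_tokens = scores.argmax(dim=-1, keepdim=True)
    input_ids = torch.cat([input_ids, next_tokens], dim=-1)

print(input_ids)                         # rows collapse to the forced EOS id once triggered
```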
166
+ # (optional) Offline helper: Lowest Group Confidence (LGC)
167
+ def deepconf_lgc_from_scores(scores_list: Iterable[torch.Tensor], top_r: int = 5, window: int = 2048) -> float:
168
+ tensors = [s for s in scores_list]
169
+ if not tensors: return float("-inf")
170
+ with torch.no_grad():
171
+ vals = [
172
+ torch.topk(torch.log_softmax(s, dim=-1), k=min(top_r, s.size(-1)), dim=-1).values.mean(dim=-1)
173
+ for s in tensors
174
+ ] # each (B,)
175
+ conf = torch.stack(vals).squeeze(-1) # (T,) if B=1
176
+ w = min(int(window), conf.numel())
177
+ kernel = torch.ones(1, 1, w, device=conf.device, dtype=conf.dtype) / w  # match conf dtype to avoid conv1d dtype mismatches
+ run = torch.nn.functional.conv1d(conf.view(1, 1, -1), weight=kernel).squeeze()
179
+ return float(run.min().item())
180
+ # ==============================================================================
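When generation is run with `output_scores=True`, the helper above can score a finished trace offline. The sketch below uses synthetic per-step score tensors in place of `outputs.scores`; the window and top-r values are illustrative.

```python
import torch

# Stand-in for `outputs.scores` from generate(..., output_scores=True, return_dict_in_generate=True):
# a list of (batch=1, vocab) logits, one entry per generated token.
steps = [torch.randn(1, 32) for _ in range(100)]

lgc = deepconf_lgc_from_scores(steps, top_r=5, window=16)
print(f"lowest group confidence: {lgc:.3f}")   # more negative = less confident trace
```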
181
+
182
+
183
+ @use_kernel_forward_from_hub("RMSNorm")
184
+ class HyperCLOVAXRMSNorm(nn.Module):
185
+ def __init__(self, hidden_size, eps=1e-6):
186
+ """
187
+ HyperCLOVAXRMSNorm is equivalent to T5LayerNorm
188
+ """
189
+ super().__init__()
190
+ self.weight = nn.Parameter(torch.ones(hidden_size))
191
+ self.variance_epsilon = eps
192
+
193
+ def forward(self, hidden_states):
194
+ input_dtype = hidden_states.dtype
195
+ hidden_states = hidden_states.to(torch.float32)
196
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
197
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
198
+ return self.weight * hidden_states.to(input_dtype)
199
+
200
+ def extra_repr(self):
201
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
202
+
203
+ ALL_LAYERNORM_LAYERS.append(HyperCLOVAXRMSNorm)
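As a quick numerical check of the layer above (a sketch, not part of the model code), RMSNorm divides each hidden vector by its root mean square, computed in float32, and then applies the learned per-channel weight:

```python
import torch

norm = HyperCLOVAXRMSNorm(hidden_size=8, eps=1e-6)
x = torch.randn(2, 3, 8)   # (batch, seq, hidden)

# Reference: weight * x / sqrt(mean(x^2) + eps), with the statistics in float32.
rms = torch.sqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)
reference = norm.weight * (x.float() / rms).to(x.dtype)

assert torch.allclose(norm(x), reference, atol=1e-5)
```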
204
+ class HyperCLOVAXRotaryEmbedding(nn.Module):
205
+ def __init__(self, config: HyperCLOVAXConfig, device=None):
206
+ super().__init__()
207
+ # BC: "rope_type" was originally "type"
208
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
209
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
210
+ else:
211
+ self.rope_type = "default"
212
+ self.max_seq_len_cached = config.max_position_embeddings
213
+ self.original_max_seq_len = config.max_position_embeddings
214
+
215
+ self.config = config
216
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
217
+
218
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
219
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
220
+ self.original_inv_freq = self.inv_freq
221
+
222
+ @torch.no_grad()
223
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
224
+ def forward(self, x, position_ids):
225
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
226
+ position_ids_expanded = position_ids[:, None, :].float()
227
+
228
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
229
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
230
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
231
+ emb = torch.cat((freqs, freqs), dim=-1)
232
+ cos = emb.cos() * self.attention_scaling
233
+ sin = emb.sin() * self.attention_scaling
234
+
235
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
236
+
237
+
238
+ def rotate_half(x):
239
+ """Rotates half the hidden dims of the input."""
240
+ x1 = x[..., : x.shape[-1] // 2]
241
+ x2 = x[..., x.shape[-1] // 2 :]
242
+ return torch.cat((-x2, x1), dim=-1)
243
+
244
+
245
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
246
+ """Applies Rotary Position Embedding to the query and key tensors.
247
+
248
+ Args:
249
+ q (`torch.Tensor`): The query tensor.
250
+ k (`torch.Tensor`): The key tensor.
251
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
252
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
253
+ position_ids (`torch.Tensor`, *optional*):
254
+ Deprecated and unused.
255
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
256
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
257
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
258
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
259
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
260
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
261
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
262
+ Returns:
263
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
264
+ """
265
+ cos = cos.unsqueeze(unsqueeze_dim)
266
+ sin = sin.unsqueeze(unsqueeze_dim)
267
+ q_embed = (q * cos) + (rotate_half(q) * sin)
268
+ k_embed = (k * cos) + (rotate_half(k) * sin)
269
+ return q_embed, k_embed
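A small shape check for the rotary helpers above; the cos/sin table is built inline with an illustrative base and sizes rather than through `HyperCLOVAXRotaryEmbedding`, which requires a full config object.

```python
import torch

batch, heads, seq, head_dim = 1, 2, 5, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)

# Toy cos/sin of shape (batch, seq, head_dim), mirroring what the rotary module emits.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq).float(), inv_freq)     # (seq, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)[None]                # (1, seq, head_dim)
cos, sin = emb.cos(), emb.sin()

# unsqueeze_dim=1 (the default) matches the (batch, heads, seq, head_dim) layout used here.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```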
270
+
271
+
272
+ class HyperCLOVAXMLP(nn.Module):
273
+ def __init__(self, config):
274
+ super().__init__()
275
+ self.config = config
276
+ self.hidden_size = config.hidden_size
277
+ self.intermediate_size = config.intermediate_size
278
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
279
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
280
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
281
+ self.act_fn = ACT2FN[config.hidden_act]
282
+
283
+ def forward(self, x):
284
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
285
+ return down_proj
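The forward above is the usual gated MLP, `down_proj(act(gate_proj(x)) * up_proj(x))`. The standalone sketch below reproduces it with SiLU, a common choice for `config.hidden_act` in this model family (an assumption here; the actual activation comes from the config):

```python
import torch

hidden, intermediate = 8, 32
gate = torch.nn.Linear(hidden, intermediate, bias=False)
up = torch.nn.Linear(hidden, intermediate, bias=False)
down = torch.nn.Linear(intermediate, hidden, bias=False)

x = torch.randn(2, 4, hidden)
y = down(torch.nn.functional.silu(gate(x)) * up(x))   # gated-MLP forward
print(y.shape)   # torch.Size([2, 4, 8])
```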
286
+
287
+
288
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
289
+ """
290
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
291
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
292
+ """
293
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
294
+ if n_rep == 1:
295
+ return hidden_states
296
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
297
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
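`repeat_kv` is a reshape-based equivalent of `torch.repeat_interleave` along the head dimension; the check below confirms the equivalence on random tensors.

```python
import torch

kv = torch.randn(2, 4, 6, 16)   # (batch, num_key_value_heads, seq, head_dim)
n_rep = 3

expanded = repeat_kv(kv, n_rep)                        # (2, 12, 6, 16)
reference = torch.repeat_interleave(kv, n_rep, dim=1)
assert torch.equal(expanded, reference)
```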
298
+
299
+
300
+ def eager_attention_forward(
301
+ module: nn.Module,
302
+ query: torch.Tensor,
303
+ key: torch.Tensor,
304
+ value: torch.Tensor,
305
+ attention_mask: Optional[torch.Tensor],
306
+ scaling: float,
307
+ dropout: float = 0.0,
308
+ **kwargs,
309
+ ):
310
+ key_states = repeat_kv(key, module.num_key_value_groups)
311
+ value_states = repeat_kv(value, module.num_key_value_groups)
312
+
313
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
314
+ if attention_mask is not None:
315
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
316
+ attn_weights = attn_weights + causal_mask
317
+
318
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
319
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
320
+ attn_output = torch.matmul(attn_weights, value_states)
321
+ attn_output = attn_output.transpose(1, 2).contiguous()
322
+
323
+ return attn_output, attn_weights
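`eager_attention_forward` only reads `num_key_value_groups` and `training` from its module argument, so a lightweight stand-in is enough to inspect the shapes it produces (the stub below is hypothetical, not part of the model):

```python
import types
import torch

stub = types.SimpleNamespace(num_key_value_groups=4, training=False)

b, q_heads, kv_heads, seq, d = 1, 8, 2, 6, 16
query = torch.randn(b, q_heads, seq, d)
key = torch.randn(b, kv_heads, seq, d)     # GQA: 8 query heads share 2 KV heads
value = torch.randn(b, kv_heads, seq, d)

out, weights = eager_attention_forward(
    stub, query, key, value, attention_mask=None, scaling=d ** -0.5
)
print(out.shape, weights.shape)   # torch.Size([1, 6, 8, 16]) torch.Size([1, 8, 6, 6])
```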
324
+
325
+
326
+ class HyperCLOVAXAttention(nn.Module):
327
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
328
+
329
+ def __init__(self, config: HyperCLOVAXConfig, layer_idx: int):
330
+ super().__init__()
331
+ self.config = config
332
+ self.layer_idx = layer_idx
333
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
334
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
335
+ self.scaling = getattr(config, "attention_multiplier", self.head_dim**-0.5) # MuP
336
+ self.attention_dropout = config.attention_dropout
337
+ self.is_causal = True
338
+
339
+ self.q_proj = nn.Linear(
340
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
341
+ )
342
+ self.k_proj = nn.Linear(
343
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
344
+ )
345
+ self.v_proj = nn.Linear(
346
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
347
+ )
348
+ self.o_proj = nn.Linear(
349
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
350
+ )
351
+
352
+ def forward(
353
+ self,
354
+ hidden_states: torch.Tensor,
355
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
356
+ attention_mask: Optional[torch.Tensor],
357
+ past_key_value: Optional[Cache] = None,
358
+ cache_position: Optional[torch.LongTensor] = None,
359
+ **kwargs: Unpack[FlashAttentionKwargs],
360
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
361
+ input_shape = hidden_states.shape[:-1]
362
+ hidden_shape = (*input_shape, -1, self.head_dim)
363
+
364
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
365
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
366
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
367
+
368
+ cos, sin = position_embeddings
369
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
370
+
371
+ if past_key_value is not None:
372
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
373
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
374
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
375
+
376
+ attention_interface: Callable = eager_attention_forward
377
+
378
+ if self.config._attn_implementation != "eager":
379
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
380
+ logger.warning_once(
381
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
382
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
383
+ )
384
+ else:
385
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
386
+
387
+ attn_output, attn_weights = attention_interface(
388
+ self,
389
+ query_states,
390
+ key_states,
391
+ value_states,
392
+ attention_mask,
393
+ dropout=0.0 if not self.training else self.attention_dropout,
394
+ scaling=self.scaling,
395
+ **kwargs,
396
+ )
397
+
398
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
399
+ attn_output = self.o_proj(attn_output)
400
+ return attn_output, attn_weights
401
+
402
+
403
+ class HyperCLOVAXDecoderLayer(GradientCheckpointingLayer):
404
+ def __init__(self, config: HyperCLOVAXConfig, layer_idx: int):
405
+ super().__init__()
406
+ self.hidden_size = config.hidden_size
407
+
408
+ self.self_attn = HyperCLOVAXAttention(config=config, layer_idx=layer_idx)
409
+
410
+ self.mlp = HyperCLOVAXMLP(config)
411
+ self.input_layernorm = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
412
+ self.post_attention_layernorm = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
413
+ self.use_post_norm = getattr(config, "use_post_norm", False)
414
+
415
+ # Peri-LN (post-norm)
416
+ if self.use_post_norm:
417
+ self.post_norm1 = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
418
+ self.post_norm2 = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
419
+
420
+ self.residual_multiplier = getattr(config, "residual_multiplier", 1.0) # MuP
421
+
422
+ def forward(
423
+ self,
424
+ hidden_states: torch.Tensor,
425
+ attention_mask: Optional[torch.Tensor] = None,
426
+ position_ids: Optional[torch.LongTensor] = None,
427
+ past_key_value: Optional[Cache] = None,
428
+ output_attentions: Optional[bool] = False,
429
+ use_cache: Optional[bool] = False,
430
+ cache_position: Optional[torch.LongTensor] = None,
431
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
432
+ **kwargs: Unpack[FlashAttentionKwargs],
433
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
434
+ residual = hidden_states
435
+ hidden_states = self.input_layernorm(hidden_states)
436
+
437
+ # Self Attention
438
+ hidden_states, self_attn_weights = self.self_attn(
439
+ hidden_states=hidden_states,
440
+ attention_mask=attention_mask,
441
+ position_ids=position_ids,
442
+ past_key_value=past_key_value,
443
+ output_attentions=output_attentions,
444
+ use_cache=use_cache,
445
+ cache_position=cache_position,
446
+ position_embeddings=position_embeddings,
447
+ **kwargs,
448
+ )
449
+
450
+ if self.use_post_norm: # Peri-LN
451
+ hidden_states = self.post_norm1(hidden_states)
452
+
453
+ hidden_states = residual + hidden_states * self.residual_multiplier # MuP
454
+
455
+ # Fully Connected
456
+ residual = hidden_states
457
+ hidden_states = self.post_attention_layernorm(hidden_states)
458
+ hidden_states = self.mlp(hidden_states)
459
+
460
+ if self.use_post_norm: # Peri-LN
461
+ hidden_states = self.post_norm2(hidden_states)
462
+
463
+ hidden_states = residual + hidden_states * self.residual_multiplier # MuP
464
+
465
+ outputs = (hidden_states,)
466
+ if output_attentions:
467
+ outputs += (self_attn_weights,)
468
+
469
+ return outputs
470
+
471
+
472
+ @auto_docstring
473
+ class HyperCLOVAXPreTrainedModel(PreTrainedModel):
474
+ config_class = HyperCLOVAXConfig
475
+ base_model_prefix = "model"
476
+ supports_gradient_checkpointing = True
477
+ _no_split_modules = ["HyperCLOVAXDecoderLayer"]
478
+ _skip_keys_device_placement = ["past_key_values"]
479
+ _supports_flash_attn_2 = True
480
+ _supports_sdpa = True
481
+ _supports_flex_attn = True
482
+ _supports_cache_class = True
483
+ _supports_quantized_cache = True
484
+ _supports_static_cache = True
485
+ _supports_attention_backend = True
486
+
487
+ def _init_weights(self, module):
488
+ std = self.config.initializer_range
489
+ if isinstance(module, nn.Linear):
490
+ module.weight.data.normal_(mean=0.0, std=std)
491
+ if module.bias is not None:
492
+ module.bias.data.zero_()
493
+ elif isinstance(module, nn.Embedding):
494
+ module.weight.data.normal_(mean=0.0, std=std)
495
+ if module.padding_idx is not None:
496
+ module.weight.data[module.padding_idx].zero_()
497
+ elif isinstance(module, HyperCLOVAXRMSNorm):
498
+ module.weight.data.fill_(1.0)
499
+
500
+
501
+ @auto_docstring
502
+ class HyperCLOVAXModel(HyperCLOVAXPreTrainedModel):
503
+ def __init__(self, config: HyperCLOVAXConfig):
504
+ super().__init__(config)
505
+ self.padding_idx = config.pad_token_id
506
+ self.vocab_size = config.vocab_size
507
+
508
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
509
+ self.layers = nn.ModuleList(
510
+ [HyperCLOVAXDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
511
+ )
512
+ self.norm = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
513
+ self.rotary_emb = HyperCLOVAXRotaryEmbedding(config=config)
514
+ self.gradient_checkpointing = False
515
+
516
+ # Initialize weights and apply final processing
517
+ self.post_init()
518
+
519
+ # MuP
520
+ self.embedding_multiplier = getattr(config, "embedding_multiplier", 1.0)
521
+
522
+ def get_input_embeddings(self):
523
+ return self.embed_tokens
524
+
525
+ def set_input_embeddings(self, value):
526
+ self.embed_tokens = value
527
+
528
+ @can_return_tuple
529
+ @auto_docstring
530
+ def forward(
531
+ self,
532
+ input_ids: Optional[torch.LongTensor] = None,
533
+ attention_mask: Optional[torch.Tensor] = None,
534
+ position_ids: Optional[torch.LongTensor] = None,
535
+ past_key_values: Optional[Cache] = None,
536
+ inputs_embeds: Optional[torch.FloatTensor] = None,
537
+ use_cache: Optional[bool] = None,
538
+ output_attentions: Optional[bool] = None,
539
+ output_hidden_states: Optional[bool] = None,
540
+ cache_position: Optional[torch.LongTensor] = None,
541
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
542
+ ) -> BaseModelOutputWithPast:
543
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
544
+ output_hidden_states = (
545
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
546
+ )
547
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
548
+
549
+ if (input_ids is None) ^ (inputs_embeds is not None):
550
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
551
+
552
+ if self.gradient_checkpointing and self.training and use_cache:
553
+ logger.warning_once(
554
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
555
+ )
556
+ use_cache = False
557
+
558
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
559
+ if not isinstance(past_key_values, (type(None), Cache)):
560
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
561
+
562
+ if inputs_embeds is None:
563
+ inputs_embeds = self.embed_tokens(input_ids)
564
+
565
+ inputs_embeds = inputs_embeds * self.embedding_multiplier # MuP
566
+
567
+ if use_cache and past_key_values is None:
568
+ past_key_values = DynamicCache()
569
+
570
+ if cache_position is None:
571
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
572
+ cache_position = torch.arange(
573
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
574
+ )
575
+
576
+ if position_ids is None:
577
+ position_ids = cache_position.unsqueeze(0)
578
+
579
+ causal_mask = self._update_causal_mask(
580
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
581
+ )
582
+
583
+ hidden_states = inputs_embeds
584
+
585
+ # create position embeddings to be shared across the decoder layers
586
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
587
+
588
+ # decoder layers
589
+ all_hidden_states = () if output_hidden_states else None
590
+ all_self_attns = () if output_attentions else None
591
+
592
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
593
+ if output_hidden_states:
594
+ all_hidden_states += (hidden_states,)
595
+
596
+ layer_outputs = decoder_layer(
597
+ hidden_states,
598
+ attention_mask=causal_mask,
599
+ position_ids=position_ids,
600
+ past_key_value=past_key_values,
601
+ output_attentions=output_attentions,
602
+ use_cache=use_cache,
603
+ cache_position=cache_position,
604
+ position_embeddings=position_embeddings,
605
+ **flash_attn_kwargs,
606
+ )
607
+
608
+ hidden_states = layer_outputs[0]
609
+
610
+ if output_attentions:
611
+ all_self_attns += (layer_outputs[1],)
612
+
613
+ hidden_states = self.norm(hidden_states)
614
+
615
+ # add hidden states from the last decoder layer
616
+ if output_hidden_states:
617
+ all_hidden_states += (hidden_states,)
618
+
619
+ return BaseModelOutputWithPast(
620
+ last_hidden_state=hidden_states,
621
+ past_key_values=past_key_values if use_cache else None,
622
+ hidden_states=all_hidden_states,
623
+ attentions=all_self_attns,
624
+ )
625
+
626
+ def _update_causal_mask(
627
+ self,
628
+ attention_mask: Union[torch.Tensor, "BlockMask"],
629
+ input_tensor: torch.Tensor,
630
+ cache_position: torch.Tensor,
631
+ past_key_values: Cache,
632
+ output_attentions: bool = False,
633
+ ):
634
+ if self.config._attn_implementation == "flash_attention_2":
635
+ if attention_mask is not None and (attention_mask == 0.0).any():
636
+ return attention_mask
637
+ return None
638
+ if self.config._attn_implementation == "flex_attention":
639
+ if isinstance(attention_mask, torch.Tensor):
640
+ attention_mask = make_flex_block_causal_mask(attention_mask)
641
+ return attention_mask
642
+
643
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
644
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
645
+ # to infer the attention mask.
646
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
647
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
648
+
649
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
650
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
651
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
652
+ attention_mask,
653
+ inputs_embeds=input_tensor,
654
+ past_key_values_length=past_seen_tokens,
655
+ is_training=self.training,
656
+ ):
657
+ return None
658
+
659
+ dtype = input_tensor.dtype
660
+ sequence_length = input_tensor.shape[1]
661
+ if using_compilable_cache:
662
+ target_length = past_key_values.get_max_cache_shape()
663
+ else:
664
+ target_length = (
665
+ attention_mask.shape[-1]
666
+ if isinstance(attention_mask, torch.Tensor)
667
+ else past_seen_tokens + sequence_length + 1
668
+ )
669
+
670
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
671
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
672
+ attention_mask,
673
+ sequence_length=sequence_length,
674
+ target_length=target_length,
675
+ dtype=dtype,
676
+ cache_position=cache_position,
677
+ batch_size=input_tensor.shape[0],
678
+ )
679
+
680
+ if (
681
+ self.config._attn_implementation == "sdpa"
682
+ and attention_mask is not None
683
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
684
+ and not output_attentions
685
+ ):
686
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
687
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
688
+ # Details: https://github.com/pytorch/pytorch/issues/110213
689
+ min_dtype = torch.finfo(dtype).min
690
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
691
+
692
+ return causal_mask
693
+
694
+ @staticmethod
695
+ def _prepare_4d_causal_attention_mask_with_cache_position(
696
+ attention_mask: torch.Tensor,
697
+ sequence_length: int,
698
+ target_length: int,
699
+ dtype: torch.dtype,
700
+ cache_position: torch.Tensor,
701
+ batch_size: int,
702
+ **kwargs,
703
+ ):
704
+ """
705
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
706
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
707
+
708
+ Args:
709
+ attention_mask (`torch.Tensor`):
710
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
711
+ `(batch_size, 1, query_length, key_value_length)`.
712
+ sequence_length (`int`):
713
+ The sequence length being processed.
714
+ target_length (`int`):
715
+ The target length: when generating with static cache, the mask should be as long as the static cache,
716
+ to account for the 0 padding, the part of the cache that is not filled yet.
717
+ dtype (`torch.dtype`):
718
+ The dtype to use for the 4D attention mask.
719
+ cache_position (`torch.Tensor`):
720
+ Indices depicting the position of the input sequence tokens in the sequence.
721
+ batch_size (`torch.Tensor`):
722
+ Batch size.
723
+ """
724
+ if attention_mask is not None and attention_mask.dim() == 4:
725
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
726
+ causal_mask = attention_mask
727
+ else:
728
+ min_dtype = torch.finfo(dtype).min
729
+ causal_mask = torch.full(
730
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
731
+ )
732
+ if sequence_length != 1:
733
+ causal_mask = torch.triu(causal_mask, diagonal=1)
734
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
735
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
736
+ if attention_mask is not None:
737
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
738
+ mask_length = attention_mask.shape[-1]
739
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
740
+ causal_mask.device
741
+ )
742
+ padding_mask = padding_mask == 0
743
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
744
+ padding_mask, min_dtype
745
+ )
746
+
747
+ return causal_mask
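The staticmethod above turns a 2D padding mask into the 4D additive mask consumed by the attention layers; a toy invocation (illustrative lengths, float32) looks like this:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])   # batch of 1, last position is padding
cache_position = torch.arange(4)

mask4d = HyperCLOVAXModel._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask,
    sequence_length=4,
    target_length=4,
    dtype=torch.float32,
    cache_position=cache_position,
    batch_size=1,
)
print(mask4d.shape)   # (1, 1, 4, 4): 0.0 where attendable, dtype-min where masked
```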
748
+
749
+
750
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
751
+
752
+
753
+ @auto_docstring
754
+ class HyperCLOVAXForCausalLM(HyperCLOVAXPreTrainedModel, GenerationMixin):
755
+ _tied_weights_keys = ["lm_head.weight"]
756
+ _tp_plan = {"lm_head": "colwise_rep"}
757
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
758
+
759
+ def __init__(self, config):
760
+ super().__init__(config)
761
+ self.model = HyperCLOVAXModel(config)
762
+ self.vocab_size = config.vocab_size
763
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
764
+ self.logits_scaling = getattr(config, "logits_scaling", 1.0)
765
+
766
+ # Initialize weights and apply final processing
767
+ self.post_init()
768
+
769
+ def get_input_embeddings(self):
770
+ return self.model.embed_tokens
771
+
772
+ def set_input_embeddings(self, value):
773
+ self.model.embed_tokens = value
774
+
775
+ def get_output_embeddings(self):
776
+ return self.lm_head
777
+
778
+ def set_output_embeddings(self, new_embeddings):
779
+ self.lm_head = new_embeddings
780
+
781
+ # -------- DeepConf helpers ----------
782
+ def _dc_collect_eos(self, explicit: Optional[Union[int, List[int]]] = None, **kwargs) -> List[int]:
783
+ ids: List[int] = []
784
+ if explicit is not None:
785
+ ids.extend([int(x) for x in (explicit if isinstance(explicit, (list,tuple)) else [explicit])])
786
+ else:
787
+ if getattr(self.config, "eos_token_id", None) is not None:
788
+ ids.append(int(self.config.eos_token_id))
789
+ if getattr(self.config, "eos_token_id_list", None):
790
+ ids.extend(int(x) for x in self.config.eos_token_id_list if x is not None)
791
+ extra = os.getenv("HYPERCLOVA_DEEPCONF_EOS_IDS", "").strip()
792
+ if extra:
793
+ ids.extend(int(tok) for tok in extra.split(",") if tok.strip().isdigit())
794
+ return sorted({i for i in ids if i >= 0})
795
+
796
+ def _dc_enabled(self) -> bool:
797
+ enabled = True
798
+ env = os.getenv("HYPERCLOVA_DEEPCONF", "").strip().lower()
799
+ if env in {"0","off","false"}: enabled = False
800
+ elif env in {"1","on","true"}: enabled = True
801
+ cfg_en = getattr(self.config, "deepconf_enable", None)
802
+ if cfg_en is not None:
803
+ enabled = bool(cfg_en) # If config is specified, it takes precedence
804
+ if getattr(self.config, "deepconf_disable", False):
805
+ enabled = False # Force OFF flag
806
+ return enabled
807
+
808
+ def _dc_params(self) -> Tuple[int,int,float,int]:
809
+ def env_int(k, d): v=os.getenv(k); return int(v) if v not in (None,"") else d
810
+ def env_flt(k, d): v=os.getenv(k); return float(v) if v not in (None,"") else d
811
+ window = env_int("HYPERCLOVA_DEEPCONF_WINDOW", getattr(self.config, "deepconf_window", 512))
812
+ top_r = env_int("HYPERCLOVA_DEEPCONF_TOPR", getattr(self.config, "deepconf_top_r", 5))
813
+ thr = env_flt("HYPERCLOVA_DEEPCONF_THRESH", getattr(self.config, "deepconf_threshold", -3.5))
814
+ warmup = env_int("HYPERCLOVA_DEEPCONF_WARMUP", getattr(self.config, "deepconf_warmup_tokens", 0))
815
+ return window, top_r, thr, warmup
816
+
817
+ def deepconf_generate(self, *args,
818
+ eos_token_id: Optional[Union[int, List[int]]] = None,
819
+ window: int = 512, top_r: int = 5, threshold: float = -3.5,
820
+ warmup_tokens: int = 0,
821
+ **kwargs):
822
+ # Prefer ChatML stop strings if tokenizer+stop_strings are provided
823
+ prefer_ids: List[int] = []
824
+ tok = kwargs.get("tokenizer", None)
825
+ stop_strings = kwargs.get("stop_strings", None)
826
+ if tok is not None and stop_strings:
827
+ for s in stop_strings:
828
+ try:
829
+ eid = tok.convert_tokens_to_ids(s)
830
+ if isinstance(eid, int) and eid >= 0:
831
+ prefer_ids.append(int(eid)); continue
832
+ except Exception:
833
+ pass
834
+ try:
835
+ enc = tok.encode(s, add_special_tokens=False)
836
+ if isinstance(enc, list) and len(enc) == 1:
837
+ prefer_ids.append(int(enc[0]))
838
+ except Exception:
839
+ pass
840
+ lp: LogitsProcessorList = kwargs.pop("logits_processor", None) or LogitsProcessorList()
841
+ lp.append(
842
+ DeepConfEOSLogitsProcessor(
843
+ self._dc_collect_eos(eos_token_id, **kwargs),
844
+ window, top_r, threshold,
845
+ warmup_tokens=warmup_tokens,
846
+ prefer_eos_ids=prefer_ids or None
847
+ )
848
+ )
849
+ kwargs["logits_processor"] = lp
850
+ return super().generate(*args, **kwargs)
851
+
852
+ # Override generate() to be default ON (auto-attach DeepConf; merge with external lps)
853
+ def generate(self, *args, **kwargs):
854
+ if self._dc_enabled():
855
+ eos_ids = self._dc_collect_eos(kwargs.get("eos_token_id", None), **kwargs)
856
+ # Prefer ChatML end tokens if provided
857
+ prefer_ids: List[int] = []
858
+ tok = kwargs.get("tokenizer", None)
859
+ stop_strings = kwargs.get("stop_strings", None)
860
+ im_end_id = None
861
+ if tok is not None and stop_strings:
862
+ for s in stop_strings:
863
+ try:
864
+ eid = tok.convert_tokens_to_ids(s)
865
+ if isinstance(eid, int) and eid >= 0:
866
+ prefer_ids.append(int(eid))
867
+ continue
868
+ except Exception:
869
+ pass
870
+ try:
871
+ enc = tok.encode(s, add_special_tokens=False)
872
+ if isinstance(enc, list) and len(enc) == 1:
873
+ prefer_ids.append(int(enc[0]))
874
+ except Exception:
875
+ pass
876
+
877
+ # For ChatML protection: extract <|im_end|> id
878
+ if tok is not None:
879
+ try:
880
+ im_end_id = tok.convert_tokens_to_ids("<|im_end|>")
881
+ if not isinstance(im_end_id, int) or im_end_id < 0:
882
+ im_end_id = None
883
+ except Exception:
884
+ im_end_id = None
885
+
886
+ if eos_ids:
887
+ window, top_r, thr, warmup = self._dc_params()
888
+
889
+ def env_int(k, d):
890
+ v = os.getenv(k)
891
+ return int(v) if v not in (None, "") else d
892
+
893
+ # Phase-aware params from ENV
894
+ require_count = env_int(
895
+ "HYPERCLOVA_DEEPCONF_REQUIRE_IM_END_COUNT", 2 if (prefer_ids and im_end_id is not None) else 0
896
+ )
897
+ thr_think_str = os.getenv("HYPERCLOVA_DEEPCONF_THRESH_THINK", None)
898
+ thr_ans_str = os.getenv("HYPERCLOVA_DEEPCONF_THRESH_ANS", None)
899
+ thr_think = float(thr_think_str) if thr_think_str is not None and thr_think_str.strip() != "" else None
900
+ thr_ans = float(thr_ans_str) if thr_ans_str is not None and thr_ans_str.strip() != "" else None
901
+
902
+ # require_prev_id is deprecated in favor of require_im_end_count; leave it as None unless explicitly re-enabled below.
903
+ require_prev = None
904
+ if os.getenv("HYPERCLOVA_DEEPCONF_REQUIRE_IM_END", "0").lower() in {"1", "on", "true"}: # Keep for BC, but default off
905
+ require_prev = im_end_id
906
+
907
+ lp: LogitsProcessorList = kwargs.pop("logits_processor", None) or LogitsProcessorList()
908
+
909
+ if os.getenv("HYPERCLOVA_DEEPCONF_VERBOSE_ATTACH", "0") in {"1", "on", "true"}:
910
+ print(
911
+ f"[DeepConf] attach window={window} top_r={top_r} thr={thr} warmup={warmup} eos={eos_ids} prefer={prefer_ids} "
912
+ f"require_prev={require_prev} im_end_id={im_end_id} require_count={require_count} thr_think={thr_think} thr_ans={thr_ans}"
913
+ )
914
+
915
+ lp.append(
916
+ DeepConfEOSLogitsProcessor(
917
+ eos_ids,
918
+ window,
919
+ top_r,
920
+ thr,
921
+ warmup_tokens=warmup,
922
+ prefer_eos_ids=prefer_ids or None,
923
+ require_prev_id=require_prev,
924
+ im_end_id=im_end_id,
925
+ require_im_end_count=require_count,
926
+ threshold_think=thr_think,
927
+ threshold_answer=thr_ans,
928
+ )
929
+ )
930
+ kwargs["logits_processor"] = lp
931
+ return super().generate(*args, **kwargs)
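Because `generate()` attaches the DeepConf processor automatically when enabled, callers only need the environment variables read above (and, optionally, `tokenizer`/`stop_strings` for ChatML-aware stopping). The sketch below assumes a recent `transformers` release with `stop_strings` support and uses a placeholder repository id; substitute the checkpoint that actually ships this file.

```python
import os
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "naver-hyperclovax/<checkpoint-id>"   # placeholder, not a real repo id

# Tune DeepConf through the env vars parsed in _dc_params() / generate().
os.environ["HYPERCLOVA_DEEPCONF_WINDOW"] = "256"
os.environ["HYPERCLOVA_DEEPCONF_THRESH"] = "-3.0"
os.environ["HYPERCLOVA_DEEPCONF_VERBOSE_ATTACH"] = "1"

tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

inputs = tokenizer("Hello, how are you?", return_tensors="pt")
output = model.generate(
    **inputs,
    max_new_tokens=64,
    tokenizer=tokenizer,              # lets the processor resolve <|im_end|>
    stop_strings=["<|im_end|>"],
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```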
932
+
933
+ def set_decoder(self, decoder):
934
+ self.model = decoder
935
+
936
+ def get_decoder(self):
937
+ return self.model
938
+
939
+ @can_return_tuple
940
+ @auto_docstring
941
+ def forward(
942
+ self,
943
+ input_ids: Optional[torch.LongTensor] = None,
944
+ attention_mask: Optional[torch.Tensor] = None,
945
+ position_ids: Optional[torch.LongTensor] = None,
946
+ past_key_values: Optional[Cache] = None,
947
+ inputs_embeds: Optional[torch.FloatTensor] = None,
948
+ labels: Optional[torch.LongTensor] = None,
949
+ use_cache: Optional[bool] = None,
950
+ output_attentions: Optional[bool] = None,
951
+ output_hidden_states: Optional[bool] = None,
952
+ cache_position: Optional[torch.LongTensor] = None,
953
+ logits_to_keep: Union[int, torch.Tensor] = 0,
954
+ **kwargs: Unpack[KwargsForCausalLM],
955
+ ) -> CausalLMOutputWithPast:
956
+ r"""
957
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
958
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
959
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
960
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
961
+
962
+ Example:
963
+
964
+ ```python
965
+ >>> from transformers import AutoTokenizer, HyperCLOVAXForCausalLM
966
+
967
+ >>> model = HyperCLOVAXForCausalLM.from_pretrained("naver-hyperclovax/{model_name}")
968
+ >>> tokenizer = AutoTokenizer.from_pretrained("naver-hyperclovax/{model_name}")
969
+
970
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
971
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
972
+
973
+ >>> # Generate
974
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
975
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
976
+ "Hey, are you conscious? Can you talk to me?
977
+ I'm not conscious, but I can talk to you."
978
+ ```"""
979
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
980
+ output_hidden_states = (
981
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
982
+ )
983
+
984
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
985
+ outputs: BaseModelOutputWithPast = self.model(
986
+ input_ids=input_ids,
987
+ attention_mask=attention_mask,
988
+ position_ids=position_ids,
989
+ past_key_values=past_key_values,
990
+ inputs_embeds=inputs_embeds,
991
+ use_cache=use_cache,
992
+ output_attentions=output_attentions,
993
+ output_hidden_states=output_hidden_states,
994
+ cache_position=cache_position,
995
+ **kwargs,
996
+ )
997
+
998
+ hidden_states = outputs.last_hidden_state
999
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1000
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1001
+ # MuP
1002
+ logits = self.lm_head(hidden_states[:, slice_indices, :]) * self.logits_scaling
1003
+
1004
+ loss = None
1005
+ if labels is not None:
1006
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
1007
+
1008
+ return CausalLMOutputWithPast(
1009
+ loss=loss,
1010
+ logits=logits,
1011
+ past_key_values=outputs.past_key_values,
1012
+ hidden_states=outputs.hidden_states,
1013
+ attentions=outputs.attentions,
1014
+ )
1015
+
1016
+
1017
+ @auto_docstring(
1018
+ custom_intro="""
1019
+ The HyperCLOVAX Model transformer with a sequence classification head on top (linear layer).
1020
+
1021
+ [`HyperCLOVAXForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1022
+ (e.g. GPT-2) do.
1023
+
1024
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1025
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1026
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1027
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1028
+ each row of the batch).
1029
+ """
1030
+ )
1031
+ class HyperCLOVAXForSequenceClassification(HyperCLOVAXPreTrainedModel):
1032
+ def __init__(self, config):
1033
+ super().__init__(config)
1034
+ self.num_labels = config.num_labels
1035
+ self.model = HyperCLOVAXModel(config)
1036
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1037
+
1038
+ # Initialize weights and apply final processing
1039
+ self.post_init()
1040
+
1041
+ def get_input_embeddings(self):
1042
+ return self.model.embed_tokens
1043
+
1044
+ def set_input_embeddings(self, value):
1045
+ self.model.embed_tokens = value
1046
+
1047
+ @can_return_tuple
1048
+ @auto_docstring
1049
+ def forward(
1050
+ self,
1051
+ input_ids: Optional[torch.LongTensor] = None,
1052
+ attention_mask: Optional[torch.Tensor] = None,
1053
+ position_ids: Optional[torch.LongTensor] = None,
1054
+ past_key_values: Optional[Cache] = None,
1055
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1056
+ labels: Optional[torch.LongTensor] = None,
1057
+ use_cache: Optional[bool] = None,
1058
+ output_attentions: Optional[bool] = None,
1059
+ output_hidden_states: Optional[bool] = None,
1060
+ ) -> SequenceClassifierOutputWithPast:
1061
+ r"""
1062
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1063
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1064
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1065
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1066
+ """
1067
+
1068
+ transformer_outputs: BaseModelOutputWithPast = self.model(
1069
+ input_ids,
1070
+ attention_mask=attention_mask,
1071
+ position_ids=position_ids,
1072
+ past_key_values=past_key_values,
1073
+ inputs_embeds=inputs_embeds,
1074
+ use_cache=use_cache,
1075
+ output_attentions=output_attentions,
1076
+ output_hidden_states=output_hidden_states,
1077
+ )
1078
+ hidden_states = transformer_outputs.last_hidden_state
1079
+ logits = self.score(hidden_states)
1080
+
1081
+ if input_ids is not None:
1082
+ batch_size = input_ids.shape[0]
1083
+ else:
1084
+ batch_size = inputs_embeds.shape[0]
1085
+
1086
+ if self.config.pad_token_id is None and batch_size != 1:
1087
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1088
+ if self.config.pad_token_id is None:
1089
+ last_non_pad_token = -1
1090
+ elif input_ids is not None:
1091
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
1092
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
1093
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
1094
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
1095
+ else:
1096
+ last_non_pad_token = -1
1097
+ logger.warning_once(
1098
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1099
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1100
+ )
1101
+
1102
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
1103
+
1104
+ loss = None
1105
+ if labels is not None:
1106
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
1107
+
1108
+ return SequenceClassifierOutputWithPast(
1109
+ loss=loss,
1110
+ logits=pooled_logits,
1111
+ past_key_values=transformer_outputs.past_key_values,
1112
+ hidden_states=transformer_outputs.hidden_states,
1113
+ attentions=transformer_outputs.attentions,
1114
+ )
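The `(token_indices * non_pad_mask).argmax(-1)` trick used in the forward above selects the rightmost non-padding position regardless of padding side; a standalone check with a hypothetical pad id:

```python
import torch

pad_token_id = 0   # hypothetical pad id
input_ids = torch.tensor([
    [5, 6, 7, 0, 0],    # right padded -> last real token at index 2
    [0, 0, 8, 9, 10],   # left padded  -> last real token at index 4
])

non_pad_mask = (input_ids != pad_token_id).int()
token_indices = torch.arange(input_ids.shape[-1])
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)   # tensor([2, 4])
```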
1115
+
1116
+
1117
+ @auto_docstring
1118
+ class HyperCLOVAXForQuestionAnswering(HyperCLOVAXPreTrainedModel):
1119
+ base_model_prefix = "transformer"
1120
+
1121
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->HyperCLOVAX
1122
+ def __init__(self, config):
1123
+ super().__init__(config)
1124
+ self.transformer = HyperCLOVAXModel(config)
1125
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1126
+
1127
+ # Initialize weights and apply final processing
1128
+ self.post_init()
1129
+
1130
+ def get_input_embeddings(self):
1131
+ return self.transformer.embed_tokens
1132
+
1133
+ def set_input_embeddings(self, value):
1134
+ self.transformer.embed_tokens = value
1135
+
1136
+ @can_return_tuple
1137
+ @auto_docstring
1138
+ def forward(
1139
+ self,
1140
+ input_ids: Optional[torch.LongTensor] = None,
1141
+ attention_mask: Optional[torch.Tensor] = None,
1142
+ position_ids: Optional[torch.LongTensor] = None,
1143
+ past_key_values: Optional[Cache] = None,
1144
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1145
+ start_positions: Optional[torch.LongTensor] = None,
1146
+ end_positions: Optional[torch.LongTensor] = None,
1147
+ output_attentions: Optional[bool] = None,
1148
+ output_hidden_states: Optional[bool] = None,
1149
+ **kwargs,
1150
+ ) -> QuestionAnsweringModelOutput:
1151
+ outputs: BaseModelOutputWithPast = self.transformer(
1152
+ input_ids,
1153
+ attention_mask=attention_mask,
1154
+ position_ids=position_ids,
1155
+ past_key_values=past_key_values,
1156
+ inputs_embeds=inputs_embeds,
1157
+ output_attentions=output_attentions,
1158
+ output_hidden_states=output_hidden_states,
1159
+ )
1160
+
1161
+ sequence_output = outputs.last_hidden_state
1162
+
1163
+ logits = self.qa_outputs(sequence_output)
1164
+ start_logits, end_logits = logits.split(1, dim=-1)
1165
+ start_logits = start_logits.squeeze(-1).contiguous()
1166
+ end_logits = end_logits.squeeze(-1).contiguous()
1167
+
1168
+ loss = None
1169
+ if start_positions is not None and end_positions is not None:
1170
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
1171
+
1172
+ return QuestionAnsweringModelOutput(
1173
+ loss=loss,
1174
+ start_logits=start_logits,
1175
+ end_logits=end_logits,
1176
+ hidden_states=outputs.hidden_states,
1177
+ attentions=outputs.attentions,
1178
+ )
1179
+
1180
+
1181
+ @auto_docstring
1182
+ class HyperCLOVAXForTokenClassification(HyperCLOVAXPreTrainedModel):
1183
+ def __init__(self, config):
1184
+ super().__init__(config)
1185
+ self.num_labels = config.num_labels
1186
+ self.model = HyperCLOVAXModel(config)
1187
+ if getattr(config, "classifier_dropout", None) is not None:
1188
+ classifier_dropout = config.classifier_dropout
1189
+ elif getattr(config, "hidden_dropout", None) is not None:
1190
+ classifier_dropout = config.hidden_dropout
1191
+ else:
1192
+ classifier_dropout = 0.1
1193
+ self.dropout = nn.Dropout(classifier_dropout)
1194
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
1195
+
1196
+ # Initialize weights and apply final processing
1197
+ self.post_init()
1198
+
1199
+ def get_input_embeddings(self):
1200
+ return self.model.embed_tokens
1201
+
1202
+ def set_input_embeddings(self, value):
1203
+ self.model.embed_tokens = value
1204
+
1205
+ @can_return_tuple
1206
+ @auto_docstring
1207
+ def forward(
1208
+ self,
1209
+ input_ids: Optional[torch.LongTensor] = None,
1210
+ attention_mask: Optional[torch.Tensor] = None,
1211
+ position_ids: Optional[torch.LongTensor] = None,
1212
+ past_key_values: Optional[Cache] = None,
1213
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1214
+ labels: Optional[torch.LongTensor] = None,
1215
+ use_cache: Optional[bool] = None,
1216
+ output_attentions: Optional[bool] = None,
1217
+ output_hidden_states: Optional[bool] = None,
1218
+ ) -> TokenClassifierOutput:
1219
+ r"""
1220
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1221
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1222
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1223
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1224
+ """
1225
+
1226
+ outputs: BaseModelOutputWithPast = self.model(
1227
+ input_ids,
1228
+ attention_mask=attention_mask,
1229
+ position_ids=position_ids,
1230
+ past_key_values=past_key_values,
1231
+ inputs_embeds=inputs_embeds,
1232
+ use_cache=use_cache,
1233
+ output_attentions=output_attentions,
1234
+ output_hidden_states=output_hidden_states,
1235
+ )
1236
+ sequence_output = outputs.last_hidden_state
1237
+ sequence_output = self.dropout(sequence_output)
1238
+ logits = self.score(sequence_output)
1239
+
1240
+ loss = None
1241
+ if labels is not None:
1242
+ loss = self.loss_function(logits, labels, self.config)
1243
+
1244
+ return TokenClassifierOutput(
1245
+ loss=loss,
1246
+ logits=logits,
1247
+ hidden_states=outputs.hidden_states,
1248
+ attentions=outputs.attentions,
1249
+ )
1250
+
1251
+
1252
+ __all__ = [
1253
+ "HyperCLOVAXForCausalLM",
1254
+ "HyperCLOVAXModel",
1255
+ "HyperCLOVAXPreTrainedModel",
1256
+ "HyperCLOVAXForSequenceClassification",
1257
+ "HyperCLOVAXForQuestionAnswering",
1258
+ "HyperCLOVAXForTokenClassification",
1259
+ ]
modeling_hyperclovax_old.py ADDED
@@ -0,0 +1,1199 @@
1
+ # coding=utf-8
2
+ # This file was created for the HyperCLOVA X SEED 14B Think architecture.
3
+ # partially copied and modified from https://github.com/huggingface/transformers
4
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
7
+ # and OPT implementations in this library. It has been modified from its
8
+ # original forms to accommodate minor architectural differences compared
9
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ from typing import Callable, Optional, Union
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache
30
+ from transformers.generation import GenerationMixin
31
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
32
+ from transformers.integrations import use_kernel_forward_from_hub
33
+ from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
34
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
+ from typing import List, Iterable, Optional, Union, Tuple
36
+ from collections import deque
37
+ import os
38
+ from transformers.modeling_layers import GradientCheckpointingLayer
39
+ from transformers.modeling_outputs import (
40
+ BaseModelOutputWithPast,
41
+ CausalLMOutputWithPast,
42
+ QuestionAnsweringModelOutput,
43
+ SequenceClassifierOutputWithPast,
44
+ TokenClassifierOutput,
45
+ )
46
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
47
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
48
+ from transformers.processing_utils import Unpack
49
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
50
+ from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
51
+ from .configuration_hyperclovax import HyperCLOVAXConfig
52
+ if is_torch_flex_attn_available():
53
+ from torch.nn.attention.flex_attention import BlockMask
54
+
55
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+ # ================= DeepConf: confidence-based online early stop =================
60
+ class DeepConfEOSLogitsProcessor(LogitsProcessor):
61
+ """
62
+ Per-sample early stop: at each step, compute token_conf = mean(logprob of top-r),
63
+ maintain group_conf = mean of last `window` token_conf; if group_conf < threshold,
64
+ force EOS for THAT sample by setting EOS logprob=0 and others to -inf.
65
+ """
66
+ def __init__(
67
+ self,
68
+ eos_token_ids: List[int],
69
+ window: int = 512,
70
+ top_r: int = 5,
71
+ threshold: float = -3.5,
72
+ warmup_tokens: int = 0,
73
+ prefer_eos_ids: Optional[List[int]] = None,
74
+ require_prev_id: Optional[int] = None,
75
+ ):
76
+ self.eos_ids: List[int] = sorted({int(i) for i in (eos_token_ids or []) if i is not None and i >= 0})
77
+ self.window: int = max(int(window), 1)
78
+ self.top_r: int = max(int(top_r), 1)
79
+ self.threshold: float = float(threshold)
80
+ self.warmup_tokens: int = max(int(warmup_tokens), 0)
81
+ self.prefer_eos_ids: List[int] = sorted({int(i) for i in (prefer_eos_ids or []) if i is not None and i >= 0})
82
+ self.require_prev_id = require_prev_id
83
+ self._buffers: Optional[List[deque]] = None
84
+ self._verbose: bool = os.getenv("HYPERCLOVA_DEEPCONF_VERBOSE", "0").strip().lower() in {"1","on","true"}
85
+ self._every: int = max(int(os.getenv("HYPERCLOVA_DEEPCONF_REPORT_EVERY", "64")), 1)
86
+ self._tick: int = 0
87
+ self._stops: int = 0
88
+
89
+ def _ensure(self, bsz: int) -> None:
90
+ if self._buffers is None or len(self._buffers) != bsz:
91
+ self._buffers = [deque(maxlen=self.window) for _ in range(bsz)]
92
+
93
+ @torch.no_grad()
94
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
95
+ bsz, vocab = scores.shape
96
+ self._ensure(bsz)
97
+ logprobs = torch.log_softmax(scores, dim=-1) # (B, V)
98
+ k = min(self.top_r, vocab)
99
+ token_conf = torch.topk(logprobs, k=k, dim=-1).values.mean(dim=-1) # (B,)
100
+
101
+ stopped = False
102
+ for i, c in enumerate(token_conf.tolist()):
103
+ buf = self._buffers[i]; buf.append(c)
104
+ group_conf = sum(buf) / len(buf)
105
+ # --- warmup gate: do not early-stop until we have enough tokens ---
106
+ if len(buf) < self.warmup_tokens:
107
+ continue
108
+
109
+ # ChatML protection: only force preferred EOS after the required previous token
110
+ prev_ok = True
111
+ if self.require_prev_id is not None:
112
+ prev_tok = int(input_ids[i, -1].item()) if input_ids is not None and input_ids.size(1) > 0 else -1
113
+ prev_ok = (prev_tok == self.require_prev_id)
114
+
115
+ if group_conf < self.threshold and (self.prefer_eos_ids or self.eos_ids) and prev_ok:
116
+ # Prefer ChatML end tokens if available; else fall back to config eos
117
+ targets = self.prefer_eos_ids if self.prefer_eos_ids else self.eos_ids
118
+ scores[i].fill_(-float("inf"))
119
+ for eid in targets:
120
+ if 0 <= eid < vocab:
121
+ scores[i, eid] = 0.0
122
+ self._stops += 1
123
+ stopped = True
124
+
125
+ if self._verbose:
126
+ self._tick += 1
127
+ if self._tick % self._every == 0:
128
+ try:
129
+ gcs = [(sum(b)/len(b)) if b else float("nan") for b in (self._buffers or [])]
130
+ valid = [x for x in gcs if not (x != x)]
131
+ mean_gc = float(sum(valid)/max(1, len(valid)))
132
+ except Exception:
133
+ mean_gc = float("nan")
134
+ print(f"[DeepConf] step={self._tick} mean_gc={mean_gc:.4f} stops={self._stops}")
135
+ return scores
136
+
137
+ # (optional) Offline helper: Lowest Group Confidence (LGC)
138
+ def deepconf_lgc_from_scores(scores_list: Iterable[torch.Tensor], top_r: int = 5, window: int = 2048) -> float:
139
+ tensors = [s for s in scores_list]
140
+ if not tensors: return float("-inf")
141
+ with torch.no_grad():
142
+ vals = [
143
+ torch.topk(torch.log_softmax(s, dim=-1), k=min(top_r, s.size(-1)), dim=-1).values.mean(dim=-1)
144
+ for s in tensors
145
+ ] # each (B,)
146
+ conf = torch.stack(vals).squeeze(-1) # (T,) if B=1
147
+ w = min(int(window), conf.numel())
148
+ kernel = torch.ones(1,1,w, device=conf.device) / w
149
+ run = torch.nn.functional.conv1d(conf.view(1,1,-1), weight=kernel).squeeze()
150
+ return float(run.min().item())
151
+ # ==============================================================================
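To make the early-stop rule above concrete, a standalone sketch (illustrative numbers only) of the statistic the processor tracks each step: token confidence is the mean log-probability of the top-r candidates, and group confidence is its sliding-window mean.

import torch
from collections import deque

top_r, window, threshold = 5, 4, -3.5
buf = deque(maxlen=window)

for step in range(6):
    scores = torch.randn(1, 10)                                              # stand-in logits for one decoding step (B=1, vocab=10)
    logprobs = torch.log_softmax(scores, dim=-1)
    token_conf = torch.topk(logprobs, k=top_r, dim=-1).values.mean(dim=-1)   # (B,)
    buf.append(float(token_conf[0]))
    group_conf = sum(buf) / len(buf)                                         # mean over the last `window` steps
    if group_conf < threshold:                                               # the processor would force EOS for this sample here
        print(f"step {step}: group_conf={group_conf:.2f} < {threshold}, force EOS")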
152
+
153
+
154
+ @use_kernel_forward_from_hub("RMSNorm")
155
+ class HyperCLOVAXRMSNorm(nn.Module):
156
+ def __init__(self, hidden_size, eps=1e-6):
157
+ """
158
+ HyperCLOVAXRMSNorm is equivalent to T5LayerNorm
159
+ """
160
+ super().__init__()
161
+ self.weight = nn.Parameter(torch.ones(hidden_size))
162
+ self.variance_epsilon = eps
163
+
164
+ def forward(self, hidden_states):
165
+ input_dtype = hidden_states.dtype
166
+ hidden_states = hidden_states.to(torch.float32)
167
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
168
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
169
+ return self.weight * hidden_states.to(input_dtype)
170
+
171
+ def extra_repr(self):
172
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
173
+
174
+ ALL_LAYERNORM_LAYERS.append(HyperCLOVAXRMSNorm)
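A quick numeric check of the normalization above (a standalone sketch, not part of the uploaded file): the layer rescales by the root mean square of the features, with no mean subtraction, then applies a learned per-feature gain.

import torch

torch.manual_seed(0)
x = torch.randn(2, 4, 8)                               # (batch, seq, hidden)
norm = HyperCLOVAXRMSNorm(hidden_size=8, eps=1e-6)     # class defined above

manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
assert torch.allclose(norm(x), manual, atol=1e-5)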
175
+ class HyperCLOVAXRotaryEmbedding(nn.Module):
176
+ def __init__(self, config: HyperCLOVAXConfig, device=None):
177
+ super().__init__()
178
+ # BC: "rope_type" was originally "type"
179
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
180
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
181
+ else:
182
+ self.rope_type = "default"
183
+ self.max_seq_len_cached = config.max_position_embeddings
184
+ self.original_max_seq_len = config.max_position_embeddings
185
+
186
+ self.config = config
187
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
188
+
189
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
190
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
191
+ self.original_inv_freq = self.inv_freq
192
+
193
+ @torch.no_grad()
194
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
195
+ def forward(self, x, position_ids):
196
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
197
+ position_ids_expanded = position_ids[:, None, :].float()
198
+
199
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
200
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
201
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
202
+ emb = torch.cat((freqs, freqs), dim=-1)
203
+ cos = emb.cos() * self.attention_scaling
204
+ sin = emb.sin() * self.attention_scaling
205
+
206
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
207
+
208
+
209
+ def rotate_half(x):
210
+ """Rotates half the hidden dims of the input."""
211
+ x1 = x[..., : x.shape[-1] // 2]
212
+ x2 = x[..., x.shape[-1] // 2 :]
213
+ return torch.cat((-x2, x1), dim=-1)
214
+
215
+
216
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
217
+ """Applies Rotary Position Embedding to the query and key tensors.
218
+
219
+ Args:
220
+ q (`torch.Tensor`): The query tensor.
221
+ k (`torch.Tensor`): The key tensor.
222
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
223
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
224
+ position_ids (`torch.Tensor`, *optional*):
225
+ Deprecated and unused.
226
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
227
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
228
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
229
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
230
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
231
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
232
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
233
+ Returns:
234
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
235
+ """
236
+ cos = cos.unsqueeze(unsqueeze_dim)
237
+ sin = sin.unsqueeze(unsqueeze_dim)
238
+ q_embed = (q * cos) + (rotate_half(q) * sin)
239
+ k_embed = (k * cos) + (rotate_half(k) * sin)
240
+ return q_embed, k_embed
241
+
242
+
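A small shape sketch for the two helpers above (2 heads and head_dim 8 are arbitrary): cos/sin arrive as (batch, seq, head_dim), and the default unsqueeze_dim=1 inserts the head axis so they broadcast against (batch, heads, seq, head_dim) queries and keys.

import torch

b, h, s, d = 1, 2, 5, 8
q = torch.randn(b, h, s, d)
k = torch.randn(b, h, s, d)
cos = torch.randn(b, s, d)        # stand-in for the cos table from HyperCLOVAXRotaryEmbedding
sin = torch.randn(b, s, d)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)    # default unsqueeze_dim=1 adds the head axis
assert q_rot.shape == (b, h, s, d) and k_rot.shape == (b, h, s, d)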
243
+ class HyperCLOVAXMLP(nn.Module):
244
+ def __init__(self, config):
245
+ super().__init__()
246
+ self.config = config
247
+ self.hidden_size = config.hidden_size
248
+ self.intermediate_size = config.intermediate_size
249
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
250
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
251
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
252
+ self.act_fn = ACT2FN[config.hidden_act]
253
+
254
+ def forward(self, x):
255
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
256
+ return down_proj
257
+
258
+
259
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
260
+ """
261
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
262
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
263
+ """
264
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
265
+ if n_rep == 1:
266
+ return hidden_states
267
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
268
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
269
+
270
+
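A shape sketch for the grouped-query expansion above (4 KV heads fanned out to 16 attention heads; the numbers are illustrative, not the model's):

import torch

kv = torch.randn(2, 4, 10, 64)          # (batch, num_key_value_heads, seq, head_dim)
expanded = repeat_kv(kv, n_rep=4)       # -> (batch, 16, seq, head_dim)
assert expanded.shape == (2, 16, 10, 64)
# Each group of 4 query heads sees an identical copy of its KV head.
assert torch.equal(expanded[:, 0], expanded[:, 3])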
271
+ def eager_attention_forward(
272
+ module: nn.Module,
273
+ query: torch.Tensor,
274
+ key: torch.Tensor,
275
+ value: torch.Tensor,
276
+ attention_mask: Optional[torch.Tensor],
277
+ scaling: float,
278
+ dropout: float = 0.0,
279
+ **kwargs,
280
+ ):
281
+ key_states = repeat_kv(key, module.num_key_value_groups)
282
+ value_states = repeat_kv(value, module.num_key_value_groups)
283
+
284
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
285
+ if attention_mask is not None:
286
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
287
+ attn_weights = attn_weights + causal_mask
288
+
289
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
290
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
291
+ attn_output = torch.matmul(attn_weights, value_states)
292
+ attn_output = attn_output.transpose(1, 2).contiguous()
293
+
294
+ return attn_output, attn_weights
295
+
296
+
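The eager path above takes the softmax scale as an explicit argument, which is how the MuP attention_multiplier flows in; a minimal sketch with a stand-in module object (SimpleNamespace only satisfies the two attribute lookups the function makes):

import torch
from types import SimpleNamespace

b, q_heads, kv_heads, s, d = 1, 4, 2, 6, 16
module = SimpleNamespace(num_key_value_groups=q_heads // kv_heads, training=False)

query = torch.randn(b, q_heads, s, d)
key = torch.randn(b, kv_heads, s, d)
value = torch.randn(b, kv_heads, s, d)

attn_out, attn_w = eager_attention_forward(
    module, query, key, value, attention_mask=None, scaling=d ** -0.5
)
assert attn_out.shape == (b, s, q_heads, d)   # transposed back to (batch, seq, heads, head_dim)
assert attn_w.shape == (b, q_heads, s, s)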
297
+ class HyperCLOVAXAttention(nn.Module):
298
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
299
+
300
+ def __init__(self, config: HyperCLOVAXConfig, layer_idx: int):
301
+ super().__init__()
302
+ self.config = config
303
+ self.layer_idx = layer_idx
304
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
305
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
306
+ self.scaling = getattr(config, "attention_multiplier", self.head_dim**-0.5) # MuP
307
+ self.attention_dropout = config.attention_dropout
308
+ self.is_causal = True
309
+
310
+ self.q_proj = nn.Linear(
311
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
312
+ )
313
+ self.k_proj = nn.Linear(
314
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
315
+ )
316
+ self.v_proj = nn.Linear(
317
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
318
+ )
319
+ self.o_proj = nn.Linear(
320
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
321
+ )
322
+
323
+ def forward(
324
+ self,
325
+ hidden_states: torch.Tensor,
326
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
327
+ attention_mask: Optional[torch.Tensor],
328
+ past_key_value: Optional[Cache] = None,
329
+ cache_position: Optional[torch.LongTensor] = None,
330
+ **kwargs: Unpack[FlashAttentionKwargs],
331
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
332
+ input_shape = hidden_states.shape[:-1]
333
+ hidden_shape = (*input_shape, -1, self.head_dim)
334
+
335
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
336
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
337
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
338
+
339
+ cos, sin = position_embeddings
340
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
341
+
342
+ if past_key_value is not None:
343
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
344
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
345
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
346
+
347
+ attention_interface: Callable = eager_attention_forward
348
+
349
+ if self.config._attn_implementation != "eager":
350
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
351
+ logger.warning_once(
352
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
353
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
354
+ )
355
+ else:
356
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
357
+
358
+ attn_output, attn_weights = attention_interface(
359
+ self,
360
+ query_states,
361
+ key_states,
362
+ value_states,
363
+ attention_mask,
364
+ dropout=0.0 if not self.training else self.attention_dropout,
365
+ scaling=self.scaling,
366
+ **kwargs,
367
+ )
368
+
369
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
370
+ attn_output = self.o_proj(attn_output)
371
+ return attn_output, attn_weights
372
+
373
+
374
+ class HyperCLOVAXDecoderLayer(GradientCheckpointingLayer):
375
+ def __init__(self, config: HyperCLOVAXConfig, layer_idx: int):
376
+ super().__init__()
377
+ self.hidden_size = config.hidden_size
378
+
379
+ self.self_attn = HyperCLOVAXAttention(config=config, layer_idx=layer_idx)
380
+
381
+ self.mlp = HyperCLOVAXMLP(config)
382
+ self.input_layernorm = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
383
+ self.post_attention_layernorm = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
384
+ self.use_post_norm = getattr(config, "use_post_norm", False)
385
+
386
+ # Peri-LN (post-norm)
387
+ if self.use_post_norm:
388
+ self.post_norm1 = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
389
+ self.post_norm2 = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
390
+
391
+ self.residual_multiplier = getattr(config, "residual_multiplier", 1.0) # MuP
392
+
393
+ def forward(
394
+ self,
395
+ hidden_states: torch.Tensor,
396
+ attention_mask: Optional[torch.Tensor] = None,
397
+ position_ids: Optional[torch.LongTensor] = None,
398
+ past_key_value: Optional[Cache] = None,
399
+ output_attentions: Optional[bool] = False,
400
+ use_cache: Optional[bool] = False,
401
+ cache_position: Optional[torch.LongTensor] = None,
402
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
403
+ **kwargs: Unpack[FlashAttentionKwargs],
404
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
405
+ residual = hidden_states
406
+ hidden_states = self.input_layernorm(hidden_states)
407
+
408
+ # Self Attention
409
+ hidden_states, self_attn_weights = self.self_attn(
410
+ hidden_states=hidden_states,
411
+ attention_mask=attention_mask,
412
+ position_ids=position_ids,
413
+ past_key_value=past_key_value,
414
+ output_attentions=output_attentions,
415
+ use_cache=use_cache,
416
+ cache_position=cache_position,
417
+ position_embeddings=position_embeddings,
418
+ **kwargs,
419
+ )
420
+
421
+ if self.use_post_norm: # Peri-LN
422
+ hidden_states = self.post_norm1(hidden_states)
423
+
424
+ hidden_states = residual + hidden_states * self.residual_multiplier # MuP
425
+
426
+ # Fully Connected
427
+ residual = hidden_states
428
+ hidden_states = self.post_attention_layernorm(hidden_states)
429
+ hidden_states = self.mlp(hidden_states)
430
+
431
+ if self.use_post_norm: # Peri-LN
432
+ hidden_states = self.post_norm2(hidden_states)
433
+
434
+ hidden_states = residual + hidden_states * self.residual_multiplier # MuP
435
+
436
+ outputs = (hidden_states,)
437
+ if output_attentions:
438
+ outputs += (self_attn_weights,)
439
+
440
+ return outputs
441
+
442
+
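A schematic sketch of the residual pattern above, using a plain Linear as a stand-in for the attention/MLP sublayer: pre-norm, sublayer, an optional extra post-norm when use_post_norm is set (Peri-LN), then a residual add scaled by residual_multiplier (MuP).

import torch
from torch import nn

hidden = 8
pre_norm = HyperCLOVAXRMSNorm(hidden)     # input_layernorm / post_attention_layernorm
post_norm = HyperCLOVAXRMSNorm(hidden)    # post_norm1 / post_norm2 when use_post_norm=True
sublayer = nn.Linear(hidden, hidden)      # stand-in for self-attention or the MLP
residual_multiplier = 1.0                 # MuP scaling of the residual branch

x = torch.randn(2, 5, hidden)
out = sublayer(pre_norm(x))               # pre-norm -> sublayer
out = post_norm(out)                      # Peri-LN: extra norm on the sublayer output
x = x + out * residual_multiplier         # scaled residual add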
443
+ @auto_docstring
444
+ class HyperCLOVAXPreTrainedModel(PreTrainedModel):
445
+ config_class = HyperCLOVAXConfig
446
+ base_model_prefix = "model"
447
+ supports_gradient_checkpointing = True
448
+ _no_split_modules = ["HyperCLOVAXDecoderLayer"]
449
+ _skip_keys_device_placement = ["past_key_values"]
450
+ _supports_flash_attn_2 = True
451
+ _supports_sdpa = True
452
+ _supports_flex_attn = True
453
+ _supports_cache_class = True
454
+ _supports_quantized_cache = True
455
+ _supports_static_cache = True
456
+ _supports_attention_backend = True
457
+
458
+ def _init_weights(self, module):
459
+ std = self.config.initializer_range
460
+ if isinstance(module, nn.Linear):
461
+ module.weight.data.normal_(mean=0.0, std=std)
462
+ if module.bias is not None:
463
+ module.bias.data.zero_()
464
+ elif isinstance(module, nn.Embedding):
465
+ module.weight.data.normal_(mean=0.0, std=std)
466
+ if module.padding_idx is not None:
467
+ module.weight.data[module.padding_idx].zero_()
468
+ elif isinstance(module, HyperCLOVAXRMSNorm):
469
+ module.weight.data.fill_(1.0)
470
+
471
+
472
+ @auto_docstring
473
+ class HyperCLOVAXModel(HyperCLOVAXPreTrainedModel):
474
+ def __init__(self, config: HyperCLOVAXConfig):
475
+ super().__init__(config)
476
+ self.padding_idx = config.pad_token_id
477
+ self.vocab_size = config.vocab_size
478
+
479
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
480
+ self.layers = nn.ModuleList(
481
+ [HyperCLOVAXDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
482
+ )
483
+ self.norm = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
484
+ self.rotary_emb = HyperCLOVAXRotaryEmbedding(config=config)
485
+ self.gradient_checkpointing = False
486
+
487
+ # Initialize weights and apply final processing
488
+ self.post_init()
489
+
490
+ # MuP
491
+ self.embedding_multiplier = getattr(config, "embedding_multiplier", 1.0)
492
+
493
+ def get_input_embeddings(self):
494
+ return self.embed_tokens
495
+
496
+ def set_input_embeddings(self, value):
497
+ self.embed_tokens = value
498
+
499
+ @can_return_tuple
500
+ @auto_docstring
501
+ def forward(
502
+ self,
503
+ input_ids: Optional[torch.LongTensor] = None,
504
+ attention_mask: Optional[torch.Tensor] = None,
505
+ position_ids: Optional[torch.LongTensor] = None,
506
+ past_key_values: Optional[Cache] = None,
507
+ inputs_embeds: Optional[torch.FloatTensor] = None,
508
+ use_cache: Optional[bool] = None,
509
+ output_attentions: Optional[bool] = None,
510
+ output_hidden_states: Optional[bool] = None,
511
+ cache_position: Optional[torch.LongTensor] = None,
512
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
513
+ ) -> BaseModelOutputWithPast:
514
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
515
+ output_hidden_states = (
516
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
517
+ )
518
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
519
+
520
+ if (input_ids is None) ^ (inputs_embeds is not None):
521
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
522
+
523
+ if self.gradient_checkpointing and self.training and use_cache:
524
+ logger.warning_once(
525
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
526
+ )
527
+ use_cache = False
528
+
529
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
530
+ if not isinstance(past_key_values, (type(None), Cache)):
531
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
532
+
533
+ if inputs_embeds is None:
534
+ inputs_embeds = self.embed_tokens(input_ids)
535
+
536
+ inputs_embeds = inputs_embeds * self.embedding_multiplier # MuP
537
+
538
+ if use_cache and past_key_values is None:
539
+ past_key_values = DynamicCache()
540
+
541
+ if cache_position is None:
542
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
543
+ cache_position = torch.arange(
544
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
545
+ )
546
+
547
+ if position_ids is None:
548
+ position_ids = cache_position.unsqueeze(0)
549
+
550
+ causal_mask = self._update_causal_mask(
551
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
552
+ )
553
+
554
+ hidden_states = inputs_embeds
555
+
556
+ # create position embeddings to be shared across the decoder layers
557
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
558
+
559
+ # decoder layers
560
+ all_hidden_states = () if output_hidden_states else None
561
+ all_self_attns = () if output_attentions else None
562
+
563
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
564
+ if output_hidden_states:
565
+ all_hidden_states += (hidden_states,)
566
+
567
+ layer_outputs = decoder_layer(
568
+ hidden_states,
569
+ attention_mask=causal_mask,
570
+ position_ids=position_ids,
571
+ past_key_value=past_key_values,
572
+ output_attentions=output_attentions,
573
+ use_cache=use_cache,
574
+ cache_position=cache_position,
575
+ position_embeddings=position_embeddings,
576
+ **flash_attn_kwargs,
577
+ )
578
+
579
+ hidden_states = layer_outputs[0]
580
+
581
+ if output_attentions:
582
+ all_self_attns += (layer_outputs[1],)
583
+
584
+ hidden_states = self.norm(hidden_states)
585
+
586
+ # add hidden states from the last decoder layer
587
+ if output_hidden_states:
588
+ all_hidden_states += (hidden_states,)
589
+
590
+ return BaseModelOutputWithPast(
591
+ last_hidden_state=hidden_states,
592
+ past_key_values=past_key_values if use_cache else None,
593
+ hidden_states=all_hidden_states,
594
+ attentions=all_self_attns,
595
+ )
596
+
597
+ def _update_causal_mask(
598
+ self,
599
+ attention_mask: Union[torch.Tensor, "BlockMask"],
600
+ input_tensor: torch.Tensor,
601
+ cache_position: torch.Tensor,
602
+ past_key_values: Cache,
603
+ output_attentions: bool = False,
604
+ ):
605
+ if self.config._attn_implementation == "flash_attention_2":
606
+ if attention_mask is not None and (attention_mask == 0.0).any():
607
+ return attention_mask
608
+ return None
609
+ if self.config._attn_implementation == "flex_attention":
610
+ if isinstance(attention_mask, torch.Tensor):
611
+ attention_mask = make_flex_block_causal_mask(attention_mask)
612
+ return attention_mask
613
+
614
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
615
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
616
+ # to infer the attention mask.
617
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
618
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
619
+
620
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
621
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
622
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
623
+ attention_mask,
624
+ inputs_embeds=input_tensor,
625
+ past_key_values_length=past_seen_tokens,
626
+ is_training=self.training,
627
+ ):
628
+ return None
629
+
630
+ dtype = input_tensor.dtype
631
+ sequence_length = input_tensor.shape[1]
632
+ if using_compilable_cache:
633
+ target_length = past_key_values.get_max_cache_shape()
634
+ else:
635
+ target_length = (
636
+ attention_mask.shape[-1]
637
+ if isinstance(attention_mask, torch.Tensor)
638
+ else past_seen_tokens + sequence_length + 1
639
+ )
640
+
641
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
642
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
643
+ attention_mask,
644
+ sequence_length=sequence_length,
645
+ target_length=target_length,
646
+ dtype=dtype,
647
+ cache_position=cache_position,
648
+ batch_size=input_tensor.shape[0],
649
+ )
650
+
651
+ if (
652
+ self.config._attn_implementation == "sdpa"
653
+ and attention_mask is not None
654
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
655
+ and not output_attentions
656
+ ):
657
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
658
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
659
+ # Details: https://github.com/pytorch/pytorch/issues/110213
660
+ min_dtype = torch.finfo(dtype).min
661
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
662
+
663
+ return causal_mask
664
+
665
+ @staticmethod
666
+ def _prepare_4d_causal_attention_mask_with_cache_position(
667
+ attention_mask: torch.Tensor,
668
+ sequence_length: int,
669
+ target_length: int,
670
+ dtype: torch.dtype,
671
+ cache_position: torch.Tensor,
672
+ batch_size: int,
673
+ **kwargs,
674
+ ):
675
+ """
676
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
677
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
678
+
679
+ Args:
680
+ attention_mask (`torch.Tensor`):
681
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
682
+ `(batch_size, 1, query_length, key_value_length)`.
683
+ sequence_length (`int`):
684
+ The sequence length being processed.
685
+ target_length (`int`):
686
+ The target length: when generating with static cache, the mask should be as long as the static cache,
687
+ to account for the 0 padding, the part of the cache that is not filled yet.
688
+ dtype (`torch.dtype`):
689
+ The dtype to use for the 4D attention mask.
690
+ cache_position (`torch.Tensor`):
691
+ Indices depicting the position of the input sequence tokens in the sequence.
692
+ batch_size (`torch.Tensor`):
693
+ Batch size.
694
+ """
695
+ if attention_mask is not None and attention_mask.dim() == 4:
696
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
697
+ causal_mask = attention_mask
698
+ else:
699
+ min_dtype = torch.finfo(dtype).min
700
+ causal_mask = torch.full(
701
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
702
+ )
703
+ if sequence_length != 1:
704
+ causal_mask = torch.triu(causal_mask, diagonal=1)
705
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
706
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
707
+ if attention_mask is not None:
708
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
709
+ mask_length = attention_mask.shape[-1]
710
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
711
+ causal_mask.device
712
+ )
713
+ padding_mask = padding_mask == 0
714
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
715
+ padding_mask, min_dtype
716
+ )
717
+
718
+ return causal_mask
719
+
720
+
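A small sketch of the static helper above (tiny sizes; the first row is left-padded): it expands a 2D padding mask into the additive 4D mask, with 0.0 for attendable positions and the dtype minimum for future or padded positions.

import torch

attention_mask = torch.tensor([[0, 1, 1, 1],      # first sequence is left-padded by one token
                               [1, 1, 1, 1]])
cache_position = torch.arange(4)

mask = HyperCLOVAXModel._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask,
    sequence_length=4,
    target_length=4,
    dtype=torch.float32,
    cache_position=cache_position,
    batch_size=2,
)
assert mask.shape == (2, 1, 4, 4)
# 1 marks attendable positions; 0 marks masked (future or padded) positions.
print((mask == 0.0).int())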
721
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
722
+
723
+
724
+ @auto_docstring
725
+ class HyperCLOVAXForCausalLM(HyperCLOVAXPreTrainedModel, GenerationMixin):
726
+ _tied_weights_keys = ["lm_head.weight"]
727
+ _tp_plan = {"lm_head": "colwise_rep"}
728
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
729
+
730
+ def __init__(self, config):
731
+ super().__init__(config)
732
+ self.model = HyperCLOVAXModel(config)
733
+ self.vocab_size = config.vocab_size
734
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
735
+ self.logits_scaling = getattr(config, "logits_scaling", 1.0)
736
+
737
+ # Initialize weights and apply final processing
738
+ self.post_init()
739
+
740
+ def get_input_embeddings(self):
741
+ return self.model.embed_tokens
742
+
743
+ def set_input_embeddings(self, value):
744
+ self.model.embed_tokens = value
745
+
746
+ def get_output_embeddings(self):
747
+ return self.lm_head
748
+
749
+ def set_output_embeddings(self, new_embeddings):
750
+ self.lm_head = new_embeddings
751
+
752
+ # -------- DeepConf helpers ----------
753
+ def _dc_collect_eos(self, explicit: Optional[Union[int, List[int]]] = None, **kwargs) -> List[int]:
754
+ ids: List[int] = []
755
+ if explicit is not None:
756
+ ids.extend([int(x) for x in (explicit if isinstance(explicit, (list,tuple)) else [explicit])])
757
+ else:
758
+ if getattr(self.config, "eos_token_id", None) is not None:
759
+ ids.append(int(self.config.eos_token_id))
760
+ if getattr(self.config, "eos_token_id_list", None):
761
+ ids.extend(int(x) for x in self.config.eos_token_id_list if x is not None)
762
+ extra = os.getenv("HYPERCLOVA_DEEPCONF_EOS_IDS", "").strip()
763
+ if extra:
764
+ ids.extend(int(tok) for tok in extra.split(",") if tok.strip().isdigit())
765
+ return sorted({i for i in ids if i >= 0})
766
+
767
+ def _dc_enabled(self) -> bool:
768
+ enabled = True
769
+ env = os.getenv("HYPERCLOVA_DEEPCONF", "").strip().lower()
770
+ if env in {"0","off","false"}: enabled = False
771
+ elif env in {"1","on","true"}: enabled = True
772
+ cfg_en = getattr(self.config, "deepconf_enable", None)
773
+ if cfg_en is not None:
774
+ enabled = bool(cfg_en) # If config is specified, it takes precedence
775
+ if getattr(self.config, "deepconf_disable", False):
776
+ enabled = False # Force OFF flag
777
+ return enabled
778
+
779
+ def _dc_params(self) -> Tuple[int,int,float,int]:
780
+ def env_int(k, d): v=os.getenv(k); return int(v) if v not in (None,"") else d
781
+ def env_flt(k, d): v=os.getenv(k); return float(v) if v not in (None,"") else d
782
+ window = env_int("HYPERCLOVA_DEEPCONF_WINDOW", getattr(self.config, "deepconf_window", 512))
783
+ top_r = env_int("HYPERCLOVA_DEEPCONF_TOPR", getattr(self.config, "deepconf_top_r", 5))
784
+ thr = env_flt("HYPERCLOVA_DEEPCONF_THRESH", getattr(self.config, "deepconf_threshold", -3.5))
785
+ warmup = env_int("HYPERCLOVA_DEEPCONF_WARMUP", getattr(self.config, "deepconf_warmup_tokens", 0))
786
+ return window, top_r, thr, warmup
787
+
788
+ def deepconf_generate(self, *args,
789
+ eos_token_id: Optional[Union[int, List[int]]] = None,
790
+ window: int = 512, top_r: int = 5, threshold: float = -3.5,
791
+ warmup_tokens: int = 0,
792
+ **kwargs):
793
+ # Prefer ChatML stop strings if tokenizer+stop_strings are provided
794
+ prefer_ids: List[int] = []
795
+ tok = kwargs.get("tokenizer", None)
796
+ stop_strings = kwargs.get("stop_strings", None)
797
+ if tok is not None and stop_strings:
798
+ for s in stop_strings:
799
+ try:
800
+ eid = tok.convert_tokens_to_ids(s)
801
+ if isinstance(eid, int) and eid >= 0:
802
+ prefer_ids.append(int(eid)); continue
803
+ except Exception:
804
+ pass
805
+ try:
806
+ enc = tok.encode(s, add_special_tokens=False)
807
+ if isinstance(enc, list) and len(enc) == 1:
808
+ prefer_ids.append(int(enc[0]))
809
+ except Exception:
810
+ pass
811
+ lp: LogitsProcessorList = kwargs.pop("logits_processor", None) or LogitsProcessorList()
812
+ lp.append(
813
+ DeepConfEOSLogitsProcessor(
814
+ self._dc_collect_eos(eos_token_id, **kwargs),
815
+ window, top_r, threshold,
816
+ warmup_tokens=warmup_tokens,
817
+ prefer_eos_ids=prefer_ids or None
818
+ )
819
+ )
820
+ kwargs["logits_processor"] = lp
821
+ return super().generate(*args, **kwargs)
822
+
823
+ # Override generate() to be default ON (auto-attach DeepConf; merge with external lps)
824
+ def generate(self, *args, **kwargs):
825
+ if self._dc_enabled():
826
+ eos_ids = self._dc_collect_eos(kwargs.get("eos_token_id", None), **kwargs)
827
+ # Prefer ChatML end tokens if provided
828
+ prefer_ids: List[int] = []
829
+ tok = kwargs.get("tokenizer", None)
830
+ stop_strings = kwargs.get("stop_strings", None)
831
+ im_end_id = None
832
+ if tok is not None and stop_strings:
833
+ for s in stop_strings:
834
+ try:
835
+ eid = tok.convert_tokens_to_ids(s)
836
+ if isinstance(eid, int) and eid >= 0: prefer_ids.append(int(eid)); continue
837
+ except Exception: pass
838
+ try:
839
+ enc = tok.encode(s, add_special_tokens=False)
840
+ if isinstance(enc, list) and len(enc) == 1: prefer_ids.append(int(enc[0]))
841
+ except Exception: pass
842
+
843
+ # For ChatML protection: extract <|im_end|> id
844
+ if tok is not None:
845
+ try:
846
+ im_end_id = tok.convert_tokens_to_ids("<|im_end|>")
847
+ if not isinstance(im_end_id, int) or im_end_id < 0:
848
+ im_end_id = None
849
+ except Exception:
850
+ im_end_id = None
851
+
852
+ if eos_ids:
853
+ window, top_r, thr, warmup = self._dc_params()
854
+ require_prev = None
855
+ if (os.getenv("HYPERCLOVA_DEEPCONF_REQUIRE_IM_END", "1").lower() in {"1","on","true"}) and prefer_ids and im_end_id is not None:
856
+ require_prev = im_end_id
857
+
858
+ lp: LogitsProcessorList = kwargs.pop("logits_processor", None) or LogitsProcessorList()
859
+
860
+ if os.getenv("HYPERCLOVA_DEEPCONF_VERBOSE_ATTACH","0") in {"1","on","true"}:
861
+ print(f"[DeepConf] attach window={window} top_r={top_r} thr={thr} warmup={warmup} eos={eos_ids} prefer={prefer_ids} require_prev={require_prev}")
862
+
863
+ lp.append(
864
+ DeepConfEOSLogitsProcessor(
865
+ eos_ids, window, top_r, thr,
866
+ warmup_tokens=warmup,
867
+ prefer_eos_ids=prefer_ids or None,
868
+ require_prev_id=require_prev
869
+ )
870
+ )
871
+ kwargs["logits_processor"] = lp
872
+ return super().generate(*args, **kwargs)
873
+
874
+ def set_decoder(self, decoder):
875
+ self.model = decoder
876
+
877
+ def get_decoder(self):
878
+ return self.model
879
+
880
+ @can_return_tuple
881
+ @auto_docstring
882
+ def forward(
883
+ self,
884
+ input_ids: Optional[torch.LongTensor] = None,
885
+ attention_mask: Optional[torch.Tensor] = None,
886
+ position_ids: Optional[torch.LongTensor] = None,
887
+ past_key_values: Optional[Cache] = None,
888
+ inputs_embeds: Optional[torch.FloatTensor] = None,
889
+ labels: Optional[torch.LongTensor] = None,
890
+ use_cache: Optional[bool] = None,
891
+ output_attentions: Optional[bool] = None,
892
+ output_hidden_states: Optional[bool] = None,
893
+ cache_position: Optional[torch.LongTensor] = None,
894
+ logits_to_keep: Union[int, torch.Tensor] = 0,
895
+ **kwargs: Unpack[KwargsForCausalLM],
896
+ ) -> CausalLMOutputWithPast:
897
+ r"""
898
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
899
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
900
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
901
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
902
+
903
+ Example:
904
+
905
+ ```python
906
+ >>> from transformers import AutoTokenizer, HyperCLOVAXForCausalLM
907
+
908
+ >>> model = HyperCLOVAXForCausalLM.from_pretrained("naver-hyperclovax/{model_name}")
909
+ >>> tokenizer = AutoTokenizer.from_pretrained("naver-hyperclovax/{model_name}")
910
+
911
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
912
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
913
+
914
+ >>> # Generate
915
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
916
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
917
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
918
+ ```"""
919
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
920
+ output_hidden_states = (
921
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
922
+ )
923
+
924
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
925
+ outputs: BaseModelOutputWithPast = self.model(
926
+ input_ids=input_ids,
927
+ attention_mask=attention_mask,
928
+ position_ids=position_ids,
929
+ past_key_values=past_key_values,
930
+ inputs_embeds=inputs_embeds,
931
+ use_cache=use_cache,
932
+ output_attentions=output_attentions,
933
+ output_hidden_states=output_hidden_states,
934
+ cache_position=cache_position,
935
+ **kwargs,
936
+ )
937
+
938
+ hidden_states = outputs.last_hidden_state
939
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
940
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
941
+ # MuP
942
+ logits = self.lm_head(hidden_states[:, slice_indices, :]) * self.logits_scaling
943
+
944
+ loss = None
945
+ if labels is not None:
946
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
947
+
948
+ return CausalLMOutputWithPast(
949
+ loss=loss,
950
+ logits=logits,
951
+ past_key_values=outputs.past_key_values,
952
+ hidden_states=outputs.hidden_states,
953
+ attentions=outputs.attentions,
954
+ )
955
+
956
+
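The generate() override above attaches DeepConf by default and reads its knobs from environment variables; a configuration sketch (the variable names come from the code above, while model, tokenizer and inputs are assumed to already exist):

import os

os.environ["HYPERCLOVA_DEEPCONF"] = "on"            # "off" disables the auto-attached processor
os.environ["HYPERCLOVA_DEEPCONF_WINDOW"] = "256"    # sliding window of token confidences
os.environ["HYPERCLOVA_DEEPCONF_TOPR"] = "5"        # top-r log-probs averaged per step
os.environ["HYPERCLOVA_DEEPCONF_THRESH"] = "-3.5"   # force EOS when group confidence drops below this
os.environ["HYPERCLOVA_DEEPCONF_WARMUP"] = "64"     # never stop before this many generated tokens
os.environ["HYPERCLOVA_DEEPCONF_VERBOSE"] = "1"     # periodic [DeepConf] progress prints

# Passing the tokenizer and ChatML stop strings lets the processor prefer <|im_end|> and, with
# HYPERCLOVA_DEEPCONF_REQUIRE_IM_END (on by default), only force EOS right after that token.
output_ids = model.generate(
    **inputs,
    max_new_tokens=512,
    tokenizer=tokenizer,
    stop_strings=["<|im_end|>"],
)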
957
+ @auto_docstring(
958
+ custom_intro="""
959
+ The HyperCLOVAX Model transformer with a sequence classification head on top (linear layer).
960
+
961
+ [`HyperCLOVAXForSequenceClassification`] uses the last token in order to do the classification, as other causal models
962
+ (e.g. GPT-2) do.
963
+
964
+ Since it does classification on the last token, it requires to know the position of the last token. If a
965
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
966
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
967
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
968
+ each row of the batch).
969
+ """
970
+ )
971
+ class HyperCLOVAXForSequenceClassification(HyperCLOVAXPreTrainedModel):
972
+ def __init__(self, config):
973
+ super().__init__(config)
974
+ self.num_labels = config.num_labels
975
+ self.model = HyperCLOVAXModel(config)
976
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
977
+
978
+ # Initialize weights and apply final processing
979
+ self.post_init()
980
+
981
+ def get_input_embeddings(self):
982
+ return self.model.embed_tokens
983
+
984
+ def set_input_embeddings(self, value):
985
+ self.model.embed_tokens = value
986
+
987
+ @can_return_tuple
988
+ @auto_docstring
989
+ def forward(
990
+ self,
991
+ input_ids: Optional[torch.LongTensor] = None,
992
+ attention_mask: Optional[torch.Tensor] = None,
993
+ position_ids: Optional[torch.LongTensor] = None,
994
+ past_key_values: Optional[Cache] = None,
995
+ inputs_embeds: Optional[torch.FloatTensor] = None,
996
+ labels: Optional[torch.LongTensor] = None,
997
+ use_cache: Optional[bool] = None,
998
+ output_attentions: Optional[bool] = None,
999
+ output_hidden_states: Optional[bool] = None,
1000
+ ) -> SequenceClassifierOutputWithPast:
1001
+ r"""
1002
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1003
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1004
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1005
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1006
+ """
1007
+
1008
+ transformer_outputs: BaseModelOutputWithPast = self.model(
1009
+ input_ids,
1010
+ attention_mask=attention_mask,
1011
+ position_ids=position_ids,
1012
+ past_key_values=past_key_values,
1013
+ inputs_embeds=inputs_embeds,
1014
+ use_cache=use_cache,
1015
+ output_attentions=output_attentions,
1016
+ output_hidden_states=output_hidden_states,
1017
+ )
1018
+ hidden_states = transformer_outputs.last_hidden_state
1019
+ logits = self.score(hidden_states)
1020
+
1021
+ if input_ids is not None:
1022
+ batch_size = input_ids.shape[0]
1023
+ else:
1024
+ batch_size = inputs_embeds.shape[0]
1025
+
1026
+ if self.config.pad_token_id is None and batch_size != 1:
1027
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1028
+ if self.config.pad_token_id is None:
1029
+ last_non_pad_token = -1
1030
+ elif input_ids is not None:
1031
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
1032
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
1033
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
1034
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
1035
+ else:
1036
+ last_non_pad_token = -1
1037
+ logger.warning_once(
1038
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1039
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1040
+ )
1041
+
1042
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
1043
+
1044
+ loss = None
1045
+ if labels is not None:
1046
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
1047
+
1048
+ return SequenceClassifierOutputWithPast(
1049
+ loss=loss,
1050
+ logits=pooled_logits,
1051
+ past_key_values=transformer_outputs.past_key_values,
1052
+ hidden_states=transformer_outputs.hidden_states,
1053
+ attentions=transformer_outputs.attentions,
1054
+ )
1055
+
1056
+
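A toy sketch of the pooling trick above, which finds the rightmost non-padding position per row and therefore works for both left and right padding (pad id 0 is only an example):

import torch

pad_token_id = 0
input_ids = torch.tensor([[0, 0, 7, 8, 9],     # left-padded
                          [5, 6, 7, 0, 0]])    # right-padded

non_pad_mask = (input_ids != pad_token_id).to(torch.int32)
token_indices = torch.arange(input_ids.shape[-1], dtype=torch.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)   # tensor([4, 2]) -> positions of tokens 9 and 7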
1057
+ @auto_docstring
1058
+ class HyperCLOVAXForQuestionAnswering(HyperCLOVAXPreTrainedModel):
1059
+ base_model_prefix = "transformer"
1060
+
1061
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->HyperCLOVAX
1062
+ def __init__(self, config):
1063
+ super().__init__(config)
1064
+ self.transformer = HyperCLOVAXModel(config)
1065
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1066
+
1067
+ # Initialize weights and apply final processing
1068
+ self.post_init()
1069
+
1070
+ def get_input_embeddings(self):
1071
+ return self.transformer.embed_tokens
1072
+
1073
+ def set_input_embeddings(self, value):
1074
+ self.transformer.embed_tokens = value
1075
+
1076
+ @can_return_tuple
1077
+ @auto_docstring
1078
+ def forward(
1079
+ self,
1080
+ input_ids: Optional[torch.LongTensor] = None,
1081
+ attention_mask: Optional[torch.Tensor] = None,
1082
+ position_ids: Optional[torch.LongTensor] = None,
1083
+ past_key_values: Optional[Cache] = None,
1084
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1085
+ start_positions: Optional[torch.LongTensor] = None,
1086
+ end_positions: Optional[torch.LongTensor] = None,
1087
+ output_attentions: Optional[bool] = None,
1088
+ output_hidden_states: Optional[bool] = None,
1089
+ **kwargs,
1090
+ ) -> QuestionAnsweringModelOutput:
1091
+ outputs: BaseModelOutputWithPast = self.transformer(
1092
+ input_ids,
1093
+ attention_mask=attention_mask,
1094
+ position_ids=position_ids,
1095
+ past_key_values=past_key_values,
1096
+ inputs_embeds=inputs_embeds,
1097
+ output_attentions=output_attentions,
1098
+ output_hidden_states=output_hidden_states,
1099
+ )
1100
+
1101
+ sequence_output = outputs.last_hidden_state
1102
+
1103
+ logits = self.qa_outputs(sequence_output)
1104
+ start_logits, end_logits = logits.split(1, dim=-1)
1105
+ start_logits = start_logits.squeeze(-1).contiguous()
1106
+ end_logits = end_logits.squeeze(-1).contiguous()
1107
+
1108
+ loss = None
1109
+ if start_positions is not None and end_positions is not None:
1110
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
1111
+
1112
+ return QuestionAnsweringModelOutput(
1113
+ loss=loss,
1114
+ start_logits=start_logits,
1115
+ end_logits=end_logits,
1116
+ hidden_states=outputs.hidden_states,
1117
+ attentions=outputs.attentions,
1118
+ )
1119
+
1120
+
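A shape sketch of the span head above: each position is projected to two scores, which are then split into start and end logits.

import torch
from torch import nn

qa_outputs = nn.Linear(16, 2)                  # stand-in hidden size of 16
sequence_output = torch.randn(3, 10, 16)       # (batch, seq, hidden)

logits = qa_outputs(sequence_output)           # (3, 10, 2)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)        # (3, 10)
end_logits = end_logits.squeeze(-1)            # (3, 10)
assert start_logits.shape == end_logits.shape == (3, 10)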
1121
+ @auto_docstring
1122
+ class HyperCLOVAXForTokenClassification(HyperCLOVAXPreTrainedModel):
1123
+ def __init__(self, config):
1124
+ super().__init__(config)
1125
+ self.num_labels = config.num_labels
1126
+ self.model = HyperCLOVAXModel(config)
1127
+ if getattr(config, "classifier_dropout", None) is not None:
1128
+ classifier_dropout = config.classifier_dropout
1129
+ elif getattr(config, "hidden_dropout", None) is not None:
1130
+ classifier_dropout = config.hidden_dropout
1131
+ else:
1132
+ classifier_dropout = 0.1
1133
+ self.dropout = nn.Dropout(classifier_dropout)
1134
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
1135
+
1136
+ # Initialize weights and apply final processing
1137
+ self.post_init()
1138
+
1139
+ def get_input_embeddings(self):
1140
+ return self.model.embed_tokens
1141
+
1142
+ def set_input_embeddings(self, value):
1143
+ self.model.embed_tokens = value
1144
+
1145
+ @can_return_tuple
1146
+ @auto_docstring
1147
+ def forward(
1148
+ self,
1149
+ input_ids: Optional[torch.LongTensor] = None,
1150
+ attention_mask: Optional[torch.Tensor] = None,
1151
+ position_ids: Optional[torch.LongTensor] = None,
1152
+ past_key_values: Optional[Cache] = None,
1153
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1154
+ labels: Optional[torch.LongTensor] = None,
1155
+ use_cache: Optional[bool] = None,
1156
+ output_attentions: Optional[bool] = None,
1157
+ output_hidden_states: Optional[bool] = None,
1158
+ ) -> TokenClassifierOutput:
1159
+ r"""
1160
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1161
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1162
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1163
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1164
+ """
1165
+
1166
+ outputs: BaseModelOutputWithPast = self.model(
1167
+ input_ids,
1168
+ attention_mask=attention_mask,
1169
+ position_ids=position_ids,
1170
+ past_key_values=past_key_values,
1171
+ inputs_embeds=inputs_embeds,
1172
+ use_cache=use_cache,
1173
+ output_attentions=output_attentions,
1174
+ output_hidden_states=output_hidden_states,
1175
+ )
1176
+ sequence_output = outputs.last_hidden_state
1177
+ sequence_output = self.dropout(sequence_output)
1178
+ logits = self.score(sequence_output)
1179
+
1180
+ loss = None
1181
+ if labels is not None:
1182
+ loss = self.loss_function(logits, labels, self.config)
1183
+
1184
+ return TokenClassifierOutput(
1185
+ loss=loss,
1186
+ logits=logits,
1187
+ hidden_states=outputs.hidden_states,
1188
+ attentions=outputs.attentions,
1189
+ )
1190
+
1191
+
1192
+ __all__ = [
1193
+ "HyperCLOVAXForCausalLM",
1194
+ "HyperCLOVAXModel",
1195
+ "HyperCLOVAXPreTrainedModel",
1196
+ "HyperCLOVAXForSequenceClassification",
1197
+ "HyperCLOVAXForQuestionAnswering",
1198
+ "HyperCLOVAXForTokenClassification",
1199
+ ]