sudy-super committed on
Commit b7c4310 · verified · 1 Parent(s): 6abf0f0

Update modeling_c_cubed.py

Files changed (1)
  1. modeling_c_cubed.py +710 -738
modeling_c_cubed.py CHANGED
@@ -1,738 +1,710 @@
1
- # coding=utf-8
2
- """PyTorch Ccubed model."""
3
-
4
- import math
5
- from dataclasses import dataclass
6
- from typing import List, Optional, Tuple, Union
7
-
8
- import numpy as np
9
- import torch
10
- import torch.utils.checkpoint
11
- from torch import nn
12
-
13
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
14
- from transformers.activations import ACT2FN
15
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
16
- from transformers.processing_utils import Unpack
17
- from transformers.image_processing_utils import select_best_resolution
18
- from transformers.modeling_outputs import ModelOutput
19
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
20
- from transformers.utils import (
21
- add_start_docstrings,
22
- add_start_docstrings_to_model_forward,
23
- logging,
24
- replace_return_docstrings,
25
- is_flash_attn_2_available,
26
- is_flash_attn_greater_or_equal_2_10
27
- )
28
- from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
29
- from .configuration_c_cubed import CcubedConfig
30
-
31
-
32
- logger = logging.get_logger(__name__)
33
-
34
- _CONFIG_FOR_DOC = "CcubedConfig"
35
-
36
-
37
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
38
- """
39
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
40
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
41
- """
42
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
43
- if n_rep == 1:
44
- return hidden_states
45
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
46
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
47
-
48
-
49
- @dataclass
50
- class CcubedCausalLMOutputWithPast(ModelOutput):
51
- """
52
- Base class for Ccubed causal language model (or autoregressive) outputs.
53
-
54
- Args:
55
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
56
- Language modeling loss (for next-token prediction).
57
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
58
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
59
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
60
- Tuple of `tuple(torch.FloatTensor)` of length `config.context_config.num_layers`, with each tuple having 2 tensors of shape
61
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
62
-
63
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
64
- `past_key_values` input) to speed up sequential decoding.
65
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
66
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
67
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
68
-
69
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
70
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
71
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
72
- sequence_length)`.
73
-
74
-            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
75
- heads.
76
- context_hidden_states (`torch.FloatTensor`, *optional*):
77
-            A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
78
-            Context hidden states of the model, produced by the context encoder after projecting its last hidden state.
79
- """
80
-
81
- loss: Optional[torch.FloatTensor] = None
82
- logits: torch.FloatTensor = None
83
- past_key_values: Optional[List[torch.FloatTensor]] = None
84
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
85
- attentions: Optional[Tuple[torch.FloatTensor]] = None
86
- context_hidden_states: Optional[torch.FloatTensor] = None
87
-
88
-
89
- class CcubedDynamicAttention(nn.Module):
90
- """
91
- Attention mechanism adapted for dynamic output size based on Mistral's architecture. This attention layer computes
92
- the output attention scores which are used to determine the pooling size dynamically.
93
- """
94
-
95
- def __init__(self, config: CcubedConfig):
96
- super().__init__()
97
-
98
- self.config = config
99
- self.hidden_size = config.context_config.hidden_size
100
- self.num_heads = config.context_config.num_attention_heads
101
- self.head_dim = getattr(config.context_config, "head_dim", self.hidden_size // self.num_heads)
102
- self.num_key_value_heads = config.context_config.num_key_value_heads
103
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
104
- self.scaling = self.head_dim ** -0.5
105
- self.attention_dropout = getattr(self.config.context_config, "attention_dropout", 0.0)
106
-
107
- # Query, Key, Value, and Output Projections
108
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
109
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
110
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
111
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, 1, bias=False)
112
-
113
- def forward(
114
- self,
115
- hidden_states: torch.Tensor,
116
- attention_mask: Optional[torch.Tensor] = None,
117
- output_attentions: bool = False,
118
- ):
119
- # Get input dimensions
120
- bsz, seq_len, hidden_size = hidden_states.size()
121
-
122
- # Query, Key, Value projections
123
- query_states = self.q_proj(hidden_states)
124
- key_states = self.k_proj(hidden_states)
125
- value_states = self.v_proj(hidden_states)
126
-
127
- # Reshape and transpose to [batch_size, num_heads, seq_len, head_dim]
128
- query_states = query_states.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
129
- key_states = key_states.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
130
- value_states = value_states.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
131
-
132
- # Repeat key and value states for multi-head attention
133
- key_states = repeat_kv(key_states, self.num_key_value_groups)
134
- value_states = repeat_kv(value_states, self.num_key_value_groups)
135
-
136
- # Compute attention scores
137
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
138
-
139
- # Apply softmax to get attention probabilities
140
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
141
-
142
- # Apply attention to values
143
- attn_output = torch.matmul(attn_weights, value_states)
144
-
145
- # Reshape attention output
146
- attn_output = attn_output.transpose(1, 2).contiguous()
147
- attn_output = attn_output.reshape(bsz, seq_len, -1)
148
-
149
- # Project to output dimension
150
- attn_output = self.o_proj(attn_output)
151
-
152
- if not output_attentions:
153
- attn_weights = None
154
-
155
- return attn_output, attn_weights
156
-
157
-
158
- class CcubedDynamicFlashAttention2(CcubedDynamicAttention):
159
- def __init__(self, config: CcubedConfig):
160
- super().__init__(config)
161
- self.is_causal = False # Assuming non-causal attention for this context
162
-
163
- def forward(
164
- self,
165
- hidden_states: torch.Tensor,
166
- attention_mask: Optional[torch.Tensor] = None,
167
- output_attentions: bool = False,
168
- **kwargs: Unpack[FlashAttentionKwargs],
169
- ):
170
- input_shape = hidden_states.shape[:-1]
171
- hidden_shape = (*input_shape, -1, self.head_dim)
172
-
173
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
174
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
175
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
176
-
177
- sliding_window = None
178
- if getattr(self.config, "sliding_window", None) is not None:
179
- sliding_window = self.config.sliding_window
180
-
181
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
182
-
183
- attn_output, attn_weights = attention_interface(
184
- self,
185
- query_states,
186
- key_states,
187
- value_states,
188
- attention_mask,
189
- dropout=0.0 if not self.training else self.attention_dropout,
190
- scaling=self.scaling,
191
- sliding_window=sliding_window, # main diff with Llama
192
- **kwargs,
193
- )
194
-
195
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
196
- attn_output = self.o_proj(attn_output)
197
- return attn_output, attn_weights
198
-
199
-
200
- class CcubedDynamicWeightedAvgPool1d(nn.Module):
201
- """
202
- A module that dynamically determines the output size based on input
203
- and performs weighted average pooling with separate attention mechanisms
204
- for output size estimation and weighted pooling.
205
- """
206
- def __init__(self, config, output_size_min=32, output_size_max=131072):
207
- super().__init__()
208
- # Attention mechanism for estimating output size
209
- self.size_estim_attn = CcubedDynamicFlashAttention2(config) # CcubedDynamicAttention(config)
210
- # Attention mechanism for weighted pooling
211
- self.imp_estim_attn = CcubedDynamicFlashAttention2(config) # CcubedDynamicAttention(config)
212
- self.output_size_min = output_size_min
213
- self.output_size_max = (
214
- config.context_config.max_position_embeddings if config.context_config.max_position_embeddings is not None else output_size_max
215
- )
216
- self.scale_param = nn.Parameter(torch.tensor(0.01))
217
-
218
- def forward(self, hidden_states, context_attention_mask=None):
219
- """
220
- Args:
221
-                hidden_states: Input tensor of shape (batch_size, seq_len, hidden_size)
222
-
223
- Returns:
224
- Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
225
- - pooled_output: Padded tensor of compressed sequences (batch_size, max_pooled_len, hidden_size)
226
- - attention_mask: Binary mask indicating valid tokens (batch_size, max_pooled_len)
227
- - dynamic_output_sizes: Dynamic output sizes for each batch (batch_size,)
228
- """
229
- batch_size, seq_len, hidden_size = hidden_states.size()
230
- device = hidden_states.device
231
-
232
- # Estimate output size using attention mechanism
233
- # attn_output_size: (batch_size, seq_len, 1)
234
- attn_output_size, _ = self.size_estim_attn(hidden_states)
235
-
236
- # Calculate dynamic output sizes for each batch item
237
- # (batch_size, seq_len, 1) -> (batch_size, 1)
238
- batch_attn_means = torch.sigmoid(attn_output_size).mean(dim=1)
239
- scaled_batch_means = batch_attn_means * self.scale_param.to(batch_attn_means.dtype)
240
-
241
- # Calculate dynamic output sizes (batch_size,)
242
- dynamic_output_sizes = (
243
- (scaled_batch_means * (self.output_size_max - self.output_size_min)) + self.output_size_min
244
- ).int().squeeze(-1)
245
-
246
- max_pooled_len = dynamic_output_sizes.max().item()
247
-
248
- # Compute attention weights for weighted pooling
249
- # attn_output_weights: (batch_size, seq_len, 1)
250
- attn_output_weights, _ = self.imp_estim_attn(hidden_states)
251
- # Normalize with sigmoid function for use as weights
252
- # attention_weights: (batch_size, seq_len)
253
- attention_weights = torch.sigmoid(attn_output_weights).squeeze(-1)
254
-
255
- # If context_attention_mask is provided, apply it to zero out weights for invalid tokens
256
- if context_attention_mask is not None:
257
- attention_weights = attention_weights * context_attention_mask
258
-
259
- # Initialize output tensors
260
- # pooled_output: (batch_size, max_pooled_len, hidden_size)
261
- pooled_output = torch.zeros(
262
- batch_size, max_pooled_len, hidden_size,
263
- device=device, dtype=hidden_states.dtype
264
- )
265
- # attention_mask: (batch_size, max_pooled_len)
266
- attention_mask = torch.zeros(
267
- batch_size, max_pooled_len,
268
- dtype=torch.bool, device=device
269
- )
270
-
271
- for batch_idx in range(batch_size):
272
- output_size = dynamic_output_sizes[batch_idx].item()
273
- item_input = hidden_states[batch_idx] # Shape: (seq_len, hidden_size)
274
- item_weights = attention_weights[batch_idx] # Shape: (seq_len)
275
-
276
- # Perform weighted pooling
277
- pooled_values = []
278
- batch_attn_mask = torch.zeros(output_size, dtype=torch.bool, device=device)
279
- # Split the sequence evenly
280
- intervals = torch.linspace(0, seq_len, steps=output_size + 1).long()
281
- for i in range(output_size):
282
- start = intervals[i].item()
283
- end = intervals[i + 1].item()
284
- chunk_input = item_input[start:end] # Shape: (chunk_size, hidden_size)
285
- chunk_weights = item_weights[start:end] # Shape: (chunk_size)
286
- if chunk_weights.sum() == 0:
287
- # If the sum of weights is zero, add a zero vector
288
- pooled_value = torch.zeros(hidden_size, device=device, dtype=hidden_states.dtype)
289
- else:
290
- # Calculate weighted average
291
- weighted_input = chunk_input * chunk_weights.unsqueeze(-1) # Shape: (chunk_size, hidden_size)
292
- pooled_value = weighted_input.sum(dim=0) / (chunk_weights.sum() + 1e-8) # Shape: (hidden_size)
293
- batch_attn_mask[i] = True
294
- pooled_values.append(pooled_value)
295
-
296
- if pooled_values: # Only stack if there are values
297
- # Convert the result to a tensor
298
- pooled_values = torch.stack(pooled_values) # Shape: (output_size, hidden_size)
299
- # Store the result
300
- pooled_output[batch_idx, -output_size:] = pooled_values
301
- attention_mask[batch_idx, -output_size:] = batch_attn_mask
302
-
303
- return pooled_output, attention_mask, dynamic_output_sizes
304
-
305
-
306
- class CcubedContextLanguageConnector(nn.Module):
307
- def __init__(self, config: CcubedConfig):
308
- super().__init__()
309
-
310
- self.dynamic_pooling = CcubedDynamicWeightedAvgPool1d(config)
311
-
312
- self.linear_1 = nn.Linear(
313
- config.context_config.hidden_size,
314
- config.text_config.hidden_size,
315
- bias=True
316
- )
317
- self.act = ACT2FN[config.projector_hidden_act]
318
- self.linear_2 = nn.Linear(
319
- config.text_config.hidden_size,
320
- config.text_config.hidden_size,
321
- bias=True
322
- )
323
-
324
- def forward(self, context_features):
325
- # context_features: [batch_size, seq_len, hidden_size]
326
- # Apply dynamic adaptive average pooling with attention
327
- pooled_output, attention_mask, dynamic_output_sizes = self.dynamic_pooling(
328
- hidden_states=context_features
329
- )
330
-
331
- hidden_states = self.linear_1(pooled_output)
332
- hidden_states = self.act(hidden_states)
333
- hidden_states = self.linear_2(hidden_states)
334
-
335
- return hidden_states, attention_mask
336
-
337
-
338
- class CcubedContextTower(nn.Module):
339
- def __init__(self, config: CcubedConfig):
340
- super().__init__()
341
-
342
- self.tower = AutoModelForCausalLM.from_config(
343
- config.context_config,
344
- attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "eager"
345
- )
346
- self.select_layer = config.context_feature_layer
347
-
348
- def feature_select(self, llm_outputs):
349
- hidden_states = llm_outputs.hidden_states
350
- return hidden_states[self.select_layer]
351
-
352
- def forward(
353
- self,
354
- input_ids,
355
- inputs_embeds,
356
- attention_mask
357
- ):
358
- outputs = self.tower(
359
- input_ids=input_ids,
360
- inputs_embeds=inputs_embeds,
361
- attention_mask=attention_mask,
362
- output_hidden_states=True
363
- )
364
- features = self.feature_select(outputs)
365
- return features
366
-
367
-
368
- class CcubedPreTrainedModel(PreTrainedModel):
369
- config_class = CcubedConfig
370
- base_model_prefix = "model"
371
- supports_gradient_checkpointing = True
372
- _no_split_modules = [] # ["CcubedContextLanguageConnector", "CcubedContextTower"]
373
- _skip_keys_device_placement = ["past_key_values"]
374
- _supports_flash_attn_2 = True
375
- _supports_sdpa = True
376
- _supports_cache_class = True
377
- _supports_quantized_cache = True
378
- _supports_static_cache = True
379
-
380
- def _init_weights(self, module):
381
- std = (
382
- self.config.initializer_range
383
- if hasattr(self.config, "initializer_range")
384
- else self.config.text_config.initializer_range
385
- )
386
- if isinstance(module, nn.Linear):
387
- module.weight.data.normal_(mean=0.0, std=std)
388
- if module.bias is not None:
389
- module.bias.data.zero_()
390
- elif isinstance(module, nn.Embedding):
391
- module.weight.data.normal_(mean=0.0, std=std)
392
- if module.padding_idx is not None:
393
- module.weight.data[module.padding_idx].zero_()
394
-
395
-
396
- class CcubedForConditionalGeneration(CcubedPreTrainedModel):
397
- def __init__(self, config: CcubedConfig):
398
- super().__init__(config)
399
- self.context_tower = CcubedContextTower(config)
400
- self.connector = CcubedContextLanguageConnector(config)
401
-
402
- self.language_model = AutoModelForCausalLM.from_config(
403
- config.text_config,
404
- attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "eager"
405
- )
406
-
407
- self.vocab_size = config.text_config.vocab_size
408
- self.ignore_index = config.ignore_index if hasattr(config, 'ignore_index') else -100
409
- self.start_of_context_token_id = config.start_of_context_token_id
410
- self.end_of_context_token_id = config.end_of_context_token_id
411
-
412
- self.post_init()
413
-
414
- def get_input_embeddings(self):
415
- return self.language_model.get_input_embeddings()
416
-
417
- def get_context_input_embeddings(self):
418
- return self.context_tower.tower.get_input_embeddings()
419
-
420
- def set_input_embeddings(self, value):
421
- self.language_model.set_input_embeddings(value)
422
-
423
- def set_context_input_embeddings(self, value):
424
- self.context_tower.tower.set_input_embeddings(value)
425
-
426
- def get_output_embeddings(self):
427
- return self.language_model.get_output_embeddings()
428
-
429
- def get_context_output_embeddings(self):
430
- return self.context_tower.tower.get_output_embeddings()
431
-
432
- def set_output_embeddings(self, new_embeddings):
433
- self.language_model.set_output_embeddings(new_embeddings)
434
-
435
- def set_context_output_embeddings(self, new_embeddings):
436
- self.context_tower.tower.set_output_embeddings(new_embeddings)
437
-
438
- def set_decoder(self, decoder):
439
- self.language_model.set_decoder(decoder)
440
-
441
- def set_context_encoder(self, decoder):
442
- self.context_tower.tower.set_decoder(decoder)
443
-
444
- def get_decoder(self):
445
- return self.language_model.get_decoder()
446
-
447
- def get_context_encoder(self):
448
- return self.context_tower.tower.get_decoder()
449
-
450
- def tie_weights(self):
451
- return self.language_model.tie_weights()
452
-
453
- def context_tie_weights(self):
454
- return self.context_tower.tower.tie_weights()
455
-
456
- def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
457
- model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
458
- # update vocab size
459
- self.config.text_config.vocab_size = model_embeds.num_embeddings
460
- self.vocab_size = model_embeds.num_embeddings
461
- return model_embeds
462
-
463
- def _merge_context_features(
464
- self,
465
- context_features = None,
466
- inputs_embeds = None,
467
- attention_mask = None,
468
- context_attention_mask=None,
469
- position_ids=None,
470
- labels=None,
471
- ):
472
- if context_features is None:
473
- return inputs_embeds, attention_mask, position_ids, labels
474
-
475
- batch_size, seq_length, embed_dim = inputs_embeds.shape
476
- context_seq_len = context_features.size(1)
477
-
478
- # Create embeddings for begin and end of context tokens
479
- begin_context_embed = self.get_input_embeddings()(torch.tensor(self.start_of_context_token_id, device=context_features.device))
480
- end_context_embed = self.get_input_embeddings()(torch.tensor(self.end_of_context_token_id, device=context_features.device))
481
-
482
- # Determine the actual lengths of context sequences (excluding padding)
483
- if context_attention_mask is not None:
484
- # context_attention_mask: [batch_size, context_seq_len, 1]
485
- context_attention_mask = context_attention_mask.squeeze(-1) # [batch_size, context_seq_len]
486
- # Sum over sequence length to get actual lengths
487
- context_lengths = context_attention_mask.sum(dim=1).long() # [batch_size]
488
- else:
489
- # If no context_attention_mask is provided, assume full length
490
- context_lengths = torch.full((batch_size,), context_seq_len, device=context_features.device, dtype=torch.long)
491
- context_attention_mask = torch.ones(batch_size, context_seq_len, device=context_features.device, dtype=torch.long)
492
-
493
- # Rearrange context features to include padding at the beginning
494
- # Identify the maximum context length (excluding padding)
495
- max_context_length = context_lengths.max().item()
496
- # Calculate the amount of padding needed for each sample
497
- padding_lengths = context_seq_len - context_lengths # [batch_size]
498
-
499
- # Create new context_features with padding at the beginning
500
- new_context_features = []
501
- for i in range(batch_size):
502
- padding_len = padding_lengths[i].item()
503
- # Create padding embeddings (zeros)
504
- padding_embed = torch.zeros(padding_len, embed_dim, device=context_features.device, dtype=context_features.dtype)
505
- # Get actual context features (excluding padding)
506
- actual_context = context_features[i, padding_len:context_seq_len]
507
- # Concatenate padding, begin token, actual context, end token
508
- sample_context = torch.cat([
509
- padding_embed,
510
- begin_context_embed.unsqueeze(0),
511
- actual_context,
512
- end_context_embed.unsqueeze(0)
513
- ], dim=0) # [context_seq_len + 2, embed_dim]
514
- new_context_features.append(sample_context)
515
- # Stack to create [batch_size, new_context_seq_len, embed_dim]
516
- context_features = torch.stack(new_context_features, dim=0)
517
- new_context_seq_len = context_features.size(1)
518
-
519
- # Update context_attention_mask accordingly
520
- new_context_attention_mask = []
521
- for i in range(batch_size):
522
- padding_len = padding_lengths[i].item()
523
- # Create padding mask (zeros)
524
- padding_mask = torch.zeros(padding_len, device=context_features.device, dtype=attention_mask.dtype)
525
- # Begin and end token masks
526
- begin_attention = torch.ones(1, device=context_features.device, dtype=attention_mask.dtype)
527
- end_attention = torch.ones(1, device=context_features.device, dtype=attention_mask.dtype)
528
- # Actual context attention mask (excluding padding)
529
- actual_mask = context_attention_mask[i, padding_len:context_seq_len]
530
- # Concatenate masks
531
- sample_mask = torch.cat([
532
- padding_mask,
533
- begin_attention,
534
- actual_mask,
535
- end_attention
536
- ], dim=0) # [context_seq_len + 2]
537
- new_context_attention_mask.append(sample_mask)
538
- # Stack to create [batch_size, new_context_seq_len]
539
- context_attention_mask = torch.stack(new_context_attention_mask, dim=0)
540
-
541
- # Concatenate context features with input embeddings
542
- new_inputs_embeds = torch.cat([context_features, inputs_embeds], dim=1) # [batch_size, total_seq_len, embed_dim]
543
-
544
- # Concatenate attention masks
545
- new_attention_mask = torch.cat([context_attention_mask, attention_mask], dim=1)
546
-
547
- # Create new position_ids
548
- total_seq_len = new_inputs_embeds.size(1)
549
- new_position_ids = torch.arange(total_seq_len, device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
550
-
551
- # Update labels if provided
552
- if labels is not None:
553
- # Create ignore labels for context (including padding and special tokens)
554
- context_labels = torch.full((batch_size, new_context_seq_len), self.ignore_index, device=labels.device, dtype=labels.dtype)
555
- new_labels = torch.cat([context_labels, labels], dim=1)
556
- else:
557
- new_labels = None
558
-
559
- return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels
560
-
561
-
562
- @replace_return_docstrings(output_type=CcubedCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
563
- def forward(
564
- self,
565
- context_input_ids: torch.LongTensor = None,
566
- context_inputs_embeds: Optional[torch.FloatTensor] = None,
567
- context_attention_mask: Optional[torch.Tensor] = None,
568
- input_ids: torch.LongTensor = None,
569
- inputs_embeds: Optional[torch.FloatTensor] = None,
570
- attention_mask: Optional[torch.Tensor] = None,
571
- position_ids: Optional[torch.LongTensor] = None,
572
- past_key_values: Optional[List[torch.FloatTensor]] = None,
573
- labels: Optional[torch.LongTensor] = None,
574
- use_cache: Optional[bool] = None,
575
- output_attentions: Optional[bool] = None,
576
- output_hidden_states: Optional[bool] = None,
577
- return_dict: Optional[bool] = None,
578
- cache_position: Optional[torch.LongTensor] = None,
579
- logits_to_keep: int = 0,
580
- ) -> Union[Tuple, CcubedCausalLMOutputWithPast]:
581
- """
582
- Perform a forward pass through the Ccubed model, optionally conditioning on context input.
583
-
584
- Args:
585
- context_input_ids (`torch.LongTensor` of shape `(batch_size, context_sequence_length)`, *optional*):
586
- Token IDs of the context input sequence.
587
- context_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, context_sequence_length, hidden_size)`, *optional*):
588
- Pre-computed context embeddings. If provided, will not compute embeddings from context_input_ids.
589
- context_attention_mask (`torch.Tensor` of shape `(batch_size, context_sequence_length)`, *optional*):
590
- Attention mask for context input sequence.
591
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
592
- Token IDs of the input sequence.
593
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
594
- Optionally, instead of passing `input_ids`, you can pass an embedded representation directly.
595
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
596
- Mask to avoid performing attention on padding token indices.
597
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
598
- Indices of positions of each input sequence token.
599
- past_key_values (`List[torch.FloatTensor]`, *optional*):
600
- Pre-computed hidden-states (key and value tensors) that can be used to speed up sequential decoding.
601
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
602
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
603
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
604
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
605
- use_cache (`bool`, *optional*):
606
- If `True`, past key values will be used to speed up decoding.
607
- output_attentions (`bool`, *optional*):
608
- If `True`, return the attention tensors for each layer.
609
- output_hidden_states (`bool`, *optional*):
610
- If `True`, return the hidden states of all layers.
611
- return_dict (`bool`, *optional*):
612
- If `True`, return a `CcubedCausalLMOutputWithPast` instead of a plain tuple.
613
-            logits_to_keep (`int`, *optional*):
614
-                Calculate logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
615
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
616
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
617
-
618
- Returns:
619
- `Union[Tuple, CcubedCausalLMOutputWithPast]`: A tuple containing various model outputs or a `CcubedCausalLMOutputWithPast` instance.
620
- The CcubedCausalLMOutputWithPast contains the following fields:
621
- - loss (`torch.FloatTensor`, *optional*): Language modeling loss if labels provided, None otherwise.
622
- - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`): Prediction scores.
623
- - past_key_values (`List[torch.FloatTensor]`, *optional*): Pre-computed hidden states for efficient decoding.
624
- - hidden_states (`Tuple[torch.FloatTensor]`, *optional*): Layer hidden states if output_hidden_states=True.
625
- - attentions (`Tuple[torch.FloatTensor]`, *optional*): Layer attention weights if output_attentions=True.
626
- - context_hidden_states (`torch.FloatTensor`, *optional*): Final hidden states from the context tower.
627
- """
628
-
629
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
630
- output_hidden_states = (
631
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
632
- )
633
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
634
-
635
-
636
- all_inputs_none = (
637
- input_ids is None and
638
- inputs_embeds is None and
639
- context_input_ids is None and
640
- context_inputs_embeds is None
641
- )
642
-
643
- if all_inputs_none:
644
- raise ValueError("You must provide either non-empty input_ids/inputs_embeds or context_input_ids/context_inputs_embeds.")
645
-
646
-
647
- if context_input_ids is not None or context_inputs_embeds is not None:
648
- context_features = self.context_tower(
649
- input_ids=context_input_ids,
650
- inputs_embeds=context_inputs_embeds,
651
- attention_mask=context_attention_mask,
652
- )
653
- context_features, context_attention_mask = self.connector(
654
- context_features=context_features
655
- )
656
- else:
657
- context_features = None
658
- context_attention_mask = None
659
-
660
-
661
- if inputs_embeds is None and input_ids is not None:
662
- inputs_embeds = self.get_input_embeddings()(input_ids)
663
-
664
- if inputs_embeds is not None:
665
- inputs_embeds, attention_mask, position_ids, labels = self._merge_context_features(
666
- context_features=context_features,
667
- inputs_embeds=inputs_embeds,
668
- attention_mask=attention_mask,
669
- context_attention_mask=context_attention_mask,
670
- position_ids=position_ids,
671
- labels=labels,
672
- )
673
- else:
674
- inputs_embeds = context_features
675
- attention_mask = context_attention_mask
676
-
677
- outputs = self.language_model(
678
- attention_mask=attention_mask,
679
- position_ids=position_ids,
680
- past_key_values=past_key_values,
681
- inputs_embeds=inputs_embeds,
682
- use_cache=use_cache,
683
- output_attentions=output_attentions,
684
- output_hidden_states=output_hidden_states,
685
- return_dict=return_dict,
686
- cache_position=cache_position,
687
- logits_to_keep=logits_to_keep,
688
- )
689
-
690
- logits = outputs[0]
691
-
692
- loss = None
693
- if labels is not None:
694
- shift_logits = logits[..., :-1, :].contiguous()
695
- shift_labels = labels[..., 1:].contiguous()
696
- loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
697
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device))
698
-
699
- if not return_dict:
700
- output = (logits,) + outputs[1:]
701
- return (loss,) + output if loss is not None else output
702
-
703
- return CcubedCausalLMOutputWithPast(
704
- loss=loss,
705
- logits=logits,
706
- past_key_values=outputs.past_key_values,
707
- hidden_states=outputs.hidden_states,
708
- attentions=outputs.attentions,
709
- context_hidden_states=context_features,
710
- )
711
-
712
- def prepare_inputs_for_generation(
713
- self,
714
- input_ids,
715
- past_key_values=None,
716
- attention_mask=None,
717
- inputs_embeds=None,
718
- context_features=None,
719
- **kwargs
720
- ):
721
- if past_key_values:
722
- input_ids = input_ids[:, -1:]
723
-
724
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
725
- if inputs_embeds is not None and past_key_values is None:
726
- model_inputs = {"inputs_embeds": inputs_embeds}
727
- else:
728
- model_inputs = {"input_ids": input_ids}
729
-
730
- model_inputs.update(
731
- {
732
- "past_key_values": past_key_values,
733
- "use_cache": kwargs.get("use_cache"),
734
- "attention_mask": attention_mask,
735
- "context_features": context_features,
736
- }
737
- )
738
- return model_inputs
 
1
+ # coding=utf-8
2
+ """PyTorch Ccubed model."""
3
+
4
+ import math
5
+ from dataclasses import dataclass
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+
13
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
14
+ from transformers.activations import ACT2FN
15
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
16
+ from transformers.processing_utils import Unpack
17
+ from transformers.image_processing_utils import select_best_resolution
18
+ from transformers.modeling_outputs import ModelOutput
19
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
20
+ from transformers.utils import (
21
+ add_start_docstrings,
22
+ add_start_docstrings_to_model_forward,
23
+ logging,
24
+ replace_return_docstrings,
25
+ is_flash_attn_2_available,
26
+ is_flash_attn_greater_or_equal_2_10
27
+ )
28
+ from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
29
+ from .configuration_c_cubed import CcubedConfig
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ _CONFIG_FOR_DOC = "CcubedConfig"
35
+
36
+
37
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
38
+ """
39
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
40
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
41
+ """
42
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
43
+ if n_rep == 1:
44
+ return hidden_states
45
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
46
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
47
+
48
+
49
+ @dataclass
50
+ class CcubedCausalLMOutputWithPast(ModelOutput):
51
+ """
52
+ Base class for Ccubed causal language model (or autoregressive) outputs.
53
+
54
+ Args:
55
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
56
+ Language modeling loss (for next-token prediction).
57
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
58
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
59
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
60
+ Tuple of `tuple(torch.FloatTensor)` of length `config.context_config.num_layers`, with each tuple having 2 tensors of shape
61
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
62
+
63
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
64
+ `past_key_values` input) to speed up sequential decoding.
65
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
66
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
67
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
68
+
69
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
70
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
71
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
72
+ sequence_length)`.
73
+
74
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
75
+ heads.
76
+ context_hidden_states (`torch.FloatTensor`, *optional*):
77
+            A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
78
+            Context hidden states of the model, produced by the context encoder after projecting its last hidden state.
79
+ """
80
+
81
+ loss: Optional[torch.FloatTensor] = None
82
+ logits: torch.FloatTensor = None
83
+ past_key_values: Optional[List[torch.FloatTensor]] = None
84
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
85
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
86
+ context_hidden_states: Optional[torch.FloatTensor] = None
87
+
88
+
89
+ class CcubedDynamicAttention(nn.Module):
90
+ """
91
+ Attention mechanism adapted for dynamic output size based on Mistral's architecture. This attention layer computes
92
+ the output attention scores which are used to determine the pooling size dynamically.
93
+ """
94
+
95
+ def __init__(self, config: CcubedConfig):
96
+ super().__init__()
97
+
98
+ self.config = config
99
+ self.hidden_size = config.context_config.hidden_size
100
+ self.num_heads = config.context_config.num_attention_heads
101
+ self.head_dim = getattr(config.context_config, "head_dim", self.hidden_size // self.num_heads)
102
+ self.num_key_value_heads = config.context_config.num_key_value_heads
103
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
104
+ self.scaling = self.head_dim ** -0.5
105
+ self.attention_dropout = getattr(self.config.context_config, "attention_dropout", 0.0)
106
+
107
+ # Query, Key, Value, and Output Projections
108
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
109
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
110
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
111
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, 1, bias=False)
112
+
113
+ def forward(
114
+ self,
115
+ hidden_states: torch.Tensor,
116
+ attention_mask: Optional[torch.Tensor] = None,
117
+ output_attentions: bool = False,
118
+ ):
119
+ # Get input dimensions
120
+ bsz, seq_len, hidden_size = hidden_states.size()
121
+
122
+ # Query, Key, Value projections
123
+ query_states = self.q_proj(hidden_states)
124
+ key_states = self.k_proj(hidden_states)
125
+ value_states = self.v_proj(hidden_states)
126
+
127
+ # Reshape and transpose to [batch_size, num_heads, seq_len, head_dim]
128
+ query_states = query_states.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
129
+ key_states = key_states.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
130
+ value_states = value_states.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
131
+
132
+ # Repeat key and value states for multi-head attention
133
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
134
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
135
+
136
+ # Compute attention scores
137
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
138
+
139
+ # Apply softmax to get attention probabilities
140
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
141
+
142
+ # Apply attention to values
143
+ attn_output = torch.matmul(attn_weights, value_states)
144
+
145
+ # Reshape attention output
146
+ attn_output = attn_output.transpose(1, 2).contiguous()
147
+ attn_output = attn_output.reshape(bsz, seq_len, -1)
148
+
149
+ # Project to output dimension
150
+ attn_output = self.o_proj(attn_output)
151
+
152
+ if not output_attentions:
153
+ attn_weights = None
154
+
155
+ return attn_output, attn_weights
156
+
157
+
158
+ class CcubedDynamicFlashAttention2(CcubedDynamicAttention):
159
+ def __init__(self, config: CcubedConfig):
160
+ super().__init__(config)
161
+ self.is_causal = False # Assuming non-causal attention for this context
162
+
163
+ def forward(
164
+ self,
165
+ hidden_states: torch.Tensor,
166
+ attention_mask: Optional[torch.Tensor] = None,
167
+ output_attentions: bool = False,
168
+ **kwargs: Unpack[FlashAttentionKwargs],
169
+ ):
170
+ input_shape = hidden_states.shape[:-1]
171
+ hidden_shape = (*input_shape, -1, self.head_dim)
172
+
173
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
174
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
175
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
176
+
177
+ sliding_window = None
178
+ if getattr(self.config, "sliding_window", None) is not None:
179
+ sliding_window = self.config.sliding_window
180
+
181
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
182
+
183
+ attn_output, attn_weights = attention_interface(
184
+ self,
185
+ query_states,
186
+ key_states,
187
+ value_states,
188
+ attention_mask,
189
+ dropout=0.0 if not self.training else self.attention_dropout,
190
+ scaling=self.scaling,
191
+ sliding_window=sliding_window, # main diff with Llama
192
+ **kwargs,
193
+ )
194
+
195
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
196
+ attn_output = self.o_proj(attn_output)
197
+ return attn_output, attn_weights
198
+
199
+
200
+ class CcubedDynamicWeightedAvgPool1d(nn.Module):
201
+ """
202
+ A module that dynamically determines the output size based on input
203
+ and performs weighted average pooling with separate attention mechanisms
204
+ for output size estimation and weighted pooling.
205
+ """
206
+ def __init__(self, config, output_size_min=32, output_size_max=131072):
207
+ super().__init__()
208
+ # Attention mechanism for estimating output size
209
+ self.size_estim_attn = CcubedDynamicFlashAttention2(config) # CcubedDynamicAttention(config)
210
+ # Attention mechanism for weighted pooling
211
+ self.imp_estim_attn = CcubedDynamicFlashAttention2(config) # CcubedDynamicAttention(config)
212
+ self.output_size_min = output_size_min
213
+ self.output_size_max = (
214
+ config.context_config.max_position_embeddings if config.context_config.max_position_embeddings is not None else output_size_max
215
+ )
216
+ self.scale_param = nn.Parameter(torch.tensor(0.01))
217
+
218
+ def forward(self, hidden_states, context_attention_mask=None):
219
+ """
220
+ Args:
221
+                hidden_states: Input tensor of shape (batch_size, seq_len, hidden_size)
222
+
223
+ Returns:
224
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
225
+ - pooled_output: Padded tensor of compressed sequences (batch_size, max_pooled_len, hidden_size)
226
+ - attention_mask: Binary mask indicating valid tokens (batch_size, max_pooled_len)
227
+ - dynamic_output_sizes: Dynamic output sizes for each batch (batch_size,)
228
+ """
229
+ batch_size, seq_len, hidden_size = hidden_states.size()
230
+ device = hidden_states.device
231
+
232
+ # Estimate output size using attention mechanism
233
+ # attn_output_size: (batch_size, seq_len, 1)
234
+ attn_output_size, _ = self.size_estim_attn(hidden_states)
235
+
236
+ # Calculate dynamic output sizes for each batch item
237
+ # (batch_size, seq_len, 1) -> (batch_size, 1)
238
+ batch_attn_means = torch.sigmoid(attn_output_size).mean(dim=1)
239
+ scaled_batch_means = batch_attn_means * self.scale_param.to(batch_attn_means.dtype)
240
+
241
+ # Calculate dynamic output sizes (batch_size,)
242
+ dynamic_output_sizes = (
243
+ (scaled_batch_means * (self.output_size_max - self.output_size_min)) + self.output_size_min
244
+ ).int().squeeze(-1)
245
+
246
+ max_pooled_len = dynamic_output_sizes.max().item()
247
+
248
+ # Compute attention weights for weighted pooling
249
+ # attn_output_weights: (batch_size, seq_len, 1)
250
+ attn_output_weights, _ = self.imp_estim_attn(hidden_states)
251
+ # Normalize with sigmoid function for use as weights
252
+ # attention_weights: (batch_size, seq_len)
253
+ attention_weights = torch.sigmoid(attn_output_weights).squeeze(-1)
254
+
255
+ # If context_attention_mask is provided, apply it to zero out weights for invalid tokens
256
+ if context_attention_mask is not None:
257
+ attention_weights = attention_weights * context_attention_mask
258
+
259
+ # Initialize output tensors
260
+ # pooled_output: (batch_size, max_pooled_len, hidden_size)
261
+ pooled_output = torch.zeros(
262
+ batch_size, max_pooled_len, hidden_size,
263
+ device=device, dtype=hidden_states.dtype
264
+ )
265
+ # attention_mask: (batch_size, max_pooled_len)
266
+ attention_mask = torch.zeros(
267
+ batch_size, max_pooled_len,
268
+ dtype=torch.bool, device=device
269
+ )
270
+
271
+ for batch_idx in range(batch_size):
272
+ output_size = dynamic_output_sizes[batch_idx].item()
273
+ item_input = hidden_states[batch_idx] # Shape: (seq_len, hidden_size)
274
+ item_weights = attention_weights[batch_idx] # Shape: (seq_len)
275
+
276
+ # Perform weighted pooling
277
+ pooled_values = []
278
+ batch_attn_mask = torch.zeros(output_size, dtype=torch.bool, device=device)
279
+ # Split the sequence evenly
280
+ intervals = torch.linspace(0, seq_len, steps=output_size + 1).long()
281
+ for i in range(output_size):
282
+ start = intervals[i].item()
283
+ end = intervals[i + 1].item()
284
+ chunk_input = item_input[start:end] # Shape: (chunk_size, hidden_size)
285
+ chunk_weights = item_weights[start:end] # Shape: (chunk_size)
286
+ if chunk_weights.sum() == 0:
287
+ # If the sum of weights is zero, add a zero vector
288
+ pooled_value = torch.zeros(hidden_size, device=device, dtype=hidden_states.dtype)
289
+ else:
290
+ # Calculate weighted average
291
+ weighted_input = chunk_input * chunk_weights.unsqueeze(-1) # Shape: (chunk_size, hidden_size)
292
+ pooled_value = weighted_input.sum(dim=0) / (chunk_weights.sum() + 1e-8) # Shape: (hidden_size)
293
+ batch_attn_mask[i] = True
294
+ pooled_values.append(pooled_value)
295
+
296
+ if pooled_values: # Only stack if there are values
297
+ # Convert the result to a tensor
298
+ pooled_values = torch.stack(pooled_values) # Shape: (output_size, hidden_size)
299
+ # Store the result
300
+ pooled_output[batch_idx, -output_size:] = pooled_values
301
+ attention_mask[batch_idx, -output_size:] = batch_attn_mask
302
+
303
+ return pooled_output, attention_mask, dynamic_output_sizes
304
+
305
+
306
+ class CcubedContextLanguageConnector(nn.Module):
307
+ def __init__(self, config: CcubedConfig):
308
+ super().__init__()
309
+
310
+ self.dynamic_pooling = CcubedDynamicWeightedAvgPool1d(config)
311
+
312
+ self.linear_1 = nn.Linear(
313
+ config.context_config.hidden_size,
314
+ config.text_config.hidden_size,
315
+ bias=True
316
+ )
317
+ self.act = ACT2FN[config.projector_hidden_act]
318
+ self.linear_2 = nn.Linear(
319
+ config.text_config.hidden_size,
320
+ config.text_config.hidden_size,
321
+ bias=True
322
+ )
323
+
324
+ def forward(self, context_features):
325
+ # context_features: [batch_size, seq_len, hidden_size]
326
+ # Apply dynamic adaptive average pooling with attention
327
+ pooled_output, attention_mask, dynamic_output_sizes = self.dynamic_pooling(
328
+ hidden_states=context_features
329
+ )
330
+
331
+ hidden_states = self.linear_1(pooled_output)
332
+ hidden_states = self.act(hidden_states)
333
+ hidden_states = self.linear_2(hidden_states)
334
+
335
+ return hidden_states, attention_mask
336
+
337
+
338
+ class CcubedContextTower(nn.Module):
339
+ def __init__(self, config: CcubedConfig):
340
+ super().__init__()
341
+
342
+ self.tower = AutoModelForCausalLM.from_config(
343
+ config.context_config,
344
+ attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "eager"
345
+ )
346
+ self.select_layer = config.context_feature_layer
347
+
348
+ def feature_select(self, llm_outputs):
349
+ hidden_states = llm_outputs.hidden_states
350
+ return hidden_states[self.select_layer]
351
+
352
+ def forward(
353
+ self,
354
+ input_ids,
355
+ inputs_embeds,
356
+ attention_mask
357
+ ):
358
+ outputs = self.tower(
359
+ input_ids=input_ids,
360
+ inputs_embeds=inputs_embeds,
361
+ attention_mask=attention_mask,
362
+ output_hidden_states=True
363
+ )
364
+ features = self.feature_select(outputs)
365
+ return features
366
+
367
+
368
+ class CcubedPreTrainedModel(PreTrainedModel):
369
+ config_class = CcubedConfig
370
+ base_model_prefix = "model"
371
+ supports_gradient_checkpointing = True
372
+ _no_split_modules = [] # ["CcubedContextLanguageConnector", "CcubedContextTower"]
373
+ _skip_keys_device_placement = ["past_key_values"]
374
+ _supports_flash_attn_2 = True
375
+ _supports_sdpa = True
376
+ _supports_cache_class = True
377
+ _supports_quantized_cache = True
378
+ _supports_static_cache = True
379
+
380
+ def _init_weights(self, module):
381
+ std = (
382
+ self.config.initializer_range
383
+ if hasattr(self.config, "initializer_range")
384
+ else self.config.text_config.initializer_range
385
+ )
386
+ if isinstance(module, nn.Linear):
387
+ module.weight.data.normal_(mean=0.0, std=std)
388
+ if module.bias is not None:
389
+ module.bias.data.zero_()
390
+ elif isinstance(module, nn.Embedding):
391
+ module.weight.data.normal_(mean=0.0, std=std)
392
+ if module.padding_idx is not None:
393
+ module.weight.data[module.padding_idx].zero_()
394
+
395
+
396
+ class CcubedForConditionalGeneration(CcubedPreTrainedModel):
397
+ def __init__(self, config: CcubedConfig):
398
+ super().__init__(config)
399
+ self.context_tower = CcubedContextTower(config)
400
+ self.connector = CcubedContextLanguageConnector(config)
401
+
402
+ self.language_model = AutoModelForCausalLM.from_config(
403
+ config.text_config,
404
+ attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "eager"
405
+ )
406
+
407
+ self.vocab_size = config.text_config.vocab_size
408
+ self.ignore_index = config.ignore_index if hasattr(config, 'ignore_index') else -100
409
+ self.start_of_context_token_id = config.start_of_context_token_id
410
+ self.end_of_context_token_id = config.end_of_context_token_id
411
+
412
+ self.post_init()
413
+
414
+ def get_input_embeddings(self):
415
+ return self.language_model.get_input_embeddings()
416
+
417
+ def get_context_input_embeddings(self):
418
+ return self.context_tower.tower.get_input_embeddings()
419
+
420
+ def set_input_embeddings(self, value):
421
+ self.language_model.set_input_embeddings(value)
422
+
423
+ def set_context_input_embeddings(self, value):
424
+ self.context_tower.tower.set_input_embeddings(value)
425
+
426
+ def get_output_embeddings(self):
427
+ return self.language_model.get_output_embeddings()
428
+
429
+ def get_context_output_embeddings(self):
430
+ return self.context_tower.tower.get_output_embeddings()
431
+
432
+ def set_output_embeddings(self, new_embeddings):
433
+ self.language_model.set_output_embeddings(new_embeddings)
434
+
435
+ def set_context_output_embeddings(self, new_embeddings):
436
+ self.context_tower.tower.set_output_embeddings(new_embeddings)
437
+
438
+ def set_decoder(self, decoder):
439
+ self.language_model.set_decoder(decoder)
440
+
441
+ def set_context_encoder(self, decoder):
442
+ self.context_tower.tower.set_decoder(decoder)
443
+
444
+ def get_decoder(self):
445
+ return self.language_model.get_decoder()
446
+
447
+ def get_context_encoder(self):
448
+ return self.context_tower.tower.get_decoder()
449
+
450
+ def tie_weights(self):
451
+ return self.language_model.tie_weights()
452
+
453
+ def context_tie_weights(self):
454
+ return self.context_tower.tower.tie_weights()
455
+
456
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
457
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
458
+ # update vocab size
459
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
460
+ self.vocab_size = model_embeds.num_embeddings
461
+ return model_embeds
462
+
463
+ def _merge_context_features(
464
+ self,
465
+ context_features = None,
466
+ inputs_embeds = None,
467
+ attention_mask = None,
468
+ context_attention_mask=None,
469
+ position_ids=None,
470
+ labels=None,
471
+ ):
472
+ if context_features is None:
473
+ return inputs_embeds, attention_mask, position_ids, labels
474
+
475
+ batch_size, seq_length, embed_dim = inputs_embeds.shape
476
+ context_seq_len = context_features.size(1)
477
+
478
+ # Create embeddings for begin and end of context tokens
479
+ begin_context_embed = self.get_input_embeddings()(torch.tensor(self.start_of_context_token_id, device=context_features.device))
480
+ end_context_embed = self.get_input_embeddings()(torch.tensor(self.end_of_context_token_id, device=context_features.device))
481
+
482
+ # Determine the actual lengths of context sequences (excluding padding)
483
+ if context_attention_mask is not None:
484
+ # context_attention_mask: [batch_size, context_seq_len, 1]
485
+ context_attention_mask = context_attention_mask.squeeze(-1) # [batch_size, context_seq_len]
486
+ # Sum over sequence length to get actual lengths
487
+ context_lengths = context_attention_mask.sum(dim=1).long() # [batch_size]
488
+ else:
489
+ # If no context_attention_mask is provided, assume full length
490
+ context_lengths = torch.full((batch_size,), context_seq_len, device=context_features.device, dtype=torch.long)
491
+ context_attention_mask = torch.ones(batch_size, context_seq_len, device=context_features.device, dtype=torch.long)
492
+
493
+ # Rearrange context features to include padding at the beginning
494
+ # Identify the maximum context length (excluding padding)
495
+ max_context_length = context_lengths.max().item()
496
+ # Calculate the amount of padding needed for each sample
497
+ padding_lengths = context_seq_len - context_lengths # [batch_size]
498
+
+        # Create new context_features with padding at the beginning
+        new_context_features = []
+        for i in range(batch_size):
+            padding_len = padding_lengths[i].item()
+            # Create padding embeddings (zeros)
+            padding_embed = torch.zeros(padding_len, embed_dim, device=context_features.device, dtype=context_features.dtype)
+            # Get actual context features (excluding padding)
+            actual_context = context_features[i, padding_len:context_seq_len]
+            # Concatenate padding, begin token, actual context, end token
+            sample_context = torch.cat([
+                padding_embed,
+                begin_context_embed.unsqueeze(0),
+                actual_context,
+                end_context_embed.unsqueeze(0)
+            ], dim=0)  # [context_seq_len + 2, embed_dim]
+            new_context_features.append(sample_context)
+        # Stack to create [batch_size, new_context_seq_len, embed_dim]
+        context_features = torch.stack(new_context_features, dim=0)
+        new_context_seq_len = context_features.size(1)
+
+        # Update context_attention_mask accordingly
+        new_context_attention_mask = []
+        for i in range(batch_size):
+            padding_len = padding_lengths[i].item()
+            # Create padding mask (zeros)
+            padding_mask = torch.zeros(padding_len, device=context_features.device, dtype=attention_mask.dtype)
+            # Begin and end token masks
+            begin_attention = torch.ones(1, device=context_features.device, dtype=attention_mask.dtype)
+            end_attention = torch.ones(1, device=context_features.device, dtype=attention_mask.dtype)
+            # Actual context attention mask (excluding padding)
+            actual_mask = context_attention_mask[i, padding_len:context_seq_len]
+            # Concatenate masks
+            sample_mask = torch.cat([
+                padding_mask,
+                begin_attention,
+                actual_mask,
+                end_attention
+            ], dim=0)  # [context_seq_len + 2]
+            new_context_attention_mask.append(sample_mask)
+        # Stack to create [batch_size, new_context_seq_len]
+        context_attention_mask = torch.stack(new_context_attention_mask, dim=0)
+
+        # Concatenate context features with input embeddings
+        new_inputs_embeds = torch.cat([context_features, inputs_embeds], dim=1)  # [batch_size, total_seq_len, embed_dim]
+
+        # Concatenate attention masks
+        new_attention_mask = torch.cat([context_attention_mask, attention_mask], dim=1)
+
+        # Create new position_ids
+        total_seq_len = new_inputs_embeds.size(1)
+        new_position_ids = torch.arange(total_seq_len, device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
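+        # Note: position ids run 0..total_seq_len-1 for every sample; left-padded context
+        # slots also receive ids but are excluded from attention by the mask above.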
+
+        # Update labels if provided
+        if labels is not None:
+            # Create ignore labels for context (including padding and special tokens)
+            context_labels = torch.full((batch_size, new_context_seq_len), self.ignore_index, device=labels.device, dtype=labels.dtype)
+            new_labels = torch.cat([context_labels, labels], dim=1)
+        else:
+            new_labels = None
+
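+        # Illustrative per-sample layout of the returned tensors (assuming 2 padding slots,
+        # a 3-token context and a 4-token prompt):
+        #   embeds: [0, 0, begin_ctx, c1, c2, c3, end_ctx, p1, p2, p3, p4]
+        #   mask:   [0, 0,         1,  1,  1,  1,       1,  1,  1,  1,  1]
+        #   labels: [ignore_index] * 7 + [l1, l2, l3, l4]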
+        return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels
+
+
+    @replace_return_docstrings(output_type=CcubedCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        context_input_ids: torch.LongTensor = None,
+        context_inputs_embeds: Optional[torch.FloatTensor] = None,
+        context_attention_mask: Optional[torch.Tensor] = None,
+        input_ids: torch.LongTensor = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: int = 0,
+    ) -> Union[Tuple, CcubedCausalLMOutputWithPast]:
+        """
+        Perform a forward pass through the Ccubed model, optionally conditioning on context input.
+
+        Args:
+            context_input_ids (`torch.LongTensor` of shape `(batch_size, context_sequence_length)`, *optional*):
+                Token IDs of the context input sequence.
+            context_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, context_sequence_length, hidden_size)`, *optional*):
+                Pre-computed context embeddings. If provided, embeddings are not computed from `context_input_ids`.
+            context_attention_mask (`torch.Tensor` of shape `(batch_size, context_sequence_length)`, *optional*):
+                Attention mask for the context input sequence.
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Token IDs of the input sequence.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids`, you can pass an embedded representation directly.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices.
+            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Indices of positions of each input sequence token.
+            past_key_values (`List[torch.FloatTensor]`, *optional*):
+                Pre-computed hidden states (key and value tensors) that can be used to speed up sequential decoding.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If `True`, past key values will be used to speed up decoding.
+            output_attentions (`bool`, *optional*):
+                If `True`, return the attention tensors for each layer.
+            output_hidden_states (`bool`, *optional*):
+                If `True`, return the hidden states of all layers.
+            return_dict (`bool`, *optional*):
+                If `True`, return a `CcubedCausalLMOutputWithPast` instead of a plain tuple.
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence, used to update the
+                cache in the correct position.
+            logits_to_keep (`int`, *optional*, defaults to 0):
+                Compute logits only for the last `logits_to_keep` tokens. If `0`, logits are computed for all
+                `input_ids`. During generation only the last token's logits are needed, and computing them for that
+                token alone saves a significant amount of memory for long sequences or large vocabularies.
+
+        Returns:
+            `Union[Tuple, CcubedCausalLMOutputWithPast]`: A tuple containing various model outputs or a `CcubedCausalLMOutputWithPast` instance.
+            The `CcubedCausalLMOutputWithPast` contains the following fields:
+                - loss (`torch.FloatTensor`, *optional*): Language modeling loss if labels provided, None otherwise.
+                - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`): Prediction scores.
+                - past_key_values (`List[torch.FloatTensor]`, *optional*): Pre-computed hidden states for efficient decoding.
+                - hidden_states (`Tuple[torch.FloatTensor]`, *optional*): Layer hidden states if `output_hidden_states=True`.
+                - attentions (`Tuple[torch.FloatTensor]`, *optional*): Layer attention weights if `output_attentions=True`.
+                - context_hidden_states (`torch.FloatTensor`, *optional*): Final hidden states from the context tower.
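+
+        Example (illustrative sketch only; `model`, `tokenizer` and `context_tokenizer` are assumed to be
+        instantiated elsewhere, and the text below is a placeholder):
+
+        ```python
+        >>> context = context_tokenizer("A long reference document ...", return_tensors="pt")
+        >>> prompt = tokenizer("Question: what does the document say?", return_tensors="pt")
+        >>> outputs = model(
+        ...     context_input_ids=context.input_ids,
+        ...     context_attention_mask=context.attention_mask,
+        ...     input_ids=prompt.input_ids,
+        ...     attention_mask=prompt.attention_mask,
+        ... )
+        >>> next_token_logits = outputs.logits[:, -1, :]
+        ```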
+ """
628
+
629
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
630
+ output_hidden_states = (
631
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
632
+ )
633
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
634
+
635
+
636
+ all_inputs_none = (
637
+ input_ids is None and
638
+ inputs_embeds is None and
639
+ context_input_ids is None and
640
+ context_inputs_embeds is None
641
+ )
642
+
643
+ if all_inputs_none:
644
+ raise ValueError("You must provide either non-empty input_ids/inputs_embeds or context_input_ids/context_inputs_embeds.")
645
+
646
+
647
+ if context_input_ids is not None or context_inputs_embeds is not None:
648
+ context_features = self.context_tower(
649
+ input_ids=context_input_ids,
650
+ inputs_embeds=context_inputs_embeds,
651
+ attention_mask=context_attention_mask,
652
+ )
653
+ context_features, context_attention_mask = self.connector(
654
+ context_features=context_features
655
+ )
656
+ else:
657
+ context_features = None
658
+ context_attention_mask = None
659
+
660
+
661
+ if inputs_embeds is None and input_ids is not None:
662
+ inputs_embeds = self.get_input_embeddings()(input_ids)
663
+
664
+ if inputs_embeds is not None:
665
+ inputs_embeds, attention_mask, position_ids, labels = self._merge_context_features(
666
+ context_features=context_features,
667
+ inputs_embeds=inputs_embeds,
668
+ attention_mask=attention_mask,
669
+ context_attention_mask=context_attention_mask,
670
+ position_ids=position_ids,
671
+ labels=labels,
672
+ )
673
+ else:
674
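+            # No text inputs were provided: run the language model on the projected context features alone.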
+            inputs_embeds = context_features
+            attention_mask = context_attention_mask
+
+        outputs = self.language_model(
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            logits_to_keep=logits_to_keep,
+        )
+
+        logits = outputs[0]
+
+        loss = None
+        if labels is not None:
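+            # Standard causal-LM shift: the logit at position t is scored against the token at
+            # position t + 1; context and padding positions carry ignore_index and therefore do
+            # not contribute to the loss. The shift below assumes logits were computed for the
+            # full sequence (i.e. logits_to_keep == 0).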
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CcubedCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            context_hidden_states=context_features,
+        )