Kwai-Keye committed
Commit e0fb953 · verified · 1 Parent(s): 4afc13d

upload modeling_keye.py to support non-flash inference

Files changed (1)
  1. modeling_keye.py +226 -765
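This commit drops the hard FlashAttention-2 requirement from the custom modeling code, so the checkpoint can be loaded with SDPA or eager attention. A minimal loading sketch follows; the checkpoint id, the auto class, and the dtype choice are placeholders for illustration, not part of this commit:

```python
# Sketch: load Keye without flash-attn installed (checkpoint id is a placeholder).
import torch
from transformers import AutoModel, AutoProcessor

model_id = "Kwai-Keye/Keye-VL-8B-Preview"  # replace with the actual repo id
model = AutoModel.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",   # or "eager"; "flash_attention_2" still works if installed
    trust_remote_code=True,       # picks up this modeling_keye.py
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
```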
modeling_keye.py CHANGED
@@ -31,19 +31,10 @@ import torch.nn.functional as F
31
  from torch.nn import CrossEntropyLoss
32
 
33
  from transformers.activations import ACT2FN
34
- from transformers.cache_utils import (
35
- Cache,
36
- DynamicCache,
37
- SlidingWindowCache,
38
- StaticCache,
39
- )
40
  from transformers.generation import GenerationMixin
41
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
42
- from transformers.modeling_outputs import (
43
- BaseModelOutputWithPast,
44
- BaseModelOutput,
45
- BaseModelOutputWithPooling,
46
- )
47
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
48
  from transformers.modeling_utils import PreTrainedModel, sdpa_attention_forward
49
  from transformers.activations import GELUActivation, ACT2FN, PytorchGELUTanh
@@ -55,7 +46,7 @@ from transformers.utils import (
55
  logging,
56
  replace_return_docstrings,
57
  torch_int,
58
- is_flash_attn_greater_or_equal_2_10,
59
  )
60
  from .configuration_keye import KeyeConfig, KeyeVisionConfig
61
 
@@ -64,9 +55,9 @@ import warnings
64
  from typing import Any, Callable, Optional, Tuple, Union, List
65
  from torch import nn
66
  from torch.nn.init import _calculate_fan_in_and_fan_out
 
67
 
68
 
69
- assert is_flash_attn_2_available()
70
  if is_flash_attn_2_available():
71
  from flash_attn import flash_attn_varlen_func
72
  from flash_attn.layers.rotary import apply_rotary_emb
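The deleted `assert is_flash_attn_2_available()` is what previously made flash-attn mandatory at import time; the conditional import kept just below it is the pattern that now carries non-flash setups. A sketch of that guarded-import idea (the `None` fallbacks are illustrative, not taken from the file):

```python
# Sketch: import flash-attn only when present, so SDPA/eager-only
# environments can still import modeling_keye.py.
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
    from flash_attn.layers.rotary import apply_rotary_emb
else:
    # Non-flash attention paths must never reach these symbols.
    flash_attn_varlen_func = None
    apply_rotary_emb = None
```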
@@ -80,7 +71,6 @@ logger = logging.get_logger(__name__)
80
 
81
  _CONFIG_FOR_DOC = "KeyeConfig"
82
 
83
-
84
  class KeyeMLP(nn.Module):
85
  def __init__(self, config, bias: bool = False):
86
  super().__init__()
@@ -92,9 +82,7 @@ class KeyeMLP(nn.Module):
92
  self.act_fn = ACT2FN[config.hidden_act]
93
 
94
  def forward(self, hidden_state):
95
- return self.down_proj(
96
- self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)
97
- )
98
 
99
 
100
  def _trunc_normal_(tensor, mean, std, a, b):
@@ -134,11 +122,7 @@ def _trunc_normal_(tensor, mean, std, a, b):
134
 
135
 
136
  def trunc_normal_tf_(
137
- tensor: torch.Tensor,
138
- mean: float = 0.0,
139
- std: float = 1.0,
140
- a: float = -2.0,
141
- b: float = 2.0,
142
  ) -> torch.Tensor:
143
  """Fills the input Tensor with values drawn from a truncated
144
  normal distribution. The values are effectively drawn from the
@@ -196,39 +180,9 @@ def default_flax_embed_init(tensor):
196
  variance_scaling_(tensor, mode="fan_in", distribution="normal")
197
 
198
 
199
- @dataclass
200
- # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
201
- class SiglipVisionModelOutput(ModelOutput):
202
- """
203
- Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
204
-
205
- Args:
206
- image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
207
- The image embeddings obtained by applying the projection layer to the pooler_output.
208
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
209
- Sequence of hidden-states at the output of the last layer of the model.
210
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
211
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
212
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
213
-
214
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
215
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
216
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
217
- sequence_length)`.
218
-
219
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
220
- heads.
221
- """
222
-
223
- image_embeds: Optional[torch.FloatTensor] = None
224
- last_hidden_state: Optional[torch.FloatTensor] = None
225
- hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
226
- attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
227
-
228
-
229
  class Projector(nn.Module):
230
 
231
- def __init__(self, text_config: KeyeConfig, vision_config: KeyeVisionConfig):
232
  super().__init__()
233
  self.text_config = text_config
234
  self.vision_config = vision_config
@@ -247,9 +201,7 @@ class Projector(nn.Module):
247
  self.hidden_size, self.text_config.hidden_size, bias=True
248
  )
249
 
250
- def forward(
251
- self, image_features: torch.Tensor, image_grid_thw: List[Tuple[int, int, int]]
252
- ) -> torch.Tensor:
253
  m1, m2 = self.merge_kernel_size
254
  if isinstance(image_features, (list, tuple)):
255
  processed_features = list()
@@ -258,15 +210,7 @@ class Projector(nn.Module):
258
  t, h, w = image_grid
259
  from einops import rearrange
260
 
261
- image_feature = rearrange(
262
- image_feature,
263
- "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
264
- t=t,
265
- h=h // m1,
266
- p1=m1,
267
- w=w // m2,
268
- p2=m2,
269
- )
270
  hidden_states = self.linear_1(image_feature)
271
  hidden_states = self.act(hidden_states)
272
  hidden_states = self.linear_2(hidden_states)
@@ -284,7 +228,6 @@ class Projector(nn.Module):
284
 
285
  return hidden_states.view(*dims, -1)
286
 
287
-
288
  class SiglipVisionEmbeddings(nn.Module):
289
  def __init__(self, config: KeyeVisionConfig):
290
  super().__init__()
@@ -308,19 +251,9 @@ class SiglipVisionEmbeddings(nn.Module):
308
  self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
309
  self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)
310
 
311
- self.register_buffer(
312
- "position_ids",
313
- torch.arange(self.num_positions).expand((1, -1)),
314
- persistent=False,
315
- )
316
 
317
- def interpolate_pos_encoding(
318
- self,
319
- embeddings: torch.Tensor,
320
- height: int,
321
- width: int,
322
- is_after_patchify: bool = False,
323
- ) -> torch.Tensor:
324
  """
325
  This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
326
  images. This method is also adapted to support torch.jit tracing and no class embeddings.
@@ -343,9 +276,7 @@ class SiglipVisionEmbeddings(nn.Module):
343
  new_width = width // self.patch_size
344
 
345
  sqrt_num_positions = torch_int(num_positions**0.5)
346
- patch_pos_embed = patch_pos_embed.reshape(
347
- 1, sqrt_num_positions, sqrt_num_positions, dim
348
- )
349
  patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
350
 
351
  patch_pos_embed = nn.functional.interpolate(
@@ -373,42 +304,33 @@ class SiglipVisionEmbeddings(nn.Module):
373
  if grid in self.cache_position_embedding:
374
  self.cache_position_count[grid] += 1
375
  return self.cache_position_embedding[grid]
376
-
377
  if len(self.cache_position_embedding) >= max_cache:
378
- min_hit_grid = min(
379
- self.cache_position_count, key=self.cache_position_count.get
380
- )
381
  self.cache_position_count.pop(min_hit_grid)
382
  self.cache_position_embedding.pop(min_hit_grid)
383
-
384
  position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True)
385
  self.cache_position_count[grid] = 1
386
  self.cache_position_embedding[grid] = position_embedding
387
  return position_embedding
388
 
389
  def forward(
390
- self,
391
- pixel_values: torch.FloatTensor,
392
  position_ids: Optional[torch.Tensor] = None,
393
- image_grid_thw: Optional[
394
- List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]
395
- ] = None,
396
- interpolate_pos_encoding=False,
397
  ) -> torch.Tensor:
398
  if pixel_values.dim() == 5:
399
  assert position_ids is not None
400
  from einops import rearrange
401
-
402
  batch_size, squence_len, channel, height, width = pixel_values.shape
403
  target_dtype = self.patch_embedding.weight.dtype
404
  pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
405
- patch_embeds = self.patch_embedding(
406
- pixel_values.to(dtype=target_dtype)
407
- ) # shape = [*, width, grid, grid]
408
  embeddings = patch_embeds.flatten(-2).squeeze(-1)
409
- embeddings = rearrange(
410
- embeddings, "(b l) d -> b l d", b=batch_size, l=squence_len
411
- )
412
 
413
  # todo: not dubug
414
  if interpolate_pos_encoding and image_grid_thw is not None:
@@ -416,21 +338,15 @@ class SiglipVisionEmbeddings(nn.Module):
416
  assert batch_size == 1
417
  start = 0
418
  image_embedding_list = list()
419
- assert (
420
- sum([np.prod(x) for x in flatten_image_grid_thw])
421
- == embeddings.shape[1]
422
- ), (flatten_image_grid_thw, embeddings.shape)
423
  embeddings = embeddings.squeeze(0)
424
  tmp_embeddings = list()
425
  for image_grid in image_grid_thw:
426
  t, h, w = image_grid
427
  end = start + t * h * w
428
- image_embeddings = embeddings[start:end, :]
429
- position_embedding = (
430
- self.interpolate_pos_encoding(image_embeddings, h, w, True)
431
- .squeeze(0)
432
- .repeat(t, 1)
433
- )
434
  image_embeddings = image_embeddings + position_embedding
435
  tmp_embeddings.append(image_embeddings)
436
  start = end
@@ -456,12 +372,8 @@ def eager_attention_forward(
456
  if attention_mask is not None:
457
  attn_weights = attn_weights + attention_mask
458
 
459
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
460
- query.dtype
461
- )
462
- attn_weights = nn.functional.dropout(
463
- attn_weights, p=dropout, training=module.training
464
- )
465
 
466
  attn_output = torch.matmul(attn_weights, value)
467
  attn_output = attn_output.transpose(1, 2).contiguous()
@@ -502,9 +414,7 @@ class SiglipAttention(nn.Module):
502
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
503
  """Input shape: Batch x Time x Channel"""
504
 
505
- use_flash_attn = (
506
- cu_seqlens is not None
507
- ) and self.config._attn_implementation == "flash_attention_2"
508
 
509
  batch_size, seq_length, embed_dim = hidden_states.shape
510
 
@@ -513,28 +423,21 @@ class SiglipAttention(nn.Module):
513
  values = self.v_proj(hidden_states)
514
 
515
  if rope_emb is None:
516
- queries = queries.view(
517
- batch_size, seq_length, self.num_heads, self.head_dim
518
- ).transpose(1, 2)
519
- keys = keys.view(
520
- batch_size, seq_length, self.num_heads, self.head_dim
521
- ).transpose(1, 2)
522
- values = values.view(
523
- batch_size, seq_length, self.num_heads, self.head_dim
524
- ).transpose(1, 2)
525
  else:
526
  assert cu_seqlens is not None, "Rope support flash attn only."
527
  cos, sin = rope_emb
528
- queries = queries.view(
529
- batch_size, seq_length, self.num_heads, self.head_dim
530
- )
531
  keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim)
532
- queries, keys = apply_rotary_pos_emb_flashatt(queries, keys, cos, sin)
 
 
 
533
  queries = queries.transpose(1, 2)
534
  keys = keys.transpose(1, 2)
535
- values = values.view(
536
- batch_size, seq_length, self.num_heads, self.head_dim
537
- ).transpose(1, 2)
538
 
539
  if not use_flash_attn:
540
  attention_interface: Callable = eager_attention_forward
@@ -557,25 +460,16 @@ class SiglipAttention(nn.Module):
557
  scaling=self.scale,
558
  dropout=0.0 if not self.training else self.dropout,
559
  )
560
- attn_output = attn_output.reshape(
561
- batch_size, seq_length, embed_dim
562
- ).contiguous()
563
  else:
564
  assert batch_size == 1, hidden_states.shape
565
  queries = queries.transpose(1, 2).squeeze(0)
566
  keys = keys.transpose(1, 2).squeeze(0)
567
  values = values.transpose(1, 2).squeeze(0)
568
 
569
- from flash_attn import flash_attn_func, flash_attn_varlen_func
570
-
571
  max_seqlen_q = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
572
  max_seqlen_k = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
573
- assert (
574
- cu_seqlens[-1].item()
575
- == queries.shape[0]
576
- == keys.shape[0]
577
- == values.shape[0]
578
- ), (cu_seqlens, queries.shape, keys.shape, values.shape)
579
 
580
  attn_output = flash_attn_varlen_func(
581
  queries,
@@ -841,9 +735,7 @@ class SiglipEncoder(nn.Module):
841
  embed_dim = config.hidden_size
842
  num_heads = config.num_attention_heads
843
  head_dim = embed_dim // num_heads
844
- self.layers = nn.ModuleList(
845
- [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
846
- )
847
  self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)
848
  self.gradient_checkpointing = False
849
 
@@ -859,7 +751,6 @@ class SiglipEncoder(nn.Module):
859
 
860
  def build_window_index(self, image_grid, window_size, device):
861
  from einops import rearrange
862
-
863
  window_indices = list()
864
  pad_values = -100
865
  start_window_index = 0
@@ -871,25 +762,16 @@ class SiglipEncoder(nn.Module):
871
  pad_w = (-w) % window_size
872
  assert pad_h >= 0 and pad_w >= 0, (pad_h, pad_w)
873
  window_index = F.pad(window_index, (0, pad_w, 0, pad_h), value=pad_values)
874
- window_index = rearrange(
875
- window_index,
876
- "t (h p1) (w p2) -> t (h w) (p1 p2)",
877
- p1=window_size,
878
- p2=window_size,
879
- )
880
  window_seqlens = (window_index != pad_values).long().sum(-1).reshape(-1)
881
  window_index = window_index.reshape(-1)
882
  window_index = window_index[window_index != pad_values]
883
  window_indices.append(window_index + start_window_index)
884
- cu_seqlens_within_windows.append(
885
- window_seqlens.cumsum(0) + start_window_index
886
- )
887
  start_window_index += t * h * w
888
  window_indices = torch.concat(window_indices, dim=0)
889
  cu_seqlens_within_windows = torch.concat(cu_seqlens_within_windows, dim=0)
890
- cu_seqlens_within_windows = F.pad(
891
- cu_seqlens_within_windows, (1, 0), value=0
892
- ).to(torch.int32)
893
  return window_indices, cu_seqlens_within_windows
894
 
895
  # Ignore copy
@@ -901,9 +783,7 @@ class SiglipEncoder(nn.Module):
901
  output_attentions: Optional[bool] = None,
902
  output_hidden_states: Optional[bool] = None,
903
  cu_seqlens: Optional[List[torch.Tensor]] = None,
904
- image_grid_thw: Optional[
905
- List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]
906
- ] = None,
907
  height_position_ids: Optional[torch.Tensor] = None,
908
  width_position_ids: Optional[torch.Tensor] = None,
909
  use_rope: Optional[bool] = False,
@@ -936,17 +816,11 @@ class SiglipEncoder(nn.Module):
936
 
937
  vision_or_text = "vision"
938
  assert vision_or_text in ["vision", "text"]
939
- use_window_attn = window_size > 0 and vision_or_text == "vision"
940
  use_rope = (use_rope is True) and (vision_or_text == "vision")
941
- output_attentions = (
942
- output_attentions
943
- if output_attentions is not None
944
- else self.config.output_attentions
945
- )
946
  output_hidden_states = (
947
- output_hidden_states
948
- if output_hidden_states is not None
949
- else self.config.output_hidden_states
950
  )
951
 
952
  encoder_states = () if output_hidden_states else None
@@ -954,17 +828,10 @@ class SiglipEncoder(nn.Module):
954
 
955
  device = inputs_embeds.device
956
  hidden_states = inputs_embeds
957
- attention_mask = (
958
- attention_mask.to(inputs_embeds.dtype)
959
- if attention_mask is not None
960
- else None
961
- )
962
  if use_rope is True:
963
  flatten_image_grid_thw = self.flatten_list(image_grid_thw)
964
- assert (
965
- sum([np.prod(x) for x in flatten_image_grid_thw])
966
- == hidden_states.shape[1]
967
- ), (flatten_image_grid_thw, hidden_states.shape)
968
 
969
  if width_position_ids is None or height_position_ids is None:
970
  split_hids = list()
@@ -977,13 +844,11 @@ class SiglipEncoder(nn.Module):
977
  split_wids.append(sample_wids)
978
  width_position_ids = torch.concat(split_wids, dim=0)
979
  height_position_ids = torch.concat(split_hids, dim=0)
980
-
981
  window_indices, cu_seqlens_within_windows = None, None
982
 
983
  if use_window_attn:
984
- window_indices, cu_seqlens_within_windows = self.build_window_index(
985
- flatten_image_grid_thw, window_size, device
986
- )
987
  reversed_window_indices = window_indices.argsort()
988
  height_position_ids = height_position_ids[window_indices]
989
  width_position_ids = width_position_ids[window_indices]
@@ -998,17 +863,12 @@ class SiglipEncoder(nn.Module):
998
 
999
  rope_emb = None
1000
  window_indices, cu_seqlens_within_windows = None, None
1001
-
1002
  if use_window_attn:
1003
  flatten_image_grid_thw = self.flatten_list(image_grid_thw)
1004
- assert (
1005
- sum([np.prod(x) for x in flatten_image_grid_thw])
1006
- == hidden_states.shape[1]
1007
- ), (flatten_image_grid_thw, hidden_states.shape)
1008
-
1009
- window_indices, cu_seqlens_within_windows = self.build_window_index(
1010
- flatten_image_grid_thw, window_size, device
1011
- )
1012
  reversed_window_indices = window_indices.argsort()
1013
 
1014
  if use_window_attn:
@@ -1020,11 +880,7 @@ class SiglipEncoder(nn.Module):
1020
 
1021
  for encoder_layer in self.layers:
1022
  if output_hidden_states:
1023
- encoder_states = encoder_states + (
1024
- (hidden_states[:, reversed_window_indices, :],)
1025
- if use_window_attn
1026
- else (hidden_states,)
1027
- )
1028
  if self.gradient_checkpointing and self.training:
1029
  layer_outputs = self._gradient_checkpointing_func(
1030
  encoder_layer.__call__,
@@ -1070,17 +926,13 @@ class SiglipVisionTransformer(nn.Module):
1070
  self.embeddings = SiglipVisionEmbeddings(config)
1071
  self.encoder = SiglipEncoder(config)
1072
  self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1073
- self.use_head = (
1074
- True if not hasattr(config, "vision_use_head") else config.vision_use_head
1075
- )
1076
  if self.use_head:
1077
  self.head = SiglipMultiheadAttentionPoolingHead(config)
1078
 
1079
  # @can_return_tuple
1080
  @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1081
- @replace_return_docstrings(
1082
- output_type=BaseModelOutputWithPooling, config_class=KeyeVisionConfig
1083
- )
1084
  def forward(
1085
  self,
1086
  pixel_values,
@@ -1096,9 +948,7 @@ class SiglipVisionTransformer(nn.Module):
1096
  cu_seqlens: Optional[List[torch.Tensor]] = None,
1097
  padding_mask: Optional[torch.Tensor] = None,
1098
  vision_return_embed_list: Optional[bool] = False,
1099
- image_grid_thw: Optional[
1100
- List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]
1101
- ] = None,
1102
  return_pooler_output: Optional[bool] = True,
1103
  use_rope: Optional[bool] = False,
1104
  window_size: Optional[bool] = -1,
@@ -1107,21 +957,15 @@ class SiglipVisionTransformer(nn.Module):
1107
  Returns:
1108
 
1109
  """
1110
- output_attentions = (
1111
- output_attentions
1112
- if output_attentions is not None
1113
- else self.config.output_attentions
1114
- )
1115
  output_hidden_states = (
1116
- output_hidden_states
1117
- if output_hidden_states is not None
1118
- else self.config.output_hidden_states
1119
  )
1120
  hidden_states = self.embeddings(
1121
- pixel_values,
1122
- interpolate_pos_encoding=interpolate_pos_encoding,
1123
  position_ids=position_ids,
1124
- image_grid_thw=image_grid_thw,
1125
  )
1126
 
1127
  encoder_outputs: BaseModelOutput = self.encoder(
@@ -1157,32 +1001,22 @@ class SiglipVisionTransformer(nn.Module):
1157
  token_indices = (sample_index == sample_idx).nonzero().flatten()
1158
  sample_hidden_state = hidden_state[token_indices]
1159
  sample_hidden_state_list.append(sample_hidden_state)
1160
-
1161
  if not vision_return_embed_list:
1162
- max_length = max(
1163
- [_state.shape[0] for _state in sample_hidden_state_list]
1164
- )
1165
  tmp_sample_hidden_state_list = list()
1166
  padding_mask = list()
1167
  for idx, _state in enumerate(sample_hidden_state_list):
1168
  padding_length = max_length - _state.shape[0]
1169
- mask = _state.new_zeros(size=(max_length,), dtype=torch.int64)
1170
- mask[-padding_length:] = 1
1171
  padding_mask.append(mask)
1172
  padding = _state.new_zeros(size=(padding_length, dim))
1173
  new_state = torch.concat([_state, padding], dim=0)
1174
  tmp_sample_hidden_state_list.append(new_state)
1175
- sample_hidden_state = torch.stack(
1176
- tmp_sample_hidden_state_list, dim=0
1177
- )
1178
- padding_mask = (
1179
- torch.stack(padding_mask, dim=0)
1180
- .float()
1181
- .to(last_hidden_state.dtype)
1182
- )
1183
- pooler_output = self.head(
1184
- sample_hidden_state, key_padding_mask=padding_mask
1185
- )
1186
  else:
1187
  pooler_output = list()
1188
  for state in sample_hidden_state_list:
@@ -1206,15 +1040,15 @@ class SiglipVisionTransformer(nn.Module):
1206
  hidden_states=encoder_outputs.hidden_states,
1207
  attentions=encoder_outputs.attentions,
1208
  )
1209
-
1210
  sample_hidden_state = list()
1211
  assert cu_seqlens is not None
1212
  for i in range(cu_seqlens.shape[0] - 1):
1213
  start = cu_seqlens[i]
1214
  end = cu_seqlens[i + 1]
1215
- tensor = last_hidden_state[:, start:end, :].squeeze(0)
1216
  sample_hidden_state.append(tensor)
1217
-
1218
  return BaseModelOutputWithPooling(
1219
  last_hidden_state=sample_hidden_state,
1220
  pooler_output=None,
@@ -1230,9 +1064,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
1230
  super().__init__()
1231
 
1232
  self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
1233
- self.attention = torch.nn.MultiheadAttention(
1234
- config.hidden_size, config.num_attention_heads, batch_first=True
1235
- )
1236
  self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1237
  self.mlp = SiglipMLP(config)
1238
 
@@ -1240,9 +1072,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
1240
  batch_size = hidden_state.shape[0]
1241
  probe = self.probe.repeat(batch_size, 1, 1)
1242
 
1243
- hidden_state = self.attention(
1244
- probe, hidden_state, hidden_state, key_padding_mask=key_padding_mask
1245
- )[0]
1246
 
1247
  residual = hidden_state
1248
  hidden_state = self.layernorm(hidden_state)
@@ -1272,9 +1102,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
1272
 
1273
  # @can_return_tuple
1274
  @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1275
- @replace_return_docstrings(
1276
- output_type=BaseModelOutputWithPooling, config_class=KeyeVisionConfig
1277
- )
1278
  def forward(
1279
  self,
1280
  pixel_values,
@@ -1284,9 +1112,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
1284
  interpolate_pos_encoding: bool = False,
1285
  position_ids: Optional[torch.Tensor] = None,
1286
  vision_return_embed_list: Optional[bool] = False,
1287
- image_grid_thw: Optional[
1288
- List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]
1289
- ] = None,
1290
  cu_seqlens: Optional[List[torch.Tensor]] = None,
1291
  return_pooler_output: Optional[bool] = True,
1292
  use_rope: Optional[bool] = False,
@@ -1331,6 +1157,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
1331
  )
1332
 
1333
 
 
1334
  class Qwen3RMSNorm(nn.Module):
1335
  def __init__(self, hidden_size, eps=1e-6):
1336
  """
@@ -1377,6 +1204,7 @@ def apply_rotary_pos_emb_flashatt(
1377
  return q_embed, k_embed
1378
 
1379
 
 
1380
  def rotate_half(x):
1381
  """Rotates half the hidden dims of the input."""
1382
  x1 = x[..., : x.shape[-1] // 2]
@@ -1397,156 +1225,6 @@ def apply_rotary_pos_emb_vision(
1397
  k_embed = k_embed.to(orig_k_dtype)
1398
  return q_embed, k_embed
1399
 
1400
-
1401
- class KeyeVisionAttention(nn.Module):
1402
- def __init__(self, dim: int, num_heads: int = 16) -> None:
1403
- super().__init__()
1404
- self.num_heads = num_heads
1405
- self.head_dim = dim // num_heads
1406
- self.qkv = nn.Linear(dim, dim * 3, bias=True)
1407
- self.proj = nn.Linear(dim, dim)
1408
-
1409
- def forward(
1410
- self,
1411
- hidden_states: torch.Tensor,
1412
- cu_seqlens: torch.Tensor,
1413
- rotary_pos_emb: Optional[torch.Tensor] = None,
1414
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1415
- ) -> torch.Tensor:
1416
- seq_length = hidden_states.shape[0]
1417
- q, k, v = (
1418
- self.qkv(hidden_states)
1419
- .reshape(seq_length, self.num_heads, 3, -1)
1420
- .permute(2, 0, 1, 3)
1421
- .unbind(0)
1422
- )
1423
- if position_embeddings is None:
1424
- logger.warning_once(
1425
- "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
1426
- "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
1427
- "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
1428
- "removed and `position_embeddings` will be mandatory."
1429
- )
1430
- emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
1431
- cos = emb.cos()
1432
- sin = emb.sin()
1433
- else:
1434
- cos, sin = position_embeddings
1435
- q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
1436
-
1437
- attention_mask = torch.full(
1438
- [1, seq_length, seq_length],
1439
- torch.finfo(q.dtype).min,
1440
- device=q.device,
1441
- dtype=q.dtype,
1442
- )
1443
- for i in range(1, len(cu_seqlens)):
1444
- attention_mask[
1445
- ...,
1446
- cu_seqlens[i - 1] : cu_seqlens[i],
1447
- cu_seqlens[i - 1] : cu_seqlens[i],
1448
- ] = 0
1449
-
1450
- q = q.transpose(0, 1)
1451
- k = k.transpose(0, 1)
1452
- v = v.transpose(0, 1)
1453
- attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
1454
- attn_weights = attn_weights + attention_mask
1455
- attn_weights = nn.functional.softmax(
1456
- attn_weights, dim=-1, dtype=torch.float32
1457
- ).to(q.dtype)
1458
- attn_output = torch.matmul(attn_weights, v)
1459
- attn_output = attn_output.transpose(0, 1)
1460
- attn_output = attn_output.reshape(seq_length, -1)
1461
- attn_output = self.proj(attn_output)
1462
- return attn_output
1463
-
1464
-
1465
- class KeyeVisionSdpaAttention(nn.Module):
1466
- def __init__(self, dim: int, num_heads: int = 16) -> None:
1467
- super().__init__()
1468
- self.num_heads = num_heads
1469
- self.qkv = nn.Linear(dim, dim * 3, bias=True)
1470
- self.proj = nn.Linear(dim, dim)
1471
-
1472
- def forward(
1473
- self,
1474
- hidden_states: torch.Tensor,
1475
- cu_seqlens: torch.Tensor,
1476
- rotary_pos_emb: Optional[torch.Tensor] = None,
1477
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1478
- ) -> torch.Tensor:
1479
- seq_length = hidden_states.shape[0]
1480
- # q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
1481
- q, k, v = (
1482
- self.qkv(hidden_states)
1483
- .reshape(seq_length, self.num_heads, 3, -1)
1484
- .permute(2, 0, 1, 3)
1485
- .unbind(0)
1486
- )
1487
- if position_embeddings is None:
1488
- logger.warning_once(
1489
- "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
1490
- "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
1491
- "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
1492
- "removed and `position_embeddings` will be mandatory."
1493
- )
1494
- emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
1495
- cos = emb.cos()
1496
- sin = emb.sin()
1497
- else:
1498
- cos, sin = position_embeddings
1499
- q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
1500
-
1501
- attention_mask = torch.zeros(
1502
- [1, seq_length, seq_length], device=q.device, dtype=torch.bool
1503
- )
1504
- for i in range(1, len(cu_seqlens)):
1505
- attention_mask[
1506
- ...,
1507
- cu_seqlens[i - 1] : cu_seqlens[i],
1508
- cu_seqlens[i - 1] : cu_seqlens[i],
1509
- ] = True
1510
- q = q.transpose(0, 1)
1511
- k = k.transpose(0, 1)
1512
- v = v.transpose(0, 1)
1513
- attn_output = F.scaled_dot_product_attention(
1514
- q, k, v, attention_mask, dropout_p=0.0
1515
- )
1516
- attn_output = attn_output.transpose(0, 1)
1517
- attn_output = attn_output.reshape(seq_length, -1)
1518
- attn_output = self.proj(attn_output)
1519
- return attn_output
1520
-
1521
-
1522
- class KeyeVisionBlock(nn.Module):
1523
- def __init__(self, config, attn_implementation: str = "sdpa") -> None:
1524
- super().__init__()
1525
- self.norm1 = Qwen3RMSNorm(config.hidden_size, eps=1e-6)
1526
- self.norm2 = Qwen3RMSNorm(config.hidden_size, eps=1e-6)
1527
- assert attn_implementation == "flash_attention_2"
1528
- self.attn = QWEN3_ATTENTION_CLASSES[attn_implementation](
1529
- config.hidden_size, num_heads=config.num_heads
1530
- )
1531
- self.mlp = KeyeMLP(config, bias=True)
1532
-
1533
- def forward(
1534
- self,
1535
- hidden_states: torch.Tensor,
1536
- cu_seqlens: torch.Tensor,
1537
- rotary_pos_emb: Optional[torch.Tensor] = None,
1538
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1539
- ) -> torch.Tensor:
1540
- hidden_states = hidden_states + self.attn(
1541
- self.norm1(hidden_states),
1542
- cu_seqlens=cu_seqlens,
1543
- rotary_pos_emb=rotary_pos_emb,
1544
- position_embeddings=position_embeddings,
1545
- )
1546
- hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
1547
- return hidden_states
1548
-
1549
-
1550
  Keye_START_DOCSTRING = r"""
1551
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1552
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
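The removed `KeyeVisionSdpaAttention` above shows the non-flash substitute for `flash_attn_varlen_func`: packed variable-length sequences are handled by building a block-diagonal boolean mask from `cu_seqlens` and calling `scaled_dot_product_attention`. A self-contained sketch of that mechanism, with illustrative names and shapes:

```python
# Sketch: varlen attention without flash-attn, via SDPA and a
# block-diagonal mask built from cu_seqlens (illustrative helper).
import torch
import torch.nn.functional as F

def sdpa_varlen(q, k, v, cu_seqlens):
    # q, k, v: (total_tokens, num_heads, head_dim); cu_seqlens: (num_seqs + 1,)
    seq_len = q.shape[0]
    mask = torch.zeros(1, seq_len, seq_len, dtype=torch.bool, device=q.device)
    for i in range(1, len(cu_seqlens)):
        s, e = cu_seqlens[i - 1], cu_seqlens[i]
        mask[..., s:e, s:e] = True  # tokens attend only within their own sequence
    q, k, v = (t.transpose(0, 1) for t in (q, k, v))  # -> (num_heads, tokens, head_dim)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(0, 1).reshape(seq_len, -1)  # (total_tokens, num_heads * head_dim)
```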
@@ -1572,7 +1250,7 @@ class Qwen3PreTrainedModel(PreTrainedModel):
1572
  config_class = KeyeConfig
1573
  base_model_prefix = "model"
1574
  supports_gradient_checkpointing = True
1575
- _no_split_modules = ["KeyeDecoderLayer", "KeyeVisionBlock"]
1576
  _skip_keys_device_placement = "past_key_values"
1577
  _supports_flash_attn_2 = True
1578
  _supports_sdpa = True
@@ -1591,6 +1269,7 @@ class Qwen3PreTrainedModel(PreTrainedModel):
1591
  module.weight.data[module.padding_idx].zero_()
1592
 
1593
 
 
1594
  class SigLIPRotaryEmbedding(nn.Module):
1595
  def __init__(self, dim: int, theta: float = 10000.0) -> None:
1596
  super().__init__()
@@ -1599,15 +1278,11 @@ class SigLIPRotaryEmbedding(nn.Module):
1599
  self.rope_init()
1600
 
1601
  def rope_init(self):
1602
- inv_freq = 1.0 / (
1603
- self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim)
1604
- )
1605
  self.register_buffer("inv_freq", inv_freq, persistent=False)
1606
 
1607
  def forward(self, seqlen: int) -> torch.Tensor:
1608
- seq = torch.arange(
1609
- seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
1610
- )
1611
  freqs = torch.outer(seq, self.inv_freq)
1612
  return freqs
1613
 
@@ -1634,19 +1309,15 @@ class KeyeRotaryEmbedding(nn.Module):
1634
  else:
1635
  # BC: "rope_type" was originally "type"
1636
  if config.rope_scaling is not None:
1637
- self.rope_type = config.rope_scaling.get(
1638
- "rope_type", config.rope_scaling.get("type")
1639
- )
1640
  else:
1641
  self.rope_type = "default"
1642
  self.max_seq_len_cached = config.max_position_embeddings
1643
  self.original_max_seq_len = config.max_position_embeddings
1644
-
1645
  # BC: "rope_type" was originally "type"
1646
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
1647
- self.rope_type = config.rope_scaling.get(
1648
- "rope_type", config.rope_scaling.get("type")
1649
- )
1650
  else:
1651
  self.rope_type = "default"
1652
  self.max_seq_len_cached = config.max_position_embeddings
@@ -1670,15 +1341,10 @@ class KeyeRotaryEmbedding(nn.Module):
1670
  inv_freq, self.attention_scaling = self.rope_init_fn(
1671
  self.config, device, seq_len=seq_len, **self.rope_kwargs
1672
  )
1673
- self.register_buffer(
1674
- "inv_freq", inv_freq, persistent=False
1675
- ) # TODO joao: may break with compilation
1676
  self.max_seq_len_cached = seq_len
1677
 
1678
- if (
1679
- seq_len < self.original_max_seq_len
1680
- and self.max_seq_len_cached > self.original_max_seq_len
1681
- ): # reset
1682
  self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
1683
  self.max_seq_len_cached = self.original_max_seq_len
1684
 
@@ -1689,25 +1355,13 @@ class KeyeRotaryEmbedding(nn.Module):
1689
 
1690
  # Core RoPE block. In contrast to other models, Keye has different position ids for the grids
1691
  # So we expand the inv_freq to shape (3, ...)
1692
- inv_freq_expanded = (
1693
- self.inv_freq[None, None, :, None]
1694
- .float()
1695
- .expand(3, position_ids.shape[1], -1, 1)
1696
- )
1697
- position_ids_expanded = position_ids[
1698
- :, :, None, :
1699
- ].float() # shape (3, bs, 1, positions)
1700
  # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
1701
  device_type = x.device.type
1702
- device_type = (
1703
- device_type
1704
- if isinstance(device_type, str) and device_type != "mps"
1705
- else "cpu"
1706
- )
1707
  with torch.autocast(device_type=device_type, enabled=False):
1708
- freqs = (
1709
- inv_freq_expanded.float() @ position_ids_expanded.float()
1710
- ).transpose(2, 3)
1711
  emb = torch.cat((freqs, freqs), dim=-1)
1712
  cos = emb.cos()
1713
  sin = emb.sin()
@@ -1777,12 +1431,12 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
1777
  `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
1778
  """
1779
  mrope_section = mrope_section * 2
1780
- cos = torch.cat(
1781
- [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1
1782
- ).unsqueeze(unsqueeze_dim)
1783
- sin = torch.cat(
1784
- [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1
1785
- ).unsqueeze(unsqueeze_dim)
1786
 
1787
  q_embed = (q * cos) + (rotate_half(q) * sin)
1788
  k_embed = (k * cos) + (rotate_half(k) * sin)
@@ -1797,9 +1451,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
1797
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
1798
  if n_rep == 1:
1799
  return hidden_states
1800
- hidden_states = hidden_states[:, :, None, :, :].expand(
1801
- batch, num_key_value_heads, n_rep, slen, head_dim
1802
- )
1803
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
1804
 
1805
 
@@ -1822,43 +1474,27 @@ class KeyeAttention(nn.Module):
1822
 
1823
  self.hidden_size = config.hidden_size
1824
  self.num_heads = config.num_attention_heads
1825
- self.head_dim = getattr(
1826
- config, "head_dim", config.hidden_size // config.num_attention_heads
1827
- )
1828
  self.num_key_value_heads = config.num_key_value_heads
1829
- self.num_key_value_groups = (
1830
- config.num_attention_heads // config.num_key_value_heads
1831
- )
1832
  self.is_causal = True
1833
  self.attention_dropout = config.attention_dropout
1834
  self.rope_scaling = config.rope_scaling
1835
 
1836
  self.q_proj = nn.Linear(
1837
- config.hidden_size,
1838
- config.num_attention_heads * self.head_dim,
1839
- bias=config.attention_bias,
1840
  )
1841
  self.k_proj = nn.Linear(
1842
- config.hidden_size,
1843
- config.num_key_value_heads * self.head_dim,
1844
- bias=config.attention_bias,
1845
  )
1846
  self.v_proj = nn.Linear(
1847
- config.hidden_size,
1848
- config.num_key_value_heads * self.head_dim,
1849
- bias=config.attention_bias,
1850
  )
1851
  self.o_proj = nn.Linear(
1852
- config.num_attention_heads * self.head_dim,
1853
- config.hidden_size,
1854
- bias=config.attention_bias,
1855
  )
1856
- self.q_norm = Qwen3RMSNorm(
1857
- self.head_dim, eps=config.rms_norm_eps
1858
- ) # unlike olmo, only on the head dim!
1859
- self.k_norm = Qwen3RMSNorm(
1860
- self.head_dim, eps=config.rms_norm_eps
1861
- ) # thus post q_norm does not need reshape
1862
 
1863
  self.rotary_emb = KeyeRotaryEmbedding(config=config)
1864
 
@@ -1871,18 +1507,12 @@ class KeyeAttention(nn.Module):
1871
  output_attentions: bool = False,
1872
  use_cache: bool = False,
1873
  cache_position: Optional[torch.LongTensor] = None,
1874
- position_embeddings: Optional[
1875
- Tuple[torch.Tensor, torch.Tensor]
1876
- ] = None, # necessary, but kept here for BC
1877
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1878
  bsz, q_len, _ = hidden_states.size()
1879
 
1880
- query_states = self.q_norm(
1881
- self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim)
1882
- )
1883
- key_states = self.k_norm(
1884
- self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim)
1885
- )
1886
  value_states = self.v_proj(hidden_states)
1887
 
1888
  query_states = query_states.transpose(1, 2)
@@ -1895,22 +1525,15 @@ class KeyeAttention(nn.Module):
1895
  )
1896
 
1897
  if past_key_value is not None:
1898
- cache_kwargs = {
1899
- "sin": sin,
1900
- "cos": cos,
1901
- "cache_position": cache_position,
1902
- } # Specific to RoPE models
1903
- key_states, value_states = past_key_value.update(
1904
- key_states, value_states, self.layer_idx, cache_kwargs
1905
- )
1906
 
1907
  # repeat k/v heads if n_kv_heads < n_heads
1908
  key_states = repeat_kv(key_states, self.num_key_value_groups)
1909
  value_states = repeat_kv(value_states, self.num_key_value_groups)
1910
 
1911
- attn_weights = torch.matmul(
1912
- query_states, key_states.transpose(2, 3)
1913
- ) / math.sqrt(self.head_dim)
1914
 
1915
  if attention_mask is not None: # no matter the length, we just slice it
1916
  causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
@@ -1919,17 +1542,11 @@ class KeyeAttention(nn.Module):
1919
  # Fix precision issues in float16 inference
1920
  # Replace inf values with zeros in attention weights to prevent NaN propagation
1921
  if query_states.dtype == torch.float16:
1922
- attn_weights = torch.where(
1923
- torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights
1924
- )
1925
 
1926
  # upcast attention to fp32
1927
- attn_weights = nn.functional.softmax(
1928
- attn_weights, dim=-1, dtype=torch.float32
1929
- ).to(query_states.dtype)
1930
- attn_weights = nn.functional.dropout(
1931
- attn_weights, p=self.attention_dropout, training=self.training
1932
- )
1933
  attn_output = torch.matmul(attn_weights, value_states)
1934
 
1935
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -1975,19 +1592,15 @@ class KeyeFlashAttention2(KeyeAttention):
1975
  output_attentions: bool = False,
1976
  use_cache: bool = False,
1977
  cache_position: Optional[torch.LongTensor] = None,
1978
- position_embeddings: Optional[
1979
- Tuple[torch.Tensor, torch.Tensor]
1980
- ] = None, # necessary, but kept here for BC
1981
  cu_seqlens: Optional[torch.Tensor] = None,
1982
- sliding_window=-1,
1983
  **kwargs,
1984
  ):
1985
  bsz, q_len, _ = hidden_states.size()
1986
- q = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim)
1987
  query_states = self.q_norm(q)
1988
- key_states = self.k_norm(
1989
- self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim)
1990
- )
1991
  value_states = self.v_proj(hidden_states)
1992
 
1993
  query_states = query_states.transpose(1, 2)
@@ -2001,20 +1614,14 @@ class KeyeFlashAttention2(KeyeAttention):
2001
  )
2002
 
2003
  if past_key_value is not None:
2004
- cache_kwargs = {
2005
- "sin": sin,
2006
- "cos": cos,
2007
- "cache_position": cache_position,
2008
- } # Specific to RoPE models
2009
- key_states, value_states = past_key_value.update(
2010
- key_states, value_states, self.layer_idx, cache_kwargs
2011
- )
2012
 
2013
  # repeat k/v heads if n_kv_heads < n_heads
2014
  key_states = repeat_kv(key_states, self.num_key_value_groups)
2015
  value_states = repeat_kv(value_states, self.num_key_value_groups)
2016
  dropout_rate = 0.0 if not self.training else self.attention_dropout
2017
-
2018
  # In PEFT, usually we cast the layer norms in float32 for training stability reasons
2019
  # therefore the input hidden states gets silently casted in float32. Hence, we need
2020
  # cast them back in float16 just to be sure everything works as expected.
@@ -2068,7 +1675,7 @@ class KeyeFlashAttention2(KeyeAttention):
2068
  max_seqlen,
2069
  dropout_p=dropout_rate,
2070
  window_size=(sliding_window, sliding_window),
2071
- causal=self.is_causal,
2072
  )
2073
  else:
2074
  attn_output = _flash_attention_forward(
@@ -2108,9 +1715,7 @@ class KeyeSdpaAttention(KeyeAttention):
2108
  output_attentions: bool = False,
2109
  use_cache: bool = False,
2110
  cache_position: Optional[torch.LongTensor] = None,
2111
- position_embeddings: Optional[
2112
- Tuple[torch.Tensor, torch.Tensor]
2113
- ] = None, # necessary, but kept here for BC
2114
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
2115
  if output_attentions:
2116
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -2131,12 +1736,8 @@ class KeyeSdpaAttention(KeyeAttention):
2131
 
2132
  bsz, q_len, _ = hidden_states.size()
2133
 
2134
- query_states = self.q_norm(
2135
- self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim)
2136
- )
2137
- key_states = self.k_norm(
2138
- self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim)
2139
- )
2140
  value_states = self.v_proj(hidden_states)
2141
 
2142
  query_states = query_states.transpose(1, 2)
@@ -2149,14 +1750,8 @@ class KeyeSdpaAttention(KeyeAttention):
2149
  )
2150
 
2151
  if past_key_value is not None:
2152
- cache_kwargs = {
2153
- "sin": sin,
2154
- "cos": cos,
2155
- "cache_position": cache_position,
2156
- } # Specific to RoPE models
2157
- key_states, value_states = past_key_value.update(
2158
- key_states, value_states, self.layer_idx, cache_kwargs
2159
- )
2160
 
2161
  key_states = repeat_kv(key_states, self.num_key_value_groups)
2162
  value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -2194,6 +1789,7 @@ class KeyeSdpaAttention(KeyeAttention):
2194
  return attn_output, None, past_key_value
2195
 
2196
 
 
2197
  QWEN3_ATTENTION_CLASSES = {
2198
  "eager": KeyeAttention,
2199
  "flash_attention_2": KeyeFlashAttention2,
@@ -2205,24 +1801,17 @@ class KeyeDecoderLayer(nn.Module):
2205
  def __init__(self, config: KeyeConfig, layer_idx: int):
2206
  super().__init__()
2207
  self.hidden_size = config.hidden_size
2208
-
2209
- if (
2210
- config.use_sliding_window
2211
- and config._attn_implementation != "flash_attention_2"
2212
- ):
2213
  logger.warning_once(
2214
  f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
2215
  "unexpected results may be encountered."
2216
  )
2217
 
2218
- self.self_attn = QWEN3_ATTENTION_CLASSES[config._attn_implementation](
2219
- config, layer_idx
2220
- )
2221
  self.mlp = Qwen3MLP(config)
2222
  self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
2223
- self.post_attention_layernorm = Qwen3RMSNorm(
2224
- config.hidden_size, eps=config.rms_norm_eps
2225
- )
2226
 
2227
  def forward(
2228
  self,
@@ -2233,13 +1822,9 @@ class KeyeDecoderLayer(nn.Module):
2233
  output_attentions: Optional[bool] = False,
2234
  use_cache: Optional[bool] = False,
2235
  cache_position: Optional[torch.LongTensor] = None,
2236
- position_embeddings: Optional[
2237
- Tuple[torch.Tensor, torch.Tensor]
2238
- ] = None, # necessary, but kept here for BC
2239
  **kwargs,
2240
- ) -> Tuple[
2241
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
2242
- ]:
2243
  """
2244
  Args:
2245
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -2275,7 +1860,7 @@ class KeyeDecoderLayer(nn.Module):
2275
  use_cache=use_cache,
2276
  cache_position=cache_position,
2277
  position_embeddings=position_embeddings,
2278
- **kwargs,
2279
  )
2280
 
2281
  hidden_states = residual + hidden_states
@@ -2291,6 +1876,7 @@ class KeyeDecoderLayer(nn.Module):
2291
  if output_attentions:
2292
  outputs += (self_attn_weights,)
2293
 
 
2294
  if use_cache:
2295
  outputs += (present_key_value,)
2296
 
@@ -2307,14 +1893,9 @@ class Qwen3Model(Qwen3PreTrainedModel):
2307
  self.padding_idx = config.pad_token_id
2308
  self.vocab_size = config.vocab_size
2309
 
2310
- self.embed_tokens = nn.Embedding(
2311
- config.vocab_size, config.hidden_size, self.padding_idx
2312
- )
2313
  self.layers = nn.ModuleList(
2314
- [
2315
- KeyeDecoderLayer(config, layer_idx)
2316
- for layer_idx in range(config.num_hidden_layers)
2317
- ]
2318
  )
2319
  self._attn_implementation = config._attn_implementation
2320
  self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -2342,28 +1923,18 @@ class Qwen3Model(Qwen3PreTrainedModel):
2342
  output_hidden_states: Optional[bool] = None,
2343
  return_dict: Optional[bool] = None,
2344
  cache_position: Optional[torch.LongTensor] = None,
2345
- **kwargs,
2346
  ) -> Union[Tuple, BaseModelOutputWithPast]:
2347
- output_attentions = (
2348
- output_attentions
2349
- if output_attentions is not None
2350
- else self.config.output_attentions
2351
- )
2352
  output_hidden_states = (
2353
- output_hidden_states
2354
- if output_hidden_states is not None
2355
- else self.config.output_hidden_states
2356
  )
2357
  use_cache = use_cache if use_cache is not None else self.config.use_cache
2358
 
2359
- return_dict = (
2360
- return_dict if return_dict is not None else self.config.use_return_dict
2361
- )
2362
 
2363
  if (input_ids is None) ^ (inputs_embeds is not None):
2364
- raise ValueError(
2365
- "You must specify exactly one of input_ids or inputs_embeds"
2366
- )
2367
 
2368
  if self.gradient_checkpointing and self.training:
2369
  if use_cache:
@@ -2380,29 +1951,19 @@ class Qwen3Model(Qwen3PreTrainedModel):
2380
  inputs_embeds = self.embed_tokens(input_ids)
2381
 
2382
  if cache_position is None:
2383
- past_seen_tokens = (
2384
- past_key_values.get_seq_length() if past_key_values is not None else 0
2385
- )
2386
  cache_position = torch.arange(
2387
- past_seen_tokens,
2388
- past_seen_tokens + inputs_embeds.shape[1],
2389
- device=inputs_embeds.device,
2390
  )
2391
 
2392
  # the hard coded `3` is for temporal, height and width.
2393
  if position_ids is None:
2394
- position_ids = cache_position.view(1, 1, -1).expand(
2395
- 3, inputs_embeds.shape[0], -1
2396
- )
2397
  elif position_ids.dim() == 2:
2398
  position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
2399
 
2400
  causal_mask = self._update_causal_mask(
2401
- attention_mask,
2402
- inputs_embeds,
2403
- cache_position,
2404
- past_key_values,
2405
- output_attentions,
2406
  )
2407
  hidden_states = inputs_embeds
2408
 
@@ -2462,11 +2023,7 @@ class Qwen3Model(Qwen3PreTrainedModel):
2462
  next_cache = next_decoder_cache if use_cache else None
2463
 
2464
  if not return_dict:
2465
- return tuple(
2466
- v
2467
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
2468
- if v is not None
2469
- )
2470
  return BaseModelOutputWithPast(
2471
  last_hidden_state=hidden_states,
2472
  past_key_values=next_cache,
@@ -2484,9 +2041,7 @@ class Qwen3Model(Qwen3PreTrainedModel):
2484
  ):
2485
  if self.config._attn_implementation == "flash_attention_2":
2486
  if attention_mask is not None and past_key_values is not None:
2487
- is_padding_right = (
2488
- attention_mask[:, -1].sum().item() != input_tensor.size()[0]
2489
- )
2490
  if is_padding_right:
2491
  raise ValueError(
2492
  "You are attempting to perform batched generation with padding_side='right'"
@@ -2500,9 +2055,7 @@ class Qwen3Model(Qwen3PreTrainedModel):
2500
  # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
2501
  # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
2502
  # to infer the attention mask.
2503
- past_seen_tokens = (
2504
- past_key_values.get_seq_length() if past_key_values is not None else 0
2505
- )
2506
  using_static_cache = isinstance(past_key_values, StaticCache)
2507
  using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
2508
 
@@ -2557,9 +2110,7 @@ class Qwen3Model(Qwen3PreTrainedModel):
2557
  # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
2558
  # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
2559
  # Details: https://github.com/pytorch/pytorch/issues/110213
2560
- causal_mask = AttentionMaskConverter._unmask_unattended(
2561
- causal_mask, min_dtype
2562
- )
2563
 
2564
  return causal_mask
2565
 
@@ -2605,41 +2156,31 @@ class Qwen3Model(Qwen3PreTrainedModel):
2605
  else:
2606
  min_dtype = torch.finfo(dtype).min
2607
  causal_mask = torch.full(
2608
- (sequence_length, target_length),
2609
- fill_value=min_dtype,
2610
- dtype=dtype,
2611
- device=device,
2612
  )
2613
- diagonal_attend_mask = torch.arange(
2614
- target_length, device=device
2615
- ) > cache_position.reshape(-1, 1)
2616
  if config.sliding_window is not None:
2617
  # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
2618
  # the check is needed to verify is current checkpoint was trained with sliding window or not
2619
- if (
2620
- not isinstance(past_key_values, SlidingWindowCache)
2621
- or sequence_length > target_length
2622
- ):
2623
- sliding_attend_mask = torch.arange(
2624
- target_length, device=device
2625
- ) <= (cache_position.reshape(-1, 1) - config.sliding_window)
2626
  diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
2627
  causal_mask *= diagonal_attend_mask
2628
  causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
2629
  if attention_mask is not None:
2630
- causal_mask = (
2631
- causal_mask.clone()
2632
- ) # copy to contiguous memory for in-place edit
2633
  if attention_mask.shape[-1] > target_length:
2634
  attention_mask = attention_mask[:, :target_length]
2635
  mask_length = attention_mask.shape[-1]
2636
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
2637
- :, None, None, :
2638
- ].to(causal_mask.device)
2639
  padding_mask = padding_mask == 0
2640
- causal_mask[:, :, :, :mask_length] = causal_mask[
2641
- :, :, :, :mask_length
2642
- ].masked_fill(padding_mask, min_dtype)
2643
  return causal_mask
2644
 
2645
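The mask helpers touched in this hunk matter precisely for non-flash inference: eager and SDPA consume an explicit 4D additive mask, whereas flash-attention relies on its own causal flag. A compact sketch of the causal part of that construction (padding and sliding-window handling omitted; names are illustrative):

```python
# Sketch: the additive causal mask consumed by the eager/SDPA paths.
import torch

def causal_mask_4d(sequence_length, target_length, cache_position, dtype, device):
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype, device=device)
    # A query token may attend to every key position up to its own cache position.
    allowed = torch.arange(target_length, device=device) <= cache_position.reshape(-1, 1)
    mask = mask.masked_fill(allowed, 0.0)
    return mask[None, None, :, :]  # (1, 1, q_len, kv_len), broadcast over batch and heads
```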
 
@@ -2699,6 +2240,7 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
2699
  # Initialize weights and apply final processing
2700
  self.post_init()
2701
 
 
2702
  def get_input_embeddings(self):
2703
  return self.model.embed_tokens
2704
 
@@ -2783,9 +2325,7 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
2783
  video_token_id = self.config.video_token_id
2784
  vision_start_token_id = self.config.vision_start_token_id
2785
  mrope_position_deltas = []
2786
- if input_ids is not None and (
2787
- image_grid_thw is not None or video_grid_thw is not None
2788
- ):
2789
  total_input_ids = input_ids
2790
  if attention_mask is None:
2791
  attention_mask = torch.ones_like(total_input_ids)
@@ -2801,9 +2341,7 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
2801
  for i, input_ids in enumerate(total_input_ids):
2802
  input_ids = input_ids[attention_mask[i] == 1]
2803
  image_nums, video_nums = 0, 0
2804
- vision_start_indices = torch.argwhere(
2805
- input_ids == vision_start_token_id
2806
- ).squeeze(1)
2807
  vision_tokens = input_ids[vision_start_indices + 1]
2808
  image_nums = (vision_tokens == image_token_id).sum()
2809
  video_nums = (vision_tokens == video_token_id).sum()
@@ -2851,80 +2389,39 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
2851
  )
2852
  text_len = ed - st
2853
 
2854
- st_idx = (
2855
- llm_pos_ids_list[-1].max() + 1
2856
- if len(llm_pos_ids_list) > 0
2857
- else 0
2858
- )
2859
- llm_pos_ids_list.append(
2860
- torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
2861
- )
2862
 
2863
- if torch.is_tensor(second_per_grid_t):
2864
- second_per_grid_t = second_per_grid_t.detach().item()
2865
  range_tensor = torch.arange(llm_grid_t).view(-1, 1)
2866
  expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
2867
 
2868
- time_tensor = (
2869
- expanded_range
2870
- * second_per_grid_t
2871
- * self.config.vision_config.tokens_per_second
2872
- )
2873
 
2874
  time_tensor_long = time_tensor.long()
2875
  t_index = time_tensor_long.flatten()
2876
 
2877
- h_index = (
2878
- torch.arange(llm_grid_h)
2879
- .view(1, -1, 1)
2880
- .expand(llm_grid_t, -1, llm_grid_w)
2881
- .flatten()
2882
- )
2883
- w_index = (
2884
- torch.arange(llm_grid_w)
2885
- .view(1, 1, -1)
2886
- .expand(llm_grid_t, llm_grid_h, -1)
2887
- .flatten()
2888
- )
2889
- llm_pos_ids_list.append(
2890
- torch.stack([t_index, h_index, w_index]) + text_len + st_idx
2891
- )
2892
  st = ed + llm_grid_t * llm_grid_h * llm_grid_w
2893
 
2894
  if st < len(input_tokens):
2895
- st_idx = (
2896
- llm_pos_ids_list[-1].max() + 1
2897
- if len(llm_pos_ids_list) > 0
2898
- else 0
2899
- )
2900
  text_len = len(input_tokens) - st
2901
- llm_pos_ids_list.append(
2902
- torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
2903
- )
2904
 
2905
  llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
2906
- position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
2907
- position_ids.device
2908
- )
2909
- mrope_position_deltas.append(
2910
- llm_positions.max() + 1 - len(total_input_ids[i])
2911
- )
2912
- mrope_position_deltas = torch.tensor(
2913
- mrope_position_deltas, device=input_ids.device
2914
- ).unsqueeze(1)
2915
  return position_ids, mrope_position_deltas
2916
  else:
2917
  if attention_mask is not None:
2918
  position_ids = attention_mask.long().cumsum(-1) - 1
2919
  position_ids.masked_fill_(attention_mask == 0, 1)
2920
- position_ids = (
2921
- position_ids.unsqueeze(0)
2922
- .expand(3, -1, -1)
2923
- .to(attention_mask.device)
2924
- )
2925
- max_position_ids = position_ids.max(0, keepdim=False)[0].max(
2926
- -1, keepdim=True
2927
- )[0]
2928
  mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
2929
  else:
2930
  position_ids = (
@@ -2940,9 +2437,7 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
2940
 
2941
  return position_ids, mrope_position_deltas
2942
 
2943
- @replace_return_docstrings(
2944
- output_type=KeyeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
2945
- )
2946
  def forward(
2947
  self,
2948
  input_ids: torch.LongTensor = None,
@@ -2962,7 +2457,7 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
2962
  rope_deltas: Optional[torch.LongTensor] = None,
2963
  cache_position: Optional[torch.LongTensor] = None,
2964
  second_per_grid_ts: Optional[torch.Tensor] = None,
2965
- **kwargs,
2966
  ) -> Union[Tuple, KeyeCausalLMOutputWithPast]:
2967
  r"""
2968
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -3003,19 +2498,11 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
3003
  "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
3004
  ```"""
3005
 
3006
- output_attentions = (
3007
- output_attentions
3008
- if output_attentions is not None
3009
- else self.config.output_attentions
3010
- )
3011
  output_hidden_states = (
3012
- output_hidden_states
3013
- if output_hidden_states is not None
3014
- else self.config.output_hidden_states
3015
- )
3016
- return_dict = (
3017
- return_dict if return_dict is not None else self.config.use_return_dict
3018
  )
 
3019
 
3020
  if inputs_embeds is None:
3021
  inputs_embeds = self.model.embed_tokens(input_ids)
@@ -3034,21 +2521,15 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
3034
  image_grid_hws.append(thw_tuple)
3035
  image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
3036
  siglip_position_ids.append(image_position_ids)
3037
- sample_indices.append(torch.full((numel,), idx, dtype=torch.int64))
3038
  cu_seqlens.append(cu_seqlens[-1] + numel)
3039
-
3040
- siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(
3041
- pixel_values.device
3042
- )
3043
- cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
3044
- pixel_values.device
3045
- )
3046
- sample_indices = torch.concat(sample_indices, dim=0).to(
3047
- pixel_values.device
3048
- )
3049
 
3050
  vision_outputs = self.visual(
3051
- pixel_values=pixel_values,
3052
  image_grid_thw=image_grid_hws,
3053
  position_ids=siglip_position_ids,
3054
  vision_return_embed_list=True,
@@ -3057,29 +2538,27 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
3057
  cu_seqlens=cu_seqlens,
3058
  return_pooler_output=False,
3059
  use_rope=True,
3060
- window_size=-1,
3061
  )
3062
  image_embeds = vision_outputs.last_hidden_state
3063
 
3064
  image_embeds = self.mlp_AR(image_embeds, image_grid_thw)
3065
-
3066
  n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
3067
- # image_embeds is a list of tensor, each tensor is a image feature,I want to concat them all into a tensor
3068
- image_embeds = torch.cat(image_embeds, dim=0)
3069
  n_image_features = image_embeds.shape[0]
3070
  if n_image_tokens != n_image_features:
3071
  raise ValueError(
3072
  f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
3073
  )
3074
 
3075
- mask = input_ids == self.config.image_token_id
3076
  mask_unsqueezed = mask.unsqueeze(-1)
3077
  mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
3078
  image_mask = mask_expanded.to(inputs_embeds.device)
3079
 
3080
- image_embeds = image_embeds.to(
3081
- inputs_embeds.device, inputs_embeds.dtype
3082
- )
3083
 
3084
  inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
3085
 
@@ -3098,20 +2577,14 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
3098
  video_grid_hws.append(thw_tuple)
3099
  video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
3100
  siglip_position_ids.append(video_position_ids)
3101
- sample_indices.append(torch.full((numel,), idx, dtype=torch.int64))
3102
  cu_seqlens.append(cu_seqlens[-1] + numel)
3103
- siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(
3104
- pixel_values_videos.device
3105
- )
3106
- cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
3107
- pixel_values_videos.device
3108
- )
3109
- sample_indices = torch.concat(sample_indices, dim=0).to(
3110
- pixel_values_videos.device
3111
- )
3112
 
3113
  vision_outputs = self.visual(
3114
- pixel_values=pixel_values_videos,
3115
  image_grid_thw=video_grid_hws,
3116
  position_ids=siglip_position_ids,
3117
  vision_return_embed_list=True,
@@ -3120,12 +2593,12 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
3120
  cu_seqlens=cu_seqlens,
3121
  return_pooler_output=False,
3122
  use_rope=True,
3123
- window_size=-1,
3124
  )
3125
  video_embeds = vision_outputs.last_hidden_state
3126
  video_embeds = self.mlp_AR(video_embeds, video_grid_thw)
3127
  n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
3128
- video_embeds = torch.cat(video_embeds, dim=0)
3129
  n_video_features = video_embeds.shape[0]
3130
  if n_video_tokens != n_video_features:
3131
  raise ValueError(
@@ -3137,18 +2610,14 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
3137
  mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
3138
  video_mask = mask_expanded.to(inputs_embeds.device)
3139
 
3140
- video_embeds = video_embeds.to(
3141
- inputs_embeds.device, inputs_embeds.dtype
3142
- )
3143
  inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
3144
 
3145
  if attention_mask is not None:
3146
  attention_mask = attention_mask.to(inputs_embeds.device)
3147
 
3148
  # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
3149
- if position_ids is None and (
3150
- attention_mask is None or attention_mask.ndim == 2
3151
- ):
3152
  # calculate RoPE index once per generation in the pre-fill stage only
3153
  if (
3154
  (cache_position is not None and cache_position[0] == 0)
@@ -3189,7 +2658,7 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
3189
  output_hidden_states=output_hidden_states,
3190
  return_dict=return_dict,
3191
  cache_position=cache_position,
3192
- **kwargs,
3193
  )
3194
 
3195
  hidden_states = outputs[0]
@@ -3309,13 +2778,7 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
3309
  if expand_size == 1:
3310
  return input_ids, model_kwargs
3311
 
3312
- visual_keys = [
3313
- "pixel_values",
3314
- "image_grid_thw",
3315
- "pixel_values_videos",
3316
- "video_grid_thw",
3317
- "second_per_grid_ts",
3318
- ]
3319
 
3320
  def _expand_dict_for_generation_visual(dict_to_expand):
3321
  image_grid_thw = model_kwargs.get("image_grid_thw", None)
@@ -3325,9 +2788,7 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
3325
  def _repeat_interleave_samples(x, lengths, repeat_times):
3326
  samples = torch.split(x, lengths)
3327
  repeat_args = [repeat_times] + [1] * (x.dim() - 1)
3328
- result = torch.cat(
3329
- [sample.repeat(*repeat_args) for sample in samples], dim=0
3330
- )
3331
  return result
3332
 
3333
  for key in dict_to_expand:
@@ -3363,9 +2824,7 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
3363
  )
3364
  tensor = torch.tensor(dict_to_expand[key])
3365
  lengths = list(video_nums)
3366
- tensor = _repeat_interleave_samples(
3367
- tensor, lengths=lengths, repeat_times=expand_size
3368
- )
3369
  dict_to_expand[key] = tensor.tolist()
3370
  return dict_to_expand
3371
 
@@ -3377,9 +2836,7 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
3377
  and isinstance(dict_to_expand[key], torch.Tensor)
3378
  and key not in visual_keys
3379
  ):
3380
- dict_to_expand[key] = dict_to_expand[key].repeat_interleave(
3381
- expand_size, dim=0
3382
- )
3383
  return dict_to_expand
3384
 
3385
  # input_ids is required for expanding visual inputs
@@ -3394,11 +2851,15 @@ class KeyeForConditionalGeneration(Qwen3PreTrainedModel, GenerationMixin):
3394
 
3395
  if is_encoder_decoder:
3396
  if model_kwargs.get("encoder_outputs") is None:
3397
- raise ValueError(
3398
- "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
3399
- )
3400
- model_kwargs["encoder_outputs"] = _expand_dict_for_generation(
3401
- model_kwargs["encoder_outputs"]
3402
- )
3403
 
3404
  return input_ids, model_kwargs
 
31
  from torch.nn import CrossEntropyLoss
32
 
33
  from transformers.activations import ACT2FN
34
+ from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
35
  from transformers.generation import GenerationMixin
36
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
37
+ from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutput, BaseModelOutputWithPooling
38
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
39
  from transformers.modeling_utils import PreTrainedModel, sdpa_attention_forward
40
  from transformers.activations import GELUActivation, ACT2FN, PytorchGELUTanh
 
46
  logging,
47
  replace_return_docstrings,
48
  torch_int,
49
+ is_flash_attn_greater_or_equal_2_10
50
  )
51
  from .configuration_keye import KeyeConfig, KeyeVisionConfig
52
 
 
55
  from typing import Any, Callable, Optional, Tuple, Union, List
56
  from torch import nn
57
  from torch.nn.init import _calculate_fan_in_and_fan_out
58
+ from einops import repeat
59
 
60
 
 
61
  if is_flash_attn_2_available():
62
  from flash_attn import flash_attn_varlen_func
63
  from flash_attn.layers.rotary import apply_rotary_emb
 
71
 
72
  _CONFIG_FOR_DOC = "KeyeConfig"
73
 
 
74
  class KeyeMLP(nn.Module):
75
  def __init__(self, config, bias: bool = False):
76
  super().__init__()
 
82
  self.act_fn = ACT2FN[config.hidden_act]
83
 
84
  def forward(self, hidden_state):
85
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
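A minimal standalone sketch (hypothetical sizes, assuming config.hidden_act resolves to SiLU) of the gated MLP computed by the one-line forward above:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 8, 16          # hypothetical sizes
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(2, 4, hidden_size)              # (batch, seq, hidden)
out = down_proj(F.silu(gate_proj(x)) * up_proj(x))
assert out.shape == x.shape                     # the MLP preserves the hidden size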
 
 
86
 
87
 
88
  def _trunc_normal_(tensor, mean, std, a, b):
 
122
 
123
 
124
  def trunc_normal_tf_(
125
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
126
  ) -> torch.Tensor:
127
  """Fills the input Tensor with values drawn from a truncated
128
  normal distribution. The values are effectively drawn from the
 
180
  variance_scaling_(tensor, mode="fan_in", distribution="normal")
181
 
182
 
183
  class Projector(nn.Module):
184
 
185
+ def __init__(self, text_config: KeyeConfig, vision_config: KeyeVisionConfig):
186
  super().__init__()
187
  self.text_config = text_config
188
  self.vision_config = vision_config
 
201
  self.hidden_size, self.text_config.hidden_size, bias=True
202
  )
203
 
204
+ def forward(self, image_features: torch.Tensor, image_grid_thw: List[Tuple[int, int, int]]) -> torch.Tensor:
 
 
205
  m1, m2 = self.merge_kernel_size
206
  if isinstance(image_features, (list, tuple)):
207
  processed_features = list()
 
210
  t, h, w = image_grid
211
  from einops import rearrange
212
 
213
+ image_feature = rearrange(image_feature, "(t h p1 w p2) d -> (t h w) (p1 p2 d)", t=t, h=h // m1, p1=m1, w=w // m2, p2=m2)
214
  hidden_states = self.linear_1(image_feature)
215
  hidden_states = self.act(hidden_states)
216
  hidden_states = self.linear_2(hidden_states)
 
228
 
229
  return hidden_states.view(*dims, -1)
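A standalone sketch (hypothetical grid and merge kernel) of the einops regrouping in Projector.forward above: m1 x m2 neighbouring patches are folded into one token before the two linear layers.

import torch
from einops import rearrange

t, h, w, d = 1, 4, 6, 8                          # hypothetical grid and feature dim
m1, m2 = 2, 2                                    # merge kernel
feats = torch.randn(t * h * w, d)                # flattened (t*h*w, d) patch features
merged = rearrange(
    feats, "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
    t=t, h=h // m1, p1=m1, w=w // m2, p2=m2,
)
# each output token now carries an m1 x m2 block of patch features
assert merged.shape == (t * (h // m1) * (w // m2), m1 * m2 * d)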
230
 
 
231
  class SiglipVisionEmbeddings(nn.Module):
232
  def __init__(self, config: KeyeVisionConfig):
233
  super().__init__()
 
251
  self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
252
  self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)
253
 
254
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
255
 
256
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int, is_after_patchify: bool = False) -> torch.Tensor:
257
  """
258
  This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
259
  images. This method is also adapted to support torch.jit tracing and no class embeddings.
 
276
  new_width = width // self.patch_size
277
 
278
  sqrt_num_positions = torch_int(num_positions**0.5)
279
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
 
 
280
  patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
281
 
282
  patch_pos_embed = nn.functional.interpolate(
 
304
  if grid in self.cache_position_embedding:
305
  self.cache_position_count[grid] += 1
306
  return self.cache_position_embedding[grid]
307
+
308
  if len(self.cache_position_embedding) >= max_cache:
309
+ min_hit_grid = min(self.cache_position_count, key=self.cache_position_count.get)
 
 
310
  self.cache_position_count.pop(min_hit_grid)
311
  self.cache_position_embedding.pop(min_hit_grid)
312
+
313
  position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True)
314
  self.cache_position_count[grid] = 1
315
  self.cache_position_embedding[grid] = position_embedding
316
  return position_embedding
317
 
318
  def forward(
319
+ self,
320
+ pixel_values: torch.FloatTensor,
321
  position_ids: Optional[torch.Tensor] = None,
322
+ image_grid_thw: Optional[List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]] = None,
323
+ interpolate_pos_encoding=False
 
 
324
  ) -> torch.Tensor:
325
  if pixel_values.dim() == 5:
326
  assert position_ids is not None
327
  from einops import rearrange
 
328
  batch_size, squence_len, channel, height, width = pixel_values.shape
329
  target_dtype = self.patch_embedding.weight.dtype
330
  pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
331
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
 
 
332
  embeddings = patch_embeds.flatten(-2).squeeze(-1)
333
+ embeddings = rearrange(embeddings, "(b l) d -> b l d", b=batch_size, l=squence_len)
 
 
334
 
335
  # todo: not yet debugged
336
  if interpolate_pos_encoding and image_grid_thw is not None:
 
338
  assert batch_size == 1
339
  start = 0
340
  image_embedding_list = list()
341
+ assert sum([np.prod(x) for x in flatten_image_grid_thw]) == embeddings.shape[1], (flatten_image_grid_thw, embeddings.shape)
342
  embeddings = embeddings.squeeze(0)
343
  tmp_embeddings = list()
344
  for image_grid in image_grid_thw:
345
  t, h, w = image_grid
346
  end = start + t * h * w
347
+ image_embeddings = embeddings[start: end, :]
348
+ position_embedding = self.interpolate_pos_encoding(image_embeddings, h, w, True).squeeze(0).repeat(
349
+ t, 1)
350
  image_embeddings = image_embeddings + position_embedding
351
  tmp_embeddings.append(image_embeddings)
352
  start = end
 
372
  if attention_mask is not None:
373
  attn_weights = attn_weights + attention_mask
374
 
375
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
376
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
377
 
378
  attn_output = torch.matmul(attn_weights, value)
379
  attn_output = attn_output.transpose(1, 2).contiguous()
 
414
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
415
  """Input shape: Batch x Time x Channel"""
416
 
417
+ use_flash_attn = (cu_seqlens is not None) and self.config._attn_implementation == "flash_attention_2"
 
 
418
 
419
  batch_size, seq_length, embed_dim = hidden_states.shape
420
 
 
423
  values = self.v_proj(hidden_states)
424
 
425
  if rope_emb is None:
426
+ queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
427
+ keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
428
+ values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
429
  else:
430
  assert cu_seqlens is not None, "Rope support flash attn only."
431
  cos, sin = rope_emb
432
+ queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim)
 
 
433
  keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim)
434
+ if use_flash_attn:
435
+ queries, keys = apply_rotary_pos_emb_flashatt(queries, keys, cos, sin)
436
+ else:
437
+ queries, keys = apply_rotary_pos_emb_vision(queries, keys, cos, sin)
438
  queries = queries.transpose(1, 2)
439
  keys = keys.transpose(1, 2)
440
+ values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
 
 
441
 
442
  if not use_flash_attn:
443
  attention_interface: Callable = eager_attention_forward
 
460
  scaling=self.scale,
461
  dropout=0.0 if not self.training else self.dropout,
462
  )
463
+ attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
 
 
464
  else:
465
  assert batch_size == 1, hidden_states.shape
466
  queries = queries.transpose(1, 2).squeeze(0)
467
  keys = keys.transpose(1, 2).squeeze(0)
468
  values = values.transpose(1, 2).squeeze(0)
469
 
 
 
470
  max_seqlen_q = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
471
  max_seqlen_k = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
472
+ assert cu_seqlens[-1].item() == queries.shape[0] == keys.shape[0] == values.shape[0], (cu_seqlens, queries.shape, keys.shape, values.shape)
473
 
474
  attn_output = flash_attn_varlen_func(
475
  queries,
 
735
  embed_dim = config.hidden_size
736
  num_heads = config.num_attention_heads
737
  head_dim = embed_dim // num_heads
738
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
 
 
739
  self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)
740
  self.gradient_checkpointing = False
741
 
 
751
 
752
  def build_window_index(self, image_grid, window_size, device):
753
  from einops import rearrange
 
754
  window_indices = list()
755
  pad_values = -100
756
  start_window_index = 0
 
762
  pad_w = (-w) % window_size
763
  assert pad_h >= 0 and pad_w >= 0, (pad_h, pad_w)
764
  window_index = F.pad(window_index, (0, pad_w, 0, pad_h), value=pad_values)
765
+ window_index = rearrange(window_index, "t (h p1) (w p2) -> t (h w) (p1 p2)", p1=window_size, p2=window_size)
766
  window_seqlens = (window_index != pad_values).long().sum(-1).reshape(-1)
767
  window_index = window_index.reshape(-1)
768
  window_index = window_index[window_index != pad_values]
769
  window_indices.append(window_index + start_window_index)
770
+ cu_seqlens_within_windows.append(window_seqlens.cumsum(0) + start_window_index)
 
 
771
  start_window_index += t * h * w
772
  window_indices = torch.concat(window_indices, dim=0)
773
  cu_seqlens_within_windows = torch.concat(cu_seqlens_within_windows, dim=0)
774
+ cu_seqlens_within_windows = F.pad(cu_seqlens_within_windows, (1, 0), value=0).to(torch.int32)
 
 
775
  return window_indices, cu_seqlens_within_windows
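A standalone sketch (hypothetical sizes, single image, not the model's exact helper) of the window partitioning idea used by build_window_index above: pad the patch grid to a multiple of the window size, regroup token indices window by window, and keep per-window sequence lengths for varlen attention.

import torch
import torch.nn.functional as F
from einops import rearrange

h, w, window_size = 5, 7, 4
index = torch.arange(h * w).reshape(1, h, w)                  # t=1 grid of token indices
pad_h, pad_w = (-h) % window_size, (-w) % window_size
index = F.pad(index, (0, pad_w, 0, pad_h), value=-100)        # pad to a multiple of the window
index = rearrange(index, "t (h p1) (w p2) -> t (h w) (p1 p2)", p1=window_size, p2=window_size)
seqlens = (index != -100).sum(-1).reshape(-1)                 # real tokens per window
flat = index.reshape(-1)
flat = flat[flat != -100]                                     # window-ordered token indices
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0), value=0).to(torch.int32)
assert flat.numel() == h * w and cu_seqlens[-1].item() == h * w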
776
 
777
  # Ignore copy
 
783
  output_attentions: Optional[bool] = None,
784
  output_hidden_states: Optional[bool] = None,
785
  cu_seqlens: Optional[List[torch.Tensor]] = None,
786
+ image_grid_thw: Optional[List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]] = None,
 
 
787
  height_position_ids: Optional[torch.Tensor] = None,
788
  width_position_ids: Optional[torch.Tensor] = None,
789
  use_rope: Optional[bool] = False,
 
816
 
817
  vision_or_text = "vision"
818
  assert vision_or_text in ["vision", "text"]
819
+ use_window_attn = (window_size > 0 and vision_or_text == "vision")
820
  use_rope = (use_rope is True) and (vision_or_text == "vision")
821
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
822
  output_hidden_states = (
823
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
824
  )
825
 
826
  encoder_states = () if output_hidden_states else None
 
828
 
829
  device = inputs_embeds.device
830
  hidden_states = inputs_embeds
831
+ attention_mask = attention_mask.to(inputs_embeds.dtype) if attention_mask is not None else None
832
  if use_rope is True:
833
  flatten_image_grid_thw = self.flatten_list(image_grid_thw)
834
+ assert sum([np.prod(x) for x in flatten_image_grid_thw]) == hidden_states.shape[1], (flatten_image_grid_thw, hidden_states.shape)
835
 
836
  if width_position_ids is None or height_position_ids is None:
837
  split_hids = list()
 
844
  split_wids.append(sample_wids)
845
  width_position_ids = torch.concat(split_wids, dim=0)
846
  height_position_ids = torch.concat(split_hids, dim=0)
847
+
848
  window_indices, cu_seqlens_within_windows = None, None
849
 
850
  if use_window_attn:
851
+ window_indices, cu_seqlens_within_windows = self.build_window_index(flatten_image_grid_thw, window_size, device)
 
 
852
  reversed_window_indices = window_indices.argsort()
853
  height_position_ids = height_position_ids[window_indices]
854
  width_position_ids = width_position_ids[window_indices]
 
863
 
864
  rope_emb = None
865
  window_indices, cu_seqlens_within_windows = None, None
866
+
867
  if use_window_attn:
868
  flatten_image_grid_thw = self.flatten_list(image_grid_thw)
869
+ assert sum([np.prod(x) for x in flatten_image_grid_thw]) == hidden_states.shape[1], (flatten_image_grid_thw, hidden_states.shape)
870
+
871
+ window_indices, cu_seqlens_within_windows = self.build_window_index(flatten_image_grid_thw, window_size, device)
872
  reversed_window_indices = window_indices.argsort()
873
 
874
  if use_window_attn:
 
880
 
881
  for encoder_layer in self.layers:
882
  if output_hidden_states:
883
+ encoder_states = encoder_states + ((hidden_states[:, reversed_window_indices, :],) if use_window_attn else (hidden_states, ))
884
  if self.gradient_checkpointing and self.training:
885
  layer_outputs = self._gradient_checkpointing_func(
886
  encoder_layer.__call__,
 
926
  self.embeddings = SiglipVisionEmbeddings(config)
927
  self.encoder = SiglipEncoder(config)
928
  self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
929
+ self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
 
 
930
  if self.use_head:
931
  self.head = SiglipMultiheadAttentionPoolingHead(config)
932
 
933
  # @can_return_tuple
934
  @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
935
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=KeyeVisionConfig)
 
 
936
  def forward(
937
  self,
938
  pixel_values,
 
948
  cu_seqlens: Optional[List[torch.Tensor]] = None,
949
  padding_mask: Optional[torch.Tensor] = None,
950
  vision_return_embed_list: Optional[bool] = False,
951
+ image_grid_thw: Optional[List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]] = None,
 
 
952
  return_pooler_output: Optional[bool] = True,
953
  use_rope: Optional[bool] = False,
954
  window_size: Optional[bool] = -1,
 
957
  Returns:
958
 
959
  """
960
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
961
  output_hidden_states = (
962
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
963
  )
964
  hidden_states = self.embeddings(
965
+ pixel_values,
966
+ interpolate_pos_encoding=interpolate_pos_encoding,
967
  position_ids=position_ids,
968
+ image_grid_thw=image_grid_thw
969
  )
970
 
971
  encoder_outputs: BaseModelOutput = self.encoder(
 
1001
  token_indices = (sample_index == sample_idx).nonzero().flatten()
1002
  sample_hidden_state = hidden_state[token_indices]
1003
  sample_hidden_state_list.append(sample_hidden_state)
1004
+
1005
  if not vision_return_embed_list:
1006
+ max_length = max([_state.shape[0] for _state in sample_hidden_state_list])
 
 
1007
  tmp_sample_hidden_state_list = list()
1008
  padding_mask = list()
1009
  for idx, _state in enumerate(sample_hidden_state_list):
1010
  padding_length = max_length - _state.shape[0]
1011
+ mask = _state.new_zeros(size=(max_length,), dtype=torch.int64)
1012
+ if padding_length > 0: mask[-padding_length:] = 1  # guard: mask[-0:] would flag the whole (unpadded) sample as padding
1013
  padding_mask.append(mask)
1014
  padding = _state.new_zeros(size=(padding_length, dim))
1015
  new_state = torch.concat([_state, padding], dim=0)
1016
  tmp_sample_hidden_state_list.append(new_state)
1017
+ sample_hidden_state = torch.stack(tmp_sample_hidden_state_list, dim=0)
1018
+ padding_mask = torch.stack(padding_mask, dim=0).float().to(last_hidden_state.dtype)
1019
+ pooler_output = self.head(sample_hidden_state, key_padding_mask=padding_mask)
1020
  else:
1021
  pooler_output = list()
1022
  for state in sample_hidden_state_list:
 
1040
  hidden_states=encoder_outputs.hidden_states,
1041
  attentions=encoder_outputs.attentions,
1042
  )
1043
+
1044
  sample_hidden_state = list()
1045
  assert cu_seqlens is not None
1046
  for i in range(cu_seqlens.shape[0] - 1):
1047
  start = cu_seqlens[i]
1048
  end = cu_seqlens[i + 1]
1049
+ tensor = last_hidden_state[:, start: end, :].squeeze(0)
1050
  sample_hidden_state.append(tensor)
1051
+
1052
  return BaseModelOutputWithPooling(
1053
  last_hidden_state=sample_hidden_state,
1054
  pooler_output=None,
 
1064
  super().__init__()
1065
 
1066
  self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
1067
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
 
 
1068
  self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1069
  self.mlp = SiglipMLP(config)
1070
 
 
1072
  batch_size = hidden_state.shape[0]
1073
  probe = self.probe.repeat(batch_size, 1, 1)
1074
 
1075
+ hidden_state = self.attention(probe, hidden_state, hidden_state, key_padding_mask=key_padding_mask)[0]
 
 
1076
 
1077
  residual = hidden_state
1078
  hidden_state = self.layernorm(hidden_state)
 
1102
 
1103
  # @can_return_tuple
1104
  @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1105
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=KeyeVisionConfig)
 
 
1106
  def forward(
1107
  self,
1108
  pixel_values,
 
1112
  interpolate_pos_encoding: bool = False,
1113
  position_ids: Optional[torch.Tensor] = None,
1114
  vision_return_embed_list: Optional[bool] = False,
1115
+ image_grid_thw: Optional[List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]] = None,
 
 
1116
  cu_seqlens: Optional[List[torch.Tensor]] = None,
1117
  return_pooler_output: Optional[bool] = True,
1118
  use_rope: Optional[bool] = False,
 
1157
  )
1158
 
1159
 
1160
+
1161
  class Qwen3RMSNorm(nn.Module):
1162
  def __init__(self, hidden_size, eps=1e-6):
1163
  """
 
1204
  return q_embed, k_embed
1205
 
1206
 
1207
+
1208
  def rotate_half(x):
1209
  """Rotates half the hidden dims of the input."""
1210
  x1 = x[..., : x.shape[-1] // 2]
 
1225
  k_embed = k_embed.to(orig_k_dtype)
1226
  return q_embed, k_embed
1227
 
1228
  Keye_START_DOCSTRING = r"""
1229
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1230
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
 
1250
  config_class = KeyeConfig
1251
  base_model_prefix = "model"
1252
  supports_gradient_checkpointing = True
1253
+ _no_split_modules = ["KeyeDecoderLayer"]
1254
  _skip_keys_device_placement = "past_key_values"
1255
  _supports_flash_attn_2 = True
1256
  _supports_sdpa = True
 
1269
  module.weight.data[module.padding_idx].zero_()
1270
 
1271
 
1272
+
1273
  class SigLIPRotaryEmbedding(nn.Module):
1274
  def __init__(self, dim: int, theta: float = 10000.0) -> None:
1275
  super().__init__()
 
1278
  self.rope_init()
1279
 
1280
  def rope_init(self):
1281
+ inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim))
 
 
1282
  self.register_buffer("inv_freq", inv_freq, persistent=False)
1283
 
1284
  def forward(self, seqlen: int) -> torch.Tensor:
1285
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
 
 
1286
  freqs = torch.outer(seq, self.inv_freq)
1287
  return freqs
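A standalone sketch (hypothetical dim and sequence length) of the rotary-frequency table this module produces: pairwise inverse frequencies crossed with positions, later consumed as cos/sin.

import torch

dim, theta, seqlen = 8, 10000.0, 5
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
seq = torch.arange(seqlen, dtype=inv_freq.dtype)
freqs = torch.outer(seq, inv_freq)        # (seqlen, dim // 2) rotation angles per position
cos, sin = freqs.cos(), freqs.sin()
assert freqs.shape == (seqlen, dim // 2)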
1288
 
 
1309
  else:
1310
  # BC: "rope_type" was originally "type"
1311
  if config.rope_scaling is not None:
1312
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
 
 
1313
  else:
1314
  self.rope_type = "default"
1315
  self.max_seq_len_cached = config.max_position_embeddings
1316
  self.original_max_seq_len = config.max_position_embeddings
1317
+
1318
  # BC: "rope_type" was originally "type"
1319
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
1320
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
 
 
1321
  else:
1322
  self.rope_type = "default"
1323
  self.max_seq_len_cached = config.max_position_embeddings
 
1341
  inv_freq, self.attention_scaling = self.rope_init_fn(
1342
  self.config, device, seq_len=seq_len, **self.rope_kwargs
1343
  )
1344
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
 
 
1345
  self.max_seq_len_cached = seq_len
1346
 
1347
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
1348
  self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
1349
  self.max_seq_len_cached = self.original_max_seq_len
1350
 
 
1355
 
1356
  # Core RoPE block. In contrast to other models, Keye has different position ids for the grids
1357
  # So we expand the inv_freq to shape (3, ...)
1358
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
1359
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
1360
  # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
1361
  device_type = x.device.type
1362
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
1363
  with torch.autocast(device_type=device_type, enabled=False):
1364
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
 
 
1365
  emb = torch.cat((freqs, freqs), dim=-1)
1366
  cos = emb.cos()
1367
  sin = emb.sin()
 
1431
  `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
1432
  """
1433
  mrope_section = mrope_section * 2
1434
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
1435
+ unsqueeze_dim
1436
+ )
1437
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
1438
+ unsqueeze_dim
1439
+ )
1440
 
1441
  q_embed = (q * cos) + (rotate_half(q) * sin)
1442
  k_embed = (k * cos) + (rotate_half(k) * sin)
 
1451
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
1452
  if n_rep == 1:
1453
  return hidden_states
1454
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
 
 
1455
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
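A standalone sketch (hypothetical shapes) of the grouped-query-attention expansion repeat_kv performs above; it is equivalent to torch.repeat_interleave along the head dimension.

import torch

batch, num_kv_heads, n_rep, slen, head_dim = 2, 4, 3, 5, 8
kv = torch.randn(batch, num_kv_heads, slen, head_dim)
expanded = kv[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
expanded = expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
# every KV head is repeated n_rep times so it can serve n_rep query heads
assert torch.equal(expanded, torch.repeat_interleave(kv, n_rep, dim=1))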
1456
 
1457
 
 
1474
 
1475
  self.hidden_size = config.hidden_size
1476
  self.num_heads = config.num_attention_heads
1477
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
 
 
1478
  self.num_key_value_heads = config.num_key_value_heads
1479
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
 
 
1480
  self.is_causal = True
1481
  self.attention_dropout = config.attention_dropout
1482
  self.rope_scaling = config.rope_scaling
1483
 
1484
  self.q_proj = nn.Linear(
1485
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
 
 
1486
  )
1487
  self.k_proj = nn.Linear(
1488
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
 
 
1489
  )
1490
  self.v_proj = nn.Linear(
1491
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
 
 
1492
  )
1493
  self.o_proj = nn.Linear(
1494
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
 
 
1495
  )
1496
+ self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
1497
+ self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
1498
 
1499
  self.rotary_emb = KeyeRotaryEmbedding(config=config)
1500
 
 
1507
  output_attentions: bool = False,
1508
  use_cache: bool = False,
1509
  cache_position: Optional[torch.LongTensor] = None,
1510
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
 
 
1511
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1512
  bsz, q_len, _ = hidden_states.size()
1513
 
1514
+ query_states = self.q_norm(self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim))
1515
+ key_states = self.k_norm(self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim))
1516
  value_states = self.v_proj(hidden_states)
1517
 
1518
  query_states = query_states.transpose(1, 2)
 
1525
  )
1526
 
1527
  if past_key_value is not None:
1528
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
1529
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1530
 
1531
  # repeat k/v heads if n_kv_heads < n_heads
1532
  key_states = repeat_kv(key_states, self.num_key_value_groups)
1533
  value_states = repeat_kv(value_states, self.num_key_value_groups)
1534
 
1535
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
1536
+
 
1537
 
1538
  if attention_mask is not None: # no matter the length, we just slice it
1539
  causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
 
1542
  # Fix precision issues in float16 inference
1543
  # Replace inf values with zeros in attention weights to prevent NaN propagation
1544
  if query_states.dtype == torch.float16:
1545
+ attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
 
 
1546
 
1547
  # upcast attention to fp32
1548
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
1549
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
1550
  attn_output = torch.matmul(attn_weights, value_states)
1551
 
1552
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 
1592
  output_attentions: bool = False,
1593
  use_cache: bool = False,
1594
  cache_position: Optional[torch.LongTensor] = None,
1595
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
 
 
1596
  cu_seqlens: Optional[torch.Tensor] = None,
1597
+ sliding_window=-1,
1598
  **kwargs,
1599
  ):
1600
  bsz, q_len, _ = hidden_states.size()
1601
+ q = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim)
1602
  query_states = self.q_norm(q)
1603
+ key_states = self.k_norm(self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim))
 
 
1604
  value_states = self.v_proj(hidden_states)
1605
 
1606
  query_states = query_states.transpose(1, 2)
 
1614
  )
1615
 
1616
  if past_key_value is not None:
1617
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
1618
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1619
 
1620
  # repeat k/v heads if n_kv_heads < n_heads
1621
  key_states = repeat_kv(key_states, self.num_key_value_groups)
1622
  value_states = repeat_kv(value_states, self.num_key_value_groups)
1623
  dropout_rate = 0.0 if not self.training else self.attention_dropout
1624
+
1625
  # In PEFT, usually we cast the layer norms in float32 for training stability reasons
1626
  # therefore the input hidden states gets silently casted in float32. Hence, we need
1627
  # cast them back in float16 just to be sure everything works as expected.
 
1675
  max_seqlen,
1676
  dropout_p=dropout_rate,
1677
  window_size=(sliding_window, sliding_window),
1678
+ causal=self.is_causal
1679
  )
1680
  else:
1681
  attn_output = _flash_attention_forward(
 
1715
  output_attentions: bool = False,
1716
  use_cache: bool = False,
1717
  cache_position: Optional[torch.LongTensor] = None,
1718
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
 
 
1719
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1720
  if output_attentions:
1721
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
 
1736
 
1737
  bsz, q_len, _ = hidden_states.size()
1738
 
1739
+ query_states = self.q_norm(self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim))
1740
+ key_states = self.k_norm(self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim))
1741
  value_states = self.v_proj(hidden_states)
1742
 
1743
  query_states = query_states.transpose(1, 2)
 
1750
  )
1751
 
1752
  if past_key_value is not None:
1753
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
1754
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1755
 
1756
  key_states = repeat_kv(key_states, self.num_key_value_groups)
1757
  value_states = repeat_kv(value_states, self.num_key_value_groups)
 
1789
  return attn_output, None, past_key_value
1790
 
1791
 
1792
+
1793
  QWEN3_ATTENTION_CLASSES = {
1794
  "eager": KeyeAttention,
1795
  "flash_attention_2": KeyeFlashAttention2,
 
1801
  def __init__(self, config: KeyeConfig, layer_idx: int):
1802
  super().__init__()
1803
  self.hidden_size = config.hidden_size
1804
+
1805
+ if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
1806
  logger.warning_once(
1807
  f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
1808
  "unexpected results may be encountered."
1809
  )
1810
 
1811
+ self.self_attn = QWEN3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
 
 
1812
  self.mlp = Qwen3MLP(config)
1813
  self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1814
+ self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
1815
 
1816
  def forward(
1817
  self,
 
1822
  output_attentions: Optional[bool] = False,
1823
  use_cache: Optional[bool] = False,
1824
  cache_position: Optional[torch.LongTensor] = None,
1825
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
 
 
1826
  **kwargs,
1827
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
 
1828
  """
1829
  Args:
1830
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
 
1860
  use_cache=use_cache,
1861
  cache_position=cache_position,
1862
  position_embeddings=position_embeddings,
1863
+ **kwargs
1864
  )
1865
 
1866
  hidden_states = residual + hidden_states
 
1876
  if output_attentions:
1877
  outputs += (self_attn_weights,)
1878
 
1879
+
1880
  if use_cache:
1881
  outputs += (present_key_value,)
1882
 
 
1893
  self.padding_idx = config.pad_token_id
1894
  self.vocab_size = config.vocab_size
1895
 
1896
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
 
1897
  self.layers = nn.ModuleList(
1898
+ [KeyeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1899
  )
1900
  self._attn_implementation = config._attn_implementation
1901
  self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
1923
  output_hidden_states: Optional[bool] = None,
1924
  return_dict: Optional[bool] = None,
1925
  cache_position: Optional[torch.LongTensor] = None,
1926
+ **kwargs
1927
  ) -> Union[Tuple, BaseModelOutputWithPast]:
1928
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1929
  output_hidden_states = (
1930
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
1931
  )
1932
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1933
 
1934
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1935
 
1936
  if (input_ids is None) ^ (inputs_embeds is not None):
1937
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
 
1938
 
1939
  if self.gradient_checkpointing and self.training:
1940
  if use_cache:
 
1951
  inputs_embeds = self.embed_tokens(input_ids)
1952
 
1953
  if cache_position is None:
1954
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
 
 
1955
  cache_position = torch.arange(
1956
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
 
 
1957
  )
1958
 
1959
  # the hard coded `3` is for temporal, height and width.
1960
  if position_ids is None:
1961
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
 
 
1962
  elif position_ids.dim() == 2:
1963
  position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
1964
 
1965
  causal_mask = self._update_causal_mask(
1966
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
1967
  )
1968
  hidden_states = inputs_embeds
1969
 
 
2023
  next_cache = next_decoder_cache if use_cache else None
2024
 
2025
  if not return_dict:
2026
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
2027
  return BaseModelOutputWithPast(
2028
  last_hidden_state=hidden_states,
2029
  past_key_values=next_cache,
 
2041
  ):
2042
  if self.config._attn_implementation == "flash_attention_2":
2043
  if attention_mask is not None and past_key_values is not None:
2044
+ is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
 
 
2045
  if is_padding_right:
2046
  raise ValueError(
2047
  "You are attempting to perform batched generation with padding_side='right'"
 
2055
  # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
2056
  # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
2057
  # to infer the attention mask.
2058
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
 
 
2059
  using_static_cache = isinstance(past_key_values, StaticCache)
2060
  using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
2061
 
 
2110
  # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
2111
  # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
2112
  # Details: https://github.com/pytorch/pytorch/issues/110213
2113
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
 
 
2114
 
2115
  return causal_mask
2116
 
 
2156
  else:
2157
  min_dtype = torch.finfo(dtype).min
2158
  causal_mask = torch.full(
2159
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
2160
  )
2161
+ diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
 
 
2162
  if config.sliding_window is not None:
2163
  # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
2164
  # the check is needed to verify is current checkpoint was trained with sliding window or not
2165
+ if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
2166
+ sliding_attend_mask = torch.arange(target_length, device=device) <= (
2167
+ cache_position.reshape(-1, 1) - config.sliding_window
2168
+ )
2169
  diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
2170
  causal_mask *= diagonal_attend_mask
2171
  causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
2172
  if attention_mask is not None:
2173
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
 
 
2174
  if attention_mask.shape[-1] > target_length:
2175
  attention_mask = attention_mask[:, :target_length]
2176
  mask_length = attention_mask.shape[-1]
2177
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
2178
+ causal_mask.device
2179
+ )
2180
  padding_mask = padding_mask == 0
2181
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
2182
+ padding_mask, min_dtype
2183
+ )
2184
  return causal_mask
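A simplified, standalone sketch (hypothetical lengths, no batch or head dims) of the mask recipe above: start from the dtype minimum, zero out keys at or before each query's cache position, and re-mask keys that fall outside the sliding window.

import torch

dtype = torch.float32
sequence_length, target_length, sliding_window = 4, 6, 3
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(2, 2 + sequence_length)         # absolute query positions

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
diagonal_attend_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)   # future keys stay masked
sliding_attend_mask = torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)         # also mask keys older than the window
causal_mask *= diagonal_attend_mask                           # 0 where attention is allowed
assert (causal_mask[0] == 0).sum() == sliding_window          # each query sees at most `sliding_window` keys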
2185
 
2186
 
 
2240
  # Initialize weights and apply final processing
2241
  self.post_init()
2242
 
2243
+
2244
  def get_input_embeddings(self):
2245
  return self.model.embed_tokens
2246
 
 
2325
  video_token_id = self.config.video_token_id
2326
  vision_start_token_id = self.config.vision_start_token_id
2327
  mrope_position_deltas = []
2328
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
 
 
2329
  total_input_ids = input_ids
2330
  if attention_mask is None:
2331
  attention_mask = torch.ones_like(total_input_ids)
 
2341
  for i, input_ids in enumerate(total_input_ids):
2342
  input_ids = input_ids[attention_mask[i] == 1]
2343
  image_nums, video_nums = 0, 0
2344
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
 
 
2345
  vision_tokens = input_ids[vision_start_indices + 1]
2346
  image_nums = (vision_tokens == image_token_id).sum()
2347
  video_nums = (vision_tokens == video_token_id).sum()
 
2389
  )
2390
  text_len = ed - st
2391
 
2392
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
2393
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
2394
 
2395
+ if torch.is_tensor(second_per_grid_t): second_per_grid_t = second_per_grid_t.detach().item()
 
2396
  range_tensor = torch.arange(llm_grid_t).view(-1, 1)
2397
  expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
2398
 
2399
+ time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
2400
 
2401
  time_tensor_long = time_tensor.long()
2402
  t_index = time_tensor_long.flatten()
2403
 
2404
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
2405
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
2406
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
2407
  st = ed + llm_grid_t * llm_grid_h * llm_grid_w
2408
 
2409
  if st < len(input_tokens):
2410
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
2411
  text_len = len(input_tokens) - st
2412
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
 
 
2413
 
2414
  llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
2415
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
2416
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
2417
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
2418
  return position_ids, mrope_position_deltas
2419
  else:
2420
  if attention_mask is not None:
2421
  position_ids = attention_mask.long().cumsum(-1) - 1
2422
  position_ids.masked_fill_(attention_mask == 0, 1)
2423
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
2424
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
2425
  mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
2426
  else:
2427
  position_ids = (
 
2437
 
2438
  return position_ids, mrope_position_deltas
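A small worked example (not from the file) of the text-only fallback above: position ids are derived from the attention mask and the rope delta records how far each sample's last position falls short of the padded length.

import torch

attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 0, 0, 0, 0]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)     # (3, bs, seq): t/h/w axes share ids for text
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
assert mrope_position_deltas.tolist() == [[-2], [-4]]          # shorter samples get larger negative deltas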
2439
 
2440
+ @replace_return_docstrings(output_type=KeyeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
 
2441
  def forward(
2442
  self,
2443
  input_ids: torch.LongTensor = None,
 
2457
  rope_deltas: Optional[torch.LongTensor] = None,
2458
  cache_position: Optional[torch.LongTensor] = None,
2459
  second_per_grid_ts: Optional[torch.Tensor] = None,
2460
+ **kwargs
2461
  ) -> Union[Tuple, KeyeCausalLMOutputWithPast]:
2462
  r"""
2463
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
2498
  "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
2499
  ```"""
2500
 
2501
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2502
  output_hidden_states = (
2503
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2504
  )
2505
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2506
 
2507
  if inputs_embeds is None:
2508
  inputs_embeds = self.model.embed_tokens(input_ids)
 
2521
  image_grid_hws.append(thw_tuple)
2522
  image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
2523
  siglip_position_ids.append(image_position_ids)
2524
+ sample_indices.append(torch.full((numel, ), idx, dtype=torch.int64))
2525
  cu_seqlens.append(cu_seqlens[-1] + numel)
2526
+
2527
+ siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(pixel_values.device)
2528
+ cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(pixel_values.device)
2529
+ sample_indices = torch.concat(sample_indices, dim=0).to(pixel_values.device)
2530
 
2531
  vision_outputs = self.visual(
2532
+ pixel_values=pixel_values,
2533
  image_grid_thw=image_grid_hws,
2534
  position_ids=siglip_position_ids,
2535
  vision_return_embed_list=True,
 
2538
  cu_seqlens=cu_seqlens,
2539
  return_pooler_output=False,
2540
  use_rope=True,
2541
+ window_size=-1,
2542
  )
2543
  image_embeds = vision_outputs.last_hidden_state
2544
 
2545
  image_embeds = self.mlp_AR(image_embeds, image_grid_thw)
2546
+
2547
  n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
2548
+ # image_embeds is a list of per-image feature tensors; concatenate them into a single tensor
2549
+ image_embeds = torch.cat(image_embeds, dim=0)
2550
  n_image_features = image_embeds.shape[0]
2551
  if n_image_tokens != n_image_features:
2552
  raise ValueError(
2553
  f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
2554
  )
2555
 
2556
+ mask = (input_ids == self.config.image_token_id)
2557
  mask_unsqueezed = mask.unsqueeze(-1)
2558
  mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
2559
  image_mask = mask_expanded.to(inputs_embeds.device)
2560
 
2561
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
 
 
2562
 
2563
  inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
2564
 
 
2577
  video_grid_hws.append(thw_tuple)
2578
  video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
2579
  siglip_position_ids.append(video_position_ids)
2580
+ sample_indices.append(torch.full((numel, ), idx, dtype=torch.int64))
2581
  cu_seqlens.append(cu_seqlens[-1] + numel)
2582
+ siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(pixel_values_videos.device)
2583
+ cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(pixel_values_videos.device)
2584
+ sample_indices = torch.concat(sample_indices, dim=0).to(pixel_values_videos.device)
2585
 
2586
  vision_outputs = self.visual(
2587
+ pixel_values=pixel_values_videos,
2588
  image_grid_thw=video_grid_hws,
2589
  position_ids=siglip_position_ids,
2590
  vision_return_embed_list=True,
 
2593
  cu_seqlens=cu_seqlens,
2594
  return_pooler_output=False,
2595
  use_rope=True,
2596
+ window_size=-1,
2597
  )
2598
  video_embeds = vision_outputs.last_hidden_state
2599
  video_embeds = self.mlp_AR(video_embeds, video_grid_thw)
2600
  n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
2601
+ video_embeds = torch.cat(video_embeds, dim=0)
2602
  n_video_features = video_embeds.shape[0]
2603
  if n_video_tokens != n_video_features:
2604
  raise ValueError(
 
2610
  mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
2611
  video_mask = mask_expanded.to(inputs_embeds.device)
2612
 
2613
+ video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
 
 
2614
  inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
2615
 
2616
  if attention_mask is not None:
2617
  attention_mask = attention_mask.to(inputs_embeds.device)
2618
 
2619
  # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
2620
+ if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
 
 
2621
  # calculate RoPE index once per generation in the pre-fill stage only
2622
  if (
2623
  (cache_position is not None and cache_position[0] == 0)
 
2658
  output_hidden_states=output_hidden_states,
2659
  return_dict=return_dict,
2660
  cache_position=cache_position,
2661
+ **kwargs
2662
  )
2663
 
2664
  hidden_states = outputs[0]
 
2778
  if expand_size == 1:
2779
  return input_ids, model_kwargs
2780
 
2781
+ visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
2782
 
2783
  def _expand_dict_for_generation_visual(dict_to_expand):
2784
  image_grid_thw = model_kwargs.get("image_grid_thw", None)
 
2788
  def _repeat_interleave_samples(x, lengths, repeat_times):
2789
  samples = torch.split(x, lengths)
2790
  repeat_args = [repeat_times] + [1] * (x.dim() - 1)
2791
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
 
 
2792
  return result
2793
 
2794
  for key in dict_to_expand:
 
2824
  )
2825
  tensor = torch.tensor(dict_to_expand[key])
2826
  lengths = list(video_nums)
2827
+ tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)
 
 
2828
  dict_to_expand[key] = tensor.tolist()
2829
  return dict_to_expand
2830
 
 
2836
  and isinstance(dict_to_expand[key], torch.Tensor)
2837
  and key not in visual_keys
2838
  ):
2839
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
 
 
2840
  return dict_to_expand
2841
 
2842
  # input_ids is required for expanding visual inputs
 
2851
 
2852
  if is_encoder_decoder:
2853
  if model_kwargs.get("encoder_outputs") is None:
2854
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
2855
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
 
 
 
 
2856
 
2857
  return input_ids, model_kwargs
2858
+
2859
+
2860
+
2861
+
2862
+
2863
+
2864
+
2865
+