soldni committed
Commit e5a7913
1 Parent(s): 3cf55ab

Update modeling_molmo.py

Files changed (1)
  1. modeling_molmo.py +184 -205
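
Note: modeling_molmo.py is the standalone modeling file shipped with the checkpoint repo, so it is normally loaded through transformers remote code rather than imported directly. A minimal loading sketch, assuming a Hub repo whose config maps MolmoForCausalLM via auto_map (the repo id below is a placeholder, not taken from this commit):

# Hedged sketch: placeholder repo id; trust_remote_code pulls in modeling_molmo.py.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "allenai/Molmo-7B-D-0924",  # placeholder repo id
    trust_remote_code=True,
    torch_dtype="auto",
)
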
modeling_molmo.py CHANGED
@@ -32,13 +32,13 @@ import einops
32
  from transformers import PreTrainedModel
33
  from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
34
 
35
- from olmo.util import resource_path
36
  from .configuration_molmo import (
37
  MolmoConfig,
38
  VisionBackboneConfig,
39
  VisionBackboneType,
40
  ImagePooling2DType,
41
- ImageProjectType,
42
  AttentionType,
43
  MolmoConfigurationError,
44
  )
@@ -54,6 +54,20 @@ else:
54
  log = logging.getLogger(__name__)
55
 
56
57
  def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
58
  """
59
  Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
@@ -106,7 +120,7 @@ class Embedding(nn.Module):
106
  def reset_parameters(self):
107
  nn.init.normal_(self.embedding, std=self.initializer_range)
108
  nn.init.normal_(self.new_embedding, std=self.new_embed_initializer_range)
109
-
110
  def forward(self, x: torch.Tensor) -> torch.Tensor:
111
  return F.embedding(x, torch.cat([self.embedding, self.new_embedding], dim=0))
112
 
@@ -131,7 +145,7 @@ class Dropout(nn.Dropout):
131
  if self.p == 0.0 and (self.mask_p is None or self.mask_p == 0.0):
132
  return input
133
  else:
134
- if self.mask_p > 0. and self.training:
135
  assert drop_mask is not None
136
  drop_mask = drop_mask.to(input.dtype)
137
  keep_prob = 1.0 - self.p
@@ -143,7 +157,7 @@ class Dropout(nn.Dropout):
143
  multiplier = input.new_empty(dropout_shape).bernoulli_(keep_prob)
144
  multiplier.div_(keep_prob)
145
  return input * multiplier
146
- elif self.p > 0. and len(self.broadcast_dims) > 0 and self.training:
147
  keep_prob = 1.0 - self.p
148
  dropout_shape = list(input.shape)
149
  for dim in self.broadcast_dims:
@@ -212,7 +226,6 @@ class LayerNorm(LayerNormBase):
212
  else:
213
  return tensor
214
 
215
-
216
  def forward(self, x: torch.Tensor) -> torch.Tensor:
217
  if self.low_precision:
218
  module_device = x.device
@@ -227,7 +240,7 @@ class LayerNorm(LayerNormBase):
227
  )
228
  else:
229
  return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
230
-
231
  def reset_parameters(self):
232
  if self.weight is not None:
233
  torch.nn.init.ones_(self.weight) # type: ignore
@@ -239,6 +252,7 @@ class RMSLayerNorm(LayerNormBase):
239
  """
240
  RMS layer norm, a simplified :class:`LayerNorm` implementation
241
  """
 
242
  def __init__(
243
  self,
244
  config: MolmoConfig,
@@ -263,7 +277,7 @@ class RMSLayerNorm(LayerNormBase):
263
  return self.weight * x
264
  else:
265
  return x
266
-
267
  def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
268
  # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
269
  # `is_autocast_cpu_enabled()` for CPU autocast.
@@ -274,7 +288,7 @@ class RMSLayerNorm(LayerNormBase):
274
  return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
275
  else:
276
  return tensor
277
-
278
  def reset_parameters(self):
279
  if self.weight is not None:
280
  torch.nn.init.ones_(self.weight) # type: ignore
@@ -293,8 +307,7 @@ class RotaryEmbedding(nn.Module):
293
  self.__cache = cache
294
  # Warm up cache.
295
  self.get_rotary_embedding(
296
- config.max_position_embeddings or config.max_sequence_length,
297
- _non_meta_init_device(config)
298
  )
299
 
300
  def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -313,8 +326,14 @@ class RotaryEmbedding(nn.Module):
313
  return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
314
 
315
  with torch.autocast(device.type, enabled=False):
316
- dim = self.config.head_dim if self.config.head_dim is not None else self.config.d_model // self.config.n_heads
317
- inv_freq = 1.0 / (self.config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
318
  seq = torch.arange(seq_len, device=device, dtype=torch.float)
319
  freqs = einsum("i , j -> i j", seq, inv_freq)
320
  if self.config.rope_impl == "cockatoo":
@@ -346,10 +365,7 @@ class RotaryEmbedding(nn.Module):
346
  return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
347
 
348
  def forward(
349
- self,
350
- q: torch.Tensor,
351
- k: torch.Tensor,
352
- position_ids: Optional[torch.Tensor] = None
353
  ) -> Tuple[torch.Tensor, torch.Tensor]:
354
  if self.config.rope_full_precision:
355
  q_, k_ = q.float(), k.float()
@@ -360,7 +376,7 @@ class RotaryEmbedding(nn.Module):
360
  batch_size = q_.shape[0]
361
  query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
362
  if position_ids is not None:
363
- freqs_cis_len = (self.config.max_position_embeddings or self.config.max_sequence_length)
364
  else:
365
  freqs_cis_len = key_len
366
  pos_sin, pos_cos = self.get_rotary_embedding(freqs_cis_len, q_.device)
@@ -368,12 +384,8 @@ class RotaryEmbedding(nn.Module):
368
  pos_cos = pos_cos.type_as(q_)
369
  if position_ids is not None:
370
  assert query_len == key_len, "Query and key lengths must be equal when using position IDs."
371
- pos_sin = pos_sin[0, 0][position_ids].view(
372
- (batch_size, 1, key_len, pos_sin.shape[-1])
373
- )
374
- pos_cos = pos_cos[0, 0][position_ids].view(
375
- (batch_size, 1, key_len, pos_cos.shape[-1])
376
- )
377
  q_ = self.apply_rotary_pos_emb(
378
  pos_sin[:, :, key_len - query_len : key_len, :],
379
  pos_cos[:, :, key_len - query_len : key_len, :],
@@ -466,11 +478,7 @@ def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.de
466
 
467
 
468
  class MolmoAttention(nn.Module):
469
- def __init__(
470
- self,
471
- config: MolmoConfig,
472
- cache: BufferCache
473
- ):
474
  super().__init__()
475
  self.config = config
476
  self.__cache = cache
@@ -478,8 +486,7 @@ class MolmoAttention(nn.Module):
478
  self.k_norm: Optional[LayerNormBase] = None
479
  self.q_norm: Optional[LayerNormBase] = None
480
  self.hidden_size = (
481
- config.mlp_hidden_size if config.mlp_hidden_size is not None \
482
- else config.mlp_ratio * config.d_model
483
  )
484
 
485
  if config.attention_layer_norm:
@@ -508,29 +515,25 @@ class MolmoAttention(nn.Module):
508
  config.n_kv_heads * head_dim,
509
  )
510
  self.att_proj = nn.Linear(
511
- config.d_model, sum(self.fused_dims),
 
512
  bias=config.include_bias or config.qkv_bias,
513
- device=config.init_device
514
- )
515
- self.attn_out = nn.Linear(
516
- input_dim, config.d_model,
517
- bias=config.include_bias,
518
- device=config.init_device
519
  )
520
- self.attn_norm = RMSLayerNorm(
521
- config,
522
- size=config.d_model,
523
- eps=config.layer_norm_eps)
524
-
525
- self.flash_attn_func = None
526
  if self.config.attention_type == AttentionType.flash:
527
  try:
528
  from flash_attn import flash_attn_func
 
529
  self.flash_attn_func = flash_attn_func
530
  except ModuleNotFoundError:
531
  pass
532
 
533
- def attention(self,
 
534
  q: torch.Tensor,
535
  k: torch.Tensor,
536
  v: torch.Tensor,
@@ -541,7 +544,7 @@ class MolmoAttention(nn.Module):
541
  use_cache: bool = False,
542
  ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
543
  B, T, C = q.size() # batch size, sequence length, d_model
544
- dtype = k.dtype
545
 
546
  # Optionally apply layer norm to keys and queries.
547
  if self.q_norm is not None and self.k_norm is not None:
@@ -658,15 +661,7 @@ class MolmoAttention(nn.Module):
658
  is_causal=is_causal,
659
  )
660
 
661
- def forward(
662
- self,
663
- x,
664
- attention_bias,
665
- position_ids,
666
- drop_mask,
667
- layer_past,
668
- use_cache
669
- ):
670
  if not self.config.norm_after:
671
  atten_in = self.attn_norm(x)
672
  else:
@@ -678,54 +673,45 @@ class MolmoAttention(nn.Module):
678
  qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
679
 
680
  q, k, v = qkv.split(self.fused_dims, dim=-1)
681
-
682
  # Get attention scores.
683
  att, cache = self.attention(
684
- q, k, v,
685
  attention_bias,
686
  position_ids=position_ids,
687
  drop_mask=drop_mask,
688
  layer_past=layer_past,
689
- use_cache=use_cache
690
  )
691
-
692
  if self.config.norm_after:
693
  att = self.attn_norm(att)
694
-
695
  return att, cache
696
 
697
 
698
  class MolmoMLP(nn.Module):
699
- def __init__(
700
- self,
701
- config: MolmoConfig
702
- ):
703
  # Feed-forward input projection.
704
  super().__init__()
705
  self.config = config
706
  self.hidden_size = (
707
- config.mlp_hidden_size if config.mlp_hidden_size is not None \
708
- else config.mlp_ratio * config.d_model
709
  )
710
  self.act = SwiGLU(config)
711
  self.ff_proj = nn.Linear(
712
- config.d_model,
713
- self.hidden_size,
714
- bias=config.include_bias,
715
- device=config.init_device
716
- )
717
  self.ff_out = nn.Linear(
718
  int(self.act.output_multiplier * self.hidden_size),
719
  config.d_model,
720
  bias=config.include_bias,
721
  device=config.init_device,
722
  )
723
- self.ff_norm = RMSLayerNorm(
724
- config,
725
- size=config.d_model,
726
- eps=config.layer_norm_eps
727
- )
728
-
729
  def forward(self, x):
730
  if not self.config.norm_after:
731
  x = self.ff_norm(x)
@@ -744,12 +730,8 @@ class MolmoDecoderLayer(nn.Module):
744
  """
745
  A base class for transformer block implementations.
746
  """
747
- def __init__(
748
- self,
749
- layer_id: int,
750
- config: MolmoConfig,
751
- cache: BufferCache
752
- ):
753
  super().__init__()
754
  self.self_attn = MolmoAttention(config, cache)
755
  self.mlp = MolmoMLP(config)
@@ -763,10 +745,7 @@ class MolmoDecoderLayer(nn.Module):
763
  assert config.d_model % config.n_heads == 0
764
 
765
  # Dropout.
766
- self.dropout = Dropout(
767
- config.residual_dropout,
768
- mask_p=config.response_residual_dropout
769
- )
770
 
771
  def forward(
772
  self,
@@ -787,12 +766,12 @@ class MolmoDecoderLayer(nn.Module):
787
  """
788
 
789
  att, cache = self.self_attn(
790
- x,
791
  attention_bias=attention_bias,
792
  position_ids=position_ids,
793
  drop_mask=drop_mask,
794
  layer_past=layer_past,
795
- use_cache=use_cache
796
  )
797
  x = x + self.dropout(att, drop_mask=drop_mask)
798
  og_x = x
@@ -822,7 +801,7 @@ class MultiHeadDotProductAttention(nn.Module):
822
  super().__init__()
823
  self.config = config
824
  self.use_bias = use_bias
825
-
826
  v_cfg = config.vision_backbone
827
  self.embed_dim = v_cfg.image_emb_dim
828
  self.num_heads = v_cfg.image_num_heads
@@ -862,7 +841,7 @@ class MultiHeadDotProductAttention(nn.Module):
862
  if v_cfg.attention_dropout > 0:
863
  self.attention_dropout = Dropout(v_cfg.attention_dropout, broadcast_dims=(0, 1))
864
  self.residual_dropout = Dropout(v_cfg.residual_dropout)
865
-
866
  def reset_parameters(self):
867
  nn.init.normal_(self.wq.weight, std=self.initializer_range)
868
  nn.init.normal_(self.wk.weight, std=self.initializer_range)
@@ -879,15 +858,15 @@ class MultiHeadDotProductAttention(nn.Module):
879
 
880
  def _merge_heads(self, hidden_states) -> torch.Tensor:
881
  return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
882
-
883
- def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
884
  if inputs_kv is not None:
885
  inputs_k = inputs_kv
886
  inputs_v = inputs_kv
887
  else:
888
  inputs_k = inputs_q
889
  inputs_v = inputs_q
890
-
891
  xq, xk, xv = self.wq(inputs_q), self.wk(inputs_k), self.wv(inputs_v)
892
 
893
  xq = self._split_heads(xq, self.num_heads)
@@ -918,7 +897,7 @@ class MultiHeadDotProductAttention(nn.Module):
918
  xk.transpose(1, 2).contiguous(),
919
  xv.transpose(1, 2).contiguous(),
920
  is_causal=False,
921
- dropout_p=self.config.vision_backbone.attention_dropout
922
  ).transpose(1, 2)
923
  else:
924
  raise NotImplementedError(self.config.attention_type)
@@ -940,7 +919,7 @@ class MultiHeadAttentionPool(nn.Module):
940
  output_layer: bool = True,
941
  mean_residual: bool = False,
942
  query: str = "mean",
943
- is_vit_layer: Optional[bool] = True
944
  ):
945
  super().__init__()
946
  self.config = config
@@ -950,7 +929,7 @@ class MultiHeadAttentionPool(nn.Module):
950
  self.output_layer = output_layer
951
  self.mean_residual = mean_residual
952
  self.query = query
953
-
954
  v_cfg = config.vision_backbone
955
  input_dim = v_cfg.image_emb_dim
956
  self.embed_dim = v_cfg.image_emb_dim * factor
@@ -985,7 +964,9 @@ class MultiHeadAttentionPool(nn.Module):
985
  if query == "vector":
986
  self.attention_query = nn.Parameter(
987
  torch.zeros(
988
- 1, self.num_key_value_heads * self.head_dim, device=config.init_device,
989
  ),
990
  )
991
 
@@ -1024,7 +1005,6 @@ class MultiHeadAttentionPool(nn.Module):
1024
  return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
1025
 
1026
  def forward(self, inputs_kv: torch.Tensor) -> torch.Tensor:
1027
-
1028
  xk, xv = self.wk(inputs_kv), self.wv(inputs_kv)
1029
 
1030
  if self.query == "mean":
@@ -1093,14 +1073,14 @@ class ViTMLP(nn.Module):
1093
  bias=True,
1094
  device=config.init_device,
1095
  )
1096
-
1097
  def reset_parameters(self):
1098
  v_cfg = self.config.vision_backbone
1099
  nn.init.trunc_normal_(self.w1.weight, std=math.sqrt(1 / v_cfg.image_emb_dim), a=-2.0, b=2.0)
1100
  nn.init.trunc_normal_(self.w2.weight, std=math.sqrt(1 / v_cfg.image_mlp_dim), a=-2.0, b=2.0)
1101
  nn.init.zeros_(self.w1.bias)
1102
  nn.init.zeros_(self.w2.bias)
1103
-
1104
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1105
  x = self.w1(x)
1106
  x = self.act(x)
@@ -1111,7 +1091,7 @@ class ViTMLP(nn.Module):
1111
  class MLP(nn.Module):
1112
  def __init__(self, config: MolmoConfig, input_dim: int, dropout: float = 0.0):
1113
  super().__init__()
1114
- self.config = config
1115
  self.hidden_size = (
1116
  config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
1117
  )
@@ -1135,15 +1115,15 @@ class MLP(nn.Module):
1135
  bias=False,
1136
  device=config.init_device,
1137
  )
1138
- #`MLP` assume the activation takes two inputs, so it must be a 'llama' version.
1139
  self.act = LlamaSwiGLU(config)
1140
  self.dropout = Dropout(dropout)
1141
-
1142
  def reset_parameters(self):
1143
  nn.init.normal_(self.w1.weight, std=self.initializer_range)
1144
  nn.init.normal_(self.w2.weight, std=self.initializer_range)
1145
  nn.init.normal_(self.w3.weight, std=self.initializer_range)
1146
-
1147
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1148
  x = self.w2(self.act(self.w1(x), self.w3(x)))
1149
  x = self.dropout(x)
@@ -1154,26 +1134,26 @@ class Residual(nn.Module):
1154
  def __init__(self, submodule: nn.Module):
1155
  super().__init__()
1156
  self.submodule = submodule
1157
-
1158
  def reset_parameters(self):
1159
  self.submodule.reset_parameters()
1160
-
1161
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1162
  return x + self.submodule(x)
1163
 
1164
 
1165
  class LayerNormFp32(nn.LayerNorm):
1166
- """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).
1167
- Derived from https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py.
1168
- """
1169
-
1170
- def forward(self, x: torch.Tensor) -> torch.Tensor:
1171
- orig_type = x.dtype
1172
- if self.training:
1173
- x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
1174
- else:
1175
- x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
1176
- return x.to(orig_type)
1177
 
1178
 
1179
  class ResidualAttentionBlock(nn.Module):
@@ -1200,7 +1180,7 @@ class ResidualAttentionBlock(nn.Module):
1200
  self.feed_forward.reset_parameters()
1201
  self.attention_norm.reset_parameters()
1202
  self.ffn_norm.reset_parameters()
1203
-
1204
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1205
  x = x + self.attention(self.attention_norm(x))
1206
  x = x + self.feed_forward(self.ffn_norm(x))
@@ -1213,10 +1193,8 @@ class BlockCollection(nn.Module):
1213
  self.config = config
1214
 
1215
  v_cfg = config.vision_backbone
1216
- self.resblocks = nn.ModuleList([
1217
- ResidualAttentionBlock(config) for _ in range(v_cfg.image_num_layers)
1218
- ])
1219
-
1220
  def reset_parameters(self):
1221
  for r in self.resblocks:
1222
  r.reset_parameters()
@@ -1240,7 +1218,7 @@ class VisionTransformer(nn.Module):
1240
 
1241
  v_cfg = config.vision_backbone
1242
  # class embeddings and positional embeddings
1243
- self.scale = v_cfg.image_emb_dim ** -0.5
1244
  self.class_embedding = nn.Parameter(
1245
  torch.zeros(v_cfg.image_emb_dim, device=config.init_device),
1246
  )
@@ -1264,14 +1242,14 @@ class VisionTransformer(nn.Module):
1264
  )
1265
 
1266
  self.transformer = BlockCollection(config)
1267
-
1268
  def reset_parameters(self):
1269
  nn.init.normal_(self.class_embedding, std=self.scale)
1270
  nn.init.normal_(self.positional_embedding, std=self.scale)
1271
  nn.init.normal_(self.patch_embedding.weight, std=0.02)
1272
  self.pre_ln.reset_parameters()
1273
  self.transformer.reset_parameters()
1274
-
1275
  def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
1276
  cls_emb = self.positional_embedding[0:1]
1277
  pos_emb = self.positional_embedding[1:]
@@ -1279,7 +1257,7 @@ class VisionTransformer(nn.Module):
1279
  pos_emb = pos_emb.reshape(
1280
  (int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])
1281
  )
1282
-
1283
  (patch_num_0, patch_num_1) = patch_num
1284
 
1285
  if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
@@ -1287,7 +1265,11 @@ class VisionTransformer(nn.Module):
1287
  # antialias: default True in jax.image.resize
1288
  pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
1289
  pos_emb = F.interpolate(
1290
- pos_emb, size=(patch_num_0, patch_num_1), mode="bicubic", align_corners=False, antialias=True,
1291
  )
1292
  pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
1293
 
@@ -1355,7 +1337,7 @@ class MolmoVisionBackbone(nn.Module):
1355
  input_dim = nlayers * config.vision_backbone.image_emb_dim
1356
  else:
1357
  raise NotImplementedError(f"Unknown image pooling 2D method: {config.image_pooling_2d}")
1358
-
1359
  self.input_dim = input_dim
1360
 
1361
  self.image_projector = MLP(config, input_dim)
@@ -1380,9 +1362,11 @@ class MolmoVisionBackbone(nn.Module):
1380
  self.image_projector.reset_parameters()
1381
 
1382
  @abstractmethod
1383
- def forward(self, images: torch.Tensor, image_masks: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
1384
  raise NotImplementedError
1385
-
1386
 
1387
  class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
1388
  def __init__(self, config: MolmoConfig):
@@ -1408,13 +1392,11 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
1408
 
1409
  self.pad_embed = None
1410
  if config.image_padding_embed:
1411
- image_dim = v_cfg.image_emb_dim*len(self.config.vit_layers)
1412
  if config.image_padding_embed in ["pad_embed", "regress"]:
1413
- self.pad_embed = nn.Parameter(
1414
- torch.zeros((image_dim,), device=config.init_device))
1415
  elif config.image_padding_embed == "pad_and_partial_pad":
1416
- self.pad_embed = nn.Parameter(
1417
- torch.zeros((2, image_dim), device=config.init_device))
1418
  else:
1419
  raise ValueError(config.image_padding_embed)
1420
 
@@ -1423,7 +1405,8 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
1423
  if self.config.vit_load_path:
1424
  vit_load_path = Path(self.config.vit_load_path)
1425
  state_dict_path = resource_path(
1426
- vit_load_path.parent, vit_load_path.name,
 
1427
  local_cache=vit_load_path.parent,
1428
  )
1429
  assert state_dict_path.is_file(), f"Model file {str(state_dict_path)} not found"
@@ -1441,7 +1424,7 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
1441
  self.image_vit.reset_parameters()
1442
  if self.config.use_cls_feature:
1443
  nn.init.xavier_uniform_(self.cls_projector.weight)
1444
-
1445
  def encode_image(self, images: torch.Tensor) -> torch.Tensor:
1446
  """
1447
  : param images: (batch_size, num_crops, num_patch, n_pixels)
@@ -1469,15 +1452,17 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
1469
  if self.num_prefix_tokens > 0:
1470
  cls_embed = image_features[:, 0]
1471
  image_features = image_features[:, 1:]
1472
-
1473
  image_features = image_features * mask
1474
  image_features = image_features.view(B, T, N, -1)
1475
 
1476
  cls_embed = cls_embed.view(B, T, -1) if cls_embed is not None else None
1477
 
1478
  return image_features, cls_embed
1479
-
1480
- def forward(self, images: torch.Tensor, image_masks: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  cfg = self.config
1482
 
1483
  # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
@@ -1493,12 +1478,16 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
1493
  image_features = image_features + pad_embed * torch.unsqueeze(all_pad, -1)
1494
  elif cfg.image_padding_embed == "regress":
1495
  pad_embed = self.pad_embed[None, None, None, :]
1496
- image_features = image_features + pad_embed * torch.unsqueeze(torch.maximum(image_masks, torch.zeros_like(image_masks)), -1)
1497
  elif cfg.image_padding_embed == "pad_and_partial_pad":
1498
  og_dtype = image_features.dtype
1499
  pad_embed = self.pad_embed[:, None, None, None, :]
1500
  all_pad = image_masks == 0
1501
- partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=torch.float32)
1502
  all_pad = all_pad.to(dtype=torch.float32)
1503
  image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1)
1504
  image_features = image_features + pad_embed[1] * torch.unsqueeze(partial_pad, -1)
@@ -1509,7 +1498,7 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
1509
  image_features = self.image_feature_dropout(image_features)
1510
  if cls_embed is not None:
1511
  cls_embed = self.image_feature_dropout(cls_embed)
1512
-
1513
  image_features = image_features.reshape(
1514
  (batch_size, num_image) + cfg.vision_backbone.image_num_patch + (-1,),
1515
  )
@@ -1520,11 +1509,11 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
1520
  image_features,
1521
  (0, 0, 0, 1, 0, 1, 0, 0, 0, 0),
1522
  )
1523
-
1524
  # image pooling
1525
  image_features = einops.rearrange(
1526
  image_features,
1527
- 'b n (h dh) (w dw) c -> (b n h w) (dh dw) c',
1528
  dh=cfg.image_pooling_h,
1529
  dw=cfg.image_pooling_w,
1530
  )
@@ -1546,7 +1535,7 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
1546
  image_features = module(image_features)
1547
  else:
1548
  image_features = self.image_projector(image_features)
1549
-
1550
  if self.config.use_cls_feature:
1551
  cls_embed = self.cls_projector(cls_embed)
1552
  if cfg.image_projector == ImageProjectType.mlpx2:
@@ -1554,7 +1543,7 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
1554
  cls_embed = module(cls_embed)
1555
  else:
1556
  cls_embed = self.image_projector(cls_embed)
1557
-
1558
  # image_features: (batch_size, num_image, num_patch, d_model)
1559
  # cls_embed: (batch_size, num_image, d_model)
1560
  return image_features, cls_embed
@@ -1579,11 +1568,7 @@ class MolmoPretrainedModel(PreTrainedModel):
1579
 
1580
 
1581
  class MolmoModel(MolmoPretrainedModel):
1582
- def __init__(
1583
- self,
1584
- config: MolmoConfig,
1585
- init_params: bool = True
1586
- ):
1587
  super().__init__(config)
1588
  self.config = config
1589
  self.__cache = BufferCache()
@@ -1616,10 +1601,10 @@ class MolmoModel(MolmoPretrainedModel):
1616
  config.d_model,
1617
  device=config.init_device,
1618
  initializer_range=config.initializer_range,
1619
- new_embed_initializer_range=config.new_embedding_init_range
1620
  )
1621
  else:
1622
- wte=nn.Embedding(
1623
  config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
1624
  )
1625
 
@@ -1627,26 +1612,20 @@ class MolmoModel(MolmoPretrainedModel):
1627
  dict(
1628
  wte=wte,
1629
  emb_drop=Dropout(config.embedding_dropout),
1630
- ln_f=RMSLayerNorm(
1631
- config,
1632
- size=config.d_model,
1633
- eps=config.layer_norm_eps),
1634
  )
1635
  )
1636
 
1637
- layers = [
1638
- MolmoDecoderLayer(i, config, self.__cache) \
1639
- for i in range(config.n_layers)
1640
- ]
1641
  self.transformer.update({"layers": nn.ModuleList(layers)})
1642
-
1643
  self.vision_backbone: Optional[MolmoVisionBackbone] = None
1644
  if config.vision_backbone is not None:
1645
  self.vision_backbone = MolmoVisionBackbone.build(config)
1646
 
1647
  if self.vision_backbone is not None:
1648
  self.vision_backbone.reset_with_pretrained_weights()
1649
-
1650
  @property
1651
  def device(self) -> torch.device:
1652
  device: torch.device = self.transformer.wte.weight.device # type: ignore
@@ -1655,7 +1634,6 @@ class MolmoModel(MolmoPretrainedModel):
1655
  else:
1656
  return device
1657
 
1658
-
1659
  def forward(
1660
  self,
1661
  input_ids: torch.LongTensor,
@@ -1716,7 +1694,9 @@ class MolmoModel(MolmoPretrainedModel):
1716
  has_image = images is not None
1717
 
1718
  assert not (has_image and input_embeddings is not None), "Cannot provide both images and input embeddings."
1719
- assert not (has_image and past_key_values is not None), "Cached key and values should not be used with images."
1720
 
1721
  batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
1722
  if past_key_values is None:
@@ -1730,16 +1710,17 @@ class MolmoModel(MolmoPretrainedModel):
1730
 
1731
  if self.config.use_position_ids and attention_mask is None:
1732
  attention_mask = input_ids != -1
1733
-
1734
  if subsegment_ids is not None:
1735
  assert not use_cache, "Subsegment_ids cannot be used with cache."
1736
  subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1)
1737
  attention_mask = (
1738
- subsegment_mask.to(attention_mask.dtype) *
1739
- attention_mask.unsqueeze(2) *
1740
- attention_mask.unsqueeze(1))
 
1741
  if position_ids is None:
1742
- raise ValueError(f"Positioned ids must be given if using subsegment_ids")
1743
  else:
1744
  if self.config.use_position_ids and position_ids is None:
1745
  position_ids = torch.clamp(
@@ -1776,10 +1757,8 @@ class MolmoModel(MolmoPretrainedModel):
1776
 
1777
  if self.config.use_cls_feature:
1778
  x = torch.cat([x[:, :1], cls_embed, x[:, 1:-num_image]], dim=1)
1779
-
1780
- valid_images = torch.any(
1781
- (image_input_idx >= 0).view(batch_size, num_image, num_patch), dim=-1
1782
- )
1783
  valid_images = valid_images.to(attention_mask.dtype)
1784
  attention_mask = torch.cat(
1785
  [attention_mask[:, :1], valid_images, attention_mask[:, 1:-num_image]],
@@ -1796,13 +1775,13 @@ class MolmoModel(MolmoPretrainedModel):
1796
 
1797
  # normalized
1798
  if self.config.normalize_input_embeds:
1799
- x = x * (self.config.d_model ** 0.5)
1800
 
1801
  # Transform the attention mask into what the blocks expect.
1802
  if attention_mask is not None:
1803
  # shape: (batch_size, 1, 1, seq_len)
1804
  if len(attention_mask.shape) == 2:
1805
- attention_mask = attention_mask[:, :past_length + seq_len]
1806
  attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
1807
  else:
1808
  attention_mask = attention_mask.unsqueeze(1).to(dtype=torch.float)
@@ -1852,16 +1831,23 @@ class MolmoModel(MolmoPretrainedModel):
1852
 
1853
  layer_past = None if past_key_values is None else past_key_values[block_idx]
1854
  # shape: (batch_size, seq_len, d_model)
1855
- x, cache = layer(x, attention_bias=attention_bias, position_ids=position_ids, drop_mask=response_mask, layer_past=layer_past, use_cache=use_cache)
1856
 
1857
  if attn_key_values is not None:
1858
  assert cache is not None
1859
  attn_key_values.append(cache)
1860
-
1861
  if images is not None and self.config.use_cls_feature:
1862
  assert num_image is not None
1863
  x = torch.cat(
1864
- [x[:, :1], x[:, num_image+1:], torch.zeros_like(x[:, :num_image])],
1865
  dim=1,
1866
  )
1867
 
@@ -1869,7 +1855,8 @@ class MolmoModel(MolmoPretrainedModel):
1869
  # shape: (batch_size, 1, d_model)
1870
  if append_last_valid_logits is not None:
1871
  last_valid_output = x[
1872
- torch.arange(x.shape[0], device=x.device), append_last_valid_logits.to(x.device)]
 
1873
  x = last_valid_output.unsqueeze(1)
1874
  else:
1875
  x = x[:, -1, :].unsqueeze(1)
@@ -1886,23 +1873,20 @@ class MolmoModel(MolmoPretrainedModel):
1886
  return MolmoOutput(
1887
  last_hidden_states=x,
1888
  attn_key_values=attn_key_values,
1889
- hidden_states=tuple(all_hidden_states) \
1890
- if output_hidden_states else None
1891
- )
1892
 
1893
 
1894
  class MolmoForCausalLM(PreTrainedModel):
1895
  """
1896
  Extremely barebones HF model wrapper.
1897
  """
 
1898
  config_class = MolmoConfig
1899
  base_model_prefix = "model"
1900
  _no_split_modules = ["MolmoDecoderLayer"]
1901
 
1902
- def __init__(
1903
- self,
1904
- config: MolmoConfig
1905
- ):
1906
  super().__init__(config)
1907
  # model_config = create_model_config_from_pretrained_config(config)
1908
  # Initialize model (always on CPU to start with so we don't run out of GPU memory).
@@ -1972,7 +1956,7 @@ class MolmoForCausalLM(PreTrainedModel):
1972
  output_hidden_states=output_hidden_states,
1973
  append_last_valid_logits=append_last_valid_logits,
1974
  )
1975
-
1976
  x = outputs.last_hidden_states
1977
  if self.config.weight_tying:
1978
  logits = F.linear(x, self.model.transformer.wte.weight, None) # type: ignore
@@ -1981,15 +1965,16 @@ class MolmoForCausalLM(PreTrainedModel):
1981
 
1982
  if self.config.scale_logits:
1983
  logits.mul_(1 / math.sqrt(self.config.d_model))
1984
-
1985
  if self.config.final_logit_softcapping is not None:
1986
  logits = logits / self.config.final_logit_softcapping
1987
  logits = torch.tanh(logits)
1988
  logits = logits * self.config.final_logit_softcapping
1989
-
1990
  if not last_logits_only and append_last_valid_logits is not None:
1991
  last_valid_logit = logits[
1992
1993
  logits = torch.cat([logits[:, :-1], last_valid_logit[:, None]], dim=1)
1994
 
1995
  loss = None
@@ -2001,7 +1986,7 @@ class MolmoForCausalLM(PreTrainedModel):
2001
  labels.masked_fill_(~(loss_masks > 0), -100)
2002
  labels = labels.view(-1)
2003
  logits_for_loss = logits.to(torch.float32).view(-1, logits.size(-1))
2004
- loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
2005
  loss = loss_fct(logits_for_loss, labels)
2006
  loss = loss.view(input_ids.shape[0], -1)
2007
  loss = loss * loss_masks
@@ -2063,10 +2048,7 @@ class MolmoForCausalLM(PreTrainedModel):
2063
  append_last_valid_logits: Optional[torch.Tensor] = None
2064
  if self.config.use_position_ids and attention_mask is None:
2065
  attention_mask = input_ids != -1
2066
- position_ids = torch.clamp(
2067
- torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1,
2068
- min=0
2069
- )
2070
  append_last_valid_logits = attention_mask.long().sum(dim=-1) - 1
2071
  attention_mask = torch.cat(
2072
  [attention_mask, attention_mask.new_ones((batch_size, max_new_tokens))],
@@ -2074,7 +2056,7 @@ class MolmoForCausalLM(PreTrainedModel):
2074
  )
2075
  if attention_mask is not None:
2076
  assert attention_mask.shape == (batch_size, mask_len)
2077
-
2078
  out = super().generate(
2079
  input_ids,
2080
  generation_config,
@@ -2088,7 +2070,7 @@ class MolmoForCausalLM(PreTrainedModel):
2088
  )
2089
 
2090
  return out
2091
-
2092
  def prepare_inputs_for_generation(
2093
  self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
2094
  ):
@@ -2116,7 +2098,7 @@ class MolmoForCausalLM(PreTrainedModel):
2116
  model_inputs["image_masks"] = image_masks
2117
  model_inputs["image_input_idx"] = image_input_idx
2118
  model_inputs["append_last_valid_logits"] = append_last_valid_logits
2119
- else:
2120
  model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
2121
 
2122
  model_inputs.update(kwargs)
@@ -2236,7 +2218,4 @@ class MolmoForCausalLM(PreTrainedModel):
2236
  # Tie weights again if needed
2237
  self.tie_weights()
2238
 
2239
- return model_embeds
2240
-
2241
- # Always register for multi-modal features
2242
- AutoModelForCausalLM.register(MolmoConfig, MolmoForCausalLM)
 
32
  from transformers import PreTrainedModel
33
  from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
34
 
35
+ # from olmo.util import resource_path
36
  from .configuration_molmo import (
37
  MolmoConfig,
38
  VisionBackboneConfig,
39
  VisionBackboneType,
40
  ImagePooling2DType,
41
+ ImageProjectType,
42
  AttentionType,
43
  MolmoConfigurationError,
44
  )
 
54
  log = logging.getLogger(__name__)
55
 
56
 
57
+ def resource_path(
58
+ folder: Union[str, Path],
59
+ fname: str,
60
+ local_cache: Optional[Union[str, Path]] = None,
61
+ ) -> Path:
62
+ if local_cache is not None and (local_path := Path(local_cache) / fname).is_file():
63
+ log.info(f"Found local cache of {fname} at {local_path}")
64
+ return local_path
65
+ else:
66
+ from cached_path import cached_path
67
+
68
+ return cached_path(f"{str(folder).rstrip('/')}/{fname}")
69
+
70
+
71
  def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
72
  """
73
  Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
 
120
  def reset_parameters(self):
121
  nn.init.normal_(self.embedding, std=self.initializer_range)
122
  nn.init.normal_(self.new_embedding, std=self.new_embed_initializer_range)
123
+
124
  def forward(self, x: torch.Tensor) -> torch.Tensor:
125
  return F.embedding(x, torch.cat([self.embedding, self.new_embedding], dim=0))
126
 
 
145
  if self.p == 0.0 and (self.mask_p is None or self.mask_p == 0.0):
146
  return input
147
  else:
148
+ if self.mask_p > 0.0 and self.training:
149
  assert drop_mask is not None
150
  drop_mask = drop_mask.to(input.dtype)
151
  keep_prob = 1.0 - self.p
 
157
  multiplier = input.new_empty(dropout_shape).bernoulli_(keep_prob)
158
  multiplier.div_(keep_prob)
159
  return input * multiplier
160
+ elif self.p > 0.0 and len(self.broadcast_dims) > 0 and self.training:
161
  keep_prob = 1.0 - self.p
162
  dropout_shape = list(input.shape)
163
  for dim in self.broadcast_dims:
 
226
  else:
227
  return tensor
228
 
 
229
  def forward(self, x: torch.Tensor) -> torch.Tensor:
230
  if self.low_precision:
231
  module_device = x.device
 
240
  )
241
  else:
242
  return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
243
+
244
  def reset_parameters(self):
245
  if self.weight is not None:
246
  torch.nn.init.ones_(self.weight) # type: ignore
 
252
  """
253
  RMS layer norm, a simplified :class:`LayerNorm` implementation
254
  """
255
+
256
  def __init__(
257
  self,
258
  config: MolmoConfig,
 
277
  return self.weight * x
278
  else:
279
  return x
280
+
281
  def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
282
  # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
283
  # `is_autocast_cpu_enabled()` for CPU autocast.
 
288
  return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
289
  else:
290
  return tensor
291
+
292
  def reset_parameters(self):
293
  if self.weight is not None:
294
  torch.nn.init.ones_(self.weight) # type: ignore
 
307
  self.__cache = cache
308
  # Warm up cache.
309
  self.get_rotary_embedding(
310
+ config.max_position_embeddings or config.max_sequence_length, _non_meta_init_device(config)
 
311
  )
312
 
313
  def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
 
326
  return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
327
 
328
  with torch.autocast(device.type, enabled=False):
329
+ dim = (
330
+ self.config.head_dim
331
+ if self.config.head_dim is not None
332
+ else self.config.d_model // self.config.n_heads
333
+ )
334
+ inv_freq = 1.0 / (
335
+ self.config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)
336
+ )
337
  seq = torch.arange(seq_len, device=device, dtype=torch.float)
338
  freqs = einsum("i , j -> i j", seq, inv_freq)
339
  if self.config.rope_impl == "cockatoo":
 
365
  return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
366
 
367
  def forward(
368
+ self, q: torch.Tensor, k: torch.Tensor, position_ids: Optional[torch.Tensor] = None
369
  ) -> Tuple[torch.Tensor, torch.Tensor]:
370
  if self.config.rope_full_precision:
371
  q_, k_ = q.float(), k.float()
 
376
  batch_size = q_.shape[0]
377
  query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
378
  if position_ids is not None:
379
+ freqs_cis_len = self.config.max_position_embeddings or self.config.max_sequence_length
380
  else:
381
  freqs_cis_len = key_len
382
  pos_sin, pos_cos = self.get_rotary_embedding(freqs_cis_len, q_.device)
 
384
  pos_cos = pos_cos.type_as(q_)
385
  if position_ids is not None:
386
  assert query_len == key_len, "Query and key lengths must be equal when using position IDs."
387
+ pos_sin = pos_sin[0, 0][position_ids].view((batch_size, 1, key_len, pos_sin.shape[-1]))
388
+ pos_cos = pos_cos[0, 0][position_ids].view((batch_size, 1, key_len, pos_cos.shape[-1]))
389
  q_ = self.apply_rotary_pos_emb(
390
  pos_sin[:, :, key_len - query_len : key_len, :],
391
  pos_cos[:, :, key_len - query_len : key_len, :],
 
478
 
479
 
480
  class MolmoAttention(nn.Module):
481
+ def __init__(self, config: MolmoConfig, cache: BufferCache):
482
  super().__init__()
483
  self.config = config
484
  self.__cache = cache
 
486
  self.k_norm: Optional[LayerNormBase] = None
487
  self.q_norm: Optional[LayerNormBase] = None
488
  self.hidden_size = (
489
+ config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
 
490
  )
491
 
492
  if config.attention_layer_norm:
 
515
  config.n_kv_heads * head_dim,
516
  )
517
  self.att_proj = nn.Linear(
518
+ config.d_model,
519
+ sum(self.fused_dims),
520
  bias=config.include_bias or config.qkv_bias,
521
+ device=config.init_device,
522
  )
523
+ self.attn_out = nn.Linear(input_dim, config.d_model, bias=config.include_bias, device=config.init_device)
524
+ self.attn_norm = RMSLayerNorm(config, size=config.d_model, eps=config.layer_norm_eps)
525
+
526
+ self.flash_attn_func = None
527
  if self.config.attention_type == AttentionType.flash:
528
  try:
529
  from flash_attn import flash_attn_func
530
+
531
  self.flash_attn_func = flash_attn_func
532
  except ModuleNotFoundError:
533
  pass
534
 
535
+ def attention(
536
+ self,
537
  q: torch.Tensor,
538
  k: torch.Tensor,
539
  v: torch.Tensor,
 
544
  use_cache: bool = False,
545
  ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
546
  B, T, C = q.size() # batch size, sequence length, d_model
547
+ dtype = k.dtype
548
 
549
  # Optionally apply layer norm to keys and queries.
550
  if self.q_norm is not None and self.k_norm is not None:
 
661
  is_causal=is_causal,
662
  )
663
 
664
+ def forward(self, x, attention_bias, position_ids, drop_mask, layer_past, use_cache):
665
  if not self.config.norm_after:
666
  atten_in = self.attn_norm(x)
667
  else:
 
673
  qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
674
 
675
  q, k, v = qkv.split(self.fused_dims, dim=-1)
676
+
677
  # Get attention scores.
678
  att, cache = self.attention(
679
+ q,
680
+ k,
681
+ v,
682
  attention_bias,
683
  position_ids=position_ids,
684
  drop_mask=drop_mask,
685
  layer_past=layer_past,
686
+ use_cache=use_cache,
687
  )
688
+
689
  if self.config.norm_after:
690
  att = self.attn_norm(att)
691
+
692
  return att, cache
693
 
694
 
695
  class MolmoMLP(nn.Module):
696
+ def __init__(self, config: MolmoConfig):
697
  # Feed-forward input projection.
698
  super().__init__()
699
  self.config = config
700
  self.hidden_size = (
701
+ config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
 
702
  )
703
  self.act = SwiGLU(config)
704
  self.ff_proj = nn.Linear(
705
+ config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
706
+ )
707
  self.ff_out = nn.Linear(
708
  int(self.act.output_multiplier * self.hidden_size),
709
  config.d_model,
710
  bias=config.include_bias,
711
  device=config.init_device,
712
  )
713
+ self.ff_norm = RMSLayerNorm(config, size=config.d_model, eps=config.layer_norm_eps)
714
+
715
  def forward(self, x):
716
  if not self.config.norm_after:
717
  x = self.ff_norm(x)
 
730
  """
731
  A base class for transformer block implementations.
732
  """
733
+
734
+ def __init__(self, layer_id: int, config: MolmoConfig, cache: BufferCache):
735
  super().__init__()
736
  self.self_attn = MolmoAttention(config, cache)
737
  self.mlp = MolmoMLP(config)
 
745
  assert config.d_model % config.n_heads == 0
746
 
747
  # Dropout.
748
+ self.dropout = Dropout(config.residual_dropout, mask_p=config.response_residual_dropout)
749
 
750
  def forward(
751
  self,
 
766
  """
767
 
768
  att, cache = self.self_attn(
769
+ x,
770
  attention_bias=attention_bias,
771
  position_ids=position_ids,
772
  drop_mask=drop_mask,
773
  layer_past=layer_past,
774
+ use_cache=use_cache,
775
  )
776
  x = x + self.dropout(att, drop_mask=drop_mask)
777
  og_x = x
 
801
  super().__init__()
802
  self.config = config
803
  self.use_bias = use_bias
804
+
805
  v_cfg = config.vision_backbone
806
  self.embed_dim = v_cfg.image_emb_dim
807
  self.num_heads = v_cfg.image_num_heads
 
841
  if v_cfg.attention_dropout > 0:
842
  self.attention_dropout = Dropout(v_cfg.attention_dropout, broadcast_dims=(0, 1))
843
  self.residual_dropout = Dropout(v_cfg.residual_dropout)
844
+
845
  def reset_parameters(self):
846
  nn.init.normal_(self.wq.weight, std=self.initializer_range)
847
  nn.init.normal_(self.wk.weight, std=self.initializer_range)
 
858
 
859
  def _merge_heads(self, hidden_states) -> torch.Tensor:
860
  return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
861
+
862
+ def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
863
  if inputs_kv is not None:
864
  inputs_k = inputs_kv
865
  inputs_v = inputs_kv
866
  else:
867
  inputs_k = inputs_q
868
  inputs_v = inputs_q
869
+
870
  xq, xk, xv = self.wq(inputs_q), self.wk(inputs_k), self.wv(inputs_v)
871
 
872
  xq = self._split_heads(xq, self.num_heads)
 
897
  xk.transpose(1, 2).contiguous(),
898
  xv.transpose(1, 2).contiguous(),
899
  is_causal=False,
900
+ dropout_p=self.config.vision_backbone.attention_dropout,
901
  ).transpose(1, 2)
902
  else:
903
  raise NotImplementedError(self.config.attention_type)
 
919
  output_layer: bool = True,
920
  mean_residual: bool = False,
921
  query: str = "mean",
922
+ is_vit_layer: Optional[bool] = True,
923
  ):
924
  super().__init__()
925
  self.config = config
 
929
  self.output_layer = output_layer
930
  self.mean_residual = mean_residual
931
  self.query = query
932
+
933
  v_cfg = config.vision_backbone
934
  input_dim = v_cfg.image_emb_dim
935
  self.embed_dim = v_cfg.image_emb_dim * factor
 
964
  if query == "vector":
965
  self.attention_query = nn.Parameter(
966
  torch.zeros(
967
+ 1,
968
+ self.num_key_value_heads * self.head_dim,
969
+ device=config.init_device,
970
  ),
971
  )
972
 
 
1005
  return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
1006
 
1007
  def forward(self, inputs_kv: torch.Tensor) -> torch.Tensor:
 
1008
  xk, xv = self.wk(inputs_kv), self.wv(inputs_kv)
1009
 
1010
  if self.query == "mean":
 
1073
  bias=True,
1074
  device=config.init_device,
1075
  )
1076
+
1077
  def reset_parameters(self):
1078
  v_cfg = self.config.vision_backbone
1079
  nn.init.trunc_normal_(self.w1.weight, std=math.sqrt(1 / v_cfg.image_emb_dim), a=-2.0, b=2.0)
1080
  nn.init.trunc_normal_(self.w2.weight, std=math.sqrt(1 / v_cfg.image_mlp_dim), a=-2.0, b=2.0)
1081
  nn.init.zeros_(self.w1.bias)
1082
  nn.init.zeros_(self.w2.bias)
1083
+
1084
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1085
  x = self.w1(x)
1086
  x = self.act(x)
 
1091
  class MLP(nn.Module):
1092
  def __init__(self, config: MolmoConfig, input_dim: int, dropout: float = 0.0):
1093
  super().__init__()
1094
+ self.config = config
1095
  self.hidden_size = (
1096
  config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
1097
  )
 
1115
  bias=False,
1116
  device=config.init_device,
1117
  )
1118
+ # `MLP` assume the activation takes two inputs, so it must be a 'llama' version.
1119
  self.act = LlamaSwiGLU(config)
1120
  self.dropout = Dropout(dropout)
1121
+
1122
  def reset_parameters(self):
1123
  nn.init.normal_(self.w1.weight, std=self.initializer_range)
1124
  nn.init.normal_(self.w2.weight, std=self.initializer_range)
1125
  nn.init.normal_(self.w3.weight, std=self.initializer_range)
1126
+
1127
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1128
  x = self.w2(self.act(self.w1(x), self.w3(x)))
1129
  x = self.dropout(x)
 
1134
  def __init__(self, submodule: nn.Module):
1135
  super().__init__()
1136
  self.submodule = submodule
1137
+
1138
  def reset_parameters(self):
1139
  self.submodule.reset_parameters()
1140
+
1141
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1142
  return x + self.submodule(x)
1143
 
1144
 
1145
  class LayerNormFp32(nn.LayerNorm):
1146
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).
1147
+ Derived from https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py.
1148
+ """
1149
+
1150
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1151
+ orig_type = x.dtype
1152
+ if self.training:
1153
+ x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
1154
+ else:
1155
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
1156
+ return x.to(orig_type)
1157
 
1158
 
1159
  class ResidualAttentionBlock(nn.Module):
 
1180
  self.feed_forward.reset_parameters()
1181
  self.attention_norm.reset_parameters()
1182
  self.ffn_norm.reset_parameters()
1183
+
1184
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1185
  x = x + self.attention(self.attention_norm(x))
1186
  x = x + self.feed_forward(self.ffn_norm(x))
 
1193
  self.config = config
1194
 
1195
  v_cfg = config.vision_backbone
1196
+ self.resblocks = nn.ModuleList([ResidualAttentionBlock(config) for _ in range(v_cfg.image_num_layers)])
1197
+
1198
  def reset_parameters(self):
1199
  for r in self.resblocks:
1200
  r.reset_parameters()
 
1218
 
1219
  v_cfg = config.vision_backbone
1220
  # class embeddings and positional embeddings
1221
+ self.scale = v_cfg.image_emb_dim**-0.5
1222
  self.class_embedding = nn.Parameter(
1223
  torch.zeros(v_cfg.image_emb_dim, device=config.init_device),
1224
  )
 
1242
  )
1243
 
1244
  self.transformer = BlockCollection(config)
1245
+
1246
  def reset_parameters(self):
1247
  nn.init.normal_(self.class_embedding, std=self.scale)
1248
  nn.init.normal_(self.positional_embedding, std=self.scale)
1249
  nn.init.normal_(self.patch_embedding.weight, std=0.02)
1250
  self.pre_ln.reset_parameters()
1251
  self.transformer.reset_parameters()
1252
+
1253
  def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
1254
  cls_emb = self.positional_embedding[0:1]
1255
  pos_emb = self.positional_embedding[1:]
 
1257
  pos_emb = pos_emb.reshape(
1258
  (int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])
1259
  )
1260
+
1261
  (patch_num_0, patch_num_1) = patch_num
1262
 
1263
  if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
 
1265
  # antialias: default True in jax.image.resize
1266
  pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
1267
  pos_emb = F.interpolate(
1268
+ pos_emb,
1269
+ size=(patch_num_0, patch_num_1),
1270
+ mode="bicubic",
1271
+ align_corners=False,
1272
+ antialias=True,
1273
  )
1274
  pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
1275
 
 
1337
  input_dim = nlayers * config.vision_backbone.image_emb_dim
1338
  else:
1339
  raise NotImplementedError(f"Unknown image pooling 2D method: {config.image_pooling_2d}")
1340
+
1341
  self.input_dim = input_dim
1342
 
1343
  self.image_projector = MLP(config, input_dim)
 
1362
  self.image_projector.reset_parameters()
1363
 
1364
  @abstractmethod
1365
+ def forward(
1366
+ self, images: torch.Tensor, image_masks: torch.Tensor
1367
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
1368
  raise NotImplementedError
1369
+
1370
 
1371
  class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
1372
  def __init__(self, config: MolmoConfig):
 
1392
 
1393
  self.pad_embed = None
1394
  if config.image_padding_embed:
1395
+ image_dim = v_cfg.image_emb_dim * len(self.config.vit_layers)
1396
  if config.image_padding_embed in ["pad_embed", "regress"]:
1397
+ self.pad_embed = nn.Parameter(torch.zeros((image_dim,), device=config.init_device))
 
1398
  elif config.image_padding_embed == "pad_and_partial_pad":
1399
+ self.pad_embed = nn.Parameter(torch.zeros((2, image_dim), device=config.init_device))
 
1400
  else:
1401
  raise ValueError(config.image_padding_embed)
1402
 
 
1405
  if self.config.vit_load_path:
1406
  vit_load_path = Path(self.config.vit_load_path)
1407
  state_dict_path = resource_path(
1408
+ vit_load_path.parent,
1409
+ vit_load_path.name,
1410
  local_cache=vit_load_path.parent,
1411
  )
1412
  assert state_dict_path.is_file(), f"Model file {str(state_dict_path)} not found"
 
1424
  self.image_vit.reset_parameters()
1425
  if self.config.use_cls_feature:
1426
  nn.init.xavier_uniform_(self.cls_projector.weight)
1427
+
1428
  def encode_image(self, images: torch.Tensor) -> torch.Tensor:
1429
  """
1430
  : param images: (batch_size, num_crops, num_patch, n_pixels)
 
1452
  if self.num_prefix_tokens > 0:
1453
  cls_embed = image_features[:, 0]
1454
  image_features = image_features[:, 1:]
1455
+
1456
  image_features = image_features * mask
1457
  image_features = image_features.view(B, T, N, -1)
1458
 
1459
  cls_embed = cls_embed.view(B, T, -1) if cls_embed is not None else None
1460
 
1461
  return image_features, cls_embed
1462
+
1463
+ def forward(
1464
+ self, images: torch.Tensor, image_masks: torch.Tensor
1465
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
1466
  cfg = self.config
1467
 
1468
  # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
 
1478
  image_features = image_features + pad_embed * torch.unsqueeze(all_pad, -1)
1479
  elif cfg.image_padding_embed == "regress":
1480
  pad_embed = self.pad_embed[None, None, None, :]
1481
+ image_features = image_features + pad_embed * torch.unsqueeze(
1482
+ torch.maximum(image_masks, torch.zeros_like(image_masks)), -1
1483
+ )
1484
  elif cfg.image_padding_embed == "pad_and_partial_pad":
1485
  og_dtype = image_features.dtype
1486
  pad_embed = self.pad_embed[:, None, None, None, :]
1487
  all_pad = image_masks == 0
1488
+ partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(
1489
+ dtype=torch.float32
1490
+ )
1491
  all_pad = all_pad.to(dtype=torch.float32)
1492
  image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1)
1493
  image_features = image_features + pad_embed[1] * torch.unsqueeze(partial_pad, -1)
 
1498
  image_features = self.image_feature_dropout(image_features)
1499
  if cls_embed is not None:
1500
  cls_embed = self.image_feature_dropout(cls_embed)
1501
+
1502
  image_features = image_features.reshape(
1503
  (batch_size, num_image) + cfg.vision_backbone.image_num_patch + (-1,),
1504
  )
 
1509
  image_features,
1510
  (0, 0, 0, 1, 0, 1, 0, 0, 0, 0),
1511
  )
1512
+
1513
  # image pooling
1514
  image_features = einops.rearrange(
1515
  image_features,
1516
+ "b n (h dh) (w dw) c -> (b n h w) (dh dw) c",
1517
  dh=cfg.image_pooling_h,
1518
  dw=cfg.image_pooling_w,
1519
  )
 
1535
  image_features = module(image_features)
1536
  else:
1537
  image_features = self.image_projector(image_features)
1538
+
1539
  if self.config.use_cls_feature:
1540
  cls_embed = self.cls_projector(cls_embed)
1541
  if cfg.image_projector == ImageProjectType.mlpx2:
 
1543
  cls_embed = module(cls_embed)
1544
  else:
1545
  cls_embed = self.image_projector(cls_embed)
1546
+
1547
  # image_features: (batch_size, num_image, num_patch, d_model)
1548
  # cls_embed: (batch_size, num_image, d_model)
1549
  return image_features, cls_embed
 
1568
 
1569
 
1570
  class MolmoModel(MolmoPretrainedModel):
1571
+ def __init__(self, config: MolmoConfig, init_params: bool = True):
1572
  super().__init__(config)
1573
  self.config = config
1574
  self.__cache = BufferCache()
 
1601
  config.d_model,
1602
  device=config.init_device,
1603
  initializer_range=config.initializer_range,
1604
+ new_embed_initializer_range=config.new_embedding_init_range,
1605
  )
1606
  else:
1607
+ wte = nn.Embedding(
1608
  config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
1609
  )
1610
 
 
1612
  dict(
1613
  wte=wte,
1614
  emb_drop=Dropout(config.embedding_dropout),
1615
+ ln_f=RMSLayerNorm(config, size=config.d_model, eps=config.layer_norm_eps),
1616
  )
1617
  )
1618
 
1619
+ layers = [MolmoDecoderLayer(i, config, self.__cache) for i in range(config.n_layers)]
1620
  self.transformer.update({"layers": nn.ModuleList(layers)})
1621
+
1622
  self.vision_backbone: Optional[MolmoVisionBackbone] = None
1623
  if config.vision_backbone is not None:
1624
  self.vision_backbone = MolmoVisionBackbone.build(config)
1625
 
1626
  if self.vision_backbone is not None:
1627
  self.vision_backbone.reset_with_pretrained_weights()
1628
+
1629
  @property
1630
  def device(self) -> torch.device:
1631
  device: torch.device = self.transformer.wte.weight.device # type: ignore
 
1634
  else:
1635
  return device
1636
 
 
1637
  def forward(
1638
  self,
1639
  input_ids: torch.LongTensor,
 
1694
  has_image = images is not None
1695
 
1696
  assert not (has_image and input_embeddings is not None), "Cannot provide both images and input embeddings."
1697
+ assert not (
1698
+ has_image and past_key_values is not None
1699
+ ), "Cached key and values should not be used with images."
1700
 
1701
  batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
1702
  if past_key_values is None:
 
1710
 
1711
  if self.config.use_position_ids and attention_mask is None:
1712
  attention_mask = input_ids != -1
1713
+
1714
  if subsegment_ids is not None:
1715
  assert not use_cache, "Subsegment_ids cannot be used with cache."
1716
  subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1)
1717
  attention_mask = (
1718
+ subsegment_mask.to(attention_mask.dtype)
1719
+ * attention_mask.unsqueeze(2)
1720
+ * attention_mask.unsqueeze(1)
1721
+ )
1722
  if position_ids is None:
1723
+ raise ValueError("Positioned ids must be given if using subsegment_ids")
1724
  else:
1725
  if self.config.use_position_ids and position_ids is None:
1726
  position_ids = torch.clamp(
 
1757
 
1758
  if self.config.use_cls_feature:
1759
  x = torch.cat([x[:, :1], cls_embed, x[:, 1:-num_image]], dim=1)
1760
+
1761
+ valid_images = torch.any((image_input_idx >= 0).view(batch_size, num_image, num_patch), dim=-1)
1762
  valid_images = valid_images.to(attention_mask.dtype)
1763
  attention_mask = torch.cat(
1764
  [attention_mask[:, :1], valid_images, attention_mask[:, 1:-num_image]],
 
1775
 
1776
  # normalized
1777
  if self.config.normalize_input_embeds:
1778
+ x = x * (self.config.d_model**0.5)
1779
 
1780
  # Transform the attention mask into what the blocks expect.
1781
  if attention_mask is not None:
1782
  # shape: (batch_size, 1, 1, seq_len)
1783
  if len(attention_mask.shape) == 2:
1784
+ attention_mask = attention_mask[:, : past_length + seq_len]
1785
  attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
1786
  else:
1787
  attention_mask = attention_mask.unsqueeze(1).to(dtype=torch.float)
 
1831
 
1832
  layer_past = None if past_key_values is None else past_key_values[block_idx]
1833
  # shape: (batch_size, seq_len, d_model)
1834
+ x, cache = layer(
1835
+ x,
1836
+ attention_bias=attention_bias,
1837
+ position_ids=position_ids,
1838
+ drop_mask=response_mask,
1839
+ layer_past=layer_past,
1840
+ use_cache=use_cache,
1841
+ )
1842
 
1843
  if attn_key_values is not None:
1844
  assert cache is not None
1845
  attn_key_values.append(cache)
1846
+
1847
  if images is not None and self.config.use_cls_feature:
1848
  assert num_image is not None
1849
  x = torch.cat(
1850
+ [x[:, :1], x[:, num_image + 1 :], torch.zeros_like(x[:, :num_image])],
1851
  dim=1,
1852
  )
1853
 
 
1855
  # shape: (batch_size, 1, d_model)
1856
  if append_last_valid_logits is not None:
1857
  last_valid_output = x[
1858
+ torch.arange(x.shape[0], device=x.device), append_last_valid_logits.to(x.device)
1859
+ ]
1860
  x = last_valid_output.unsqueeze(1)
1861
  else:
1862
  x = x[:, -1, :].unsqueeze(1)
 
1873
  return MolmoOutput(
1874
  last_hidden_states=x,
1875
  attn_key_values=attn_key_values,
1876
+ hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
1877
+ )
 
1878
 
1879
 
1880
  class MolmoForCausalLM(PreTrainedModel):
1881
  """
1882
  Extremely barebones HF model wrapper.
1883
  """
1884
+
1885
  config_class = MolmoConfig
1886
  base_model_prefix = "model"
1887
  _no_split_modules = ["MolmoDecoderLayer"]
1888
 
1889
+ def __init__(self, config: MolmoConfig):
1890
  super().__init__(config)
1891
  # model_config = create_model_config_from_pretrained_config(config)
1892
  # Initialize model (always on CPU to start with so we don't run out of GPU memory).
 
1956
  output_hidden_states=output_hidden_states,
1957
  append_last_valid_logits=append_last_valid_logits,
1958
  )
1959
+
1960
  x = outputs.last_hidden_states
1961
  if self.config.weight_tying:
1962
  logits = F.linear(x, self.model.transformer.wte.weight, None) # type: ignore
 
1965
 
1966
  if self.config.scale_logits:
1967
  logits.mul_(1 / math.sqrt(self.config.d_model))
1968
+
1969
  if self.config.final_logit_softcapping is not None:
1970
  logits = logits / self.config.final_logit_softcapping
1971
  logits = torch.tanh(logits)
1972
  logits = logits * self.config.final_logit_softcapping
1973
+
1974
  if not last_logits_only and append_last_valid_logits is not None:
1975
  last_valid_logit = logits[
1976
+ torch.arange(logits.shape[0], device=logits.device), append_last_valid_logits
1977
+ ]
1978
  logits = torch.cat([logits[:, :-1], last_valid_logit[:, None]], dim=1)
1979
 
1980
  loss = None
 
1986
  labels.masked_fill_(~(loss_masks > 0), -100)
1987
  labels = labels.view(-1)
1988
  logits_for_loss = logits.to(torch.float32).view(-1, logits.size(-1))
1989
+ loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
1990
  loss = loss_fct(logits_for_loss, labels)
1991
  loss = loss.view(input_ids.shape[0], -1)
1992
  loss = loss * loss_masks
 
2048
  append_last_valid_logits: Optional[torch.Tensor] = None
2049
  if self.config.use_position_ids and attention_mask is None:
2050
  attention_mask = input_ids != -1
2051
+ position_ids = torch.clamp(torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1, min=0)
2052
  append_last_valid_logits = attention_mask.long().sum(dim=-1) - 1
2053
  attention_mask = torch.cat(
2054
  [attention_mask, attention_mask.new_ones((batch_size, max_new_tokens))],
 
2056
  )
2057
  if attention_mask is not None:
2058
  assert attention_mask.shape == (batch_size, mask_len)
2059
+
2060
  out = super().generate(
2061
  input_ids,
2062
  generation_config,
 
2070
  )
2071
 
2072
  return out
2073
+
2074
  def prepare_inputs_for_generation(
2075
  self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
2076
  ):
 
2098
  model_inputs["image_masks"] = image_masks
2099
  model_inputs["image_input_idx"] = image_input_idx
2100
  model_inputs["append_last_valid_logits"] = append_last_valid_logits
2101
+ else:
2102
  model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
2103
 
2104
  model_inputs.update(kwargs)
 
2218
  # Tie weights again if needed
2219
  self.tie_weights()
2220
 
2221
+ return model_embeds
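
Beyond the formatting cleanup, the commit inlines resource_path (previously imported from olmo.util) so the modeling file carries no olmo dependency, restores the ImageProjectType import, and drops the module-level AutoModelForCausalLM.register call. A minimal usage sketch of the inlined resource_path helper, assuming the cached_path package is installed and that the helper is imported from this modeling module; the folder, file name, and cache directory below are placeholders:

# Hedged sketch of the inlined resource_path helper (placeholder paths).
from pathlib import Path

# If /tmp/vit_cache/model.pt already exists, that local file is returned directly.
ckpt = resource_path(
    "https://example.org/checkpoints/vit",  # placeholder remote folder
    "model.pt",
    local_cache=Path("/tmp/vit_cache"),
)

# On a cache miss the helper falls back to cached_path("<folder>/<fname>") and
# returns the downloaded file's local Path.
print(ckpt)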