chrisc36 committed on
Commit 9eb32aa
1 Parent(s): 6b320f7

Upload modeling_molmo.py with huggingface_hub

Files changed (1)
  1. modeling_molmo.py +7 -15
modeling_molmo.py CHANGED
@@ -725,17 +725,6 @@ def _expand_token(token, batch_size: int):
     return token.view(1, 1, -1).expand(batch_size, -1, -1)
 
 
-class LayerNormFp32(nn.LayerNorm):
-    """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).
-    Derived from https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py.
-    """
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        orig_type = x.dtype
-        x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
-        return x.to(orig_type)
-
-
 class ViTMLP(nn.Module):
     def __init__(self, config: FullMolmoConfig):
         super().__init__()
@@ -855,10 +844,9 @@ class VisionTransformer(nn.Module):
             device=config.init_device,
         )
 
-        self.pre_ln = LayerNormFp32(
+        self.pre_ln = nn.LayerNorm(
            v_cfg.image_emb_dim,
            eps=v_cfg.image_norm_eps,
-            device=config.init_device,
         )
 
         self.transformer = BlockCollection(config)
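The two hunks above drop the custom LayerNormFp32 wrapper and construct pre_ln as a plain nn.LayerNorm. For reference, a minimal sketch (illustrative only, not code from this file) contrasting the two on a low-precision input: the removed wrapper always normalizes in float32 and casts back, while the plain module computes in whatever dtype its parameters and input use.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sketch only: contrast the removed fp32-casting wrapper with a
# plain nn.LayerNorm on a bfloat16 input.
class LayerNormFp32(nn.LayerNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        # Normalize in float32, then cast back to the input dtype.
        x = F.layer_norm(x.to(torch.float32), self.normalized_shape,
                         self.weight, self.bias, self.eps)
        return x.to(orig_type)

x = torch.randn(2, 8, dtype=torch.bfloat16)
y_wrapped = LayerNormFp32(8)(x)                  # math in fp32, output cast back to bf16
y_plain = nn.LayerNorm(8).to(torch.bfloat16)(x)  # math and output in bf16
print(y_wrapped.dtype, y_plain.dtype)            # torch.bfloat16 torch.bfloat16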
@@ -1013,6 +1001,8 @@ class MultiHeadDotProductAttention(nn.Module):
             attn_output = torch.einsum("...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv)
 
         elif self.config.attention_type == "sdpa":
+            if self.config.float32_attention and not torch.is_autocast_enabled():
+                xv = xv.to(torch.float32)
             attn_output = F.scaled_dot_product_attention(
                 xq.transpose(1, 2).contiguous(),
                 xk.transpose(1, 2).contiguous(),
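The added branch upcasts the value tensor to float32 before calling SDPA when float32_attention is set and autocast is not active (the query and key are presumably handled the same way earlier in the method). A minimal standalone sketch, assuming only that F.scaled_dot_product_attention follows the dtype of its inputs:

import torch
import torch.nn.functional as F

# Illustrative sketch only: SDPA computes in the dtype of its inputs, so
# casting q/k/v to float32 keeps the attention math in full precision when
# autocast is not managing dtypes.
q = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)
k = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)
v = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)

out_low = F.scaled_dot_product_attention(q, k, v)
out_fp32 = F.scaled_dot_product_attention(q.float(), k.float(), v.float())
print(out_low.dtype, out_fp32.dtype)  # torch.bfloat16 torch.float32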
@@ -1389,8 +1379,8 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
             elif cfg.image_padding_embed == "pad_and_partial_pad":
                 pad_embed = self.pad_embed[:, None, None, None, :]
                 all_pad = image_masks == 0
-                partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=torch.float32)
-                all_pad = all_pad.to(dtype=torch.float32)
+                partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=image_features.dtype)
+                all_pad = all_pad.to(dtype=image_features.dtype)
                 image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1)
                 image_features = image_features + pad_embed[1] * torch.unsqueeze(partial_pad, -1)
             else:
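The mask dtype change above interacts with PyTorch type promotion: a float32 mask multiplied into half-precision embeddings upcasts the whole expression, while a mask built in image_features.dtype keeps the addition in the features' own precision. A small illustrative sketch (not code from this file):

import torch

# Illustrative sketch only: a float32 mask promotes the result to float32,
# while a mask in image_features.dtype keeps the math in the features' dtype.
image_features = torch.randn(2, 3, dtype=torch.bfloat16)
pad_embed = torch.randn(3, dtype=torch.bfloat16)

mask_fp32 = torch.ones(2, 1, dtype=torch.float32)
mask_same = torch.ones(2, 1, dtype=image_features.dtype)

print((image_features + pad_embed * mask_fp32).dtype)  # torch.float32
print((image_features + pad_embed * mask_same).dtype)  # torch.bfloat16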
@@ -1769,6 +1759,7 @@ class Molmo(nn.Module):
             for block_group in self.transformer.block_groups:
                 block_group.reset_parameters()
 
+
     def forward(
         self,
         input_ids: torch.LongTensor,
@@ -2070,6 +2061,7 @@ class MolmoForCausalLM(PreTrainedModel):
         else:
             self.model = model
 
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,