Upload modeling_molmo.py with huggingface_hub
Browse files- modeling_molmo.py +7 -15
modeling_molmo.py
CHANGED
@@ -725,17 +725,6 @@ def _expand_token(token, batch_size: int):
|
|
725 |
return token.view(1, 1, -1).expand(batch_size, -1, -1)
|
726 |
|
727 |
|
728 |
-
class LayerNormFp32(nn.LayerNorm):
|
729 |
-
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).
|
730 |
-
Derived from https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py.
|
731 |
-
"""
|
732 |
-
|
733 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
734 |
-
orig_type = x.dtype
|
735 |
-
x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
|
736 |
-
return x.to(orig_type)
|
737 |
-
|
738 |
-
|
739 |
class ViTMLP(nn.Module):
|
740 |
def __init__(self, config: FullMolmoConfig):
|
741 |
super().__init__()
|
@@ -855,10 +844,9 @@ class VisionTransformer(nn.Module):
|
|
855 |
device=config.init_device,
|
856 |
)
|
857 |
|
858 |
-
self.pre_ln =
|
859 |
v_cfg.image_emb_dim,
|
860 |
eps=v_cfg.image_norm_eps,
|
861 |
-
device=config.init_device,
|
862 |
)
|
863 |
|
864 |
self.transformer = BlockCollection(config)
|
@@ -1013,6 +1001,8 @@ class MultiHeadDotProductAttention(nn.Module):
|
|
1013 |
attn_output = torch.einsum("...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv)
|
1014 |
|
1015 |
elif self.config.attention_type == "sdpa":
|
|
|
|
|
1016 |
attn_output = F.scaled_dot_product_attention(
|
1017 |
xq.transpose(1, 2).contiguous(),
|
1018 |
xk.transpose(1, 2).contiguous(),
|
@@ -1389,8 +1379,8 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
|
|
1389 |
elif cfg.image_padding_embed == "pad_and_partial_pad":
|
1390 |
pad_embed = self.pad_embed[:, None, None, None, :]
|
1391 |
all_pad = image_masks == 0
|
1392 |
-
partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=
|
1393 |
-
all_pad = all_pad.to(dtype=
|
1394 |
image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1)
|
1395 |
image_features = image_features + pad_embed[1] * torch.unsqueeze(partial_pad, -1)
|
1396 |
else:
|
@@ -1769,6 +1759,7 @@ class Molmo(nn.Module):
|
|
1769 |
for block_group in self.transformer.block_groups:
|
1770 |
block_group.reset_parameters()
|
1771 |
|
|
|
1772 |
def forward(
|
1773 |
self,
|
1774 |
input_ids: torch.LongTensor,
|
@@ -2070,6 +2061,7 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
2070 |
else:
|
2071 |
self.model = model
|
2072 |
|
|
|
2073 |
def forward(
|
2074 |
self,
|
2075 |
input_ids: torch.LongTensor = None,
|
|
|
725 |
return token.view(1, 1, -1).expand(batch_size, -1, -1)
|
726 |
|
727 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
728 |
class ViTMLP(nn.Module):
|
729 |
def __init__(self, config: FullMolmoConfig):
|
730 |
super().__init__()
|
|
|
844 |
device=config.init_device,
|
845 |
)
|
846 |
|
847 |
+
self.pre_ln = nn.LayerNorm(
|
848 |
v_cfg.image_emb_dim,
|
849 |
eps=v_cfg.image_norm_eps,
|
|
|
850 |
)
|
851 |
|
852 |
self.transformer = BlockCollection(config)
|
|
|
1001 |
attn_output = torch.einsum("...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv)
|
1002 |
|
1003 |
elif self.config.attention_type == "sdpa":
|
1004 |
+
if self.config.float32_attention and not torch.is_autocast_enabled():
|
1005 |
+
xv = xv.to(torch.float32)
|
1006 |
attn_output = F.scaled_dot_product_attention(
|
1007 |
xq.transpose(1, 2).contiguous(),
|
1008 |
xk.transpose(1, 2).contiguous(),
|
|
|
1379 |
elif cfg.image_padding_embed == "pad_and_partial_pad":
|
1380 |
pad_embed = self.pad_embed[:, None, None, None, :]
|
1381 |
all_pad = image_masks == 0
|
1382 |
+
partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=image_features.dtype)
|
1383 |
+
all_pad = all_pad.to(dtype=image_features.dtype)
|
1384 |
image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1)
|
1385 |
image_features = image_features + pad_embed[1] * torch.unsqueeze(partial_pad, -1)
|
1386 |
else:
|
|
|
1759 |
for block_group in self.transformer.block_groups:
|
1760 |
block_group.reset_parameters()
|
1761 |
|
1762 |
+
|
1763 |
def forward(
|
1764 |
self,
|
1765 |
input_ids: torch.LongTensor,
|
|
|
2061 |
else:
|
2062 |
self.model = model
|
2063 |
|
2064 |
+
|
2065 |
def forward(
|
2066 |
self,
|
2067 |
input_ids: torch.LongTensor = None,
|