Fill-Mask
Transformers
PyTorch
Safetensors
English
nomic_bert
custom_code

Update modeling_hf_nomic_bert.py

#6 by zpn - opened
Files changed (1)
  1. modeling_hf_nomic_bert.py +127 -90
modeling_hf_nomic_bert.py CHANGED
@@ -3,24 +3,26 @@
3
  # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
  # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
 
 
6
  import logging
7
 
8
  # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
  import math
10
- import numpy as np
11
- import collections
12
  import os
13
  import re
14
  from collections import OrderedDict
15
  from functools import partial
16
  from typing import List, Optional, Tuple, Union
17
 
 
18
  import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
  from einops import rearrange, repeat
22
  from safetensors.torch import load_file as safe_load_file
23
- from transformers import GPT2Config, PreTrainedModel, ViTModel, ViTConfig
 
 
24
  from transformers.models.bert.modeling_bert import (
25
  BaseModelOutputWithPoolingAndCrossAttentions,
26
  MaskedLMOutput,
@@ -28,11 +30,14 @@ from transformers.models.bert.modeling_bert import (
28
  )
29
  from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
30
  from transformers.utils.hub import cached_file, get_checkpoint_shard_files
31
- from transformers.modeling_outputs import BaseModelOutputWithPast
32
- from torch.nn.modules.utils import _pair
33
 
34
  from .configuration_hf_nomic_bert import NomicBertConfig
35
 
 
 
 
 
 
36
  logger = logging.getLogger(__name__)
37
 
38
 
@@ -66,9 +71,7 @@ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None
66
  else: # Try loading from HF hub instead of from local files
67
  resolved_archive_file = None
68
  for weight_name in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
69
- resolved_archive_file = cached_file(
70
- model_name, weight_name, _raise_exceptions_for_missing_entries=False
71
- )
72
  if resolved_archive_file is not None:
73
  if weight_name in [SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME]:
74
  load_safe = True
@@ -273,18 +276,20 @@ def remap_bert_state_dict(
273
 
274
  return state_dict
275
 
276
-
277
  def _trunc_normal_(tensor, mean, std, a, b):
278
  # Cut & paste from PyTorch official master until it's in a few official releases - RW
279
  # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
280
  def norm_cdf(x):
281
  # Computes standard normal cumulative distribution function
282
- return (1. + math.erf(x / math.sqrt(2.))) / 2.
283
 
284
  if (mean < a - 2 * std) or (mean > b + 2 * std):
285
- print("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
286
- "The distribution of values may be incorrect.",
287
- stacklevel=2)
 
 
288
 
289
  # Values are generated by using a truncated uniform distribution and
290
  # then using the inverse CDF for the normal distribution.
@@ -301,14 +306,15 @@ def _trunc_normal_(tensor, mean, std, a, b):
301
  tensor.erfinv_()
302
 
303
  # Transform to proper mean, std
304
- tensor.mul_(std * math.sqrt(2.))
305
  tensor.add_(mean)
306
 
307
  # Clamp to ensure it's in the proper range
308
  tensor.clamp_(min=a, max=b)
309
  return tensor
310
 
311
- def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
 
312
  r"""Fills the input Tensor with values drawn from a truncated
313
  normal distribution. The values are effectively drawn from the
314
  normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
@@ -449,11 +455,13 @@ def _init_weights(module, initializer_range=0.02):
449
  if module.padding_idx is not None:
450
  nn.init.zeros_(module.weight[module.padding_idx])
451
 
 
452
  def _ntuple(n):
453
  def parse(x):
454
  if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
455
  return tuple(x)
456
  return tuple(repeat(x, n))
 
457
  return parse
458
 
459
 
@@ -481,7 +489,7 @@ def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
481
  position embeddings (with or without classification token)
482
  """
483
  grid_h = np.arange(grid_size, dtype=np.float32)
484
-
485
  grid_w = np.arange(grid_size, dtype=np.float32)
486
  grid = np.meshgrid(grid_w, grid_h) # here w goes first
487
  grid = np.stack(grid, axis=0)
@@ -525,6 +533,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
525
  emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
526
  return emb
527
 
 
528
  def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
529
  """generate N-D grid in dimension order.
530
 
@@ -548,18 +557,19 @@ def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
548
  # the old behaviour of meshgrid was 'ij'
549
  return torch.meshgrid(*tensors)
550
 
 
551
  def build_fourier_pos_embed(
552
- feat_shape: List[int],
553
- bands: Optional[torch.Tensor] = None,
554
- num_bands: int = 64,
555
- max_res: int = 224,
556
- temperature: float = 10000.,
557
- linear_bands: bool = False,
558
- include_grid: bool = False,
559
- in_pixels: bool = True,
560
- ref_feat_shape: Optional[List[int]] = None,
561
- dtype: torch.dtype = torch.float32,
562
- device: Optional[torch.device] = None,
563
  ) -> List[torch.Tensor]:
564
  """
565
 
@@ -601,7 +611,7 @@ def build_fourier_pos_embed(
601
  dtype = bands.dtype
602
 
603
  if in_pixels:
604
- t = [torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) for s in feat_shape]
605
  else:
606
  t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]
607
 
@@ -619,16 +629,16 @@ def build_fourier_pos_embed(
619
 
620
 
621
  def build_rotary_pos_embed(
622
- feat_shape: List[int],
623
- bands: Optional[torch.Tensor] = None,
624
- dim: int = 64,
625
- max_res: int = 224,
626
- temperature: float = 10000.,
627
- linear_bands: bool = False,
628
- in_pixels: bool = True,
629
- ref_feat_shape: Optional[List[int]] = None,
630
- dtype: torch.dtype = torch.float32,
631
- device: Optional[torch.device] = None,
632
  ):
633
  """
634
 
@@ -666,22 +676,23 @@ def build_rotary_pos_embed(
666
  cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
667
  return sin_emb, cos_emb
668
 
 
669
  def freq_bands(
670
- num_bands: int,
671
- temperature: float = 10000.,
672
- step: int = 2,
673
- device: Optional[torch.device] = None,
674
  ) -> torch.Tensor:
675
  exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
676
- bands = 1. / (temperature ** exp)
677
  return bands
678
 
679
-
680
  def pixel_freq_bands(
681
- num_bands: int,
682
- max_freq: float = 224.,
683
- linear_bands: bool = True,
684
- device: Optional[torch.device] = None,
685
  ):
686
  if linear_bands:
687
  bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device)
@@ -689,18 +700,21 @@ def pixel_freq_bands(
689
  bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
690
  return bands * torch.pi
691
 
 
692
  def rot(x):
693
  return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
694
 
 
695
  def apply_rot_embed_cat(x: torch.Tensor, emb):
696
  sin_emb, cos_emb = emb.tensor_split(2, -1)
697
  if sin_emb.ndim == 3:
698
  return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
699
  return x * cos_emb + rot(x) * sin_emb
700
 
 
701
  # taken from https://github.com/huggingface/pytorch-image-models/blob/cb0e4391beedcc5ac3ae4bce16561b95c326f32c/timm/layers/pos_embed_sincos.py#L363
702
  class NomicVisionRotaryEmbeddingCat(nn.Module):
703
- """ Rotary position embedding w/ concatenatd sin & cos
704
 
705
  The following impl/resources were referenced for this impl:
706
  * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
@@ -708,14 +722,14 @@ class NomicVisionRotaryEmbeddingCat(nn.Module):
708
  """
709
 
710
  def __init__(
711
- self,
712
- dim,
713
- max_res=224,
714
- temperature=10000,
715
- in_pixels=True,
716
- linear_bands: bool = False,
717
- feat_shape: Optional[List[int]] = None,
718
- ref_feat_shape: Optional[List[int]] = None,
719
  ):
720
  super().__init__()
721
  self.dim = dim
@@ -782,6 +796,7 @@ class NomicVisionRotaryEmbeddingCat(nn.Module):
782
  pos_embed = self.get_embed(x.shape[2:])
783
  return apply_rot_embed_cat(x, pos_embed)
784
 
 
785
  class NomicVisionPatchEmbeddings(nn.Module):
786
  def __init__(
787
  self,
@@ -803,13 +818,19 @@ class NomicVisionPatchEmbeddings(nn.Module):
803
  self.sinusoidal_pos_embedding = False
804
  self.no_embed_class = getattr(config, "no_embed_class", False)
805
 
806
- self.cls_token = nn.Parameter(torch.zeros(1, 1, config.n_embd)) if not getattr(config, "no_cls_token", False) else None
 
 
807
  if config.learned_pos_embedding:
808
  # this is the default in DINO
809
  self.learned_pos_embedding = True
810
  # hack for timm dinov2 with registers
811
  num_patches = self.num_patches if getattr(config, "register_tokens", 0) > 0 else self.num_patches + 1
812
- self.pos_embed = nn.Parameter(torch.randn(1, num_patches, config.n_embd) * 0.02) if getattr(config, "use_pos_embed", True) else None
 
 
 
 
813
  elif getattr(config, "sinusoidal_pos_embedding", False):
814
  self.sinusoidal_pos_embedding = True
815
  if getattr(config, "use_pos_embed", True):
@@ -819,12 +840,16 @@ class NomicVisionPatchEmbeddings(nn.Module):
819
  else:
820
  self.pos_embed = None
821
  else:
822
- self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, config.n_embd) * 0.02) if getattr(config, "use_pos_embed", True) else None
 
 
 
 
823
 
824
  if getattr(config, "register_tokens", 0) > 0:
825
  self.reg_token = nn.Parameter(torch.randn(1, config.register_tokens, config.n_embd) * 0.02)
826
  else:
827
- self.reg_token = None
828
 
829
  if config.mask_token:
830
  self.mask_token = nn.Parameter(torch.zeros(1, config.n_embd))
@@ -843,7 +868,6 @@ class NomicVisionPatchEmbeddings(nn.Module):
843
  else:
844
  self.rope = None
845
 
846
-
847
  def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
848
  """
849
  This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
@@ -913,7 +937,7 @@ class NomicVisionPatchEmbeddings(nn.Module):
913
  embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
914
  else:
915
  if self.pos_embed is not None:
916
- embeddings = embeddings + self.pos_embed
917
  if to_cat:
918
  embeddings = torch.cat(to_cat + [embeddings], dim=1)
919
  else:
@@ -924,7 +948,7 @@ class NomicVisionPatchEmbeddings(nn.Module):
924
  embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
925
  else:
926
  if self.pos_embed is not None:
927
- embeddings = embeddings + self.pos_embed
928
 
929
  embeddings = self.patch_dropout(embeddings)
930
 
@@ -1350,8 +1374,12 @@ class NomicBertAttention(nn.Module):
1350
  qkv = rearrange(qkv, "b h three s d -> b s three h d")
1351
  elif rope is not None:
1352
  q, k, v = qkv.permute(0, 3, 1, 2, 4).unbind(dim=-2)
1353
- q = torch.cat([q[:, :, :self.num_prefix_tokens], apply_rot_embed_cat(q[:, :, self.num_prefix_tokens:], rope)], dim=2).type_as(q)
1354
- k = torch.cat([k[:, :, :self.num_prefix_tokens], apply_rot_embed_cat(k[:, :, self.num_prefix_tokens:], rope)], dim=2).type_as(q)
 
 
 
 
1355
 
1356
  qkv = torch.stack([q, k, v], dim=-2)
1357
  qkv = rearrange(qkv, "b h s three d -> b s three h d")
@@ -1361,15 +1389,20 @@ class NomicBertAttention(nn.Module):
1361
  query = query.permute(0, 2, 1, 3)
1362
  key = key.permute(0, 2, 1, 3)
1363
  value = value.permute(0, 2, 1, 3)
 
 
 
 
 
 
 
 
1364
 
1365
- attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
1366
- if attention_mask is not None:
1367
- attention_scores = attention_scores + attention_mask
1368
 
1369
- attentions_probs = F.softmax(attention_scores, dim=-1)
1370
- attentions_probs = self.drop(attentions_probs)
1371
 
1372
- attn_output = torch.matmul(attentions_probs, value)
1373
  attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")
1374
 
1375
  attn_output = self.out_proj(attn_output)
@@ -1807,6 +1840,7 @@ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1807
  attentions=outputs.attentions,
1808
  )
1809
 
 
1810
  def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config:
1811
  return GPT2Config(
1812
  n_embd=vit_config.hidden_size,
@@ -1814,7 +1848,7 @@ def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config:
1814
  n_head=vit_config.num_attention_heads,
1815
  n_inner=vit_config.intermediate_size,
1816
  activation_function=vit_config.hidden_act,
1817
- vocab_size=0, # no vocab since using patches
1818
  n_positions=0, # No absolute position embedding
1819
  resid_pdrop=0.0, # No dropout
1820
  embd_pdrop=getattr(vit_config, "dropout", 0.0),
@@ -1850,15 +1884,12 @@ def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config:
1850
  mask_token=False,
1851
  learned_pos_embedding=False,
1852
  patch_dropout=0,
1853
- sinusoidal_pos_embedding=vit_config.model_type == "vit_mae"
1854
  )
1855
 
1856
-
1857
  class NomicAttentionPooling(nn.Module):
1858
- def __init__(
1859
- self,
1860
- config
1861
- ):
1862
  super().__init__()
1863
  self.embed_dim = config.n_embd
1864
  self.use_flash_attn = config.use_flash_attn
@@ -1879,7 +1910,7 @@ class NomicAttentionPooling(nn.Module):
1879
 
1880
  self.Wq = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
1881
  self.Wkv = nn.Linear(self.embed_dim, kv_dim, bias=config.qkv_proj_bias)
1882
-
1883
  self.latent = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
1884
 
1885
  self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
@@ -1887,7 +1918,7 @@ class NomicAttentionPooling(nn.Module):
1887
  self.drop = nn.Dropout(config.attn_pdrop)
1888
 
1889
  def init_weights(self):
1890
- trunc_normal_tf_(self.latent, std=self.embed_dim ** -0.5)
1891
 
1892
  def forward(
1893
  self,
@@ -1938,7 +1969,7 @@ class NomicAttentionPooling(nn.Module):
1938
 
1939
  return attn_output
1940
 
1941
-
1942
  class NomicMultiHeadAttentionPooling(nn.Module):
1943
  def __init__(
1944
  self,
@@ -1993,15 +2024,16 @@ class NomicMultiHeadAttentionPooling(nn.Module):
1993
  """
1994
 
1995
  attn_outputs = self.attn(
1996
- hidden_states,
1997
- attention_mask=attention_mask,
1998
- )
1999
 
2000
  normed = self.norm1(attn_outputs)
2001
  hidden_states = hidden_states + self.mlp(normed)
2002
 
2003
  return hidden_states
2004
 
 
2005
  class NomicVisionPreTrainedModel(PreTrainedModel):
2006
  """An abstract class to handle weights initialization and
2007
  a simple interface for dowloading and loading pretrained models.
@@ -2025,6 +2057,7 @@ class NomicVisionPreTrainedModel(PreTrainedModel):
2025
  )
2026
  self.config = config
2027
 
 
2028
  class NomicVisionModel(NomicVisionPreTrainedModel):
2029
  def __init__(self, config):
2030
  super().__init__(config)
@@ -2035,7 +2068,9 @@ class NomicVisionModel(NomicVisionPreTrainedModel):
2035
  self.selector = NomicMultiHeadAttentionPooling(config)
2036
 
2037
  self.global_pool = getattr(config, "global_pool", None)
2038
- self.num_prefix_tokens = (1 if not getattr(config, "no_cls_token", False) else 0) + getattr(config, "register_tokens", 0)
 
 
2039
 
2040
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
2041
 
@@ -2052,20 +2087,22 @@ class NomicVisionModel(NomicVisionPreTrainedModel):
2052
 
2053
  original_dtype = embeddings.dtype
2054
 
2055
- hidden_states = embeddings
2056
  # unused but easier to pass to gradient checkpointing as words
2057
  residual = None
2058
  for layer in self.layers:
2059
  # need to pass none for backwards compatability
2060
- hidden_states, _, residual = layer(hidden_states, None, residual=residual, is_padded_inputs=False, rope=rope)
 
 
2061
 
2062
  hidden_states = hidden_states + residual
2063
  if self.global_pool == "avg":
2064
- hidden_states = hidden_states[:, self.num_prefix_tokens:].mean(dim=1)
2065
 
2066
  pooled_output = self.selector(hidden_states)
2067
 
2068
  return BaseModelOutputWithPast(
2069
  last_hidden_state=pooled_output,
2070
  hidden_states=hidden_states,
2071
- )
 
3
  # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
  # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
 
6
+ import collections
7
  import logging
8
 
9
  # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
10
  import math
 
 
11
  import os
12
  import re
13
  from collections import OrderedDict
14
  from functools import partial
15
  from typing import List, Optional, Tuple, Union
16
 
17
+ import numpy as np
18
  import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
  from einops import rearrange, repeat
22
  from safetensors.torch import load_file as safe_load_file
23
+ from torch.nn.modules.utils import _pair
24
+ from transformers import GPT2Config, PreTrainedModel, ViTConfig, ViTModel
25
+ from transformers.modeling_outputs import BaseModelOutputWithPast
26
  from transformers.models.bert.modeling_bert import (
27
  BaseModelOutputWithPoolingAndCrossAttentions,
28
  MaskedLMOutput,
 
30
  )
31
  from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
32
  from transformers.utils.hub import cached_file, get_checkpoint_shard_files
 
 
33
 
34
  from .configuration_hf_nomic_bert import NomicBertConfig
35
 
36
+ try:
37
+ from torch.nn.functional import scaled_dot_product_attention
38
+ except ImportError:
39
+ scaled_dot_product_attention = None
40
+
41
  logger = logging.getLogger(__name__)
42
 
43
 
 
71
  else: # Try loading from HF hub instead of from local files
72
  resolved_archive_file = None
73
  for weight_name in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
74
+ resolved_archive_file = cached_file(model_name, weight_name, _raise_exceptions_for_missing_entries=False)
 
 
75
  if resolved_archive_file is not None:
76
  if weight_name in [SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME]:
77
  load_safe = True
 
276
 
277
  return state_dict
278
 
279
+
280
  def _trunc_normal_(tensor, mean, std, a, b):
281
  # Cut & paste from PyTorch official master until it's in a few official releases - RW
282
  # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
283
  def norm_cdf(x):
284
  # Computes standard normal cumulative distribution function
285
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
286
 
287
  if (mean < a - 2 * std) or (mean > b + 2 * std):
288
+ print(
289
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
290
+ "The distribution of values may be incorrect.",
291
+ stacklevel=2,
292
+ )
293
 
294
  # Values are generated by using a truncated uniform distribution and
295
  # then using the inverse CDF for the normal distribution.
 
306
  tensor.erfinv_()
307
 
308
  # Transform to proper mean, std
309
+ tensor.mul_(std * math.sqrt(2.0))
310
  tensor.add_(mean)
311
 
312
  # Clamp to ensure it's in the proper range
313
  tensor.clamp_(min=a, max=b)
314
  return tensor
315
 
316
+
317
+ def trunc_normal_tf_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
318
  r"""Fills the input Tensor with values drawn from a truncated
319
  normal distribution. The values are effectively drawn from the
320
  normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
 
455
  if module.padding_idx is not None:
456
  nn.init.zeros_(module.weight[module.padding_idx])
457
 
458
+
459
  def _ntuple(n):
460
  def parse(x):
461
  if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
462
  return tuple(x)
463
  return tuple(repeat(x, n))
464
+
465
  return parse
466
 
467
 
 
489
  position embeddings (with or without classification token)
490
  """
491
  grid_h = np.arange(grid_size, dtype=np.float32)
492
+
493
  grid_w = np.arange(grid_size, dtype=np.float32)
494
  grid = np.meshgrid(grid_w, grid_h) # here w goes first
495
  grid = np.stack(grid, axis=0)
 
533
  emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
534
  return emb
535
 
536
+
537
  def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
538
  """generate N-D grid in dimension order.
539
 
 
557
  # the old behaviour of meshgrid was 'ij'
558
  return torch.meshgrid(*tensors)
559
 
560
+
561
  def build_fourier_pos_embed(
562
+ feat_shape: List[int],
563
+ bands: Optional[torch.Tensor] = None,
564
+ num_bands: int = 64,
565
+ max_res: int = 224,
566
+ temperature: float = 10000.0,
567
+ linear_bands: bool = False,
568
+ include_grid: bool = False,
569
+ in_pixels: bool = True,
570
+ ref_feat_shape: Optional[List[int]] = None,
571
+ dtype: torch.dtype = torch.float32,
572
+ device: Optional[torch.device] = None,
573
  ) -> List[torch.Tensor]:
574
  """
575
 
 
611
  dtype = bands.dtype
612
 
613
  if in_pixels:
614
+ t = [torch.linspace(-1.0, 1.0, steps=s, device=device, dtype=torch.float32) for s in feat_shape]
615
  else:
616
  t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]
617
 
 
629
 
630
 
631
  def build_rotary_pos_embed(
632
+ feat_shape: List[int],
633
+ bands: Optional[torch.Tensor] = None,
634
+ dim: int = 64,
635
+ max_res: int = 224,
636
+ temperature: float = 10000.0,
637
+ linear_bands: bool = False,
638
+ in_pixels: bool = True,
639
+ ref_feat_shape: Optional[List[int]] = None,
640
+ dtype: torch.dtype = torch.float32,
641
+ device: Optional[torch.device] = None,
642
  ):
643
  """
644
 
 
676
  cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
677
  return sin_emb, cos_emb
678
 
679
+
680
  def freq_bands(
681
+ num_bands: int,
682
+ temperature: float = 10000.0,
683
+ step: int = 2,
684
+ device: Optional[torch.device] = None,
685
  ) -> torch.Tensor:
686
  exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
687
+ bands = 1.0 / (temperature**exp)
688
  return bands
689
 
690
+
691
  def pixel_freq_bands(
692
+ num_bands: int,
693
+ max_freq: float = 224.0,
694
+ linear_bands: bool = True,
695
+ device: Optional[torch.device] = None,
696
  ):
697
  if linear_bands:
698
  bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device)
 
700
  bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
701
  return bands * torch.pi
702
 
703
+
704
  def rot(x):
705
  return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
706
 
707
+
708
  def apply_rot_embed_cat(x: torch.Tensor, emb):
709
  sin_emb, cos_emb = emb.tensor_split(2, -1)
710
  if sin_emb.ndim == 3:
711
  return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
712
  return x * cos_emb + rot(x) * sin_emb
713
 
714
+
715
  # taken from https://github.com/huggingface/pytorch-image-models/blob/cb0e4391beedcc5ac3ae4bce16561b95c326f32c/timm/layers/pos_embed_sincos.py#L363
716
  class NomicVisionRotaryEmbeddingCat(nn.Module):
717
+ """Rotary position embedding w/ concatenatd sin & cos
718
 
719
  The following impl/resources were referenced for this impl:
720
  * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
 
722
  """
723
 
724
  def __init__(
725
+ self,
726
+ dim,
727
+ max_res=224,
728
+ temperature=10000,
729
+ in_pixels=True,
730
+ linear_bands: bool = False,
731
+ feat_shape: Optional[List[int]] = None,
732
+ ref_feat_shape: Optional[List[int]] = None,
733
  ):
734
  super().__init__()
735
  self.dim = dim
 
796
  pos_embed = self.get_embed(x.shape[2:])
797
  return apply_rot_embed_cat(x, pos_embed)
798
 
799
+
800
  class NomicVisionPatchEmbeddings(nn.Module):
801
  def __init__(
802
  self,
 
818
  self.sinusoidal_pos_embedding = False
819
  self.no_embed_class = getattr(config, "no_embed_class", False)
820
 
821
+ self.cls_token = (
822
+ nn.Parameter(torch.zeros(1, 1, config.n_embd)) if not getattr(config, "no_cls_token", False) else None
823
+ )
824
  if config.learned_pos_embedding:
825
  # this is the default in DINO
826
  self.learned_pos_embedding = True
827
  # hack for timm dinov2 with registers
828
  num_patches = self.num_patches if getattr(config, "register_tokens", 0) > 0 else self.num_patches + 1
829
+ self.pos_embed = (
830
+ nn.Parameter(torch.randn(1, num_patches, config.n_embd) * 0.02)
831
+ if getattr(config, "use_pos_embed", True)
832
+ else None
833
+ )
834
  elif getattr(config, "sinusoidal_pos_embedding", False):
835
  self.sinusoidal_pos_embedding = True
836
  if getattr(config, "use_pos_embed", True):
 
840
  else:
841
  self.pos_embed = None
842
  else:
843
+ self.pos_embed = (
844
+ nn.Parameter(torch.randn(1, self.num_patches + 1, config.n_embd) * 0.02)
845
+ if getattr(config, "use_pos_embed", True)
846
+ else None
847
+ )
848
 
849
  if getattr(config, "register_tokens", 0) > 0:
850
  self.reg_token = nn.Parameter(torch.randn(1, config.register_tokens, config.n_embd) * 0.02)
851
  else:
852
+ self.reg_token = None
853
 
854
  if config.mask_token:
855
  self.mask_token = nn.Parameter(torch.zeros(1, config.n_embd))
 
868
  else:
869
  self.rope = None
870
 
 
871
  def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
872
  """
873
  This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
 
937
  embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
938
  else:
939
  if self.pos_embed is not None:
940
+ embeddings = embeddings + self.pos_embed
941
  if to_cat:
942
  embeddings = torch.cat(to_cat + [embeddings], dim=1)
943
  else:
 
948
  embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
949
  else:
950
  if self.pos_embed is not None:
951
+ embeddings = embeddings + self.pos_embed
952
 
953
  embeddings = self.patch_dropout(embeddings)
954
 
 
1374
  qkv = rearrange(qkv, "b h three s d -> b s three h d")
1375
  elif rope is not None:
1376
  q, k, v = qkv.permute(0, 3, 1, 2, 4).unbind(dim=-2)
1377
+ q = torch.cat(
1378
+ [q[:, :, : self.num_prefix_tokens], apply_rot_embed_cat(q[:, :, self.num_prefix_tokens :], rope)], dim=2
1379
+ ).type_as(q)
1380
+ k = torch.cat(
1381
+ [k[:, :, : self.num_prefix_tokens], apply_rot_embed_cat(k[:, :, self.num_prefix_tokens :], rope)], dim=2
1382
+ ).type_as(q)
1383
 
1384
  qkv = torch.stack([q, k, v], dim=-2)
1385
  qkv = rearrange(qkv, "b h s three d -> b s three h d")
 
1389
  query = query.permute(0, 2, 1, 3)
1390
  key = key.permute(0, 2, 1, 3)
1391
  value = value.permute(0, 2, 1, 3)
1392
+ if scaled_dot_product_attention is not None:
1393
+ attn_output = F.scaled_dot_product_attention(
1394
+ query, key, value, attn_mask=attention_mask, dropout_p=self.drop.p, is_causal=False
1395
+ )
1396
+ else:
1397
+ attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
1398
+ if attention_mask is not None:
1399
+ attention_scores = attention_scores + attention_mask
1400
 
1401
+ attentions_probs = F.softmax(attention_scores, dim=-1)
1402
+ attentions_probs = self.drop(attentions_probs)
 
1403
 
1404
+ attn_output = torch.matmul(attentions_probs, value)
 
1405
 
 
1406
  attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")
1407
 
1408
  attn_output = self.out_proj(attn_output)
 
1840
  attentions=outputs.attentions,
1841
  )
1842
 
1843
+
1844
  def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config:
1845
  return GPT2Config(
1846
  n_embd=vit_config.hidden_size,
 
1848
  n_head=vit_config.num_attention_heads,
1849
  n_inner=vit_config.intermediate_size,
1850
  activation_function=vit_config.hidden_act,
1851
+ vocab_size=0, # no vocab since using patches
1852
  n_positions=0, # No absolute position embedding
1853
  resid_pdrop=0.0, # No dropout
1854
  embd_pdrop=getattr(vit_config, "dropout", 0.0),
 
1884
  mask_token=False,
1885
  learned_pos_embedding=False,
1886
  patch_dropout=0,
1887
+ sinusoidal_pos_embedding=vit_config.model_type == "vit_mae",
1888
  )
1889
 
1890
+
1891
  class NomicAttentionPooling(nn.Module):
1892
+ def __init__(self, config):
 
 
 
1893
  super().__init__()
1894
  self.embed_dim = config.n_embd
1895
  self.use_flash_attn = config.use_flash_attn
 
1910
 
1911
  self.Wq = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
1912
  self.Wkv = nn.Linear(self.embed_dim, kv_dim, bias=config.qkv_proj_bias)
1913
+
1914
  self.latent = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
1915
 
1916
  self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
 
1918
  self.drop = nn.Dropout(config.attn_pdrop)
1919
 
1920
  def init_weights(self):
1921
+ trunc_normal_tf_(self.latent, std=self.embed_dim**-0.5)
1922
 
1923
  def forward(
1924
  self,
 
1969
 
1970
  return attn_output
1971
 
1972
+
1973
  class NomicMultiHeadAttentionPooling(nn.Module):
1974
  def __init__(
1975
  self,
 
2024
  """
2025
 
2026
  attn_outputs = self.attn(
2027
+ hidden_states,
2028
+ attention_mask=attention_mask,
2029
+ )
2030
 
2031
  normed = self.norm1(attn_outputs)
2032
  hidden_states = hidden_states + self.mlp(normed)
2033
 
2034
  return hidden_states
2035
 
2036
+
2037
  class NomicVisionPreTrainedModel(PreTrainedModel):
2038
  """An abstract class to handle weights initialization and
2039
  a simple interface for dowloading and loading pretrained models.
 
2057
  )
2058
  self.config = config
2059
 
2060
+
2061
  class NomicVisionModel(NomicVisionPreTrainedModel):
2062
  def __init__(self, config):
2063
  super().__init__(config)
 
2068
  self.selector = NomicMultiHeadAttentionPooling(config)
2069
 
2070
  self.global_pool = getattr(config, "global_pool", None)
2071
+ self.num_prefix_tokens = (1 if not getattr(config, "no_cls_token", False) else 0) + getattr(
2072
+ config, "register_tokens", 0
2073
+ )
2074
 
2075
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
2076
 
 
2087
 
2088
  original_dtype = embeddings.dtype
2089
 
2090
+ hidden_states = embeddings
2091
  # unused but easier to pass to gradient checkpointing as words
2092
  residual = None
2093
  for layer in self.layers:
2094
  # need to pass none for backwards compatability
2095
+ hidden_states, _, residual = layer(
2096
+ hidden_states, None, residual=residual, is_padded_inputs=False, rope=rope
2097
+ )
2098
 
2099
  hidden_states = hidden_states + residual
2100
  if self.global_pool == "avg":
2101
+ hidden_states = hidden_states[:, self.num_prefix_tokens :].mean(dim=1)
2102
 
2103
  pooled_output = self.selector(hidden_states)
2104
 
2105
  return BaseModelOutputWithPast(
2106
  last_hidden_state=pooled_output,
2107
  hidden_states=hidden_states,
2108
+ )
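
The functional core of this diff is the attention path in NomicBertAttention: when torch.nn.functional.scaled_dot_product_attention can be imported it is used with the additive attention_mask, and otherwise the original matmul / softmax / dropout path runs. Below is a minimal, self-contained sketch of that fallback pattern, not the model's actual module: the helper name attend, the sqrt(head_dim) scaling standing in for self.norm_factor, and the gating of dropout_p on a training flag are my assumptions (the diff passes self.drop.p unconditionally, and SDPA applies dropout_p regardless of a module's train/eval state).

import math

import torch
import torch.nn.functional as F

# Guard the import the same way the diff does: torch builds before 2.0
# do not ship torch.nn.functional.scaled_dot_product_attention.
try:
    from torch.nn.functional import scaled_dot_product_attention
except ImportError:
    scaled_dot_product_attention = None


def attend(query, key, value, attention_mask=None, dropout_p=0.0, training=False):
    """Toy stand-in for the non-flash attention branch.

    query/key/value: (batch, heads, seq, head_dim). attention_mask is an
    additive float mask broadcastable to (batch, heads, seq, seq), with
    large negative values at masked positions.
    """
    if scaled_dot_product_attention is not None:
        # SDPA accepts the same additive float mask; apply dropout only
        # while training, since SDPA has no notion of eval mode.
        return F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attention_mask,
            dropout_p=dropout_p if training else 0.0,
            is_causal=False,
        )

    # Manual fallback: scores -> additive mask -> softmax -> dropout -> values.
    norm_factor = math.sqrt(query.shape[-1])
    scores = torch.matmul(query, key.transpose(-1, -2)) / norm_factor
    if attention_mask is not None:
        scores = scores + attention_mask
    probs = torch.softmax(scores, dim=-1)
    probs = F.dropout(probs, p=dropout_p, training=training)
    return torch.matmul(probs, value)


if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(2, 4, 8, 16)
    k = torch.randn(2, 4, 8, 16)
    v = torch.randn(2, 4, 8, 16)
    mask = torch.zeros(2, 1, 1, 8)
    mask[:, :, :, -2:] = torch.finfo(torch.float32).min  # mask the last two keys
    print(attend(q, k, v, attention_mask=mask).shape)  # torch.Size([2, 4, 8, 16])

On torch >= 2.0 the SDPA branch runs; on older builds the manual branch produces an output of the same shape, so either path can back the same module.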
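A smaller note on the _trunc_normal_ hunk: both the removed and the added versions pass stacklevel=2 to print, which print() does not accept, so the out-of-range warning branch would raise a TypeError if it ever fired. The PyTorch helper the comment says this was cut and pasted from emits the message via warnings.warn, which does take stacklevel. A sketch of that guard follows; the wrapper name _warn_if_outside is mine, the message is the one from the diff.

import warnings


def _warn_if_outside(mean, std, a, b):
    # Same guard as in _trunc_normal_, but routed through warnings.warn,
    # which (unlike print) accepts the stacklevel keyword.
    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )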