Update modeling_hf_nomic_bert.py
#6, opened by zpn

modeling_hf_nomic_bert.py  +127 -90  CHANGED
@@ -3,24 +3,26 @@
 # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
 # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py

+import collections
 import logging

 # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
 import math
-import numpy as np
-import collections
 import os
 import re
 from collections import OrderedDict
 from functools import partial
 from typing import List, Optional, Tuple, Union

+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from safetensors.torch import load_file as safe_load_file
-from
+from torch.nn.modules.utils import _pair
+from transformers import GPT2Config, PreTrainedModel, ViTConfig, ViTModel
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.bert.modeling_bert import (
     BaseModelOutputWithPoolingAndCrossAttentions,
     MaskedLMOutput,
@@ -28,11 +30,14 @@ from transformers.models.bert.modeling_bert import (
 )
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
 from transformers.utils.hub import cached_file, get_checkpoint_shard_files
-from transformers.modeling_outputs import BaseModelOutputWithPast
-from torch.nn.modules.utils import _pair

 from .configuration_hf_nomic_bert import NomicBertConfig

+try:
+    from torch.nn.functional import scaled_dot_product_attention
+except ImportError:
+    scaled_dot_product_attention = None
+
 logger = logging.getLogger(__name__)

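Note: the try/except above makes the fused kernel optional: on torch >= 2.0 the name resolves, on older builds it is bound to None and the attention forward falls back to the manual path (see the @@ -1361 hunk below). A minimal standalone probe of the same idea, assuming only a working torch install:

    import torch
    import torch.nn.functional as F

    # Same feature detection as the guarded import in the diff.
    HAS_SDPA = hasattr(F, "scaled_dot_product_attention")
    print(f"torch {torch.__version__}: fused SDPA available = {HAS_SDPA}")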
@@ -66,9 +71,7 @@ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None
     else:  # Try loading from HF hub instead of from local files
         resolved_archive_file = None
         for weight_name in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
-            resolved_archive_file = cached_file(
-                model_name, weight_name, _raise_exceptions_for_missing_entries=False
-            )
+            resolved_archive_file = cached_file(model_name, weight_name, _raise_exceptions_for_missing_entries=False)
             if resolved_archive_file is not None:
                 if weight_name in [SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME]:
                     load_safe = True
@@ -273,18 +276,20 @@ def remap_bert_state_dict(

     return state_dict

-
+
 def _trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
     def norm_cdf(x):
         # Computes standard normal cumulative distribution function
-        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

     if (mean < a - 2 * std) or (mean > b + 2 * std):
-        print(
-            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
-            "The distribution of values may be incorrect.", stacklevel=2)
+        print(
+            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+            "The distribution of values may be incorrect.",
+            stacklevel=2,
+        )

     # Values are generated by using a truncated uniform distribution and
     # then using the inverse CDF for the normal distribution.
@@ -301,14 +306,15 @@ def _trunc_normal_(tensor, mean, std, a, b):
     tensor.erfinv_()

     # Transform to proper mean, std
-    tensor.mul_(std * math.sqrt(2.))
+    tensor.mul_(std * math.sqrt(2.0))
     tensor.add_(mean)

     # Clamp to ensure it's in the proper range
     tensor.clamp_(min=a, max=b)
     return tensor

-def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
+
+def trunc_normal_tf_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
     r"""Fills the input Tensor with values drawn from a truncated
     normal distribution. The values are effectively drawn from the
     normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
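Note: assuming the timm-style semantics the docstring describes (the [a, b] bounds are applied while sampling the unit normal, and scaling by std happens afterwards), usage looks like this sketch:

    import torch

    w = torch.empty(768, 768)
    trunc_normal_tf_(w, std=0.02)           # N(0, 1) truncated to [-2, 2], then scaled by std
    assert w.abs().max() <= 2 * 0.02 + 1e-6  # all values land within 2 * std of zero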
@@ -449,11 +455,13 @@ def _init_weights(module, initializer_range=0.02):
         if module.padding_idx is not None:
             nn.init.zeros_(module.weight[module.padding_idx])

+
 def _ntuple(n):
     def parse(x):
         if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
             return tuple(x)
         return tuple(repeat(x, n))
+
     return parse

@@ -481,7 +489,7 @@ def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
     position embeddings (with or without classification token)
     """
     grid_h = np.arange(grid_size, dtype=np.float32)
-
+
     grid_w = np.arange(grid_size, dtype=np.float32)
     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
     grid = np.stack(grid, axis=0)
@@ -525,6 +533,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
     return emb

+
 def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
     """generate N-D grid in dimension order.

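Note: the 1-D sin/cos recipe above follows the MAE reference; the same computation in a few lines for a toy size:

    import numpy as np

    embed_dim, pos = 8, np.arange(4, dtype=np.float32)
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float32) / (embed_dim / 2))
    out = np.einsum("m,d->md", pos, omega)                    # (M, D/2) phase matrix
    emb = np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (4, 8): sin half, cos half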
@@ -548,18 +557,19 @@ def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
     # the old behaviour of meshgrid was 'ij'
     return torch.meshgrid(*tensors)

+
 def build_fourier_pos_embed(
-        feat_shape: List[int],
-        bands: Optional[torch.Tensor] = None,
-        num_bands: int = 64,
-        max_res: int = 224,
-        temperature: float = 10000.,
-        linear_bands: bool = False,
-        include_grid: bool = False,
-        in_pixels: bool = True,
-        ref_feat_shape: Optional[List[int]] = None,
-        dtype: torch.dtype = torch.float32,
-        device: Optional[torch.device] = None,
+    feat_shape: List[int],
+    bands: Optional[torch.Tensor] = None,
+    num_bands: int = 64,
+    max_res: int = 224,
+    temperature: float = 10000.0,
+    linear_bands: bool = False,
+    include_grid: bool = False,
+    in_pixels: bool = True,
+    ref_feat_shape: Optional[List[int]] = None,
+    dtype: torch.dtype = torch.float32,
+    device: Optional[torch.device] = None,
 ) -> List[torch.Tensor]:
     """

@@ -601,7 +611,7 @@ def build_fourier_pos_embed(
         dtype = bands.dtype

     if in_pixels:
-        t = [torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) for s in feat_shape]
+        t = [torch.linspace(-1.0, 1.0, steps=s, device=device, dtype=torch.float32) for s in feat_shape]
     else:
         t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]

@@ -619,16 +629,16 @@


 def build_rotary_pos_embed(
-        feat_shape: List[int],
-        bands: Optional[torch.Tensor] = None,
-        dim: int = 64,
-        max_res: int = 224,
-        temperature: float = 10000.,
-        linear_bands: bool = False,
-        in_pixels: bool = True,
-        ref_feat_shape: Optional[List[int]] = None,
-        dtype: torch.dtype = torch.float32,
-        device: Optional[torch.device] = None,
+    feat_shape: List[int],
+    bands: Optional[torch.Tensor] = None,
+    dim: int = 64,
+    max_res: int = 224,
+    temperature: float = 10000.0,
+    linear_bands: bool = False,
+    in_pixels: bool = True,
+    ref_feat_shape: Optional[List[int]] = None,
+    dtype: torch.dtype = torch.float32,
+    device: Optional[torch.device] = None,
 ):
     """

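Note: a hypothetical call of the rebuilt signature, for a 14x14 patch grid and 64-dim heads; per the reshape in the next hunk, sin and cos come back flattened over the grid and get concatenated before use:

    import torch

    sin_emb, cos_emb = build_rotary_pos_embed(feat_shape=[14, 14], dim=64)
    rope = torch.cat([sin_emb, cos_emb], dim=-1)  # layout consumed by apply_rot_embed_cat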
@@ -666,22 +676,23 @@
     cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
     return sin_emb, cos_emb

+
 def freq_bands(
-        num_bands: int,
-        temperature: float = 10000.,
-        step: int = 2,
-        device: Optional[torch.device] = None,
+    num_bands: int,
+    temperature: float = 10000.0,
+    step: int = 2,
+    device: Optional[torch.device] = None,
 ) -> torch.Tensor:
     exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
-    bands = 1. / (temperature ** exp)
+    bands = 1.0 / (temperature**exp)
     return bands

+
 def pixel_freq_bands(
-        num_bands: int,
-        max_freq: float = 224.,
-        linear_bands: bool = True,
-        device: Optional[torch.device] = None,
+    num_bands: int,
+    max_freq: float = 224.0,
+    linear_bands: bool = True,
+    device: Optional[torch.device] = None,
 ):
     if linear_bands:
         bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device)
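Note: worked values for freq_bands with num_bands=4 and the defaults above:

    import torch

    exp = torch.arange(0, 4, 2, dtype=torch.int64).to(torch.float32) / 4  # tensor([0.0, 0.5])
    bands = 1.0 / (10000.0 ** exp)                                        # tensor([1.0, 0.01])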
@@ -689,18 +700,21 @@ def pixel_freq_bands(
         bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
     return bands * torch.pi

+
 def rot(x):
     return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)

+
 def apply_rot_embed_cat(x: torch.Tensor, emb):
     sin_emb, cos_emb = emb.tensor_split(2, -1)
     if sin_emb.ndim == 3:
         return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
     return x * cos_emb + rot(x) * sin_emb

+
 # taken from https://github.com/huggingface/pytorch-image-models/blob/cb0e4391beedcc5ac3ae4bce16561b95c326f32c/timm/layers/pos_embed_sincos.py#L363
 class NomicVisionRotaryEmbeddingCat(nn.Module):
-    """
+    """Rotary position embedding w/ concatenatd sin & cos

     The following impl/resources were referenced for this impl:
     * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
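Note: a quick sanity check on apply_rot_embed_cat as defined above: with sin = 0 and cos = 1 the rotation angle is zero, so the input passes through unchanged.

    import torch

    x = torch.randn(2, 4, 16)                                         # (batch, seq, dim)
    emb = torch.cat([torch.zeros(4, 16), torch.ones(4, 16)], dim=-1)  # [sin | cos]
    assert torch.allclose(apply_rot_embed_cat(x, emb), x)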
@@ -708,14 +722,14 @@ class NomicVisionRotaryEmbeddingCat(nn.Module):
     """

     def __init__(
-            self,
-            dim,
-            max_res=224,
-            temperature=10000,
-            in_pixels=True,
-            linear_bands: bool = False,
-            feat_shape: Optional[List[int]] = None,
-            ref_feat_shape: Optional[List[int]] = None,
+        self,
+        dim,
+        max_res=224,
+        temperature=10000,
+        in_pixels=True,
+        linear_bands: bool = False,
+        feat_shape: Optional[List[int]] = None,
+        ref_feat_shape: Optional[List[int]] = None,
     ):
         super().__init__()
         self.dim = dim
@@ -782,6 +796,7 @@ class NomicVisionRotaryEmbeddingCat(nn.Module):
         pos_embed = self.get_embed(x.shape[2:])
         return apply_rot_embed_cat(x, pos_embed)

+
 class NomicVisionPatchEmbeddings(nn.Module):
     def __init__(
         self,
|
|
803 |
self.sinusoidal_pos_embedding = False
|
804 |
self.no_embed_class = getattr(config, "no_embed_class", False)
|
805 |
|
806 |
-
self.cls_token =
|
|
|
|
|
807 |
if config.learned_pos_embedding:
|
808 |
# this is the default in DINO
|
809 |
self.learned_pos_embedding = True
|
810 |
# hack for timm dinov2 with registers
|
811 |
num_patches = self.num_patches if getattr(config, "register_tokens", 0) > 0 else self.num_patches + 1
|
812 |
-
self.pos_embed =
|
|
|
|
|
|
|
|
|
813 |
elif getattr(config, "sinusoidal_pos_embedding", False):
|
814 |
self.sinusoidal_pos_embedding = True
|
815 |
if getattr(config, "use_pos_embed", True):
|
@@ -819,12 +840,16 @@ class NomicVisionPatchEmbeddings(nn.Module):
|
|
819 |
else:
|
820 |
self.pos_embed = None
|
821 |
else:
|
822 |
-
self.pos_embed =
|
|
|
|
|
|
|
|
|
823 |
|
824 |
if getattr(config, "register_tokens", 0) > 0:
|
825 |
self.reg_token = nn.Parameter(torch.randn(1, config.register_tokens, config.n_embd) * 0.02)
|
826 |
else:
|
827 |
-
self.reg_token = None
|
828 |
|
829 |
if config.mask_token:
|
830 |
self.mask_token = nn.Parameter(torch.zeros(1, config.n_embd))
|
@@ -843,7 +868,6 @@ class NomicVisionPatchEmbeddings(nn.Module):
         else:
             self.rope = None

-
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
|
913 |
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
914 |
else:
|
915 |
if self.pos_embed is not None:
|
916 |
-
embeddings = embeddings + self.pos_embed
|
917 |
if to_cat:
|
918 |
embeddings = torch.cat(to_cat + [embeddings], dim=1)
|
919 |
else:
|
@@ -924,7 +948,7 @@ class NomicVisionPatchEmbeddings(nn.Module):
                 embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
             else:
                 if self.pos_embed is not None:
-                    embeddings = embeddings + self.pos_embed
+                    embeddings = embeddings + self.pos_embed

         embeddings = self.patch_dropout(embeddings)

@@ -1350,8 +1374,12 @@ class NomicBertAttention(nn.Module):
             qkv = rearrange(qkv, "b h three s d -> b s three h d")
         elif rope is not None:
             q, k, v = qkv.permute(0, 3, 1, 2, 4).unbind(dim=-2)
-            q = torch.cat([q[:, :, :self.num_prefix_tokens], apply_rot_embed_cat(q[:, :, self.num_prefix_tokens:], rope)], dim=2).type_as(q)
-            k = torch.cat([k[:, :, :self.num_prefix_tokens], apply_rot_embed_cat(k[:, :, self.num_prefix_tokens:], rope)], dim=2).type_as(q)
+            q = torch.cat(
+                [q[:, :, : self.num_prefix_tokens], apply_rot_embed_cat(q[:, :, self.num_prefix_tokens :], rope)], dim=2
+            ).type_as(q)
+            k = torch.cat(
+                [k[:, :, : self.num_prefix_tokens], apply_rot_embed_cat(k[:, :, self.num_prefix_tokens :], rope)], dim=2
+            ).type_as(q)

             qkv = torch.stack([q, k, v], dim=-2)
             qkv = rearrange(qkv, "b h s three d -> b s three h d")
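Note: the concatenation above keeps prefix tokens (cls, registers) unrotated and applies RoPE to patch tokens only. A sketch with hypothetical shapes (one cls token, a 14x14 grid, 64-dim heads):

    import torch

    q = torch.randn(1, 8, 197, 64)  # (batch, heads, 1 cls + 196 patches, head_dim)
    rope = torch.randn(196, 128)    # concatenated sin/cos for the patch grid
    q_rot = torch.cat([q[:, :, :1], apply_rot_embed_cat(q[:, :, 1:], rope)], dim=2).type_as(q)
    assert q_rot.shape == q.shape and torch.equal(q_rot[:, :, 0], q[:, :, 0])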
@@ -1361,15 +1389,20 @@ class NomicBertAttention(nn.Module):
         query = query.permute(0, 2, 1, 3)
         key = key.permute(0, 2, 1, 3)
         value = value.permute(0, 2, 1, 3)
-
-        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
-        if attention_mask is not None:
-            attention_scores = attention_scores + attention_mask
+        if scaled_dot_product_attention is not None:
+            attn_output = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=self.drop.p, is_causal=False
+            )
+        else:
+            attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
+            if attention_mask is not None:
+                attention_scores = attention_scores + attention_mask

-        attentions_probs = F.softmax(attention_scores, dim=-1)
-        attentions_probs = self.drop(attentions_probs)
+            attentions_probs = F.softmax(attention_scores, dim=-1)
+            attentions_probs = self.drop(attentions_probs)

-        attn_output = torch.matmul(attentions_probs, value)
+            attn_output = torch.matmul(attentions_probs, value)
+
         attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")

         attn_output = self.out_proj(attn_output)
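Note: with no dropout the two branches agree numerically (norm_factor corresponds to sqrt(head_dim)); a self-contained equivalence check:

    import math

    import torch
    import torch.nn.functional as F

    q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3))
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(64)
    manual = torch.matmul(F.softmax(scores, dim=-1), v)
    fused = F.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(manual, fused, atol=1e-5)

One nuance worth flagging: the fused path passes dropout_p=self.drop.p unconditionally, while nn.Dropout is a no-op in eval mode, so with a nonzero attn_pdrop the two branches would differ at inference time.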
@@ -1807,6 +1840,7 @@ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
             attentions=outputs.attentions,
         )

+
 def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config:
     return GPT2Config(
         n_embd=vit_config.hidden_size,
@@ -1814,7 +1848,7 @@ def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config:
         n_head=vit_config.num_attention_heads,
         n_inner=vit_config.intermediate_size,
         activation_function=vit_config.hidden_act,
-        vocab_size=0,
+        vocab_size=0,  # no vocab since using patches
         n_positions=0,  # No absolute position embedding
         resid_pdrop=0.0,  # No dropout
         embd_pdrop=getattr(vit_config, "dropout", 0.0),
@@ -1850,15 +1884,12 @@ def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config:
         mask_token=False,
         learned_pos_embedding=False,
         patch_dropout=0,
-        sinusoidal_pos_embedding=vit_config.model_type == "vit_mae"
+        sinusoidal_pos_embedding=vit_config.model_type == "vit_mae",
     )

-
+
 class NomicAttentionPooling(nn.Module):
-    def __init__(
-        self,
-        config
-    ):
+    def __init__(self, config):
         super().__init__()
         self.embed_dim = config.n_embd
         self.use_flash_attn = config.use_flash_attn
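Note: NomicAttentionPooling implements latent-query pooling: a single learned query (self.latent) attends over the sequence to produce one pooled vector. A rough analogue using stock PyTorch, dims hypothetical (the real module uses its own Wq/Wkv projections):

    import torch
    import torch.nn as nn

    embed_dim, n_heads = 768, 12
    latent = nn.Parameter(torch.zeros(1, 1, embed_dim))  # learned query, trunc-normal init
    attn = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
    x = torch.randn(2, 196, embed_dim)                   # token sequence
    pooled, _ = attn(latent.expand(2, -1, -1), x, x)     # (2, 1, embed_dim)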
@@ -1879,7 +1910,7 @@ class NomicAttentionPooling(nn.Module):

         self.Wq = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
         self.Wkv = nn.Linear(self.embed_dim, kv_dim, bias=config.qkv_proj_bias)
-
+
         self.latent = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
@@ -1887,7 +1918,7 @@ class NomicAttentionPooling(nn.Module):
         self.drop = nn.Dropout(config.attn_pdrop)

     def init_weights(self):
-        trunc_normal_tf_(self.latent, std=self.embed_dim ** -0.5)
+        trunc_normal_tf_(self.latent, std=self.embed_dim**-0.5)

     def forward(
         self,
@@ -1938,7 +1969,7 @@ class NomicAttentionPooling(nn.Module):

         return attn_output

-
+
 class NomicMultiHeadAttentionPooling(nn.Module):
     def __init__(
         self,
@@ -1993,15 +2024,16 @@ class NomicMultiHeadAttentionPooling(nn.Module):
         """

         attn_outputs = self.attn(
-                hidden_states,
-                attention_mask=attention_mask,
-            )
+            hidden_states,
+            attention_mask=attention_mask,
+        )

         normed = self.norm1(attn_outputs)
         hidden_states = hidden_states + self.mlp(normed)

         return hidden_states

+
 class NomicVisionPreTrainedModel(PreTrainedModel):
     """An abstract class to handle weights initialization and
     a simple interface for dowloading and loading pretrained models.
@@ -2025,6 +2057,7 @@ class NomicVisionPreTrainedModel(PreTrainedModel):
         )
         self.config = config

+
 class NomicVisionModel(NomicVisionPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
@@ -2035,7 +2068,9 @@ class NomicVisionModel(NomicVisionPreTrainedModel):
         self.selector = NomicMultiHeadAttentionPooling(config)

         self.global_pool = getattr(config, "global_pool", None)
-        self.num_prefix_tokens = (1 if not getattr(config, "no_cls_token", False) else 0) + getattr(config, "register_tokens", 0)
+        self.num_prefix_tokens = (1 if not getattr(config, "no_cls_token", False) else 0) + getattr(
+            config, "register_tokens", 0
+        )

         self.apply(partial(_init_weights, initializer_range=config.initializer_range))

@@ -2052,20 +2087,22 @@ class NomicVisionModel(NomicVisionPreTrainedModel):

         original_dtype = embeddings.dtype

-        hidden_states = embeddings
+        hidden_states = embeddings
         # unused but easier to pass to gradient checkpointing as words
         residual = None
         for layer in self.layers:
             # need to pass none for backwards compatability
-            hidden_states, _, residual = layer(hidden_states, None, residual=residual, is_padded_inputs=False, rope=rope)
+            hidden_states, _, residual = layer(
+                hidden_states, None, residual=residual, is_padded_inputs=False, rope=rope
+            )

         hidden_states = hidden_states + residual
         if self.global_pool == "avg":
-            hidden_states = hidden_states[:, self.num_prefix_tokens:].mean(dim=1)
+            hidden_states = hidden_states[:, self.num_prefix_tokens :].mean(dim=1)

         pooled_output = self.selector(hidden_states)

         return BaseModelOutputWithPast(
             last_hidden_state=pooled_output,
             hidden_states=hidden_states,
-        )
+        )
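Note: the pooling logic above, in isolation: when global_pool == "avg" the mean skips the num_prefix_tokens leading tokens (cls plus registers), so only patch tokens contribute. Hypothetical token counts:

    import torch

    hidden_states = torch.randn(2, 1 + 4 + 196, 768)           # cls + 4 registers + 196 patches
    num_prefix_tokens = 1 + 4
    pooled = hidden_states[:, num_prefix_tokens:].mean(dim=1)  # (2, 768), patch tokens only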