|
|
|
|
|
|
|
|
|
import collections |
|
import logging |
|
|
|
import json |
|
import math |
|
import os |
|
import re |
|
from collections import OrderedDict |
|
from functools import partial |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange, repeat |
|
from safetensors.torch import load_file as safe_load_file |
|
from torch.nn.modules.utils import _pair |
|
from transformers import GPT2Config, PreTrainedModel, ViTConfig, ViTModel |
|
from transformers.models.bert.modeling_bert import ( |
|
BaseModelOutputWithPoolingAndCrossAttentions, |
|
MaskedLMOutput, |
|
SequenceClassifierOutput, |
|
) |
|
from transformers.modeling_outputs import ( |
|
MaskedLMOutput, |
|
MultipleChoiceModelOutput, |
|
QuestionAnsweringModelOutput, |
|
SequenceClassifierOutput, |
|
TokenClassifierOutput, |
|
) |
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME |
|
from transformers.utils.hub import cached_file, get_checkpoint_shard_files |
|
|
|
|
|
class ContextualNomicBertConfig(GPT2Config):
    """Configuration for the contextual Nomic BERT encoder.

    Extends ``GPT2Config`` with the architecture switches stored below
    (pre/post-norm, parallel blocks, rotary embeddings, fused kernels, MLP
    biases, vocabulary padding, ...). Keywords not listed in the signature are
    forwarded unchanged to ``GPT2Config.__init__``.
    """

    model_type = "nomic_bert"

    def __init__(
        self,
        prenorm=False,
        parallel_block=False,
        parallel_block_tied_norm=False,
        rotary_emb_fraction=0.0,
        fused_dropout_add_ln=False,
        fused_bias_fc=False,
        use_flash_attn=False,
        use_xentropy=False,
        qkv_proj_bias=True,
        rotary_emb_base=10_000,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        mlp_fc1_bias=True,
        mlp_fc2_bias=True,
        use_rms_norm=False,
        causal=False,
        type_vocab_size=2,
        dense_seq_output=True,
        pad_vocab_size_multiple=1,
        tie_word_embeddings=True,
        rotary_scaling_factor=None,
        max_trained_positions=2048,
        **kwargs,
    ):
        self.prenorm = prenorm
        self.parallel_block = parallel_block
        self.parallel_block_tied_norm = parallel_block_tied_norm
        self.rotary_emb_fraction = rotary_emb_fraction
        self.tie_word_embeddings = tie_word_embeddings
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.fused_bias_fc = fused_bias_fc
        self.use_flash_attn = use_flash_attn
        self.use_xentropy = use_xentropy
        self.qkv_proj_bias = qkv_proj_bias
        self.rotary_emb_base = rotary_emb_base
        self.rotary_emb_scale_base = rotary_emb_scale_base
        self.rotary_emb_interleaved = rotary_emb_interleaved
        self.mlp_fc1_bias = mlp_fc1_bias
        self.mlp_fc2_bias = mlp_fc2_bias
        self.use_rms_norm = use_rms_norm
        self.causal = causal
        self.type_vocab_size = type_vocab_size
        self.dense_seq_output = dense_seq_output
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.rotary_scaling_factor = rotary_scaling_factor
        self.max_trained_positions = max_trained_positions

        # transformers' ``PretrainedConfig.__init__`` (reached via GPT2Config)
        # assigns ``self.tie_word_embeddings`` from its own kwargs with a default
        # of True. This signature consumes ``tie_word_embeddings``, so it must be
        # forwarded explicitly; otherwise a user-supplied False is overwritten.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
# Bind the fused attention kernel when this torch build provides it, else None.
scaled_dot_product_attention = getattr(torch.nn.functional, "scaled_dot_product_attention", None)

# Module-level logger (used e.g. to report state-dict loading results).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None):
    """Load a (possibly sharded) checkpoint for ``model_name`` into a flat dict.

    Resolution order:
      * if ``model_name`` is a directory containing one of the weight files,
        the first hit among ``WEIGHTS_NAME``, ``WEIGHTS_INDEX_NAME``,
        ``SAFE_WEIGHTS_NAME``, ``SAFE_WEIGHTS_INDEX_NAME`` is used;
      * otherwise ``cached_file`` is tried for ``WEIGHTS_NAME``,
        ``SAFE_WEIGHTS_NAME``, ``WEIGHTS_INDEX_NAME``,
        ``SAFE_WEIGHTS_INDEX_NAME`` (in that order), and the first one that
        resolves is used.

    Index files mean the checkpoint is sharded. ``SAFE_*`` files are read with
    safetensors, all others with ``torch.load``.

    Args:
        model_name: local directory or a name understood by ``cached_file``.
        safe_serialization: unused; kept for signature compatibility.
        device: device every returned tensor is moved to.
        dtype: optional dtype every tensor is cast to. When it is neither
            ``None`` nor ``torch.float32``, files are first loaded on CPU.

    Returns:
        dict mapping parameter names to tensors.

    Raises:
        EnvironmentError: if no weight file could be resolved.
    """
    # Non-fp32 targets are staged on CPU; the cast and final move happen at the end.
    load_device = device if dtype in (None, torch.float32) else "cpu"

    index_names = (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME)
    safetensors_names = (SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME)
    local_order = (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME)
    remote_order = (WEIGHTS_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME)

    chosen = next(
        (name for name in local_order if os.path.isfile(os.path.join(model_name, name))),
        None,
    )
    if chosen is not None:
        archive = cached_file(model_name, chosen, _raise_exceptions_for_missing_entries=False)
    else:
        archive = None
        for name in remote_order:
            archive = cached_file(model_name, name, _raise_exceptions_for_missing_entries=False)
            if archive is not None:
                chosen = name
                break

    if archive is None:
        raise EnvironmentError(f"Model name {model_name} was not found.")

    # NOTE(review): torch.load unpickles arbitrary objects; only use trusted checkpoints.
    read = (
        partial(safe_load_file, device=load_device)
        if chosen in safetensors_names
        else partial(torch.load, map_location=load_device)
    )

    if chosen in index_names:
        # The index file lists the shard files; merge them into one dict.
        shard_paths, _ = get_checkpoint_shard_files(model_name, archive)
        state_dict = {}
        for shard_path in shard_paths:
            state_dict.update(read(shard_path))
    else:
        state_dict = read(archive)

    if dtype is not None:
        state_dict = {name: tensor.to(dtype=dtype) for name, tensor in state_dict.items()}
    return {name: tensor.to(device=device) for name, tensor in state_dict.items()}
|
|
|
|
|
def filter_shapes(state_dict, model):
    """
    Filters the state dict to match the current model shape.

    Keeps only the entries of ``state_dict`` whose key also exists in
    ``model.state_dict()`` and whose tensor shape equals the model's tensor
    shape for that key. All other entries are dropped.

    Args:
        state_dict: mapping of parameter/buffer names to tensors to be loaded.
        model: the ``nn.Module`` the entries are meant to be loaded into.

    Returns:
        A new dict containing only the shape-compatible entries.
    """
    # Build the model's state dict once. Calling model.state_dict() inside the
    # loop rebuilt the whole mapping twice per key (quadratic in #parameters).
    model_state = model.state_dict()
    return {
        key: value
        for key, value in state_dict.items()
        if key in model_state and value.shape == model_state[key].shape
    }
|
|
|
|
|
def remap_bert_state_dict(
    state_dict,
    config,
    remove_bert=False,
    remove_cls_weights=False,
    add_pooling_layer=False,
):
    """
    Map the state_dict of a Huggingface BERT model to be flash_attn compatible.

    Keys are rewritten in successive passes, each building a new OrderedDict:

    1. prefix every key with ``bert.`` unless it already starts with
       ``bert.`` or ``cls.``;
    2. rename ``LayerNorm.gamma`` / ``LayerNorm.beta`` to ``weight`` / ``bias``;
    3. rename ``bert.encoder.layer.`` to ``bert.encoder.layers.``;
    4. rename LayerNorms to ``emb_ln`` / ``norm1`` / ``norm2`` / ``layer_norm``;
    5. rename ``intermediate.dense`` / ``output.dense`` to ``mlp.fc1`` / ``mlp.fc2``;
    6. concatenate query/key/value projections into ``attn.Wqkv`` (or into
       ``attn.Wq`` + ``attn.Wkv`` for the last layer when
       ``config.last_layer_subset`` is truthy), and rename the attention output
       projection to ``attn.out_proj``;
    7. drop next-sentence-prediction weights and ``bert.embeddings.position_ids``,
       and move ``cls.predictions.bias`` to ``cls.predictions.decoder.bias``;
    8. optionally drop the MLM head, pad vocabulary-sized tensors up to
       ``config.vocab_size``, drop the pooler, and strip the ``bert.`` prefix.

    Args:
        state_dict: mapping of parameter names to tensors.
        config: reads ``num_hidden_layers`` and ``vocab_size``; the attributes
            ``last_layer_subset`` and ``pad_vocab_size_multiple`` are optional.
        remove_bert: strip the leading ``bert.`` from every key at the end.
        remove_cls_weights: drop the ``cls.predictions.*`` MLM head weights.
        add_pooling_layer: when it ``is False``, ``bert.pooler.dense.*`` is dropped.

    Returns:
        An OrderedDict with the remapped keys.
    """

    # Pass 1: make every key live under "bert." (head weights under "cls." are kept as-is).
    def add_bert_prefix(key):

        if key.startswith("bert.") or key.startswith("cls."):
            return key
        return f"bert.{key}"

    state_dict = OrderedDict((add_bert_prefix(k), v) for k, v in state_dict.items())

    # Pass 2: old-style LayerNorm parameter names.
    # NOTE(review): the "." in these patterns is unescaped and matches any character.
    def key_mapping_ln_gamma_beta(key):
        key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
        key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
        return key

    state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())

    # Pass 3: "encoder.layer.N" -> "encoder.layers.N".
    def key_mapping_layers(key):
        return re.sub(r"^bert.encoder.layer\.", "bert.encoder.layers.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Pass 4: LayerNorm module names (embedding LN, post-attention LN, post-MLP LN, MLM head LN).
    def key_mapping_ln(key):
        key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
        key = re.sub(
            r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
            r"bert.encoder.layers.\1.norm1.\2",
            key,
        )
        key = re.sub(
            r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
            r"bert.encoder.layers.\1.norm2.\2",
            key,
        )
        key = re.sub(
            r"^cls.predictions.transform.LayerNorm.(weight|bias)",
            r"cls.predictions.transform.layer_norm.\1",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # Pass 5: feed-forward projections.
    def key_mapping_mlp(key):
        key = re.sub(
            r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
            r"bert.encoder.layers.\1.mlp.fc1.\2",
            key,
        )
        key = re.sub(
            r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
            r"bert.encoder.layers.\1.mlp.fc2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Pass 6: concatenate the separate Q/K/V projections along dim 0.
    last_layer_subset = getattr(config, "last_layer_subset", False)
    for d in range(config.num_hidden_layers):
        # Layers without a separate query projection are left untouched.
        if f"bert.encoder.layers.{d}.attention.self.query.weight" not in state_dict:
            continue
        Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
        Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
        Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
        bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
        bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
        bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
        if not (last_layer_subset and d == config.num_hidden_layers - 1):
            state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
            state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
        else:
            # Last layer with last_layer_subset: keep Q separate and fuse only K/V.
            state_dict[f"bert.encoder.layers.{d}.attn.Wq.weight"] = Wq
            state_dict[f"bert.encoder.layers.{d}.attn.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
            state_dict[f"bert.encoder.layers.{d}.attn.Wq.bias"] = bq
            state_dict[f"bert.encoder.layers.{d}.attn.Wkv.bias"] = torch.cat([bk, bv], dim=0)

    def key_mapping_attn(key):
        return re.sub(
            r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
            r"bert.encoder.layers.\1.attn.out_proj.\2",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    # Pass 7: the MLM output bias moves under the decoder module.
    def key_mapping_decoder_bias(key):
        return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)

    # Next-sentence-prediction head and the position_ids buffer are discarded.
    state_dict.pop("cls.seq_relationship.weight", None)
    state_dict.pop("cls.seq_relationship.bias", None)
    state_dict.pop("bert.embeddings.position_ids", None)

    state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())

    if remove_cls_weights:
        cls_weights = [
            "cls.predictions.decoder.bias",
            "cls.predictions.transform.dense.weight",
            "cls.predictions.transform.dense.bias",
            "cls.predictions.transform.layer_norm.weight",
            "cls.predictions.transform.layer_norm.bias",
            "cls.predictions.decoder.weight",
        ]
        for weight in cls_weights:
            state_dict.pop(weight, None)

    # Vocabulary padding: grow vocab-sized tensors along dim 0 up to config.vocab_size.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    if pad_vocab_size_multiple > 1:
        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
        # New embedding rows are zero (F.pad default value).
        state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
            word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
        )
        if not remove_cls_weights:
            decoder_weight = state_dict["cls.predictions.decoder.weight"]
            state_dict["cls.predictions.decoder.weight"] = F.pad(
                decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
            )
            # Padded vocabulary entries get a strongly negative bias (-100.0),
            # presumably so the MLM head never predicts them.
            if "cls.predictions.decoder.bias" in state_dict:
                decoder_bias = state_dict["cls.predictions.decoder.bias"]
                state_dict["cls.predictions.decoder.bias"] = F.pad(
                    decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
                )

    if add_pooling_layer is False:
        pooler_weights = [
            "bert.pooler.dense.weight",
            "bert.pooler.dense.bias",
        ]
        for key in pooler_weights:
            state_dict.pop(key, None)

    if remove_bert:

        # Strip the "bert." prefix (used when loading into a bare encoder module).
        def remove_bert_prefix(key):
            key = re.sub(r"^bert.", "", key)
            return key

        state_dict = OrderedDict((remove_bert_prefix(k), v) for k, v in state_dict.items())

    return state_dict
|
|
|
|
|
def _trunc_normal_(tensor, mean, std, a, b): |
|
|
|
|
|
def norm_cdf(x): |
|
|
|
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 |
|
|
|
if (mean < a - 2 * std) or (mean > b + 2 * std): |
|
print( |
|
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " |
|
"The distribution of values may be incorrect.", |
|
stacklevel=2, |
|
) |
|
|
|
|
|
|
|
|
|
l = norm_cdf((a - mean) / std) |
|
u = norm_cdf((b - mean) / std) |
|
|
|
|
|
|
|
tensor.uniform_(2 * l - 1, 2 * u - 1) |
|
|
|
|
|
|
|
tensor.erfinv_() |
|
|
|
|
|
tensor.mul_(std * math.sqrt(2.0)) |
|
tensor.add_(mean) |
|
|
|
|
|
tensor.clamp_(min=a, max=b) |
|
return tensor |
|
|
|
|
|
def trunc_normal_tf_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    r"""Fill ``tensor`` in place with a TensorFlow/JAX-style truncated normal.

    Values are drawn from a standard normal :math:`\mathcal{N}(0, 1)`
    truncated to :math:`[a, b]`, and the result is subsequently scaled by
    ``std`` and shifted by ``mean``. The cutoffs are therefore applied before
    scaling, i.e. in units of the standard deviation around ``mean``.

    Args:
        tensor: an n-dimensional `torch.Tensor`, filled in place.
        mean: the mean of the resulting distribution.
        std: the standard deviation of the resulting distribution.
        a: the minimum cutoff of the standard-normal draw.
        b: the maximum cutoff of the standard-normal draw.

    Returns:
        The same ``tensor`` object.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_tf_(w, std=0.02)
    """
    with torch.no_grad():
        # Truncate in standard-normal space first ...
        _trunc_normal_(tensor, 0, 1.0, a, b)
        # ... then map to N(mean, std^2).
        scaled = tensor.mul_(std)
        scaled.add_(mean)
    return tensor
|
|
|
|
|
class ContextualNomicBertPreTrainedModel(PreTrainedModel):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    config_class = ContextualNomicBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Block"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, config, *inputs, **kwargs):
        """Store ``config``; raises ValueError if it is not a ``GPT2Config``."""
        super().__init__(config)
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(cls, model_name, config=None, *inputs, **kwargs):
        """
        Instantiate a ContextualNomicBertPreTrainedModel from a pre-trained model file or a pytorch state dict.

        Params:
            model_name: either an existing local directory containing
                ``pytorch_model.bin`` or ``model.safetensors`` (loaded as-is,
                with ``strict=False``), or a name resolved by
                ``state_dict_from_pretrained`` whose weights are then remapped
                from the Hugging Face BERT layout by ``remap_bert_state_dict``.
            config: optional config; loaded from ``model_name`` when None.
            *inputs: forwarded to the model constructor.
            **kwargs: the following keys are consumed, and all others are ignored:
                ``ignore_mismatched_sizes`` drops shape-mismatched weights.
                ``num_labels`` and ``rotary_scaling_factor`` override config values.
                ``strict`` is honoured only on the remapped (non-local) path.
                ``torch_dtype`` casts the model and the remapped weights.
                ``add_pooling_layer`` is forwarded to the constructor.

        Returns:
            The model with the loaded weights. The ``load_state_dict`` result is
            logged at warning level.
        """
        if config is None:
            config = cls.config_class.from_pretrained(model_name)
        remove_cls = cls != ContextualNomicBertForPreTraining
        remove_bert_prefix = cls not in [
            ContextualNomicBertForPreTraining,
            ContextualNomicBertForSequenceClassification,
            ContextualNomicBertForTokenClassification,
            ContextualNomicBertForMultipleChoice,
            ContextualNomicBertForQuestionAnswering,
        ]
        ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
        num_labels = kwargs.pop("num_labels", None)
        rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
        strict = kwargs.pop("strict", True)
        dtype = kwargs.pop("torch_dtype", None)
        if rotary_scaling_factor:
            config.rotary_scaling_factor = rotary_scaling_factor

        # Rotary models without learned positions still need a positive n_positions.
        if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
            config.n_positions = 2048
        if num_labels:
            config.num_labels = num_labels

        if "add_pooling_layer" in kwargs:
            model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
        elif cls == ContextualNomicBertModel:
            model = cls(config, *inputs, add_pooling_layer=False)
        else:
            model = cls(config, *inputs)

        if dtype is not None:
            model = model.to(dtype=dtype)

        if os.path.exists(model_name):
            model_path = f"{model_name}/pytorch_model.bin"
            if os.path.exists(model_path):
                # Load on CPU: load_state_dict copies into the model's own tensors,
                # and a GPU-saved checkpoint would otherwise fail on CPU-only hosts.
                # NOTE(review): torch.load unpickles; only load trusted checkpoints.
                state_dict = torch.load(model_path, map_location="cpu")
            else:
                model_path = f"{model_name}/model.safetensors"
                if not os.path.exists(model_path):
                    raise ValueError(f"Model path {model_path} not found")
                state_dict = safe_load_file(model_path)

            if ignore_mismatched_shapes:
                state_dict = filter_shapes(state_dict, model)
            load_return = model.load_state_dict(state_dict, strict=False)
        else:
            # Remote / hub checkpoints are assumed to be in Hugging Face BERT layout.
            state_dict = state_dict_from_pretrained(model_name, dtype=dtype)
            state_dict = remap_bert_state_dict(
                state_dict,
                config,
                remove_bert=remove_bert_prefix,
                remove_cls_weights=remove_cls,
                add_pooling_layer=getattr(config, "add_pooling_layer", False),
            )
            if ignore_mismatched_shapes:
                state_dict = filter_shapes(state_dict, model)

            load_return = model.load_state_dict(state_dict, strict=strict)
        logger.warning(load_return)
        return model

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle ``gradient_checkpointing`` on encoder modules."""
        if isinstance(module, ContextualNomicBertEncoder):
            module.gradient_checkpointing = value
|
|
|
|
|
|
|
def _init_weights(module, initializer_range=0.02): |
|
if isinstance(module, nn.Linear): |
|
nn.init.normal_(module.weight, std=initializer_range) |
|
if module.bias is not None: |
|
nn.init.zeros_(module.bias) |
|
elif isinstance(module, nn.Embedding): |
|
nn.init.normal_(module.weight, std=initializer_range) |
|
if module.padding_idx is not None: |
|
nn.init.zeros_(module.weight[module.padding_idx]) |
|
|
|
|
|
def _ntuple(n):
    """Return a parser that turns a value into an ``n``-tuple.

    Non-string iterables are converted with ``tuple(x)`` as-is (their length
    is not checked against ``n``). Any other value ``x`` becomes ``(x,) * n``.
    """
    from collections.abc import Iterable

    def parse(x):
        if isinstance(x, Iterable) and not isinstance(x, str):
            return tuple(x)
        # The module-level ``repeat`` is ``einops.repeat`` (a tensor op), not
        # ``itertools.repeat``, so the old ``tuple(repeat(x, n))`` raised.
        return (x,) * n

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
|
|
|
|
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """
    Create 2D sin/cos positional embeddings for a square grid.

    Args:
        embed_dim (`int`):
            Embedding dimension. Half of it is given to each grid axis, and
            each half must itself be even.
        grid_size (`int`):
            The grid height and width.
        add_cls_token (`bool`, *optional*, defaults to `False`):
            Whether to prepend an all-zero row for a classification (CLS) token.

    Returns:
        `numpy.ndarray` of shape (grid_size*grid_size, embed_dim), or
        (1+grid_size*grid_size, embed_dim) with the CLS row.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # np.meshgrid(w, h) returns (w-coordinates, h-coordinates), both (grid_size, grid_size).
    grid = np.stack(np.meshgrid(coords, coords), axis=0).reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not add_cls_token:
        return pos_embed
    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, pos_embed], axis=0)
|
|
|
|
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate the 1D sin/cos embeddings of the two grid coordinate planes.

    ``grid[0]`` fills the first ``embed_dim // 2`` channels and ``grid[1]``
    the remaining ones.

    Raises:
        ValueError: if ``embed_dim`` is odd.
    """
    if embed_dim % 2:
        raise ValueError("embed_dim must be even")

    half_dim = embed_dim // 2
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(per_axis, axis=1)
|
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode positions with fixed sine/cosine features.

    Args:
        embed_dim: output dimension D per position; must be even.
        pos: array of positions; flattened to shape (M,).

    Returns:
        Array of shape (M, D). The first D/2 columns are ``sin(pos * omega)``
        and the last D/2 are ``cos(pos * omega)``, where
        ``omega_i = 1 / 10000 ** (i / (D / 2))``.

    Raises:
        ValueError: if ``embed_dim`` is odd.
    """
    if embed_dim % 2:
        raise ValueError("embed_dim must be even")

    half = embed_dim // 2
    exponents = np.arange(half, dtype=float) / (embed_dim / 2.0)
    inv_freq = 1.0 / 10000**exponents
    angles = np.outer(pos.reshape(-1), inv_freq)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
|
|
|
|
|
def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
    """Generate an N-D grid in dimension ('ij') order.

    ``ndgrid`` is ``meshgrid`` with the roles of the first two inputs swapped:
    ``[X1, X2, X3] = ndgrid(x1, x2, x3)`` equals
    ``[X2, X1, X3] = meshgrid(x2, x1, x3)``. The MATLAB-style name avoids
    confusion caused by torch's ``meshgrid`` moving from 'ij' indexing to
    numpy's 'xy' default.
    """
    try:
        grids = torch.meshgrid(*tensors, indexing="ij")
    except TypeError:
        # Older torch without the ``indexing`` keyword; its default is 'ij'.
        grids = torch.meshgrid(*tensors)
    return grids
|
|
|
|
|
def build_fourier_pos_embed(
    feat_shape: List[int],
    bands: Optional[torch.Tensor] = None,
    num_bands: int = 64,
    max_res: int = 224,
    temperature: float = 10000.0,
    linear_bands: bool = False,
    include_grid: bool = False,
    in_pixels: bool = True,
    ref_feat_shape: Optional[List[int]] = None,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
    """
    Build sin/cos Fourier position features for a feature map.

    Args:
        feat_shape: Feature shape for embedding.
        bands: Pre-calculated frequency bands; when given, its device/dtype
            fill in a missing ``device`` / ``dtype``.
        num_bands: Number of frequency bands (determines output dim).
        max_res: Maximum resolution for pixel based freq.
        temperature: Temperature for non-pixel freq.
        linear_bands: Linear band spacing for pixel based freq.
        include_grid: Include the spatial grid in output.
        in_pixels: Coordinates in [-1, 1] with pixel bands if True, else
            integer positions with inverse-frequency bands.
        ref_feat_shape: Reference feature shape for resize / fine-tune.
        dtype: Output dtype of the sin/cos tensors.
        device: Output device.

    Returns:
        ``[pos_sin, pos_cos]``, or ``[grid, pos_sin, pos_cos]`` when
        ``include_grid``. ``grid`` has shape ``(*feat_shape, len(feat_shape), 1)``
        and the sin/cos tensors broadcast it against ``bands`` along the last dim.
    """
    if bands is not None:
        # Pre-computed bands dictate any unspecified device / dtype.
        device = bands.device if device is None else device
        dtype = bands.dtype if dtype is None else dtype
    elif in_pixels:
        bands = pixel_freq_bands(
            num_bands,
            float(max_res),
            linear_bands=linear_bands,
            device=device,
        )
    else:
        bands = freq_bands(
            num_bands,
            temperature=temperature,
            step=1,
            device=device,
        )

    if in_pixels:
        # Normalised coordinates spanning [-1, 1] along every axis.
        axes = [torch.linspace(-1.0, 1.0, steps=size, device=device, dtype=torch.float32) for size in feat_shape]
    else:
        # Integer positions 0 .. size-1, as float32.
        axes = [torch.arange(size, device=device, dtype=torch.int64).to(torch.float32) for size in feat_shape]

    if ref_feat_shape is not None:
        # Express coordinates relative to the reference feature shape.
        axes = [coord / size * ref for coord, size, ref in zip(axes, feat_shape, ref_feat_shape)]

    grid = torch.stack(ndgrid(axes), dim=-1).unsqueeze(-1)
    pos = grid * bands

    pos_sin = pos.sin().to(dtype=dtype)
    pos_cos = pos.cos().to(dtype)
    if include_grid:
        return [grid, pos_sin, pos_cos]
    return [pos_sin, pos_cos]
|
|
|
|
|
def build_rotary_pos_embed(
    feat_shape: List[int],
    bands: Optional[torch.Tensor] = None,
    dim: int = 64,
    max_res: int = 224,
    temperature: float = 10000.0,
    linear_bands: bool = False,
    in_pixels: bool = True,
    ref_feat_shape: Optional[List[int]] = None,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
):
    """
    Build flattened sin/cos tables for rotary position embedding.

    Args:
        feat_shape: Spatial shape of the target tensor for embedding.
        bands: Optional pre-generated frequency bands.
        dim: Output dimension of embedding tensor (``dim // 4`` bands are used).
        max_res: Maximum resolution for pixel mode.
        temperature: Temperature (inv freq) for non-pixel mode.
        linear_bands: Linearly (instead of log) spaced bands for pixel mode.
        in_pixels: Pixel vs language (inv freq) mode.
        ref_feat_shape: Reference feature shape for resize / fine-tune.
        dtype: Output dtype.
        device: Output device.

    Returns:
        ``(sin_emb, cos_emb)``, each with one row per spatial position
        (``prod(feat_shape)`` rows). Every frequency column is repeated twice
        (interleaved) so that each rotary channel pair shares it.
    """
    sin_emb, cos_emb = build_fourier_pos_embed(
        feat_shape,
        bands=bands,
        num_bands=dim // 4,
        max_res=max_res,
        temperature=temperature,
        linear_bands=linear_bands,
        in_pixels=in_pixels,
        ref_feat_shape=ref_feat_shape,
        device=device,
        dtype=dtype,
    )
    num_positions = math.prod(feat_shape)
    sin_flat = sin_emb.reshape(num_positions, -1).repeat_interleave(2, -1)
    cos_flat = cos_emb.reshape(num_positions, -1).repeat_interleave(2, -1)
    return sin_flat, cos_flat
|
|
|
|
|
def freq_bands(
    num_bands: int,
    temperature: float = 10000.0,
    step: int = 2,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Return inverse-frequency bands ``1 / temperature ** (i / num_bands)``.

    ``i`` runs over ``0, step, 2 * step, ...`` below ``num_bands``. The result
    is float32.
    """
    indices = torch.arange(0, num_bands, step, dtype=torch.int64, device=device)
    exponent = indices.to(torch.float32) / num_bands
    return 1.0 / (temperature**exponent)
|
|
|
|
|
def pixel_freq_bands(
    num_bands: int,
    max_freq: float = 224.0,
    linear_bands: bool = True,
    device: Optional[torch.device] = None,
):
    """Return ``num_bands`` float32 frequencies between 1 and ``max_freq / 2``, times pi.

    With ``linear_bands`` the frequencies are evenly spaced. Otherwise they are
    powers of two with evenly spaced exponents from 0 to ``log2(max_freq) - 1``.
    """
    if not linear_bands:
        exponents = torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
        return (2**exponents) * torch.pi
    linear = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device)
    return linear * torch.pi
|
|
|
|
|
def rot(x):
    """Rotate interleaved pairs ``(x0, x1) -> (-x1, x0)`` along the last dim."""
    evens = x[..., ::2]
    odds = x[..., 1::2]
    return torch.stack((-odds, evens), dim=-1).reshape(x.shape)
|
|
|
|
|
def apply_rot_embed_cat(x: torch.Tensor, emb):
    """Apply a rotary embedding stored as concatenated ``[sin, cos]`` halves.

    ``emb`` is split in two along its last dim into ``sin_emb`` and ``cos_emb``.
    When the halves are 3-D, a singleton dim is inserted at position 1 and they
    are expanded to ``x``'s shape (presumably a heads dim; confirm with callers).
    """
    sin_emb, cos_emb = emb.tensor_split(2, -1)
    if sin_emb.ndim == 3:
        cos_emb = cos_emb.unsqueeze(1).expand_as(x)
        sin_emb = sin_emb.unsqueeze(1).expand_as(x)
    return x * cos_emb + rot(x) * sin_emb
|
|
|
|
|
class ContextualNomicBertEmbeddings(nn.Module):
    """Word embeddings plus optional token-type and learned absolute position embeddings."""

    def __init__(self, config):
        """
        If max_position_embeddings <= 0, there's no position embeddings.
        If type_vocab_size <= 0, there's no token type embeddings.
        Learned position embeddings are also disabled whenever
        ``config.rotary_emb_fraction`` > 0.
        """
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        learned_positions = config.rotary_emb_fraction <= 0
        if learned_positions:
            self.max_position_embeddings = config.max_position_embeddings
        else:
            self.max_position_embeddings = 0
        self.type_vocab_size = config.type_vocab_size
        if learned_positions and self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                config.max_position_embeddings,
                config.hidden_size,
            )
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

    def forward(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        token_type_ids: (batch, seqlen)

        Missing token types default to 0 and missing positions to
        ``arange(seqlen)``; both are broadcast over the batch.
        """
        embeddings = self.word_embeddings(input_ids) if inputs_embeds is None else inputs_embeds
        _, seqlen, _ = embeddings.shape
        device = embeddings.device

        if self.type_vocab_size > 0:
            types = (
                token_type_ids
                if token_type_ids is not None
                else torch.zeros(seqlen, dtype=torch.long, device=device)
            )
            embeddings = embeddings + self.token_type_embeddings(types)

        if self.max_position_embeddings > 0:
            positions = (
                position_ids
                if position_ids is not None
                else torch.arange(seqlen, dtype=torch.long, device=device)
            )
            embeddings = embeddings + self.position_embeddings(positions)

        return embeddings
|
|
|
|
|
class ContextualNomicBertMLP(nn.Module):
    """Two-layer feed-forward block computing ``fc2(activation(fc1(x)))``.

    Args:
        in_features: input width.
        hidden_features: hidden width (default ``4 * in_features``).
        out_features: output width (default ``in_features``).
        activation: a callable, or one of the GELU names ``"gelu"``,
            ``"gelu_new"``, ``"gelu_fast"``, ``"gelu_pytorch_tanh"``, which map
            to ``nn.GELU``. The last three select the tanh approximation.
        bias1, bias2: whether ``fc1`` / ``fc2`` carry a bias.
        return_residual: if True, ``forward`` returns ``(y, x)``.
        fused_bias_fc: accepted for interface compatibility; unused here.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        bias1=True,
        bias2=True,
        return_residual=False,
        fused_bias_fc=False,
    ):
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
        gelu_names = ("gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh")
        if isinstance(activation, str) and activation in gelu_names:
            # Previously only "gelu" was converted; the tanh variants were stored
            # as raw strings and crashed in forward().
            approximate = "none" if activation == "gelu" else "tanh"
            self.activation = nn.GELU(approximate=approximate)
        else:
            self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)
|
|
|
|
|
class NomciBertGatedMLP(nn.Module):
    """Gated feed-forward block: ``fc2(norm(fc11(x) * activation(fc12(x))))``.

    Args:
        in_features: input width.
        hidden_features: hidden width before rounding (default ``int(8 * in_features / 3)``);
            always rounded up to a multiple of ``multiple_of``.
        out_features: output width (default ``in_features``).
        activation: gate activation; with ``F.sigmoid`` the product is computed via ``F.glu``.
        bias1: bias for both input projections ``fc11`` and ``fc12``.
        bias2: bias for ``fc2``.
        return_residual: if True, ``forward`` returns ``(y, x)``.
        norm_layer: insert a LayerNorm on the gated hidden state.
        fused_bias_fc, device, dtype: accepted but unused here.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=256,
        return_residual=False,
        fused_bias_fc=True,
        device=None,
        dtype=None,
        norm_layer=False,
    ):
        super().__init__()
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            hidden_features = int(8 * in_features / 3)
        # Round the hidden width up to the next multiple of ``multiple_of``.
        hidden_features = int((hidden_features + multiple_of - 1) // multiple_of * multiple_of)
        self.return_residual = return_residual

        self.fc11 = nn.Linear(in_features, hidden_features, bias=bias1)
        self.fc12 = nn.Linear(in_features, hidden_features, bias=bias1)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
        self.norm = nn.LayerNorm(hidden_features) if norm_layer else nn.Identity()

    def forward(self, x):
        value = self.fc11(x)
        gate = self.fc12(x)
        if self.activation == F.sigmoid:
            # F.glu(cat([a, b])) computes a * sigmoid(b) in a single call.
            hidden = F.glu(torch.cat([value, gate], dim=-1), dim=-1)
        else:
            hidden = value * self.activation(gate)
        out = self.fc2(self.norm(hidden))
        return (out, x) if self.return_residual else out
|
|
|
|
|
def rotate_half(x, interleaved=False):
    """Rotate channel pairs by 90 degrees for rotary embeddings.

    Non-interleaved (GPT-NeoX layout): split the last dim into halves
    ``[x1, x2]`` and return ``[-x2, x1]``.
    Interleaved (GPT-J layout): treat neighbouring channels as pairs
    ``(x_even, x_odd)`` and return ``(-x_odd, x_even)`` per pair.
    """
    if interleaved:
        even, odd = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-odd, even), dim=-1), "... d two -> ... (d two)", two=2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
|
|
|
|
|
def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False):
    """
    Apply rotary position embedding to the leading ``rotary_dim`` channels of ``x``.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    ``rotary_dim`` is ``2 * cos.shape[-1]``. Rows ``offset : offset + seqlen``
    of ``cos``/``sin`` (first axis) are used, and channels past ``rotary_dim``
    pass through unchanged.
    NOTE(review): for the batched cos/sin layout the offset slice is applied
    to the first (batch) axis; confirm this is intended.
    """
    rotary_dim = 2 * cos.shape[-1]
    assert rotary_dim <= x.shape[-1]
    window = slice(offset, offset + x.shape[1])
    # Broadcast over heads and duplicate each frequency (half-split or interleaved layout).
    pattern = "... d -> ... 1 (d 2)" if interleaved else "... d -> ... 1 (2 d)"
    cos = repeat(cos[window], pattern)
    sin = repeat(sin[window], pattern)
    x_rot = x[..., :rotary_dim]
    rotated = x_rot * cos + rotate_half(x_rot, interleaved) * sin
    return torch.cat(
        [rotated, x[..., rotary_dim:]],
        dim=-1,
    )
|
|
|
|
|
class ContextualNomicBertRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) backed by a lazily (re)built cos/sin cache.

    ``forward`` rotates the query and key slices of a packed qkv tensor and passes
    the value slice through unchanged, returning a new tensor.
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        dim: rotary dimension; one inverse frequency ``base ** (-i / dim)`` is built
            for every ``i`` in ``range(0, dim, 2)``.
        base: frequency base of the rotary embedding.
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        scale_base: if not None, a per-frequency ``scale`` buffer is created. This base
            class does not apply it when building its own cache.
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        device: device on which the ``inv_freq`` and ``scale`` buffers are created.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32

        # Pass ``device`` by keyword: subclasses may override ``_compute_inv_freq``
        # with a different positional order (e.g. a leading ``base`` argument).
        inv_freq = self._compute_inv_freq(device=device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        """Return the fp32 inverse frequencies ``1 / base ** (arange(0, dim, 2) / dim)``."""
        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """Rebuild the cos/sin cache for positions ``0 .. seqlen - 1`` when needed.

        A rebuild happens when the cache is empty or shorter than ``seqlen``, lives on a
        different device or dtype, or was created in inference mode while the module is
        now training (inference tensors cannot be used in autograd).
        """
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen

            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # If the module was cast to lower precision, inv_freq was cast too;
                # recompute it in fp32 so the outer product stays accurate.
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq

            freqs = torch.outer(t, inv_freq)
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> torch.Tensor:
        """
        qkv: packed tensor whose dim 2 indexes (q, k, v); its dim 1 is taken as the
            length along which positions are assigned.
        kv: unused; kept for signature compatibility.
        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        max_seqlen: length to grow the cache to when ``seqlen_offset`` is a tensor.

        Returns a NEW tensor stacking (rotated q, rotated k, untouched v) along dim 2;
        ``qkv`` itself is not modified.
        """
        seqlen = qkv.shape[1]
        if seqlen > self._seq_len_cached:
            # An integer offset shifts the slice taken from the cache, so the cache
            # must cover ``seqlen + offset`` positions, not just ``seqlen``.
            needed = seqlen + seqlen_offset if isinstance(seqlen_offset, int) else seqlen
            self._update_cos_sin_cache(needed, device=qkv.device, dtype=qkv.dtype)
        elif max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)

        q_rot = apply_rotary_emb(qkv[:, :, 0], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
        k_rot = apply_rotary_emb(qkv[:, :, 1], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
        return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
|
|
|
|
|
class ContextualNomicBertDynamicNTKRotaryEmbedding(ContextualNomicBertRotaryEmbedding):
    """Rotary embedding whose frequency base is enlarged for sequences longer than
    ``max_position_embeddings`` (dynamic NTK scaling).

    For ``seqlen > max_position_embeddings`` the base becomes
    ``base * (factor * seqlen / max_pos - (factor - 1)) ** (dim / (dim - 2))``,
    where ``factor`` is ``rotary_scaling_factor``.
    """

    def __init__(self, rotary_scaling_factor, max_position_embeddings, **kwargs):
        """
        rotary_scaling_factor: scaling factor used in the dynamic base formula.
        max_position_embeddings: sequence length above which the base is rescaled.
        **kwargs: forwarded to ``ContextualNomicBertRotaryEmbedding.__init__``.
        """
        super().__init__(**kwargs)
        self.rotary_scaling_factor = rotary_scaling_factor
        self.max_position_embeddings = max_position_embeddings

    def _compute_inv_freq(self, device=None, base=None):
        """Return fp32 inverse frequencies for ``base`` (defaults to ``self.base``).

        ``device`` comes first so positional calls made by the parent class
        (``_compute_inv_freq(device)``) bind to the right parameter.
        """
        if base is None:
            base = self.base
        return 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))

    def _scaled_base(self, seqlen):
        """Return the rotary base to use for a sequence of length ``seqlen``."""
        if seqlen <= self.max_position_embeddings:
            return self.base
        factor = self.rotary_scaling_factor
        return self.base * (
            (factor * seqlen / self.max_position_embeddings) - (factor - 1)
        ) ** (self.dim / (self.dim - 2))

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """Rescale ``inv_freq`` for long sequences, then rebuild the cos/sin cache if stale.

        The cache is rebuilt when it is empty or shorter than ``seqlen``, lives on a
        different device or dtype, or is an inference-mode tensor while training.
        When ``scale`` is set, ``_cos_k_cached`` / ``_sin_k_cached`` are filled as well.
        """
        if seqlen > self.max_position_embeddings:
            inv_freq = self._compute_inv_freq(base=self._scaled_base(seqlen), device=device)
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen

            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # inv_freq may have been cast to lower precision together with the
                # module; recompute it in fp32 with the matching (possibly scaled) base.
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device, base=self._scaled_base(seqlen))
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq

            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")

                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
|
|
|
|
class ContextualNomicBertAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Tokens at positions ``< self.rotary_start_pos`` (0 by default, i.e. none) are
    excluded from the rotary embedding: their q/k/v projections are restored to
    the un-rotated values after the rotation is applied.
    """

    def __init__(
        self,
        config,
    ) -> None:
        """
        config: model configuration. Fields read here: n_embd, n_head, num_heads_kv
            (optional, defaults to n_head), use_flash_attn, fused_bias_fc,
            rotary_emb_fraction, rotary_emb_base, rotary_emb_scale_base,
            rotary_emb_interleaved, rotary_scaling_factor (optional; when truthy the
            dynamic-NTK rotary variant is used), max_trained_positions,
            rotary_head_dim (optional), qkv_proj_bias, causal, attn_pdrop and
            register_tokens (optional).
        """
        super().__init__()
        self.embed_dim = config.n_embd
        self.use_flash_attn = config.use_flash_attn
        self.fused_bias_fc = config.fused_bias_fc

        self.num_heads = config.n_head
        self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // self.num_heads

        # NOTE(review): forward() splits the projection into three equal parts, so
        # num_heads_kv != num_heads (MQA/GQA) would fail there.
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        # sqrt(head_dim); only used by the non-SDPA fallback path in forward().
        self.register_buffer(
            "norm_factor",
            torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
            persistent=False,
        )

        self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction
        if self.rotary_emb_dim > 0:
            if getattr(config, "rotary_scaling_factor", None):
                self.rotary_emb = ContextualNomicBertDynamicNTKRotaryEmbedding(
                    dim=self.rotary_emb_dim,
                    base=config.rotary_emb_base,
                    scale_base=config.rotary_emb_scale_base,
                    interleaved=config.rotary_emb_interleaved,
                    rotary_scaling_factor=config.rotary_scaling_factor,
                    max_position_embeddings=config.max_trained_positions,
                )
            else:
                self.rotary_emb = ContextualNomicBertRotaryEmbedding(
                    dim=self.rotary_emb_dim,
                    base=config.rotary_emb_base,
                    scale_base=config.rotary_emb_scale_base,
                    interleaved=config.rotary_emb_interleaved,
                )

        self.rotary_head_dim = getattr(config, "rotary_head_dim", False)

        self.Wqkv = nn.Linear(self.embed_dim, qkv_dim, bias=config.qkv_proj_bias)

        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
        # NOTE(review): stored but not consulted by forward(), which always uses is_causal=False.
        self.causal = config.causal
        self.drop = nn.Dropout(config.attn_pdrop)
        self.num_prefix_tokens = max(getattr(config, "register_tokens", 1), 1)
        self.rotary_start_pos = 0

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        is_padded_inputs: Optional[bool] = True,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seq_len: Optional[int] = None,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attend over ``hidden_states`` and return the ``out_proj`` output tensor.

        hidden_states: presumably (batch, seqlen, embed_dim); the contextual-token
            branch asserts a 3-D input.
        attention_mask: passed as ``attn_mask`` to SDPA, or added to the raw scores in
            the fallback path (so an additive mask is expected there).
        past_key_value: if given, ``past_key_value[1]`` is used as the rotary offset.
        position_ids, output_attentions, is_padded_inputs, cu_seqlens, max_seq_len,
            rope: accepted for interface compatibility; not used here.

        Returns a single tensor, despite the tuple annotation.
        """
        has_layer_past = past_key_value is not None

        if has_layer_past:
            past_key_value = past_key_value[0]
            past_len = past_key_value[1]
        else:
            past_len = 0

        qkv = self.Wqkv(hidden_states)

        if self.rotary_start_pos > 0:
            # Zero out the contextual prefix before the rotary step; its original
            # projections are restored below, after the rotation.
            assert len(qkv.shape) == 3
            original_qkv = qkv.clone()
            qkv_zeros = torch.zeros_like(qkv, device=qkv.device)
            is_contextual_token_mask = torch.arange(qkv.shape[1], device=qkv.device) < self.rotary_start_pos
            qkv = qkv_zeros.where(
                is_contextual_token_mask[None, :, None].expand_as(qkv),
                qkv
            )

        qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)

        past_key_value = (past_key_value, past_len + qkv.size(1)) if use_cache else None

        assert self.rotary_emb_dim > 0
        # NOTE(review): after this rearrange, dim 1 is the head axis, and the rotary
        # module reads its length from dim 1. Kept as-is because pretrained weights
        # depend on the existing behaviour — confirm before changing.
        qkv = rearrange(qkv, "b s three h d -> b h three s d")
        qkv = self.rotary_emb(qkv, seqlen_offset=past_len)

        qkv = rearrange(qkv, "b h three s d -> b s three h d")

        if self.rotary_start_pos > 0:
            # Put the un-rotated projections back for the contextual prefix.
            original_qkv = original_qkv.reshape(qkv.shape)
            qkv = original_qkv.where(
                is_contextual_token_mask[None, :, None, None, None].expand_as(qkv),
                qkv
            )

        query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)

        if scaled_dot_product_attention is not None:
            # F.scaled_dot_product_attention applies dropout whenever dropout_p > 0,
            # regardless of module mode, so it has to be disabled explicitly in eval.
            attn_output = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=attention_mask,
                dropout_p=self.drop.p if self.training else 0.0,
                is_causal=False,
            )
        else:
            attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask

            attentions_probs = F.softmax(attention_scores, dim=-1)
            attentions_probs = self.drop(attentions_probs)

            attn_output = torch.matmul(attentions_probs, value)

        attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")

        attn_output = self.out_proj(attn_output)

        return attn_output
|
|
|
|
|
class ContextualNomicBertBlock(ContextualNomicBertPreTrainedModel):
    """A single transformer layer: self-attention then MLP.

    Each sub-layer has its own dropout, residual connection and LayerNorm; whether
    the norm comes before (prenorm) or after the residual add is chosen by
    ``config.prenorm``.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__(config=config)
        self.prenorm = config.prenorm
        self.fused_dropout_add_ln = config.fused_dropout_add_ln

        self.attn = ContextualNomicBertAttention(config)

        act_name = config.activation_function
        # "glu" gates with a sigmoid and "swiglu" with SiLU; anything else
        # (including "geglu") falls back to GELU.
        activation = {"glu": F.sigmoid, "swiglu": F.silu}.get(act_name, F.gelu)
        uses_gated_mlp = act_name in ("glu", "swiglu", "geglu")

        if uses_gated_mlp:
            self.mlp = NomciBertGatedMLP(
                config.n_embd,
                hidden_features=config.n_inner,
                bias1=config.mlp_fc1_bias,
                bias2=config.mlp_fc2_bias,
                activation=activation,
                fused_bias_fc=config.fused_bias_fc,
                norm_layer=getattr(config, "norm_mlp", False),
            )
        else:
            self.mlp = ContextualNomicBertMLP(
                config.n_embd,
                hidden_features=config.n_inner,
                bias1=config.mlp_fc1_bias,
                bias2=config.mlp_fc2_bias,
                activation=activation,
                fused_bias_fc=config.fused_bias_fc,
            )

        self.dropout1 = nn.Dropout(config.resid_pdrop)
        self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.norm2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.dropout2 = nn.Dropout(config.resid_pdrop)

    def forward(
        self,
        hidden_states: torch.Tensor,
        hidden_states2: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        is_padded_inputs: Optional[bool] = True,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seq_len: Optional[int] = None,
        rope: Optional[torch.Tensor] = None,
    ):
        r"""Run one encoder layer.

        Args:
            hidden_states: input to the layer (post-norm) or the previous layer's
                un-normalised output (prenorm).
            hidden_states2: unused; threaded through for the encoder's calling convention.
            residual: running residual stream; must be None in post-norm mode.
            attention_mask, is_padded_inputs, cu_seqlens, max_seq_len, rope:
                forwarded to the attention module.
            position_ids, past_key_value, output_attentions, use_cache: unused.

        Returns:
            ``(hidden_states, None, residual)`` in prenorm mode, where hidden_states is
            the MLP output; ``(hidden_states, None, None)`` in post-norm mode.
        """
        attn_kwargs = dict(
            attention_mask=attention_mask,
            is_padded_inputs=is_padded_inputs,
            cu_seqlens=cu_seqlens,
            max_seq_len=max_seq_len,
            rope=rope,
        )

        if not self.prenorm:
            assert residual is None
            attn_out = self.attn(hidden_states, **attn_kwargs)
            summed = self.dropout1(attn_out) + hidden_states
            hidden_states = self.norm1(summed.to(dtype=self.norm1.weight.dtype))
            mlp_out = self.mlp(hidden_states)
            summed = self.dropout2(mlp_out) + hidden_states
            hidden_states = self.norm2(summed.to(dtype=self.norm2.weight.dtype))
            return hidden_states, None, None

        # Prenorm: fold the incoming branch into the residual stream, normalise,
        # then run the sub-layer on the normalised stream.
        dropped = self.dropout1(hidden_states)
        residual = dropped if residual is None else dropped + residual
        normed = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
        attn_out = self.attn(normed, **attn_kwargs)

        dropped = self.dropout2(attn_out)
        residual = dropped if residual is None else dropped + residual
        normed = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
        mlp_out = self.mlp(normed)

        return mlp_out, None, residual
|
|
|
|
|
class ContextualNomicBertEncoder(nn.Module):
    """A stack of ``config.n_layer`` ContextualNomicBertBlock layers applied in order."""

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.layers = nn.ModuleList([ContextualNomicBertBlock(config) for _ in range(config.n_layer)])
        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        is_padded_inputs: Optional[bool] = True,
        rope: Optional[torch.Tensor] = None,
    ):
        """Feed ``hidden_states`` through every layer and return the last layer's output.

        When ``gradient_checkpointing`` is enabled and the module is training, each
        layer is run under non-reentrant activation checkpointing.
        ``inputs_embeds``, ``output_hidden_states`` and ``return_dict`` are accepted
        but not used.
        """
        hidden_states2 = None
        residual = None
        checkpointing = self.gradient_checkpointing and self.training

        for layer in self.layers:
            if checkpointing:
                hidden_states, hidden_states2, residual = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    hidden_states2,
                    residual,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    is_padded_inputs,
                    output_attentions,
                    use_cache,
                    None,  # cu_seqlens
                    None,  # max_seq_len
                    rope,
                    use_reentrant=False,
                )
                continue

            hidden_states, hidden_states2, residual = layer(
                hidden_states,
                hidden_states2,
                residual,
                attention_mask,
                position_ids,
                None,
                is_padded_inputs,
                output_attentions,
                use_cache,
                rope=rope,
            )

        return hidden_states
|
|
|
|
|
class ContextualNomicBertPooler(nn.Module):
    """Dense + tanh projection of the first token (or of an already-selected vector)."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.n_embd, config.n_embd)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, pool=True):
        """Return ``tanh(dense(x))``.

        With ``pool=True``, ``x`` is ``hidden_states[:, 0]`` (the first token along
        dim 1); otherwise ``hidden_states`` is projected as given.
        """
        selected = hidden_states[:, 0] if pool else hidden_states
        projected = self.dense(selected)
        return self.activation(projected)
|
|
|
|
|
class ContextualNomicBertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied before the MLM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)

        act_name = config.activation_function
        if act_name == "swiglu":
            self.transform_act_fn = F.silu
        else:
            # These names denote the tanh-approximated GELU; everything else uses exact GELU.
            tanh_gelu_names = {"gelu_new", "gelu_fast", "gelu_pytorch_tanh"}
            self.transform_act_fn = nn.GELU(approximate="tanh" if act_name in tanh_gelu_names else "none")

        self.layer_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``layer_norm(act(dense(hidden_states)))``."""
        projected = self.dense(hidden_states)
        activated = self.transform_act_fn(projected)
        return self.layer_norm(activated)
|
|
|
|
|
class ContextualNomicBertLMPredictionHead(nn.Module):
    """Masked-LM head: hidden-state transform followed by a projection to vocab logits."""

    def __init__(self, config):
        super().__init__()

        self.transform = ContextualNomicBertPredictionHeadTransform(config)

        self.decoder = nn.Linear(config.n_embd, config.vocab_size, bias=config.mlp_fc1_bias)

    def forward(self, hidden_states):
        """Return vocabulary logits ``decoder(transform(hidden_states))``."""
        return self.decoder(self.transform(hidden_states))
|
|
|
|
|
class ContextualNomicBertPreTrainingHeads(nn.Module):
    """Pre-training heads wrapper; only the masked-LM prediction head is present."""

    def __init__(self, config):
        super().__init__()
        self.predictions = ContextualNomicBertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return masked-LM logits for ``sequence_output``."""
        return self.predictions(sequence_output)
|
|
|
|
|
class ContextualNomicBertModel(ContextualNomicBertPreTrainedModel):
    """Embeddings + LayerNorm/dropout + transformer encoder (+ optional pooler)."""

    def __init__(self, config: GPT2Config, add_pooling_layer=True):
        """
        config: model configuration. Note that ``config.vocab_size`` is rounded up
            IN PLACE to a multiple of ``pad_vocab_size_multiple``.
        add_pooling_layer: whether to build the first-token pooler.
        """
        super().__init__(config)
        self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        if config.vocab_size % self.pad_vocab_size_multiple != 0:
            config.vocab_size += self.pad_vocab_size_multiple - (config.vocab_size % self.pad_vocab_size_multiple)

        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_pytorch_tanh",
            "swiglu",
            "geglu",
            "glu",
        ]

        self.embeddings = ContextualNomicBertEmbeddings(config)
        self.emb_drop = nn.Dropout(config.resid_pdrop)
        self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.encoder = ContextualNomicBertEncoder(config)
        self.pooler = ContextualNomicBertPooler(config) if add_pooling_layer else None

        self.apply(partial(_init_weights, initializer_range=config.initializer_range))

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        return_dict=None,
        matryoshka_dim=None,
        inputs_embeds=None,
    ):
        """Encode the inputs.

        attention_mask: optional (batch, seqlen) mask; converted with
            ``get_extended_attention_mask``. When None, no mask is applied.
        matryoshka_dim: if set, the returned ``last_hidden_state`` is truncated to
            its first ``matryoshka_dim`` hidden features (last axis).
        return_dict: accepted for API compatibility; a
            ``BaseModelOutputWithPoolingAndCrossAttentions`` is always returned.

        Raises:
            ValueError: if both ``input_ids`` and ``inputs_embeds`` are given.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = self.emb_ln(hidden_states)
        hidden_states = self.emb_drop(hidden_states)

        # get_extended_attention_mask cannot handle None; with no mask the
        # attention layers simply attend over every position.
        if attention_mask is not None:
            attention_mask = self.get_extended_attention_mask(attention_mask, hidden_states.shape[:-1])
        sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict)

        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if matryoshka_dim:
            # Matryoshka truncation applies to the embedding (hidden) axis, not to tokens.
            sequence_output = sequence_output[..., :matryoshka_dim]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )
|
|
|
|
|
class ContextualNomicBertForPreTraining(ContextualNomicBertPreTrainedModel):
    """Encoder plus a masked-language-modelling head whose decoder weight is tied
    to the word-embedding matrix."""

    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config: GPT2Config):
        super().__init__(config)

        self.bert = ContextualNomicBertModel(config, add_pooling_layer=getattr(config, "add_pooling_layer", False))
        self.cls = ContextualNomicBertPreTrainingHeads(config)
        self.mlm_loss = nn.CrossEntropyLoss()

        self.apply(partial(_init_weights, initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        """Share the MLM decoder weight with the input word embeddings."""
        self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
    ):
        """
        Compute MLM logits and, when ``labels`` is given, the MLM loss.

        labels: token ids to predict; positions to ignore must be -100 (the default
            ``ignore_index`` of ``nn.CrossEntropyLoss``).

        Returns a ``MaskedLMOutput`` with ``logits`` of shape
        [batch_size, sequence_length, vocab_size] and ``loss`` (fp32) when labels
        are provided, otherwise ``loss=None``.
        """
        mask = None if attention_mask is None else attention_mask.bool()
        outputs = self.bert(
            input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=mask,
        )
        sequence_output = outputs.last_hidden_state

        prediction_scores = self.cls(sequence_output)

        total_loss = None
        if labels is not None:
            # Flatten all leading dims so CrossEntropyLoss sees (N, vocab) vs (N,).
            flat_scores = prediction_scores.reshape(-1, prediction_scores.size(-1))
            flat_labels = labels.reshape(-1)
            total_loss = self.mlm_loss(flat_scores, flat_labels).float()

        return MaskedLMOutput(
            loss=total_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=None,
        )
|
|
|
|
|
class ContextualNomicBertForSequenceClassification(ContextualNomicBertPreTrainedModel):
    """Encoder + pooler + linear classifier for sequence-level classification or regression."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = ContextualNomicBertModel(config)
        classifier_dropout = getattr(config, "classifier_dropout", config.embd_pdrop)
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.n_embd, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        inputs_embeds: optional precomputed input embeddings, used instead of `input_ids`.
        head_mask, output_attentions, output_hidden_states: accepted but unused.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(
            input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask.bool() if attention_mask is not None else None,
            # Previously dropped silently; forward it like the other task heads do.
            inputs_embeds=inputs_embeds,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
class ContextualNomicBertForMultipleChoice(ContextualNomicBertPreTrainedModel):
    """Scores each (example, choice) pair with a linear layer on the pooled output."""

    def __init__(self, config):
        super().__init__(config)

        self.bert = ContextualNomicBertModel(config, add_pooling_layer=True)
        classifier_dropout = getattr(config, "classifier_dropout", config.resid_pdrop)
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        unpad_inputs: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Index of the correct choice per example, in `[0, ..., num_choices-1]`, where
            `num_choices` is dim 1 of `input_ids` (or of `inputs_embeds`).

        Inputs are flattened so every choice becomes its own row before encoding;
        the per-choice scores are reshaped back to `(batch_size, num_choices)`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        reference = input_ids if input_ids is not None else inputs_embeds
        num_choices = reference.shape[1]

        def _merge_choice_dim(tensor, trailing_dims):
            # Collapse every leading dim into one, keeping the last `trailing_dims` dims.
            if tensor is None:
                return None
            return tensor.view(-1, *tensor.shape[-trailing_dims:])

        input_ids = _merge_choice_dim(input_ids, 1)
        attention_mask = _merge_choice_dim(attention_mask, 1)
        token_type_ids = _merge_choice_dim(token_type_ids, 1)
        position_ids = _merge_choice_dim(position_ids, 1)
        inputs_embeds = _merge_choice_dim(inputs_embeds, 2)

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )

        choice_scores = self.classifier(self.dropout(outputs[1]))
        reshaped_logits = choice_scores.view(-1, num_choices)

        loss = nn.CrossEntropyLoss()(reshaped_logits, labels) if labels is not None else None

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return output if loss is None else (loss,) + output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
class ContextualNomicBertForTokenClassification(ContextualNomicBertPreTrainedModel):
    """Per-token classifier applied to the encoder's last hidden states."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = ContextualNomicBertModel(config, add_pooling_layer=False)
        classifier_dropout = getattr(config, "classifier_dropout", config.resid_pdrop)
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Per-token class indices in `[0, ..., config.num_labels - 1]`, used to compute
            a cross-entropy loss over all tokens.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )

        token_states = self.dropout(outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            flat_logits = logits.view(-1, self.num_labels)
            loss = nn.CrossEntropyLoss()(flat_logits, labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return output if loss is None else (loss,) + output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
class ContextualNomicBertForQuestionAnswering(ContextualNomicBertPreTrainedModel):
    """Extractive QA head: predicts span start and end logits for every token."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = ContextualNomicBertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Index of the first token of the answer span.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Index of the last token of the answer span.

        Both are clamped to `[0, sequence_length]`; a value equal to `sequence_length`
        is the ignore index, so out-of-range positions do not contribute to the loss.
        The loss is the mean of the start and end cross-entropy losses.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )

        logits = self.qa_outputs(outputs[0])
        start_logits, end_logits = (part.contiguous() for part in logits.unbind(dim=-1))

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Gathering across devices can add a trailing singleton dim; drop it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            seq_len = start_logits.size(1)
            start_positions = start_positions.clamp(0, seq_len)
            end_positions = end_positions.clamp(0, seq_len)

            criterion = nn.CrossEntropyLoss(ignore_index=seq_len)
            total_loss = (criterion(start_logits, start_positions) + criterion(end_logits, end_positions)) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return output if total_loss is None else (total_loss,) + output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config:
    """Translate a Hugging Face ``ViTConfig`` into the ``GPT2Config``-based layout.

    Sizes, activation, attention dropout, LayerNorm epsilon, initializer range
    and image/patch geometry are copied from ``vit_config``. ``embd_pdrop``,
    ``mlp_fc1_bias`` and ``mlp_fc2_bias`` are read with ``getattr`` defaults
    because not every ViT config defines them. Everything else is a fixed
    choice: pre-norm, no rotary embeddings, no vocabulary/text positions, and
    fused kernels plus flash attention requested. Sinusoidal position
    embeddings are enabled only when ``vit_config.model_type == "vit_mae"``.
    """
    config_kwargs = {
        "n_embd": vit_config.hidden_size,
        "n_layer": vit_config.num_hidden_layers,
        "n_head": vit_config.num_attention_heads,
        "n_inner": vit_config.intermediate_size,
        "activation_function": vit_config.hidden_act,
        # Image models have no token vocabulary or text positions.
        "vocab_size": 0,
        "n_positions": 0,
        "resid_pdrop": 0.0,
        "embd_pdrop": getattr(vit_config, "dropout", 0.0),
        "attn_pdrop": vit_config.attention_probs_dropout_prob,
        "layer_norm_epsilon": vit_config.layer_norm_eps,
        "initializer_range": vit_config.initializer_range,
        "bos_token_id": None,
        "eos_token_id": None,
        "drop_path_rate": 0.0,
        # Key spelling kept verbatim; presumably it matches what downstream code reads.
        "prepre_layernom": False,
        "layer_scale": False,
        "layer_scale_init": None,
        "img_size": vit_config.image_size,
        "patch_size": vit_config.patch_size,
        "num_channels": vit_config.num_channels,
        "prenorm": True,
        "parallel_block": False,
        "parallel_block_tied_norm": False,
        "rotary_emb_fraction": 0,
        "tie_word_embeddings": False,
        "fused_dropout_add_ln": True,
        "fused_bias_fc": True,
        "patch_embed_bias": True,
        "use_flash_attn": True,
        "qkv_proj_bias": True,
        "mlp_fc1_bias": getattr(vit_config, "mlp_fc1_bias", True),
        "mlp_fc2_bias": getattr(vit_config, "mlp_fc2_bias", True),
        "use_rms_norm": False,
        "causal": False,
        "hidden_features_scaling_factor": 1.0,
        "mask_token": False,
        "learned_pos_embedding": False,
        "patch_dropout": 0,
        "sinusoidal_pos_embedding": vit_config.model_type == "vit_mae",
    }
    return GPT2Config(**config_kwargs)
|
|
|
|
|
class ContextualNomicAttentionPooling(nn.Module):
    """Multi-head attention pooling driven by a single learned latent query.

    The learned ``latent`` vector is projected to queries by ``Wq`` and attends
    over the input sequence, whose keys/values come from ``Wkv``. The result is
    one pooled vector per batch element. When ``config.num_heads_kv`` is smaller
    than ``n_head`` (grouped-query attention), each key/value head is shared by
    ``n_head // num_heads_kv`` consecutive query heads.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.n_embd
        self.use_flash_attn = config.use_flash_attn
        self.fused_bias_fc = config.fused_bias_fc

        self.num_heads = config.n_head
        self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // self.num_heads

        # Keys and values are produced by one projection: 2 * H_kv * head_dim outputs.
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        # sqrt(head_dim) scaling for the dot-product scores; non-persistent, so not saved.
        self.register_buffer(
            "norm_factor",
            torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
            persistent=False,
        )

        self.Wq = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
        self.Wkv = nn.Linear(self.embed_dim, kv_dim, bias=config.qkv_proj_bias)

        # Learned latent query, shape (1, 1, embed_dim); expanded over the batch in forward().
        self.latent = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
        self.causal = config.causal
        self.drop = nn.Dropout(config.attn_pdrop)

    def init_weights(self):
        trunc_normal_tf_(self.latent, std=self.embed_dim**-0.5)

    def forward(
        self,
        kv,
        attention_mask=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
        is_padded_inputs: Optional[bool] = True,
        output_attentions: bool = False,
    ):
        """Pool ``kv`` into a single vector per batch element.

        Args:
            kv: input sequence of shape ``(batch, seq_len, n_embd)``. It is
                projected to keys and values by ``Wkv``.
            attention_mask: optional additive mask broadcastable to the score
                tensor ``(batch, num_heads, 1, seq_len)``, e.g. 0 to keep a
                position and a large negative value to drop it.
            cu_seqlens_k, max_seqlen_k, is_padded_inputs, output_attentions:
                accepted for interface compatibility; currently unused.

        Returns:
            Tensor of shape ``(batch, 1, n_embd)``: the pooled output after ``out_proj``.

        Raises:
            ValueError: if ``num_heads`` is not a multiple of ``num_heads_kv``.
        """
        batch_size = kv.size(0)
        q = self.Wq(self.latent.expand(batch_size, -1, -1))
        kv = self.Wkv(kv)

        # (B, 1, H*D) -> (B, 1, H, D)
        query = q.reshape(*q.shape[:-1], self.num_heads, self.head_dim)
        # (B, S, 2*Hkv*D) -> (B, S, 2, Hkv, D); index 0 = keys, 1 = values.
        kv = kv.reshape(*kv.shape[:-1], 2, self.num_heads_kv, self.head_dim)
        key, value = kv[:, :, 0], kv[:, :, 1]

        # Move heads in front of the sequence axis: (B, heads, len, D).
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)

        if self.num_heads_kv != self.num_heads:
            # Grouped-query attention: share each kv head among a group of query heads.
            if self.num_heads % self.num_heads_kv != 0:
                raise ValueError(
                    f"num_heads ({self.num_heads}) must be a multiple of num_heads_kv ({self.num_heads_kv})"
                )
            group_size = self.num_heads // self.num_heads_kv
            key = key.repeat_interleave(group_size, dim=1)
            value = value.repeat_interleave(group_size, dim=1)

        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = self.drop(F.softmax(attention_scores, dim=-1))

        # (B, H, 1, D) -> (B, 1, H, D) -> (B, 1, H*D)
        attn_output = torch.matmul(attention_probs, value).permute(0, 2, 1, 3).flatten(-2)
        return self.out_proj(attn_output)
|
|
|
|
|
class ContextualNomicMultiHeadAttentionPooling(nn.Module):
    """Latent-query attention pooling followed by LayerNorm and an MLP.

    Wraps ``ContextualNomicAttentionPooling`` (``self.attn``). The MLP is the
    gated variant for ``glu``/``swiglu``/``geglu`` activations and the plain
    ``ContextualNomicBertMLP`` otherwise.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()
        self.prenorm = config.prenorm
        self.fused_dropout_add_ln = config.fused_dropout_add_ln

        self.attn = ContextualNomicAttentionPooling(config)

        act_name = config.activation_function
        # glu -> sigmoid gate, swiglu -> SiLU, anything else (including geglu) -> GELU.
        if act_name == "glu":
            activation = F.sigmoid
        elif act_name == "swiglu":
            activation = F.silu
        else:
            activation = F.gelu

        is_gated = act_name in ["glu", "swiglu", "geglu"]
        mlp_cls = NomciBertGatedMLP if is_gated else ContextualNomicBertMLP
        self.mlp = mlp_cls(
            config.n_embd,
            hidden_features=config.n_inner,
            bias1=config.mlp_fc1_bias,
            bias2=config.mlp_fc2_bias,
            activation=activation,
            fused_bias_fc=config.fused_bias_fc,
        )

        # dropout1 / dropout2 are constructed but not used by forward().
        self.dropout1 = nn.Dropout(config.resid_pdrop)
        self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.dropout2 = nn.Dropout(config.resid_pdrop)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        r"""Pool ``hidden_states`` with the latent query, then apply LayerNorm + MLP.

        Args:
            hidden_states: the sequence to pool (passed straight to ``self.attn``).
            attention_mask: optional additive mask forwarded to ``self.attn``.

        Returns:
            ``hidden_states + mlp(norm1(attn(hidden_states)))``.
        """
        pooled = self.attn(
            hidden_states,
            attention_mask=attention_mask,
        )
        # NOTE(review): the MLP output (a single pooled position) is added to the
        # *input* sequence and broadcasts over its length, not to ``pooled``.
        # Confirm whether ``pooled + ...`` was intended.
        return hidden_states + self.mlp(self.norm1(pooled))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Callable, Dict, Optional, Union, Tuple |
|
import copy |
|
import math |
|
import multiprocessing |
|
import os |
|
|
|
import torch |
|
import torch.nn as nn |
|
import transformers |
|
|
|
|
|
class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
    """Catch-all configuration that stores every JSON-serializable kwarg as an attribute.

    When this class is initialized (see experiments.py), it receives the union
    of all data, model, and training args, so all of them get saved to the
    config JSON. Values that ``json.dumps`` rejects are silently skipped.
    """

    def __init__(self, **kwargs):
        # Initialize the base class first. Calling it last (with no kwargs)
        # reset every PretrainedConfig-managed attribute passed in here
        # (e.g. ``architectures``, ``output_hidden_states``) back to its default.
        super().__init__()
        for key, value in kwargs.items():
            try:
                json.dumps(value)
            except (TypeError, ValueError):
                # Not JSON-serializable (ValueError covers circular references):
                # skip it so the config can still be saved.
                continue
            setattr(self, key, value)
|
|
|
|
|
def load_embedder_and_tokenizer(name: str) -> Tuple[
    transformers.PreTrainedModel,
    transformers.PreTrainedTokenizer
]:
    """Load a backbone model and its tokenizer from a short name or hub id.

    Recognised names:
      * ``nomic*`` or ``bert-base-uncased``: the ``.bert`` sub-module of
        ``ContextualNomicBertForPreTraining``.
      * ``gtr-base`` / ``gtr_base``: the encoder of ``sentence-transformers/gtr-t5-base``.
      * ``pile-t5-base-encoder`` / ``pile-t5-base-decoder``: the encoder or
        decoder stack of ``EleutherAI/pile-t5-base``; pad token set to EOS.
      * ``gpt2*``, ``meta-llama*`` or any name containing ``Llama``: a causal LM
        loaded with ``flash_attention_2``. The tokenizer's pad token is set to
        EOS and ``add_eos_token`` is enabled.
      * anything else: ``AutoModel`` with ``trust_remote_code=True``.
    """
    print("Loading model:", name)

    if name.startswith("nomic") or name == "bert-base-uncased":
        model = ContextualNomicBertForPreTraining.from_pretrained(name, trust_remote_code=True).bert
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        return model, tokenizer

    if name in ("gtr-base", "gtr_base"):
        hub_id = "sentence-transformers/gtr-t5-base"
        model = transformers.AutoModel.from_pretrained(hub_id).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(hub_id)
        return model, tokenizer

    if name in ("pile-t5-base-encoder", "pile-t5-base-decoder"):
        hub_id = "EleutherAI/pile-t5-base"
        full_model = transformers.AutoModel.from_pretrained(hub_id)
        model = full_model.encoder if name == "pile-t5-base-encoder" else full_model.decoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(hub_id)
        tokenizer.pad_token = tokenizer.eos_token
        return model, tokenizer

    if name.startswith(("gpt2", "meta-llama")) or "Llama" in name:
        model = transformers.AutoModelForCausalLM.from_pretrained(
            name,
            attn_implementation="flash_attention_2",
            low_cpu_mem_usage=True,
        )
        # NOTE(review): this sets an attribute on the *model*, not the tokenizer.
        # If right padding was intended it is probably tokenizer.padding_side;
        # kept as-is to preserve behaviour.
        model.padding_side = "right"
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.add_eos_token = True
        return model, tokenizer

    model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
    tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    return model, tokenizer
|
|
|
|
|
def get_world_size() -> int:
    """Return the size of the default process group, or 1 outside distributed runs.

    Also returns 1 on PyTorch builds without distributed support, where the
    process-group functions may be missing entirely.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 1
    try:
        return torch.distributed.get_world_size()
    except (RuntimeError, ValueError):
        return 1
|
|
|
|
|
def get_rank() -> int:
    """Return this process's rank in the default process group, or 0 outside distributed runs.

    Also returns 0 on PyTorch builds without distributed support, where the
    process-group functions may be missing entirely.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    try:
        return torch.distributed.get_rank()
    except (RuntimeError, ValueError):
        return 0
|
|
|
def gather(t: torch.Tensor) -> torch.Tensor:
    """All-gather ``t`` from every rank and concatenate the pieces along dim 0.

    Returns ``t`` unchanged when running on a single process. A 0-dim tensor is
    treated as shape ``(1,)`` so the result has one entry per rank.
    """
    n_ranks = get_world_size()
    if n_ranks == 1:
        return t

    local = t.unsqueeze(0) if t.ndim == 0 else t
    buffers = [torch.empty_like(local) for _ in range(n_ranks)]
    torch.distributed.all_gather(buffers, local)
    # torch.distributed.all_gather is not autograd-aware; putting the local
    # tensor back presumably keeps this rank's slice in the autograd graph.
    buffers[get_rank()] = local
    return torch.cat(buffers, dim=0)
|
|
|
|
|
def gather_sum(t: torch.Tensor) -> torch.Tensor:
    """Element-wise sum of ``t`` across all ranks, computed via all_gather.

    Returns ``t`` unchanged on a single process. When distributed, a 0-dim
    input comes back with shape ``(1,)``.
    """
    n_ranks = get_world_size()
    if n_ranks == 1:
        return t

    local = t[None] if t.ndim == 0 else t
    shards = [torch.empty_like(local) for _ in range(n_ranks)]
    torch.distributed.all_gather(shards, local)
    return torch.stack(shards, dim=0).sum(dim=0)
|
|
|
|
|
def get_num_proc() -> int:
    """Number of CPU worker processes to use per distributed rank (always at least 1).

    Divides the CPUs available to this process evenly among the ranks. The
    result is clamped to 1 because a process count of 0 is invalid for typical
    consumers (e.g. ``num_proc`` arguments).
    """
    world_size: int = get_world_size()
    try:
        # Respects CPU affinity (e.g. taskset / container limits) where supported.
        n_cpus = len(os.sched_getaffinity(0))
    except AttributeError:
        # os.sched_getaffinity is not available on every platform.
        n_cpus = multiprocessing.cpu_count()
    return max(1, n_cpus // world_size)
|
|
|
|
|
def torch_main_worker_finish_first(func: Callable):
    """Decorator that runs ``func`` on rank 0 before the other ranks.

    With an initialized process group, rank 0 calls ``func`` while the others
    wait at a barrier. The others then call ``func`` themselves (presumably to
    reuse work rank 0 just cached), and a second barrier re-synchronizes every
    rank. Outside distributed runs, ``func`` is simply called once. The
    wrapper keeps ``func``'s name and docstring.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Guarding on is_available() avoids AttributeError on builds without
        # distributed support.
        ddp_enabled = torch.distributed.is_available() and torch.distributed.is_initialized()
        is_main_worker = (not ddp_enabled) or torch.distributed.get_rank() <= 0

        if is_main_worker:
            result = func(*args, **kwargs)

        if ddp_enabled:
            torch.distributed.barrier()

        if not is_main_worker:
            result = func(*args, **kwargs)

        if ddp_enabled:
            torch.distributed.barrier()
        return result

    return wrapper
|
|
|
|
|
def print0(*args, **kwargs) -> None:
    """``print`` that only emits on rank 0 (or in non-distributed runs)."""
    if get_rank() != 0:
        return
    print(*args, **kwargs)
|
|
|
|
|
def verify_ddp_weights_equal(model: torch.nn.Module, atol: float = 1e-5) -> None:
    """Assert that every parameter and its gradient are identical across ranks.

    Unwraps a ``.module`` wrapper (e.g. DDP) first. Skips the whole check when
    the world size exceeds 8, and skips parameters without a gradient. Values
    are compared against rank 0 with absolute tolerance ``atol``.

    Raises:
        AssertionError: if any parameter or gradient differs by ``atol`` or more.
    """
    model = model.module if hasattr(model, "module") else model
    world_size = get_world_size()

    if world_size > 8:
        print0(f"[verify_ddp_weights_equal] Skipping with world_size={world_size} ⚠️")
        return

    for name, param in model.named_parameters():
        if param is None:
            continue
        if param.grad is None:
            print0(f"[verify_ddp_weights_equal] Skipping param [{name}] with no grad")
            continue
        # Check the weights first, then their gradients.
        for label, tensor in (("", param), (" grad", param.grad)):
            per_rank = gather(tensor).reshape((world_size, -1))
            diffs = (per_rank[:1] - per_rank).abs()
            assert (diffs < atol).all(), f"❌ param [{name}]{label} not equal - got max_absolute_diff={diffs.max()}"

    print0("[verify_ddp_weights_equal] Verified DDP parameter correctness ✅")
|
|
|
|
|
|
|
def mean_pool_3d(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Masked mean over the third axis of a ``(B, T, S, D)`` tensor.

    Args:
        hidden_states: ``(B, T, S, D)`` values; ``T`` groups of ``S`` positions each.
        attention_mask: ``(B, T, S)`` weights; zero marks positions to ignore.

    Returns:
        ``(B, T, D)``. A group whose mask sums to zero falls back to the
        unmasked mean over all ``T * S`` positions of that batch element.
    """
    B, T, S, D = hidden_states.shape
    token_counts = attention_mask.sum(dim=2)[..., None]
    summed = (hidden_states * attention_mask[..., None]).sum(dim=2)
    pooled = summed / (token_counts + 1e-9)

    fallback = hidden_states.reshape((B, S * T, D)).mean(dim=1, keepdim=True)
    fallback = fallback.expand(-1, T, -1)
    pooled = torch.where(token_counts > 0, pooled, fallback)

    assert pooled.shape == (B, T, D)
    return pooled
|
|
|
def mean_pool(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Masked mean over the sequence axis.

    Args:
        hidden_states: ``(B, S, D)`` token representations.
        attention_mask: ``(B, S)`` weights; zero marks padding.

    Returns:
        ``(B, D)`` pooled vectors. Rows with an all-zero mask come out as zeros
        because of the tiny epsilon in the denominator.
    """
    B, _S, D = hidden_states.shape
    numerator = (hidden_states * attention_mask[..., None]).sum(dim=1)
    denominator = attention_mask.sum(dim=1)[:, None] + 1e-20
    pooled = numerator / denominator

    assert pooled.shape == (B, D)
    return pooled
|
|
|
|
|
def mean_pool_weighted(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Position-weighted mean pooling over the sequence axis.

    Each unmasked token is weighted by the running sum of the mask up to and
    including it. For a 0/1 mask this is its 1-based rank among the unmasked
    tokens, so later tokens count more.

    Args:
        hidden_states: ``(B, S, D)`` token representations.
        attention_mask: ``(B, S)``; nonzero marks tokens to include. Not modified.

    Returns:
        ``(B, D)`` pooled vectors. Rows with an all-zero mask pool to zeros
        instead of NaN.
    """
    # Build the weights out-of-place: an in-place ``*=`` would overwrite the
    # caller's mask (and any tensor it is a view of), and fails for bool masks.
    weights = (attention_mask * attention_mask.cumsum(dim=1)).float()
    weighted_sum = torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)
    total_weight = weights.sum(dim=1, keepdim=True)
    # Guard against 0 / 0 for rows with no valid tokens.
    return weighted_sum / total_weight.clamp(min=1e-9)
|
|
|
|
|
def slice_sparse_tensor_rows(t: torch.sparse.Tensor, min_row: int, max_row: int) -> torch.sparse.Tensor:
    """Return rows ``[min_row, max_row)`` of a sparse COO tensor as a new sparse tensor.

    Behaves like dense ``t[min_row:max_row]``: the result has shape
    ``(num_rows, *t.shape[1:])``, row indices are relative to ``min_row``, and
    ``max_row`` is clamped to the number of rows in ``t``.

    Raises:
        AssertionError: if ``min_row >= max_row``.
    """
    assert min_row < max_row, f"can't slice from row {min_row} to {max_row}"
    t = t.coalesce()
    max_row = min(max_row, t.shape[0])
    num_rows = max(max_row - min_row, 0)

    indices = t.indices()
    row_idxs = indices[0]
    index_mask = (min_row <= row_idxs) & (row_idxs < max_row)

    # Boolean indexing returns a copy, so shifting it in place is safe.
    idxs = indices[:, index_mask]
    # Re-base row indices onto the slice. Without this, any min_row > 0 yields
    # indices outside the (num_rows, ...) result shape.
    idxs[0] -= min_row
    vals = t.values()[index_mask]
    return torch.sparse_coo_tensor(idxs, vals, size=(num_rows, *t.shape[1:])).coalesce()
|
|
|
|
|
def slice_tensor_rows(t: torch.Tensor, min_row: int, max_row: int) -> torch.Tensor:
    """Rows ``[min_row, max_row)`` of ``t``, dispatching to the sparse helper for sparse tensors."""
    if not t.is_sparse:
        return t[min_row:max_row]
    return slice_sparse_tensor_rows(t=t, min_row=min_row, max_row=max_row)
|
|
|
|
|
@torch.no_grad
def maxsim(
    X: torch.Tensor, y: torch.Tensor,
    maximize: bool, chunk_size: int = 8_000,
    debug_mem_usage: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Row-wise best match of ``X @ y``, computed in chunks and split across ranks.

    Args:
        X: ``(n_samples, d)`` dense or sparse matrix.
        y: ``(d, k)`` matrix; ``X @ y`` gives one score per (row, column) pair.
        maximize: take the max score per row if True, otherwise the min.
        chunk_size: number of rows of ``X`` multiplied at a time.
        debug_mem_usage: print CUDA memory info and shapes for each chunk.

    Returns:
        ``(values, indices)``, each of shape ``(n_samples,)``: the best score of
        each row and the column of ``y`` it came from.
    """
    device = X.device
    n_samples = X.shape[0]

    max_sim_v = torch.zeros(n_samples, device=device, dtype=X.dtype)
    max_sim_i = torch.zeros(n_samples, device=device, dtype=torch.int64)

    # Each rank fills one contiguous slice of rows and leaves the rest at zero;
    # gather_sum then adds the per-rank buffers together.
    rank = get_rank()
    world_size = get_world_size()
    worker_worklist_size = int(math.ceil(n_samples / world_size))
    splits_start_idx = min(worker_worklist_size * rank, n_samples)
    splits_end_idx = min(worker_worklist_size * (rank + 1), n_samples)

    for start in range(splits_start_idx, splits_end_idx, chunk_size):
        # Bound chunks by this rank's slice, not n_samples. Otherwise the last
        # chunk spills into the next rank's rows and gather_sum double-counts them.
        end = min(start + chunk_size, splits_end_idx)
        sub_x = slice_tensor_rows(X, start, end)
        if debug_mem_usage: print(f"[maxsim] step {start} cuda mem free/total = {torch.cuda.mem_get_info()}")
        if debug_mem_usage: print("[maxsim] sub_x.shape:", sub_x.shape, "//", "y.shape:", y.shape)
        sub_sim = (sub_x @ y).to_dense()
        if maximize:
            sub_max_sim_v, sub_max_sim_i = sub_sim.max(dim=-1)
        else:
            sub_max_sim_v, sub_max_sim_i = sub_sim.min(dim=-1)
        del sub_sim
        del sub_x
        torch.cuda.empty_cache()
        max_sim_v[start:end] = sub_max_sim_v
        max_sim_i[start:end] = sub_max_sim_i

    max_sim_v = gather_sum(max_sim_v)
    max_sim_i = gather_sum(max_sim_i)
    k = y.shape[1]

    assert max_sim_v.shape == (n_samples,)
    assert max_sim_i.shape == (n_samples,)
    if n_samples > 0:
        assert max_sim_i.min() >= 0
        # Indices come from max/min over k columns, so they are strictly below k.
        assert max_sim_i.max() < k

    return max_sim_v, max_sim_i
|
|
|
|
|
def forward_batched(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    batch_size: int,
    dataset_input_ids: Optional[torch.Tensor] = None,
    dataset_attention_mask: Optional[torch.Tensor] = None,
    **second_stage_model_kwargs,
) -> torch.Tensor:
    """Run ``model`` over ``input_ids`` in mini-batches and concatenate the outputs.

    Single-stage model: calls ``model(input_ids=..., attention_mask=..., **kwargs)``
    per batch.

    Two-stage model (has ``first_stage_model``):
      1. Each context set in ``dataset_input_ids`` is embedded batch-by-batch
         with ``first_stage_model``. A 2-D input counts as a single set.
      2. The per-set embeddings are averaged over the sets.
      3. ``second_stage_model`` is called per batch of ``input_ids`` with those
         ``dataset_embeddings``.

    A ``.module`` wrapper (e.g. DDP) is unwrapped first.
    """
    if hasattr(model, "module"):
        model = model.module

    if not hasattr(model, "first_stage_model"):
        return torch.cat(
            [
                model(
                    input_ids=input_ids[start:start + batch_size],
                    attention_mask=attention_mask[start:start + batch_size],
                    **second_stage_model_kwargs,
                )
                for start in range(0, len(input_ids), batch_size)
            ],
            dim=0,
        )

    if dataset_input_ids.ndim == 2:
        # A single context set: add a leading "set" dimension.
        dataset_input_ids = dataset_input_ids[None]
        dataset_attention_mask = dataset_attention_mask[None]

    per_set_embeddings = []
    for set_idx, set_ids in enumerate(dataset_input_ids):
        set_mask = dataset_attention_mask[set_idx]
        chunks = [
            model.first_stage_model(
                input_ids=set_ids[start:start + batch_size],
                attention_mask=set_mask[start:start + batch_size],
            )
            for start in range(0, set_ids.shape[0], batch_size)
        ]
        per_set_embeddings.append(torch.cat(chunks, dim=0))

    # Average the context embeddings over the context sets.
    dataset_embeddings = torch.stack(per_set_embeddings, dim=0).mean(dim=0)

    return torch.cat(
        [
            model.second_stage_model(
                input_ids=input_ids[start:start + batch_size],
                attention_mask=attention_mask[start:start + batch_size],
                dataset_embeddings=dataset_embeddings,
                **second_stage_model_kwargs,
            )
            for start in range(0, len(input_ids), batch_size)
        ],
        dim=0,
    )
|
|
|
|
|
def last_token_pool(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pick the hidden state at the last position whose mask is nonzero.

    Works for both left- and right-padded batches: the last nonzero mask entry
    is found by flipping the mask and taking the first maximum.

    Args:
        hidden_state: ``(b, n, d)`` token representations.
        attention_mask: ``(b, n)`` mask.

    Returns:
        ``(b, d)``, scaled by the (float) mask value at the chosen position. A
        row with an all-zero mask therefore returns zeros.
    """
    b, n, d = hidden_state.size()

    flipped_first_max = torch.argmax(torch.flip(attention_mask, dims=(1,)), dim=1, keepdim=False)
    last_idx = (attention_mask.size(1) - flipped_first_max - 1).clamp(min=0)

    index = last_idx[:, None, None].expand(b, 1, d)
    assert index.shape == (b, 1, d)

    picked = torch.gather(hidden_state, 1, index).squeeze(dim=1)
    mask_at_last = attention_mask.gather(1, last_idx[:, None]).float()
    return picked * mask_at_last
|
|
|
def print0(*args, **kwargs) -> None:
    """``print`` that only emits on rank 0 (or in non-distributed runs).

    This redefines an identical ``print0`` declared earlier in the module; the
    later definition is the one bound at import time.
    """
    if get_rank() == 0:
        print(*args, **kwargs)
    return None
|
|
|
|
|
def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None:
    """Keep only the first ``n_layers`` transformer layers of ``model`` (in place).

    Supported layouts: ``model.transformer.h`` (falling back to
    ``model.transformer.layer``) and ``model.encoder.layers`` (falling back to
    ``model.encoder.layer``).

    Raises:
        RuntimeError: if the model has neither a ``transformer`` nor an ``encoder`` attribute.
    """
    if hasattr(model, 'transformer'):
        container = model.transformer
        attr_name = 'h' if hasattr(container, 'h') else 'layer'
    elif hasattr(model, 'encoder'):
        container = model.encoder
        attr_name = 'layers' if hasattr(container, 'layers') else 'layer'
    else:
        raise RuntimeError(f"unknown how to limit layers of model {type(model)}")

    setattr(container, attr_name, getattr(container, attr_name)[:n_layers])
|
|
|
|
|
|
|
def disable_dropout(model: torch.nn.Module):
    """Set ``p = 0.0`` on every ``torch.nn.Dropout`` sub-module (in place) and report the count on rank 0."""
    n_disabled = 0
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0
            n_disabled += 1
    print0(
        f"Disabled {n_disabled} dropout modules from model type {type(model)}"
    )
|
|
|
|
|
def disable_causality(model: torch.nn.Module):
    """Set ``is_causal = False`` on every sub-module exposing that attribute and report the count on rank 0."""
    causal_modules = [m for m in model.modules() if hasattr(m, "is_causal")]
    for module in causal_modules:
        module.is_causal = False
    print0(
        f"Set is_causal=False in {len(causal_modules)} modules from model type {type(model)}"
    )
|
|
|
|
|
class ContextualModelMixin(nn.Module):
    """Shared machinery for models conditioned on a set of corpus ("dataset") embeddings.

    ``contextual_init`` must be called by the host class after ``self.config``
    and ``self.hidden_size`` exist. ``_prepare_dataset_embeddings`` turns the
    raw corpus embeddings into a per-example prefix: corpus tokens followed by
    ``n_soft_prompt`` learned prompt tokens.
    """

    @property
    def num_corpus_tokens(self) -> int:
        # Total prefix length contributed by the corpus: documents * tokens per document.
        # NOTE: subclasses in this file override this to read
        # config.transductive_corpus_size directly.
        return self.transductive_corpus_size * self.transductive_tokens_per_document

    def contextual_init(self):
        """Create the soft-prompt projection, optional null embedding, and output head.

        Reads ``transductive_corpus_size`` (default 1),
        ``transductive_tokens_per_document`` (default 1), and
        ``transductive_sequence_dropout_prob`` (default 0.0) from ``self.config``.
        """
        self.n_soft_prompt = 8
        # Maps a constant vector to n_soft_prompt learned prompt tokens of width hidden_size.
        self.prompt_projection = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size * self.n_soft_prompt)
        )
        self.transductive_corpus_size = vars(self.config).get("transductive_corpus_size", 1)
        self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
        self.randomize_dataset_sequence_order = True
        self.sequence_dropout_prob = vars(self.config).get("transductive_sequence_dropout_prob", 0.0)
        # The learned null embedding only exists when sequence dropout is enabled;
        # null_dataset_embedding=True in _prepare_dataset_embeddings relies on it.
        if self.sequence_dropout_prob > 0.0:
            self.sequence_dropout_null_embedding = torch.nn.Parameter(
                torch.randn(self.hidden_size) * 0.01,
                requires_grad = True
            )
        self.output_projection = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size)
        )

    def _prepare_dataset_embeddings(
        self,
        input_ids: torch.Tensor, dataset_embeddings: torch.Tensor,
        null_dataset_embedding: bool = False,
    ) -> torch.Tensor:
        """Build the per-example corpus prefix fed to the second-stage backbone.

        Args:
            input_ids: only its batch size and device are used.
            dataset_embeddings: corpus embeddings (tensor or array-like). A 2-D
                input gets a leading axis of size 1. When
                ``transductive_tokens_per_document > 1``, the training path
                asserts ``shape[1] == transductive_tokens_per_document``.
                Presumably the layout is (n_docs, tokens_per_doc, hidden) — TODO confirm.
            null_dataset_embedding: replace all corpus tokens with the learned
                null embedding. Only consulted when the training-time sequence
                dropout branch is not taken.

        Returns:
            ``(batch, corpus_size + n_soft_prompt, hidden_size)``: corpus tokens
            followed by the learned soft-prompt tokens. In training, the corpus
            positions are randomly permuted per example.

        Raises:
            ValueError: if the embeddings have fewer than 3 dims after
                normalization, or fewer than ``num_corpus_tokens`` corpus tokens.
        """
        if not isinstance(dataset_embeddings, torch.Tensor):
            dataset_embeddings = torch.tensor(dataset_embeddings)

        if len(dataset_embeddings.shape) == 2:
            # Add a leading axis of size 1 (shared by the whole batch, see the expand below).
            dataset_embeddings = dataset_embeddings[None, :, :]
        dataset_embeddings = dataset_embeddings.to(input_ids.device)

        if len(dataset_embeddings.shape) < 3:
            raise ValueError(f"dataset_embeddings must have at least 3 dimensions, got {dataset_embeddings.shape}")

        batch_size = input_ids.shape[0]
        if (self.transductive_tokens_per_document > 1):
            if self.training:
                # Sample transductive_corpus_size documents per example, with
                # replacement, from the leading axis, then flatten their tokens.
                assert dataset_embeddings.shape[1] == self.transductive_tokens_per_document
                R = torch.randint(
                    low=0,
                    high=len(dataset_embeddings),
                    size=(batch_size, self.config.transductive_corpus_size),
                    device=dataset_embeddings.device
                )

                dataset_embeddings = dataset_embeddings[R].reshape((batch_size, self.num_corpus_tokens, self.hidden_size))
            else:
                # Eval: use the given documents in order as a single shared prefix.
                dataset_embeddings = dataset_embeddings.reshape((1, self.num_corpus_tokens, self.hidden_size))

        if dataset_embeddings.shape[1] < self.num_corpus_tokens:
            raise ValueError(f"dataset_embeddings must have at least {self.num_corpus_tokens} tokens, got {dataset_embeddings.shape[1]}")
        elif dataset_embeddings.shape[1] > self.num_corpus_tokens:
            # Extra corpus tokens are truncated, keeping the first num_corpus_tokens.
            dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :]

        # `_` is the leading (batch/shared) dimension of the corpus tensor.
        _, corpus_size, _hidden_size = dataset_embeddings.shape
        if _ == 1:
            # A single shared corpus is broadcast to every example in the batch.
            dataset_embeddings = dataset_embeddings.expand((batch_size, -1, -1))

        if self.training and self.sequence_dropout_prob > 0.0:
            # Independently replace each corpus token with the null embedding
            # with probability sequence_dropout_prob.
            sequence_dropout_mask = (
                torch.rand((batch_size, corpus_size), device=dataset_embeddings.device) < self.sequence_dropout_prob
            )
            null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
            dataset_embeddings = torch.where(
                sequence_dropout_mask[..., None], null_embeddings, dataset_embeddings
            )
        elif null_dataset_embedding:
            # NOTE(review): raises AttributeError unless sequence_dropout_prob > 0,
            # because the null embedding is only created in that case.
            null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
            dataset_embeddings = null_embeddings

        # Learned soft prompt: a constant all-ones vector projected to
        # n_soft_prompt tokens, appended after the corpus tokens.
        soft_prompt = torch.ones((1, self.hidden_size), device=dataset_embeddings.device, dtype=dataset_embeddings.dtype)
        soft_prompt = self.prompt_projection(soft_prompt).reshape((1, self.n_soft_prompt, self.hidden_size))
        soft_prompt = soft_prompt.expand((len(dataset_embeddings), -1, -1))
        soft_prompt = torch.cat((dataset_embeddings, soft_prompt), dim=1)

        if self.training and self.randomize_dataset_sequence_order:
            # Shuffle only the corpus positions per example; the soft-prompt
            # tokens stay at the end (the arange part of the index).
            randomized_order = torch.stack(
                [
                    torch.cat(
                        (
                            torch.randperm(corpus_size, device=soft_prompt.device),
                            torch.arange(self.n_soft_prompt, device=soft_prompt.device) + corpus_size
                        ), dim=0)
                    for _ in range(batch_size)])
            randomized_order = randomized_order.to(soft_prompt.device)
            soft_prompt = soft_prompt.gather(1, randomized_order[..., None].expand_as(soft_prompt))

        return soft_prompt
|
|
|
class BiEncoder(transformers.PreTrainedModel):
    """Single-stage embedding model: a pretrained backbone, pooling, and an MLP head.

    Config fields read here: ``embedder`` (name passed to
    ``load_embedder_and_tokenizer``), ``limit_layers``, ``embedding_output_dim``,
    ``logit_scale``, ``disable_dropout``, and optionally
    ``transductive_tokens_per_document`` (default 1) and ``pooling_strategy``
    (default ``"mean"``; any other value selects masked max pooling).
    """
    embedder: transformers.PreTrainedModel

    def __init__(
        self,
        config,
    ):
        super().__init__(config=config)
        embedder, _ = load_embedder_and_tokenizer(
            config.embedder,
        )

        if config.limit_layers:
            print0(f"Limiting layers to {config.limit_layers}")
            limit_layers(embedder, config.limit_layers)

        self.embedder = embedder
        self.hidden_size = self.embedder.config.hidden_size
        self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.GELU(),
            torch.nn.Linear(self.hidden_size, self.config.embedding_output_dim or self.hidden_size),
        )
        # Not used in forward(); presumably read by training code.
        self.temp = config.logit_scale

        if config.disable_dropout:
            disable_dropout(self)
        self.pooling_strategy = vars(config).get("pooling_strategy", "mean")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: Optional[torch.Tensor] = None,
        dataset_attention_mask: Optional[torch.Tensor] = None,
        token_type_ids = None,
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Embed a batch of token sequences.

        Args:
            input_ids: token ids, ``(batch_size, seq_length)``.
            attention_mask: same shape as ``input_ids``; zero marks padding.
            dataset_input_ids, dataset_attention_mask: accepted for interface
                compatibility with the contextual models; ignored here.
            token_type_ids: ignored.
            output_hidden_states: if True, return a dict with the backbone
                hidden states (chunked and padded when
                ``transductive_tokens_per_document > 1``) and the pooled output.

        Returns:
            ``(batch_size, out_dim)`` embeddings, or ``(batch_size,
            transductive_tokens_per_document, out_dim)`` when
            ``transductive_tokens_per_document > 1``, where ``out_dim`` is
            ``embedding_output_dim or hidden_size``. A dict with keys
            ``"hidden_states"`` and ``"pooled"`` is returned instead when
            ``output_hidden_states`` is set.
        """
        del token_type_ids

        outputs = (
            self.embedder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        )

        if self.transductive_tokens_per_document > 1:
            # Split each sequence into transductive_tokens_per_document equal
            # chunks and mean-pool each chunk into its own embedding.
            batch_size, seq_length, output_dim = outputs.shape

            if seq_length % self.transductive_tokens_per_document != 0:
                # Right-pad with masked-out zeros so the length divides evenly.
                # The padding keeps the dtype of the tensor it extends.
                n_extra_embeds = self.transductive_tokens_per_document - (seq_length % self.transductive_tokens_per_document)
                outputs = torch.cat(
                    (outputs, outputs.new_zeros((batch_size, n_extra_embeds, output_dim))),
                    dim=1
                )
                attention_mask = torch.cat(
                    (attention_mask, attention_mask.new_zeros((batch_size, n_extra_embeds))),
                    dim=1
                )
                seq_length += n_extra_embeds
                print(f"Added {n_extra_embeds} padding tokens to input_ids and attention_mask")

            outputs = outputs.reshape(
                (batch_size, self.transductive_tokens_per_document, seq_length // self.transductive_tokens_per_document, output_dim)
            )
            attention_mask = attention_mask.reshape((batch_size, self.transductive_tokens_per_document, -1))
            document_embeddings = mean_pool_3d(outputs, attention_mask)
            document_embeddings = document_embeddings.reshape((batch_size, self.transductive_tokens_per_document, output_dim))
        elif self.pooling_strategy == "mean":
            document_embeddings = mean_pool(outputs, attention_mask)
        else:
            # Max pooling over non-padded positions only. The previous code read
            # the unbound `document_embeddings` and kept max()'s (values, indices)
            # namedtuple.
            is_valid = attention_mask[..., None].bool()
            masked_outputs = outputs.masked_fill(~is_valid, torch.finfo(outputs.dtype).min)
            document_embeddings = masked_outputs.max(dim=1).values

        output = self.mlp(document_embeddings)

        if output_hidden_states:
            return {
                "hidden_states": outputs,
                "pooled": output,
            }
        else:
            return output
|
|
|
|
|
class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin):
    """Second-stage model that prefixes a decoder backbone with corpus embeddings.

    The corpus prefix built by ``_prepare_dataset_embeddings`` (width
    ``first_stage_hidden_size``) is flattened, padded, and re-chunked into
    tokens of the backbone's hidden size. It is then layer-normed and
    prepended to the token embeddings of ``input_ids``. Causal masking is
    disabled on the backbone, and the outputs for the real tokens are pooled
    and projected.
    """

    def __init__(
        self,
        config,
        dataset_backbone: transformers.PreTrainedModel,
        first_stage_hidden_size: int,
    ):
        super().__init__(config=config)
        self.backbone = dataset_backbone
        self.backbone_hidden_size = self.backbone.config.hidden_size
        # hidden_size is the *first-stage* width; the mixin's prompt_projection uses it.
        self.hidden_size = first_stage_hidden_size
        self.contextual_init()
        # Lets prefix tokens and inputs attend bidirectionally (sets is_causal=False).
        disable_causality(self.backbone)

        self.input_ln = torch.nn.LayerNorm(
            self.backbone_hidden_size,
            eps=1e-5
        )

        # Replaces the mixin's output_projection (which used hidden_size) with
        # one sized for the backbone's output width.
        self.output_projection = torch.nn.Sequential(
            torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size)
        )
        self._shift_rotary_embedding()

    @property
    def num_corpus_tokens(self) -> int:
        return self.config.transductive_corpus_size * self.transductive_tokens_per_document

    @property
    def corpus_token_ratio(self) -> float:
        # Ratio of backbone width to first-stage width.
        return self.backbone_hidden_size / self.hidden_size

    def corpus_token_pad_size(self, n_tokens: int) -> int:
        # NOTE(review): n_tokens is ignored; the result depends only on the two widths.
        return self.hidden_size % self.backbone_hidden_size

    def _shift_rotary_embedding(self) -> None:
        # Reads the flag but does not act on it; only prints a warning.
        disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
        print("Warning: Positional embedding disabling not implemented for LLAMA.")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_embeddings: torch.Tensor,
        output_hidden_states: bool = False,
        null_dataset_embedding: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids`` conditioned on ``dataset_embeddings``.

        Args:
            input_ids: ``(batch, seq_len)`` token ids for the backbone.
            attention_mask: ``(batch, seq_len)`` mask for ``input_ids``.
            dataset_embeddings: corpus embeddings, forwarded to ``_prepare_dataset_embeddings``.
            output_hidden_states: also return the per-token outputs.
            null_dataset_embedding: forwarded to ``_prepare_dataset_embeddings``.

        Returns:
            Pooled embeddings after ``output_projection``. When
            ``output_hidden_states`` is True, a dict with ``"hidden_states"``
            (outputs for the input tokens only) and ``"pooled"``. Pooling is
            chosen by ``config.pooling_strategy``: ``"last_token"``, ``"mean"``,
            or position-weighted mean for anything else.
        """
        soft_prompt = self._prepare_dataset_embeddings(
            input_ids=input_ids,
            dataset_embeddings=dataset_embeddings,
            null_dataset_embedding=null_dataset_embedding,
        )

        # Flatten the (prefix_len, first_stage_hidden) prefix per example, then
        # re-chunk it into tokens of backbone_hidden_size.
        num_soft_elements = torch.prod(torch.tensor(soft_prompt.shape[1:])).item()
        soft_prompt = soft_prompt.reshape((soft_prompt.shape[0], num_soft_elements))
        # NOTE(review): when num_soft_elements is already a multiple of
        # backbone_hidden_size this still appends one full token of padding.
        num_padding_elements = self.backbone_hidden_size - (num_soft_elements % self.backbone_hidden_size)
        padding = torch.ones((soft_prompt.shape[0], num_padding_elements), device=soft_prompt.device)
        soft_prompt = torch.cat((soft_prompt, padding), dim=1)
        soft_prompt = soft_prompt.reshape(
            (soft_prompt.shape[0], -1, self.backbone_hidden_size)
        )
        soft_prompt = self.input_ln(soft_prompt)

        # Prefix tokens are always attended to.
        backbone_attention_mask = torch.ones(
            soft_prompt.shape[0:2],
            dtype=torch.long,
            device=soft_prompt.device,
        )
        token_embeddings = self.backbone.get_input_embeddings()
        inputs_embeds = token_embeddings(input_ids)

        # [prefix tokens | input tokens]
        inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1)

        input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)

        output = self.backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=input_attention_mask,
            output_hidden_states=True,
        )

        last_hidden_state = output.hidden_states[-1]
        n_soft_prompt_tokens = soft_prompt.shape[1]

        # Drop the prefix positions; pool only over the real input tokens.
        output_vectors = last_hidden_state[:, n_soft_prompt_tokens:, :]
        output_attention_mask = input_attention_mask[:, n_soft_prompt_tokens:]

        if vars(self.config).get("pooling_strategy") == "last_token":
            output_pooled = last_token_pool(output_vectors, output_attention_mask)
        elif vars(self.config).get("pooling_strategy") == "mean":
            output_pooled = mean_pool(output_vectors, output_attention_mask)
        else:
            # NOTE(review): depending on its implementation, mean_pool_weighted may
            # modify this mask view (and hence input_attention_mask) in place.
            output_pooled = mean_pool_weighted(output_vectors, output_attention_mask)

        output = self.output_projection(output_pooled)

        if output_hidden_states:
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        else:
            return output
|
|
|
|
|
class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    """Second-stage encoder that conditions a backbone on dataset embeddings.

    The dataset embeddings are converted into a soft prompt by
    ``_prepare_dataset_embeddings`` (not defined in this class; presumably
    provided by ``ContextualModelMixin``). That soft prompt is prepended to the
    token embeddings of ``input_ids``. The backbone outputs for the real tokens
    are mean-pooled and passed through ``self.output_projection`` (also not
    defined here; presumably created by ``contextual_init``).
    """

    def __init__(
        self,
        config,
        dataset_backbone: transformers.PreTrainedModel,
    ):
        super().__init__(config=config)
        self.backbone = dataset_backbone
        self.hidden_size = self.backbone.config.hidden_size
        self.contextual_init()
        self._shift_rotary_embedding()

    @property
    def num_corpus_tokens(self) -> int:
        """Number of soft-prompt positions reserved for the corpus.

        Computed as ``config.transductive_corpus_size`` times
        ``self.transductive_tokens_per_document``. The latter is not assigned in
        this class; presumably ``contextual_init`` sets it — TODO confirm.
        """
        return self.config.transductive_corpus_size * self.transductive_tokens_per_document

    def _shift_rotary_embedding(self) -> None:
        """Set ``rotary_start_pos`` on the rotary modules of a Nomic backbone.

        This only runs when both conditions hold:

        * the backbone's ``model_type`` starts with ``"nomic"``;
        * ``config.disable_transductive_rotary_embedding`` is truthy (it
          defaults to True when the key is absent).

        Every backbone submodule that has a ``rotary_emb_dim`` attribute then
        gets ``rotary_start_pos = self.num_corpus_tokens``.
        """
        disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
        if self.backbone.config.model_type.startswith("nomic") and disable_transductive_rotary_embedding:
            # NOTE(review): the backbone *config* gets 0.0, while the module
            # attributes below get num_corpus_tokens; confirm this asymmetry
            # is intended.
            self.backbone.config.rotary_start_pos = 0.0
            rotary_disabled = 0
            rotary_start_pos = self.num_corpus_tokens
            for module in self.backbone.modules():
                if hasattr(module, "rotary_emb_dim"):
                    print("editing module", type(module))
                    module.rotary_start_pos = rotary_start_pos
                    rotary_disabled += 1
            print0(f"modified {rotary_disabled} rotary modules – set rotary_start_pos to {rotary_start_pos}")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_embeddings: torch.Tensor,
        output_hidden_states: bool = False,
        null_dataset_embedding: bool = False,
    ) -> torch.Tensor:
        """Encode ``input_ids`` with the dataset soft prompt prepended.

        Args:
            input_ids: Token ids for the documents to embed.
            attention_mask: Mask for ``input_ids``. It is concatenated after an
                all-ones mask covering the soft prompt.
            dataset_embeddings: First-stage embeddings used to build the soft
                prompt.
            output_hidden_states: If True, return a dict with the per-token
                outputs as well as the pooled embedding.
            null_dataset_embedding: Forwarded unchanged to
                ``_prepare_dataset_embeddings``.

        Returns:
            ``self.output_projection`` applied to the mean-pooled token
            outputs. If ``output_hidden_states`` is set, the return value is
            instead ``{"hidden_states": <outputs without soft-prompt
            positions>, "pooled": <projected embedding>}``.
        """
        soft_prompt = self._prepare_dataset_embeddings(
            input_ids=input_ids,
            dataset_embeddings=dataset_embeddings,
            null_dataset_embedding=null_dataset_embedding,
        )

        # Soft-prompt positions are always attended to. Dims 0:2 of the soft
        # prompt are used as (batch, n_prompt_tokens).
        backbone_attention_mask = torch.ones(
            soft_prompt.shape[0:2],
            dtype=torch.long,
            device=soft_prompt.device,
        )
        inputs_embeds = self.backbone.embeddings(input_ids)
        inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1)
        attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)

        output = self.backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

        # Drop the soft-prompt positions so pooling only sees the real tokens.
        n_soft_prompt_tokens = soft_prompt.shape[1]
        output_vectors = output.last_hidden_state[:, n_soft_prompt_tokens:, :]
        output_attention_mask = attention_mask[:, n_soft_prompt_tokens:]

        output_pooled = mean_pool(output_vectors, output_attention_mask)
        output = self.output_projection(output_pooled)

        if output_hidden_states:
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        else:
            return output
|
|
|
|
|
class DatasetPrefixBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    """Biencoder that prepends a randomly sampled dataset document to each input.

    For every row of ``input_ids``, one row of ``dataset_input_ids`` is drawn
    uniformly at random (with replacement). It is concatenated in front of that
    row as raw tokens. Only the original input tokens contribute to the pooled
    embedding.
    """

    def __init__(
        self,
        config,
        embedder: transformers.PreTrainedModel,
    ):
        super().__init__(config=config)
        self.embedder = embedder
        self.hidden_size = self.embedder.config.hidden_size
        self.contextual_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: torch.Tensor,
        dataset_attention_mask: torch.Tensor,
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids`` with a random dataset document as prefix.

        Args:
            input_ids: Token ids of the documents to embed (batch size B).
            attention_mask: Mask for ``input_ids``.
            dataset_input_ids: Candidate prefix documents (any batch size D >= 1).
                Their sequence length must match ``dataset_attention_mask``.
            dataset_attention_mask: Mask for ``dataset_input_ids``. Only its
                dtype, device and sequence length are used: prefix tokens are
                always fully attended.
            output_hidden_states: If True, also return the per-token outputs of
                the non-prefix positions.

        Returns:
            The projected mean-pooled embedding. If ``output_hidden_states`` is
            set, the return value is instead ``{"hidden_states": ...,
            "pooled": ...}``.
        """
        sample_idx = torch.randint(
            low=0,
            high=len(dataset_input_ids),
            size=(len(input_ids),),
            device=dataset_input_ids.device,
        )
        dataset_input_ids = dataset_input_ids[sample_idx]
        input_ids = torch.cat((dataset_input_ids, input_ids), dim=1)

        # Prefix tokens are fully attended. Build the mask from the *sampled*
        # prefix so its batch dimension is B, not D; otherwise the concatenation
        # with `attention_mask` fails whenever D != B.
        dataset_attention_mask = torch.ones_like(
            dataset_input_ids,
            dtype=dataset_attention_mask.dtype,
            device=dataset_attention_mask.device,
        )
        input_attention_mask = torch.cat((dataset_attention_mask, attention_mask), dim=1)
        # Prefix positions are zeroed so they are excluded from pooling.
        output_attention_mask = torch.cat(
            (torch.zeros_like(dataset_input_ids), attention_mask), dim=1
        )

        output = self.embedder(
            input_ids=input_ids,
            attention_mask=input_attention_mask,
        )

        output_vectors = output.last_hidden_state
        output_pooled = mean_pool(output_vectors, output_attention_mask)
        output = self.output_projection(output_pooled)

        if output_hidden_states:
            S_d = dataset_attention_mask.shape[1]
            output_vectors = output_vectors[:, S_d:, :]
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        else:
            return output
|
|
|
|
|
class ContextualDocumentEmbeddingTransformer(transformers.PreTrainedModel):
    """Two-stage contextual document embedder.

    Stage one (``BiEncoder``) embeds the dataset documents. Stage two
    (``DatasetConditionedBiencoder``, or ``DatasetConditionedAutoregressive``
    when ``config.autoregressive_backbone`` is truthy) embeds the actual inputs
    conditioned on those stage-one embeddings.
    """

    config_class = ContextualModelConfig

    embedder: transformers.PreTrainedModel
    dataset_backbone: transformers.PreTrainedModel

    def __init__(
        self,
        config,
    ):
        super().__init__(config=config)

        # The second-stage backbone defaults to the same model as `config.embedder`.
        backbone_spec = vars(config).get("dataset_backbone", config.embedder)
        dataset_backbone, _ = load_embedder_and_tokenizer(backbone_spec)

        if config.limit_layers:
            print0(f"Limiting layers to {config.limit_layers}")
            limit_layers(dataset_backbone, config.limit_layers)

        # Stage one gets its own config copy: no output projection dim, and its
        # own (optional) layer limit.
        first_stage_config = copy.deepcopy(config)
        first_stage_config.embedding_output_dim = None
        first_stage_config.limit_layers = vars(self.config).get("limit_layers_first_stage", None)
        self.first_stage_model = BiEncoder(config=first_stage_config)

        if not vars(config).get("autoregressive_backbone", False):
            self.second_stage_model = DatasetConditionedBiencoder(
                config=config,
                dataset_backbone=dataset_backbone,
            )
        else:
            self.second_stage_model = DatasetConditionedAutoregressive(
                config=config,
                dataset_backbone=dataset_backbone,
                first_stage_hidden_size=self.first_stage_model.hidden_size,
            )

        self.temp = config.logit_scale
        # Runs after both stages exist, so it covers every submodule.
        if config.disable_dropout:
            disable_dropout(self)

        if vars(self.config).get("transductive_tie_token_embeddings", False):
            # Share one word-embedding Parameter between both stages.
            shared_word_embeddings = self.first_stage_model.embedder.embeddings.word_embeddings.weight
            self.second_stage_model.backbone.embeddings.word_embeddings.weight = shared_word_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: Optional[torch.Tensor],
        dataset_attention_mask: Optional[torch.Tensor],
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids`` in the context of the given dataset documents.

        Args:
            input_ids: Token ids of the documents to embed.
            attention_mask: Attention mask for ``input_ids``.
            dataset_input_ids: Token ids of the context documents. Although
                annotated ``Optional``, they are passed straight to the first
                stage; whether ``None`` is accepted depends on ``BiEncoder``
                (not visible here).
            dataset_attention_mask: Attention mask for ``dataset_input_ids``.
            output_hidden_states: Forwarded to the second-stage model.

        Returns:
            Whatever ``self.second_stage_model`` returns for these arguments.
        """
        dataset_embeddings = self.first_stage_model(
            input_ids=dataset_input_ids,
            attention_mask=dataset_attention_mask,
        )
        second_stage_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "dataset_embeddings": dataset_embeddings,
            "output_hidden_states": output_hidden_states,
        }
        return self.second_stage_model(**second_stage_inputs)
|
|
|
|
|
|
|
def get_model_class(name: str):
    """Return the model class registered under ``name``.

    Args:
        name: One of ``"transductive"``, ``"biencoder"`` or
            ``"dataset_prefix_biencoder"``.

    Returns:
        The corresponding model class (not an instance).

    Raises:
        ValueError: If ``name`` is not one of the names above.
    """
    # Exact comparison: a substring test (`name in 'transductive'`) would also
    # accept "", "t", "trans", "ductive", ... and silently pick the wrong model.
    if name == 'transductive':
        return ContextualDocumentEmbeddingTransformer
    elif name == 'biencoder':
        return BiEncoder
    elif name == "dataset_prefix_biencoder":
        return DatasetPrefixBiencoder
    else:
        raise ValueError(f'unknown model cls {name}')
|
|