# Profluent-E1-600M / modeling_e1.py
# Author: lhallee
# Commit message: "Upload modeling_e1.py with huggingface_hub"
# Commit: ba292d4 (verified)
### Embedding Mixin + Pooler
import os
import sqlite3
from contextlib import closing
from typing import Callable, List, Optional

import networkx as nx
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizerBase
class Pooler:
    """Pool per-token embeddings ``(b, L, d)`` into per-sequence vectors ``(b, d)``.

    ``pooling_types`` names strategies from ``pooling_options``; ``__call__``
    applies each one in order and concatenates the results along the last
    dimension, giving ``(b, len(pooling_types) * d)``. When an
    ``attention_mask`` ``(b, L)`` is supplied, non-zero entries mark real tokens
    and zeros mark padding.
    """

    def __init__(self, pooling_types: List[str]):
        self.pooling_types = pooling_types
        # Name -> pooling callable. Every callable is invoked with the keyword
        # arguments emb, attention_mask and attentions (see __call__).
        self.pooling_options = {
            'mean': self.mean_pooling,
            'max': self.max_pooling,
            'norm': self.norm_pooling,
            'median': self.median_pooling,
            'std': self.std_pooling,
            'var': self.var_pooling,
            'cls': self.cls_pooling,
            'parti': self._pool_parti,
        }

    def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
        """Element-wise max of ``attentions`` over dim 1.

        Presumably ``(b, n_layers_or_heads, L, L) -> (b, L, L)``; TODO confirm the
        meaning of dim 1 against whatever produces ``attentions``.
        """
        maxed_attentions = torch.max(attentions, dim=1)[0]
        return maxed_attentions

    def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"):
        """Run PageRank on ``attention_matrix`` interpreted as a weighted digraph.

        Returns a dict mapping node index (token position) to its PageRank score.
        ``prune_type`` is accepted but currently unused.

        Raises:
            Exception: if the graph's node count differs from the number of rows
                of ``attention_matrix``, or if the graph has no edges.
        """
        G = self._convert_to_graph(attention_matrix)
        if G.number_of_nodes() != attention_matrix.shape[0]:
            raise Exception(
                f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.")
        if G.number_of_edges() == 0:
            raise Exception(f"You don't seem to have any attention edges left in the graph.")
        return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)

    def _convert_to_graph(self, matrix):
        """Build a directed networkx graph whose edge weights are the entries of ``matrix``."""
        G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
        return G

    def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None):
        """Normalize PageRank scores into weights that sum to 1.

        Nodes whose ``attention_mask`` entry is 0 are removed first (``dict_importance``
        is modified in place). Weights are returned in the dict's iteration order.
        """
        if attention_mask is not None:
            for k in list(dict_importance.keys()):
                if attention_mask[k] == 0:
                    del dict_importance[k]
        total = sum(dict_importance.values())
        return np.array([v / total for v in dict_importance.values()])

    def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):  # (b, L, d) -> (b, d)
        """PageRank-weighted average of token embeddings ('parti' pooling).

        Attention maps are max-reduced over dim 1 into one ``(L, L)`` matrix per
        sequence. PageRank is run on that matrix as a weighted digraph, and the
        scores of real tokens become averaging weights. Padding is assumed to sit
        at the end of each row, because the first ``mask.sum()`` positions are the
        ones averaged.
        """
        assert attentions is not None, "'parti' pooling requires `attentions` to be provided."
        if attention_mask is None:
            attention_mask = torch.ones(emb.shape[:2], dtype=torch.long, device=emb.device)
        # networkx / NumPy need CPU float arrays; `.numpy()` fails on CUDA tensors
        # and on bfloat16 (NumPy has no such dtype).
        maxed_attentions = self._create_pooled_matrices_across_layers(attentions).detach().float().cpu().numpy()
        emb_cpu = emb.detach().float().cpu()
        mask_cpu = attention_mask.detach().cpu()
        emb_pooled = []
        for e, a, mask in zip(emb_cpu, maxed_attentions, mask_cpu):
            dict_importance = self._page_rank(a)
            importance_weights = self._calculate_importance_weights(dict_importance, mask)
            num_tokens = int(mask.sum().item())
            emb_pooled.append(np.average(e[:num_tokens].numpy(), weights=importance_weights, axis=0))
        pooled = torch.tensor(np.array(emb_pooled))
        # Match the other pooling outputs so torch.cat in __call__ works on any device.
        return pooled.to(device=emb.device, dtype=emb.dtype)

    def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Average over real tokens (all tokens when no mask is given)."""
        if attention_mask is None:
            return emb.mean(dim=1)
        attention_mask = attention_mask.unsqueeze(-1)
        return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)

    def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Channel-wise maximum over real tokens."""
        if attention_mask is None:
            return emb.max(dim=1).values
        # Fill padding with -inf rather than zeroing it. With zero-fill, a channel
        # whose real values are all negative would report 0.
        padding = ~attention_mask.bool().unsqueeze(-1)
        return emb.masked_fill(padding, float('-inf')).max(dim=1).values

    def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Channel-wise L2 norm over real tokens (zeroed padding adds nothing to the norm)."""
        if attention_mask is None:
            return emb.norm(dim=1, p=2)
        attention_mask = attention_mask.unsqueeze(-1)
        return (emb * attention_mask).norm(dim=1, p=2)

    def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Channel-wise median over real tokens (lower middle value for even counts)."""
        if attention_mask is None:
            return emb.median(dim=1).values
        # Mark padding as NaN and let nanmedian skip it. Zero-filling would pull
        # zeros into the median.
        padding = ~attention_mask.bool().unsqueeze(-1)
        return emb.masked_fill(padding, float('nan')).nanmedian(dim=1).values

    def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Channel-wise standard deviation (square root of ``var_pooling`` when masked)."""
        if attention_mask is None:
            return emb.std(dim=1)
        var = self.var_pooling(emb, attention_mask, **kwargs)
        return torch.sqrt(var)

    def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Channel-wise variance over real tokens.

        NOTE(review): without a mask this uses torch's default unbiased estimator
        (divide by N-1), while the masked path divides by N. The two paths differ
        even for an all-ones mask.
        """
        if attention_mask is None:
            return emb.var(dim=1)
        attention_mask = attention_mask.unsqueeze(-1)  # (b, L, 1)
        count = attention_mask.sum(dim=1)  # (b, 1)
        mean = ((emb * attention_mask).sum(dim=1) / count).unsqueeze(1)  # (b, 1, d)
        squared_diff = (emb - mean) ** 2  # (b, L, d)
        return (squared_diff * attention_mask).sum(dim=1) / count  # (b, d)

    def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Return the embedding at position 0 of each row."""
        return emb[:, 0, :]

    def __call__(
        self,
        emb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attentions: Optional[torch.Tensor] = None
    ):
        """Apply every configured pooling and concatenate: ``(b, n_pooling_types * d)``."""
        final_emb = []
        for pooling_type in self.pooling_types:
            final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions))  # (b, d)
        return torch.cat(final_emb, dim=-1)  # (b, n_pooling_types * d)
class ProteinDataset(TorchDataset):
    """Map-style dataset over an in-memory list of protein sequence strings.

    Items are the raw strings. Tokenization is left to the DataLoader's collate_fn.
    """

    def __init__(self, sequences: list[str]):
        # Stored as given (no copy).
        self.sequences = sequences

    def __len__(self) -> int:
        n_sequences = len(self.sequences)
        return n_sequences

    def __getitem__(self, idx: int) -> str:
        sequence = self.sequences[idx]
        return sequence
def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]], dict[str, torch.Tensor]]:
    """Return a DataLoader ``collate_fn`` that tokenizes a batch of sequences.

    Each batch is padded to its longest member and returned as PyTorch tensors.
    """
    tokenize_kwargs = {"return_tensors": "pt", "padding": 'longest'}

    def _collate_fn(sequences: list[str]) -> dict[str, torch.Tensor]:
        return tokenizer(sequences, **tokenize_kwargs)

    return _collate_fn
def parse_fasta(fasta_path: str) -> List[str]:
    """Read a FASTA file and return its sequences in file order.

    Lines starting with '>' separate records and are discarded. Blank lines are
    ignored. The remaining lines of a record are stripped and concatenated.
    Records with no sequence lines produce no entry.
    """
    assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
    records: List[str] = []
    chunks: List[str] = []

    def _flush() -> None:
        # Emit the record collected so far (if any) and start a new one.
        if chunks:
            records.append(''.join(chunks))
            chunks.clear()

    with open(fasta_path, 'r') as handle:
        for raw in handle:
            stripped = raw.strip()
            if not stripped:
                continue
            if stripped.startswith('>'):
                _flush()
            else:
                chunks.append(stripped)
    _flush()
    return records
class EmbeddingMixin:
    """Mixin that adds dataset-level embedding, caching and reloading to a model.

    The host class must implement ``_embed``. It is also expected to provide
    ``self.parameters()`` (used by ``device``) and ``self.config.hidden_size``
    (read by ``embed_dataset``).
    """

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return per-token embeddings for a batch; must be implemented by the host class."""
        raise NotImplementedError

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return next(self.parameters()).device

    def _read_sequences_from_db(self, db_path: str) -> set[str]:
        """Read sequences from SQLite database.

        Returns the set of values in the ``sequence`` column of the ``embeddings`` table.
        """
        # sqlite3.Connection's own context manager only commits/rolls back and
        # never closes the connection, so wrap it in contextlib.closing.
        with closing(sqlite3.connect(db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT sequence FROM embeddings")
            return {row[0] for row in cursor}

    def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None:
        """Create the ``embeddings`` table if missing and add the ``shape``/``dtype``
        columns to tables created before those columns existed. Commits on ``conn``."""
        cursor = conn.cursor()
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS embeddings ("
            "sequence TEXT PRIMARY KEY, "
            "embedding BLOB NOT NULL, "
            "shape TEXT, "
            "dtype TEXT"
            ")"
        )
        cursor.execute("PRAGMA table_info(embeddings)")
        rows = cursor.fetchall()
        column_names = [row[1] for row in rows]
        if "shape" not in column_names:
            cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT")
        if "dtype" not in column_names:
            cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT")
        conn.commit()

    def load_embeddings_from_pth(self, save_path: str) -> dict[str, torch.Tensor]:
        """Load a ``{sequence: tensor}`` dict saved by ``embed_dataset`` (``torch.save``).

        Uses ``weights_only=True`` and validates key/value types.
        """
        assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}"
        payload = torch.load(save_path, map_location="cpu", weights_only=True)
        assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary."
        for sequence, tensor in payload.items():
            assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)."
            assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors."
        return payload

    def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> dict[str, torch.Tensor]:
        """Load embeddings stored by ``embed_dataset(sql=True)`` from ``db_path``.

        Args:
            db_path: path to an existing SQLite database. The embeddings table is
                created/migrated if needed.
            sequences: sequences to fetch; ``None`` loads every row. Sequences not
                present in the table are absent from the result.

        Returns:
            Dict mapping sequence -> CPU tensor rebuilt from the stored bytes and
            the ``shape``/``dtype`` metadata.
        """
        assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
        loaded: dict[str, torch.Tensor] = {}
        # SQLite caps bound parameters per statement (999 in builds before 3.32),
        # so large `sequences` lists are queried in chunks below that limit.
        max_params_per_query = 900
        with closing(sqlite3.connect(db_path)) as conn:
            self._ensure_embeddings_table(conn)
            cursor = conn.cursor()
            if sequences is None:
                cursor.execute("SELECT sequence, embedding, shape, dtype FROM embeddings")
                rows = cursor.fetchall()
            else:
                if len(sequences) == 0:
                    return loaded
                # Order-preserving de-duplication so no row is fetched twice across chunks.
                unique_sequences = list(dict.fromkeys(sequences))
                rows = []
                for start in range(0, len(unique_sequences), max_params_per_query):
                    chunk = unique_sequences[start:start + max_params_per_query]
                    placeholders = ",".join(["?"] * len(chunk))
                    cursor.execute(
                        f"SELECT sequence, embedding, shape, dtype FROM embeddings WHERE sequence IN ({placeholders})",
                        tuple(chunk),
                    )
                    rows.extend(cursor.fetchall())
        for sequence, embedding_bytes, shape_text, dtype_text in rows:
            assert shape_text is not None, "Missing shape metadata in embeddings table."
            assert dtype_text is not None, "Missing dtype metadata in embeddings table."
            shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0]
            assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}"
            expected_size = int(np.prod(shape_values))
            np_dtype = np.dtype(dtype_text)
            array = np.frombuffer(embedding_bytes, dtype=np_dtype)
            assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}"
            # Copy: frombuffer returns a read-only view over the blob.
            reshaped = array.copy().reshape(tuple(shape_values))
            loaded[sequence] = torch.from_numpy(reshaped)
        return loaded

    def embed_dataset(
        self,
        sequences: Optional[List[str]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        batch_size: int = 2,
        max_len: int = 512,
        truncate: bool = True,
        full_embeddings: bool = False,
        embed_dtype: torch.dtype = torch.float32,
        pooling_types: List[str] = ['mean'],
        num_workers: int = 0,
        sql: bool = False,
        save: bool = True,
        sql_db_path: str = 'embeddings.db',
        save_path: str = 'embeddings.pth',
        fasta_path: Optional[str] = None,
        **kwargs,
    ) -> Optional[dict[str, torch.Tensor]]:
        """
        Embed a dataset of protein sequences.
        Supports two modes:
        - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
        - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.
        Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via
        `fasta_path`, or both (the two sources are combined). At least one must be provided.

        Sequences are truncated to `max_len` characters when `truncate` is True,
        de-duplicated, and processed longest-first. Sequences already present in
        the SQLite database (`sql=True`) or in `save_path` are skipped.

        Returns:
            None when `sql=True` (results are written to `sql_db_path`); otherwise a
            dict mapping sequence -> CPU embedding. The embedding is pooled `(n*d,)`,
            or per-token `(tokens, hidden_size)` when `full_embeddings` is True.
            bfloat16 embeddings are stored as float32 in the SQL database because
            NumPy has no bfloat16 dtype.
        """
        if fasta_path is not None:
            fasta_sequences = parse_fasta(fasta_path)
            sequences = list(sequences or []) + fasta_sequences
        assert sequences is not None and len(sequences) > 0, \
            "Must provide at least one sequence via `sequences` or `fasta_path`."
        sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
        # Longest first so a too-large batch size fails early.
        sequences = sorted(sequences, key=len, reverse=True)
        hidden_size = self.config.hidden_size
        pooler = Pooler(pooling_types) if not full_embeddings else None
        tokenizer_mode = tokenizer is not None
        if tokenizer_mode:
            collate_fn = build_collator(tokenizer)
            device = self.device
        else:
            collate_fn = None
            device = None

        def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
            # 2-D outputs are treated as already pooled.
            if full_embeddings or residue_embeddings.ndim == 2:
                return residue_embeddings
            return pooler(residue_embeddings, attention_mask)

        def iter_batches(to_embed: List[str]):
            # Yields (batch sequences, residue embeddings, attention mask) per batch.
            if tokenizer_mode:
                assert collate_fn is not None
                assert device is not None
                dataset = ProteinDataset(to_embed)
                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False)
                for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                    seqs = to_embed[i * batch_size:(i + 1) * batch_size]
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    residue_embeddings = self._embed(input_ids, attention_mask)
                    yield seqs, residue_embeddings, attention_mask
            else:
                for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
                    seqs = to_embed[batch_start:batch_start + batch_size]
                    batch_output = self._embed(seqs, return_attention_mask=True, **kwargs)
                    assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)."
                    assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values."
                    residue_embeddings, attention_mask = batch_output
                    assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor."
                    yield seqs, residue_embeddings, attention_mask

        if sql:
            # closing() guarantees the connection is released even if embedding raises.
            with closing(sqlite3.connect(sql_db_path)) as conn:
                self._ensure_embeddings_table(conn)
                c = conn.cursor()
                already_embedded = self._read_sequences_from_db(sql_db_path)
                to_embed = [seq for seq in sequences if seq not in already_embedded]
                print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
                print(f"Embedding {len(to_embed)} new sequences")
                if len(to_embed) > 0:
                    with torch.no_grad():
                        for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)):
                            embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
                            for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                                if full_embeddings:
                                    emb = emb[mask.bool()].reshape(-1, hidden_size)
                                if emb.dtype == torch.bfloat16:
                                    # `.numpy()` raises on bfloat16. Store as float32; the
                                    # `dtype` column records what was actually written.
                                    emb = emb.float()
                                emb_np = emb.cpu().numpy()
                                emb_shape = ",".join([str(dim) for dim in emb_np.shape])
                                emb_dtype = str(emb_np.dtype)
                                c.execute(
                                    "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)",
                                    (seq, emb_np.tobytes(), emb_shape, emb_dtype),
                                )
                            if tokenizer_mode and (i + 1) % 100 == 0:
                                conn.commit()
                conn.commit()
            return None

        embeddings_dict = {}
        if os.path.exists(save_path):
            embeddings_dict = self.load_embeddings_from_pth(save_path)
            to_embed = [seq for seq in sequences if seq not in embeddings_dict]
            print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
            print(f"Embedding {len(to_embed)} new sequences")
        else:
            to_embed = sequences
            print(f"Embedding {len(to_embed)} new sequences")
        if len(to_embed) > 0:
            with torch.no_grad():
                for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
                    embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
                    for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                        if full_embeddings:
                            emb = emb[mask.bool()].reshape(-1, hidden_size)
                        embeddings_dict[seq] = emb.cpu()
        if save:
            torch.save(embeddings_dict, save_path)
        return embeddings_dict
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange, repeat
from enum import Enum
from typing import Any, TypedDict, Callable, List
from dataclasses import dataclass
from tokenizers import Tokenizer
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import ModelOutput
from transformers.utils import logging
logger = logging.get_logger(__name__)
### Kernels Flash Attention Detection
def _infer_kernels_flash_variant(kernel) -> str | None:
if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
return "flash_attn2"
if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
return "flash_attn3"
return None
def _try_get_kernels_flash():
    """Best-effort load of a flash-attention kernel through the ``kernels`` package.

    Candidates are tried in order: flash-attn3, then flash-attn2. The first one
    that loads and exposes a supported API (per ``_infer_kernels_flash_variant``)
    wins.

    Returns:
        ``(kernel, variant)`` on success, or ``(None, None)`` if ``kernels`` is not
        installed or no candidate is usable.
    """
    try:
        from kernels import get_kernel
    except ImportError:
        return None, None
    for repo_id in ("kernels-community/flash-attn3", "kernels-community/flash-attn2"):
        try:
            kernel = get_kernel(repo_id)
            variant = _infer_kernels_flash_variant(kernel)
        except Exception:
            # Deliberately best-effort: any load failure moves on to the next candidate.
            continue
        # Explicit check instead of `assert`, so the fallback still works under `python -O`.
        if variant is not None:
            return kernel, variant
    return None, None


# Resolved once at import time; None/None means no hub flash kernel is available.
FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
def _kernels_flash_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Run dense (non-varlen) attention through the loaded hub flash kernel.

    Dispatches on ``FLASH_KERNEL_VARIANT`` and returns only the attention output.
    If the kernel returns a tuple, only its first element is kept.

    Raises:
        AssertionError: if no kernel is loaded or the variant is unsupported.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    variant = FLASH_KERNEL_VARIANT
    if variant == "flash_attn2":
        outputs = FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)
        return outputs[0]
    if variant != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")
    try:
        result = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal)
    except TypeError:
        # Retry positionally if the keyword call is rejected. The positional order
        # presumably mirrors (q, k, v, dropout_p, softmax_scale, causal) — TODO
        # confirm against the loaded kernel's signature.
        result = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal)
    return result[0] if isinstance(result, tuple) else result
def _kernels_flash_varlen_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_q: int,
    max_seqlen_in_batch_k: int,
    causal: bool = False,
) -> torch.Tensor:
    """Run variable-length (unpadded) attention through the loaded hub flash kernel.

    ``cu_seqlens_*`` and ``max_seqlen_in_batch_*`` describe how the flattened token
    dimension splits into sequences. Returns only the attention output.

    Raises:
        AssertionError: if no kernel is loaded or the variant is unsupported.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    variant = FLASH_KERNEL_VARIANT
    if variant == "flash_attn2":
        outputs = FLASH_KERNEL.varlen_fwd(
            q=query_states,
            k=key_states,
            v=value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            is_causal=causal,
        )
        return outputs[0]
    if variant != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")
    try:
        result = FLASH_KERNEL.flash_attn_varlen_func(
            q=query_states,
            k=key_states,
            v=value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            causal=causal,
        )
    except TypeError:
        # Positional retry if keywords are rejected; the trailing (0.0, None, causal)
        # presumably maps to (dropout_p, softmax_scale, causal) — TODO confirm.
        result = FLASH_KERNEL.flash_attn_varlen_func(
            query_states, key_states, value_states,
            cu_seqlens_q, cu_seqlens_k,
            max_seqlen_in_batch_q, max_seqlen_in_batch_k,
            0.0, None, causal,
        )
    return result[0] if isinstance(result, tuple) else result
from torch.nn.attention.flex_attention import (
BlockMask,
create_block_mask,
flex_attention,
_create_sparse_block_from_block_mask
)
# Optional Triton layer-norm kernel fetched through the `kernels` package at import
# time. Any failure (package missing, download or build error) is logged as a
# warning and leaves `layer_norm` set to None.
try:
    from kernels import get_kernel
    layer_norm = get_kernel("kernels-community/triton-layer-norm")
except Exception as e:
    logger.warning(f"Failed to load triton layer norm kernel: {e}; Will be using PyTorch RMSNorm instead")
    layer_norm = None
### Attention Backend Enum & Resolution
class AttentionBackend(Enum):
    """Attention implementations that can be requested by name."""

    AUTO = "auto"
    KERNELS_FLASH = "kernels_flash"
    FLEX = "flex"
    SDPA = "sdpa"


# String values accepted by `resolve_attention_backend`, in declaration order.
VALID_ATTENTION_BACKENDS = tuple(member.value for member in AttentionBackend)
# Flipped to True after the first resolution is printed, so the message is shown once.
_BACKEND_CONFIRMED = False
def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
global _BACKEND_CONFIRMED
assert requested_backend in VALID_ATTENTION_BACKENDS, (
f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
)
if requested_backend == AttentionBackend.AUTO.value:
if FLASH_KERNEL is not None:
resolved = AttentionBackend.KERNELS_FLASH
elif flex_attention is not None:
resolved = AttentionBackend.FLEX
else:
resolved = AttentionBackend.SDPA
elif requested_backend == AttentionBackend.KERNELS_FLASH.value:
assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment."
resolved = AttentionBackend.KERNELS_FLASH
elif requested_backend == AttentionBackend.FLEX.value:
assert flex_attention is not None, "Flex Attention is not available in this environment."
resolved = AttentionBackend.FLEX
elif requested_backend == AttentionBackend.SDPA.value:
resolved = AttentionBackend.SDPA
else:
raise AssertionError(f"Unsupported attention backend: {requested_backend}")
if not _BACKEND_CONFIRMED:
print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'")
_BACKEND_CONFIRMED = True
return resolved
def create_block_causal_mask_optimized(sequence_ids: torch.Tensor) -> BlockMask:
# Assumes sequence_ids is sorted in increasing order for each batch item, except for
# the -1 values, which are used to indicate the padding tokens.
def document_mask(b, h, q_idx, kv_idx): # type: ignore[no-untyped-def]
return (
(sequence_ids[b, q_idx] >= sequence_ids[b, kv_idx])
& (sequence_ids[b, q_idx] != -1)
& (sequence_ids[b, kv_idx] != -1)
)
batch_size, seqlen = sequence_ids.shape
return create_block_mask(document_mask, batch_size, 1, seqlen, seqlen, device=sequence_ids.device)
def create_within_seq_block_mask(sequence_ids: torch.Tensor) -> BlockMask:
def document_mask(b, h, q_idx, kv_idx): # type: ignore[no-untyped-def]
return (
(sequence_ids[b, q_idx] == sequence_ids[b, kv_idx])
& (sequence_ids[b, q_idx] != -1)
& (sequence_ids[b, kv_idx] != -1)
)
batch_size, seqlen = sequence_ids.shape
return create_block_mask(document_mask, batch_size, 1, seqlen, seqlen, device=sequence_ids.device)
def build_within_seq_mask_4d(sequence_ids: torch.Tensor) -> torch.Tensor:
    """Dense boolean attention mask ``(batch, 1, seqlen, seqlen)`` for ``(batch, seqlen)`` ids.

    Entry ``[b, 0, q, k]`` is True iff tokens q and k share a sequence id and
    neither is padding (-1).
    """
    q_ids = sequence_ids.unsqueeze(-1)
    k_ids = sequence_ids.unsqueeze(-2)
    allowed = (q_ids == k_ids) & (q_ids != -1) & (k_ids != -1)
    return allowed.unsqueeze(1)
def build_block_causal_mask_4d(sequence_ids: torch.Tensor) -> torch.Tensor:
    """Dense boolean attention mask ``(batch, 1, seqlen, seqlen)`` for ``(batch, seqlen)`` ids.

    Entry ``[b, 0, q, k]`` is True iff k's sequence id is <= q's and neither
    token is padding (-1). Attention is bidirectional within a sequence and
    causal across sequences.
    """
    q_ids = sequence_ids.unsqueeze(-1)
    k_ids = sequence_ids.unsqueeze(-2)
    allowed = (q_ids >= k_ids) & (q_ids != -1) & (k_ids != -1)
    return allowed.unsqueeze(1)
def flex_attention_func(
query_states: torch.Tensor, # (bs, seqlen, nh, hs)
key_states: torch.Tensor, # (bs, seqlen, nkv, hs)
value_states: torch.Tensor, # (bs, seqlen, nkv, hs)
score_mod: Callable | None = None,
block_mask: BlockMask | None = None,
) -> torch.Tensor:
assert flex_attention is not None, "Flex Attention is not available in this environment"
assert score_mod is None, "Score mod is not supported yet"
query_states = query_states.transpose(1, 2).contiguous() # (bs, nh, seqlen, hs)
key_states = key_states.transpose(1, 2).contiguous() # (bs, nkv, seqlen, hs)
value_states = value_states.transpose(1, 2).contiguous() # (bs, nkv, seqlen, hs)
outputs = flex_attention(
query_states,
key_states,
value_states,
block_mask=block_mask,
score_mod=score_mod,
enable_gqa=query_states.shape[1] != key_states.shape[1], # if nkv != nh
)
outputs = outputs.transpose(1, 2) # (bs, seqlen, nh, hs)
return outputs
def kernels_flash_attention_func(
    query_states: torch.Tensor,  # (bs, seqlen, nh, hs)
    key_states: torch.Tensor,  # (bs, seqlen, nkv, hs)
    value_states: torch.Tensor,  # (bs, seqlen, nkv, hs)
    q_sequence_ids: torch.Tensor,
    k_sequence_ids: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:  # (bs, seqlen, nh, hs)
    """Attention through the hub flash kernel.

    Non-causal: padding (-1 ids) is removed, varlen attention runs per sequence,
    and the result is re-padded to ``(bs, seqlen, nh, hs)``.
    Causal: the dense kernel runs on the padded inputs directly.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if causal:
        # NOTE(review): this path ignores q/k_sequence_ids, so padding and sequence
        # boundaries are not masked — confirm callers only use it on unpadded batches.
        return _kernels_flash_forward(query_states, key_states, value_states, causal=True)
    batch_size, q_len = query_states.shape[0], query_states.shape[1]
    (
        packed_q,
        packed_k,
        packed_v,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids)
    packed_out = _kernels_flash_varlen_forward(
        packed_q,
        packed_k,
        packed_v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_in_batch_q=max_seqlen_in_batch_q,
        max_seqlen_in_batch_k=max_seqlen_in_batch_k,
        causal=False,
    )
    return pad_input(packed_out, indices_q, batch_size, q_len)
class IndexFirstAxis(torch.autograd.Function):
    """Differentiable ``input[indices]`` along the first axis.

    Forward gathers rows of ``input`` (ndim >= 2) at the 1-D ``indices``. Backward
    scatters the incoming gradient into a zero tensor whose first axis has the
    original length; rows that were not selected get zero gradient.
    """

    @staticmethod
    def forward(ctx, input, indices) -> torch.Tensor:  # type: ignore[no-untyped-def]
        # Save indices and the original first-axis length for backward.
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        # Flatten trailing dims to one axis, gather whole rows, then restore the trailing dims.
        return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape(
            -1, *other_shape
        )

    @staticmethod
    def backward(ctx, grad_output) -> tuple[torch.Tensor, None]:  # type: ignore[no-untyped-def]
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, "b ... -> b (...)")
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
        # `indices` is integer-valued, so it receives no gradient.
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
def block_min_max_seq_ids(SLEN: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    """For each ``block_size`` chunk of the packed token stream, return the first and
    last sequence index it touches.

    ``SLEN`` holds per-sequence lengths. The stream is padded up to a whole number
    of blocks with an extra pseudo-sequence (index ``len(SLEN)``), which may appear
    as the max id of the final block. Zero-length sequences are never reported.

    Returns:
        ``(MIN_SEQ_ID, MAX_SEQ_ID)``, each of shape ``(num_blocks,)``.
    """
    dev = SLEN.device
    n_real = torch.sum(SLEN)
    n_blocks = (n_real + block_size - 1) // block_size
    n_pad = n_blocks * block_size - n_real
    lengths = torch.cat([SLEN, n_pad.reshape(1).to(device=dev, dtype=SLEN.dtype)], dim=0)
    assert torch.sum(lengths) == n_blocks * block_size
    # seq_ends[i] is the exclusive end offset of sequence i in the packed stream.
    seq_ends = torch.cumsum(lengths.to(torch.long), dim=0)
    padded_total = seq_ends[-1].item()
    starts = torch.arange(0, n_blocks * block_size, block_size, device=dev, dtype=torch.long)
    stops = torch.minimum(starts + block_size, torch.tensor(padded_total, device=dev))
    # right=True finds the first sequence whose end lies strictly after the token.
    first_seq = torch.searchsorted(seq_ends, starts, right=True)
    last_token = torch.clamp(stops - 1, min=0)
    last_seq = torch.searchsorted(seq_ends, last_token, right=True)
    return first_seq, last_seq
def get_overlapping_blocks(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Classify (query block, key block) pairs for document-masked attention.

    A pair "overlaps" when the blocks' [min, max] sequence-id ranges intersect.
    Overlapping pairs where both blocks lie inside a single sequence are full
    (no per-token masking needed). All other overlapping pairs are partial.

    Returns:
        ``(full_blocks, partial_blocks)``, boolean tensors of shape
        ``(num_q_blocks, num_k_blocks)``.
    """
    q_lo, q_hi = block_min_max_seq_ids(SLEN_Q)
    k_lo, k_hi = block_min_max_seq_ids(SLEN_K)
    shares_sequence = (q_lo[:, None] <= k_hi[None, :]) & (k_lo[None, :] <= q_hi[:, None])
    single_q = (q_lo == q_hi)[:, None]
    single_k = (k_lo == k_hi)[None, :]
    both_single = single_q & single_k
    return shares_sequence & both_single, shares_sequence & ~both_single
@torch.compiler.disable
def direct_block_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask:
    """Build a document (same-sequence) flex ``BlockMask`` directly from sequence lengths.

    Block sparsity comes from ``get_overlapping_blocks`` (128x128 blocks). The
    per-token mask inside partial blocks compares document ids from
    ``repeat_interleave`` of the lengths. Uses torch's private
    ``_create_sparse_block_from_block_mask``, avoiding a dense mask evaluation.
    """
    q_doc_id = torch.repeat_interleave(SLEN_Q)
    k_doc_id = torch.repeat_interleave(SLEN_K)
    full_blocks, partial_blocks = get_overlapping_blocks(SLEN_Q, SLEN_K)

    def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor:
        return q_doc_id[q_idx] == k_doc_id[kv_idx]

    return _create_sparse_block_from_block_mask(
        (partial_blocks[None, None], full_blocks[None, None]),
        doc_mask,
        seq_lengths=(q_doc_id.shape[0], k_doc_id.shape[0]),
        Q_BLOCK_SIZE=128,
        KV_BLOCK_SIZE=128,
    )
@torch.compiler.disable
def doc_id_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask:
q_doc_id = torch.repeat_interleave(SLEN_Q)
k_doc_id = torch.repeat_interleave(SLEN_K)
def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor:
return q_doc_id[q_idx] == k_doc_id[kv_idx]
total_q_len = q_doc_id.shape[0]
total_k_len = k_doc_id.shape[0]
return create_block_mask(doc_mask, 1, 1, total_q_len, total_k_len, BLOCK_SIZE=128, device=SLEN_Q.device)
def varlen_flex_attention_func(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    q_sequence_ids: torch.Tensor,
    k_sequence_ids: torch.Tensor,
) -> torch.Tensor:
    """Flex attention on the packed (unpadded) tokens of a padded batch.

    Padding (-1 ids) is removed and all tokens are packed into one row. Attention
    runs under a document mask from ``block_mask_creator``, so tokens only see their
    own sequence. The result is re-padded to ``(batch, q_len, heads, head_dim)``.
    """
    batch_size, q_len = query_states.shape[0], query_states.shape[1]
    (
        packed_q,
        packed_k,
        packed_v,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        _max_seqlens,
    ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids)
    # (tokens, heads, dim) -> (1, heads, tokens, dim): the packed stream is a single batch row.
    packed_q, packed_k, packed_v = (
        t.unsqueeze(0).transpose(1, 2).contiguous() for t in (packed_q, packed_k, packed_v)
    )
    block_mask = block_mask_creator(torch.diff(cu_seqlens_q), torch.diff(cu_seqlens_k))
    packed_out = flex_attention(
        packed_q,
        packed_k,
        packed_v,
        block_mask=block_mask,
        enable_gqa=packed_q.shape[1] != packed_k.shape[1],
    )
    return pad_input(packed_out.transpose(1, 2).squeeze(0), indices_q, batch_size, q_len)
class IndexPutFirstAxis(torch.autograd.Function):
    """Differentiable scatter: write rows of ``values`` at ``indices`` into a zero
    tensor with ``first_axis_dim`` rows (inverse of IndexFirstAxis)."""

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:  # type: ignore[no-untyped-def]
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        # Rows not listed in `indices` stay zero (pad_input relies on this for padding).
        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
        # Plain advanced indexing is used; the scatter-based alternative is kept below for reference.
        output[indices] = values
        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
        return output

    @staticmethod
    def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]:  # type: ignore[no-untyped-def]
        (indices,) = ctx.saved_tensors
        # The gradient w.r.t. values is the rows written in forward; indices and
        # first_axis_dim receive none. Plain indexing is used; the gather-based
        # alternative is kept below for reference.
        grad_values = grad_output[indices]
        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
        return grad_values, None, None


# Functional entry point: index_put_first_axis(values, indices, first_axis_dim).
index_put_first_axis = IndexPutFirstAxis.apply
def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """Scatter packed tokens back into a zero-padded ``(batch, seqlen, ...)`` tensor.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of non-padding tokens.
        indices: (total_nnz), positions of those tokens in the flattened (batch * seqlen) input.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...), with zeros at padding positions.
    """
    flat_padded = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(flat_padded, "(b s) ... -> b s ...", b=batch)
def _get_unpad_data(sequence_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Compute the flattening metadata needed to drop padding from a packed batch.

    Args:
        sequence_ids: ``(batch, seqlen)`` integer ids; -1 marks padding. Within a
            row, tokens of one sequence are assumed to be contiguous, because runs
            are detected with ``torch.unique_consecutive``.

    Returns:
        non_pad_indices: 1-D indices of the non-padding tokens in the flattened batch.
        cu_seqlens: int32 cumulative run lengths with a leading 0.
        max_seqlen_in_batch: length of the longest run (0 if there are no tokens).
    """
    ids = sequence_ids.long()
    non_pad_indices = torch.nonzero((ids != -1).flatten(), as_tuple=False).flatten()
    # Offset each row by a stride larger than any id, so the last sequence of one
    # row never merges with the first sequence of the next. Integer arithmetic keeps
    # this exact. The former float offset (`* 1e5`) promoted to float32 and lost
    # exactness past 2**24, merging distinct sequences in batches of >= 168 rows.
    stride = ids.max().clamp(min=0) + 1 if ids.numel() > 0 else 1
    row_offsets = torch.arange(ids.shape[0], device=ids.device, dtype=torch.long)[:, None] * stride
    global_ids = (ids + row_offsets).flatten()[non_pad_indices]
    _, seqlens_in_batch = torch.unique_consecutive(global_ids, return_counts=True)
    max_seqlen_in_batch = int(seqlens_in_batch.max().item()) if seqlens_in_batch.numel() > 0 else 0
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return non_pad_indices, cu_seqlens, max_seqlen_in_batch
def _unpad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    q_sequence_ids: torch.Tensor,
    k_sequence_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]:
    """Pack padded Q/K/V tensors into the varlen layout used by flash-attention kernels.

    Q/K/V are ``(batch, len, heads, head_dim)``. Padding positions (sequence id -1) are dropped
    and the remaining rows are flattened. When the query and key sequence ids are identical,
    the key metadata is reused for the queries.

    Returns:
        (query, key, value, query_indices, (cu_seqlens_q, cu_seqlens_k),
         (max_seqlen_q, max_seqlen_k))
    """
    bsz, k_len, n_kv_heads, dim = key_layer.shape
    q_len, n_q_heads = query_layer.shape[1], query_layer.shape[2]
    assert query_layer.shape[:2] == q_sequence_ids.shape, (
        f"Shape mismatch between query layer and query sequence ids: {query_layer.shape[:2]} != {q_sequence_ids.shape}"
    )
    assert key_layer.shape[:2] == k_sequence_ids.shape, (
        f"Shape mismatch between key layer and key sequence ids: {key_layer.shape[:2]} != {k_sequence_ids.shape}"
    )
    assert q_len <= k_len, (
        f"Query length should be less than or equal to KV sequence length: {q_len} <= {k_len}"
    )

    k_idx, k_cu, k_max = _get_unpad_data(k_sequence_ids)
    key_layer = index_first_axis(key_layer.reshape(bsz * k_len, n_kv_heads, dim), k_idx)
    value_layer = index_first_axis(value_layer.reshape(bsz * k_len, n_kv_heads, dim), k_idx)

    # Reuse the key metadata when queries share the exact same sequence layout.
    shares_layout = torch.equal(q_sequence_ids, k_sequence_ids)
    q_idx, q_cu, q_max = (k_idx, k_cu, k_max) if shares_layout else _get_unpad_data(q_sequence_ids)
    query_layer = index_first_axis(query_layer.reshape(bsz * q_len, n_q_heads, dim), q_idx)

    assert q_cu.shape == k_cu.shape, (
        f"Query and KV should have the same number of sequences: {q_cu.shape} != {k_cu.shape}"
    )
    return query_layer, key_layer, value_layer, q_idx, (q_cu, k_cu), (q_max, k_max)
# Functional alias for the IndexFirstAxis autograd function (row gather), used by _unpad_input.
index_first_axis = IndexFirstAxis.apply
# FAST_BLOCK_MASK defaults to "1" and selects direct_block_mask; any other value selects doc_id_mask.
if os.getenv("FAST_BLOCK_MASK", "1") == "1":
    block_mask_creator = direct_block_mask
else:
    block_mask_creator = doc_id_mask
# Padding id the tokenizer must be configured with; checked by get_tokenizer().
PAD_TOKEN_ID = 0
def get_tokenizer() -> Tokenizer:
    """Load the E1 tokenizer.

    Tries ``tokenizer.json`` next to this module first. If loading it fails for any reason,
    the file is downloaded from the ``Synthyra/Profluent-E1-150M`` Hugging Face repo instead.

    Raises:
        ValueError: if the tokenizer has no padding configuration, or pads with an id other
            than ``PAD_TOKEN_ID``.
    """
    tokenizer: Tokenizer
    try:
        fname = os.path.join(os.path.dirname(__file__), "tokenizer.json")
        tokenizer = Tokenizer.from_file(fname)
    except Exception:
        # Best-effort fallback: any failure with the local file triggers a download.
        print("E1 Tokenizer not found in local directory, downloading from Hugging Face")
        from huggingface_hub import hf_hub_download

        fname = hf_hub_download(repo_id="Synthyra/Profluent-E1-150M", filename="tokenizer.json")
        tokenizer = Tokenizer.from_file(fname)
    # ``padding`` can be None when the tokenizer file has no padding section. Report that
    # clearly instead of failing with an opaque TypeError on ``None["pad_id"]``.
    padding = tokenizer.padding
    pad_id = None if padding is None else padding.get("pad_id")
    if pad_id != PAD_TOKEN_ID:
        raise ValueError(f"Padding token id must be {PAD_TOKEN_ID}, but got {pad_id}")
    return tokenizer
@dataclass
class DataPrepConfig:
    """Limits and options applied by E1BatchPreparer when turning strings into token ids."""

    # Maximum number of comma-separated sequences in one multi-sequence instance.
    max_num_sequences: int = 512
    # Maximum length, in characters, of a single sequence (checked before tokenization).
    max_num_positions_within_seq: int = 8192
    # If True, "X" tokens are dropped together with their position ids.
    remove_X_tokens: bool = False
def get_context(sequence: str) -> str | None:
if "," in sequence:
return sequence.rsplit(",", 1)[0]
return None
class E1BatchPreparer:
    """Tokenize E1 inputs and collate them into padded batches.

    An input string may hold several comma-separated sequences. All but the last form the
    "context" of the final sequence; context labels are masked with the pad id unless
    ``preserve_context_labels`` is set.
    """

    def __init__(
        self,
        data_prep_config: DataPrepConfig | None = None,
        tokenizer: Tokenizer | None = None,
        preserve_context_labels: bool = False,
    ):
        """
        Args:
            data_prep_config: length/count limits; defaults to ``DataPrepConfig()``.
            tokenizer: tokenizer to use; defaults to ``get_tokenizer()``.
            preserve_context_labels: keep the labels of context tokens instead of padding them out.
        """
        # Explicit None check: do not depend on the tokenizer object's truthiness.
        self.tokenizer = tokenizer if tokenizer is not None else get_tokenizer()
        self.data_prep_config = data_prep_config or DataPrepConfig()
        self.pad_token_id = self.tokenizer.token_to_id("<pad>")
        self.preserve_context_labels = preserve_context_labels
        device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu")
        # Structural tokens (sequence start/end markers and padding), as opposed to residues.
        self.boundary_token_ids = torch.tensor(
            [self.tokenizer.token_to_id(token) for token in ["<bos>", "<eos>", "1", "2", "<pad>"]], device=device
        ).long()
        self.mask_token = "?"  # nosec
        self.mask_token_id = self.tokenizer.token_to_id(self.mask_token)
        self.X_token_id = self.tokenizer.token_to_id("X")
        self.vocab = self.tokenizer.get_vocab()

    def get_batch_kwargs(  # type: ignore[override]
        self, sequences: list[str], device: torch.device = torch.device("cpu"), non_blocking: bool = False
    ) -> dict[str, torch.Tensor | list[str] | list[int]]:
        """Encode each (possibly multi-sequence) string and pad the results into one batch."""
        sequence_encodings = [self.prepare_multiseq(sequence) for sequence in sequences]
        return self.pad_encodings(sequence_encodings, device, non_blocking)

    def pad_encodings(
        self,
        sequence_encodings: list[dict[str, torch.Tensor]],
        device: torch.device = torch.device("cpu"),
        non_blocking: bool = False,
    ) -> dict[str, torch.Tensor | list[str] | list[int]]:
        """Right-pad per-instance encodings into batch tensors on ``device``.

        ``device`` may be a ``torch.device``, a device string such as ``"cuda:0"``, or a CUDA
        index. Non-blocking copies are only requested for CUDA targets.
        """
        # Normalise first so ``.type`` works for str/int devices as well.
        device = torch.device(device)
        non_blocking = non_blocking and device.type == "cuda"
        padded_encodings = {}
        # Note: We use -1 as the padding value for sequence and position ids because the 0 value
        # is a valid value for sequence and position ids. -1 is then used to distinguish valid
        # tokens from padding tokens, for example, when doing padding/unpadding for flash attention.
        for key, padding_value in {
            "input_ids": self.pad_token_id,
            "sequence_ids": -1,
            "within_seq_position_ids": -1,
            "global_position_ids": -1,
            "labels": self.pad_token_id,
        }.items():
            padded_encodings[key] = pad_sequence(
                [enc[key] for enc in sequence_encodings], batch_first=True, padding_value=padding_value
            ).to(device=device, dtype=torch.long, non_blocking=non_blocking)
        padded_encodings["context"] = [enc["context"] for enc in sequence_encodings]
        padded_encodings["context_len"] = [enc["context_len"] for enc in sequence_encodings]
        return padded_encodings

    def prepare_multiseq(self, sequence: str) -> dict[str, torch.Tensor | str | int]:
        """Encode one comma-separated multi-sequence instance.

        Returns 1-D tensors (input ids, sequence ids, within-sequence and global position ids,
        labels) plus the decoded ``context`` string and its token length ``context_len``.

        Raises:
            ValueError: if there are more sequences than ``max_num_sequences``.
        """
        single_sequences = sequence.split(",")
        if len(single_sequences) > self.data_prep_config.max_num_sequences:
            raise ValueError(
                f"Number of sequences {len(single_sequences)} exceeds max number of sequences {self.data_prep_config.max_num_sequences}"
                " in the provided multi-sequence instance. Please remove some homologous sequences before trying again."
            )
        single_sequence_encodings = [self.prepare_singleseq(sequence) for sequence in single_sequences]
        num_tokens = [len(x["input_ids"]) for x in single_sequence_encodings]
        input_ids = torch.cat([x["input_ids"] for x in single_sequence_encodings])
        labels = torch.cat([x["labels"] for x in single_sequence_encodings])
        within_seq_position_ids = torch.cat([encoding["position_ids"] for encoding in single_sequence_encodings])
        # Global positions continue across sequences: each sequence is offset past the previous one.
        global_position_ids, ctx_len = [], 0
        for encoding in single_sequence_encodings:
            global_position_ids.append(encoding["position_ids"] + ctx_len)
            ctx_len = max(ctx_len, encoding["position_ids"].max().item() + ctx_len + 1)
        global_position_ids = torch.cat(global_position_ids)
        # Sequence id i repeated for every token of the i-th sequence.
        sequence_ids = torch.repeat_interleave(torch.tensor(num_tokens))
        # Get multi-seq context & mask out all but last sequence in multi-seq instance if desired
        context_len = sum(num_tokens[:-1])
        context = self.tokenizer.decode(input_ids[:context_len].tolist(), skip_special_tokens=False)
        if not self.preserve_context_labels:
            labels[:context_len] = self.pad_token_id
        assert (
            input_ids.shape
            == sequence_ids.shape
            == within_seq_position_ids.shape
            == global_position_ids.shape
            == labels.shape
        ), "Input ids, sequence ids, within seq position ids, global position ids, and labels must have the same shape"
        assert input_ids.shape[0] >= context_len, "Input ids must have at least as many tokens as the context length"
        return {
            "input_ids": input_ids,
            "sequence_ids": sequence_ids,
            "within_seq_position_ids": within_seq_position_ids,
            "global_position_ids": global_position_ids,
            "labels": labels,
            "context": context,
            "context_len": context_len,
        }

    def prepare_singleseq(self, sequence: str) -> dict[str, torch.Tensor]:
        """Tokenize one sequence as ``<bos> 1 <residues> 2 <eos>`` with 0-based position ids.

        Raises:
            ValueError: on invalid characters or when the sequence exceeds
                ``max_num_positions_within_seq`` characters.
        """
        if not self.validate_sequence(sequence):
            raise ValueError(f"Invalid sequence: {sequence}; Input sequence should contain [A-Z] or ? characters only")
        if len(sequence) > self.data_prep_config.max_num_positions_within_seq:
            raise ValueError(
                f"Sequence length {len(sequence)} exceeds max length {self.data_prep_config.max_num_positions_within_seq}"
            )
        # Can also use `tokens = torch.tensor(self.tokenizer.encode(f"<bos>1{sequence}2<eos>").ids)`
        # but following is faster since our vocabulary is simple.
        tokens = torch.tensor([self.vocab[token] for token in ["<bos>", "1", *sequence, "2", "<eos>"]])
        position_ids = torch.arange(len(tokens))
        if self.data_prep_config.remove_X_tokens:
            # Keep original position ids of the surviving tokens (positions are not renumbered).
            X_positions = torch.where(tokens != self.X_token_id)[0]
            tokens = tokens[X_positions]
            position_ids = position_ids[X_positions]
        return {"input_ids": tokens, "labels": tokens, "position_ids": position_ids}

    def get_boundary_token_mask(self, tokens: torch.Tensor) -> torch.BoolTensor:
        """Boolean mask of positions holding a structural (bos/eos/1/2/pad) token."""
        return torch.isin(tokens, self.boundary_token_ids.to(tokens.device))

    def get_mask_positions_mask(self, tokens: torch.Tensor) -> torch.BoolTensor:
        """Boolean mask of positions holding the ``?`` mask token."""
        return tokens == self.mask_token_id

    def validate_sequence(self, sequence: str) -> bool:
        """Return True iff ``sequence`` contains only ASCII ``A``-``Z`` letters and ``?``.

        A sequence made only of ``?`` (or an empty one) is rejected, because at least one
        letter must remain after removing the mask characters.
        """
        assert isinstance(sequence, str), "Sequence must be a string"
        sequence = sequence.replace(self.mask_token, "")
        # ``isalpha``/``isupper`` also accept non-ASCII letters such as "Ä", which are not in the
        # vocabulary and would raise a KeyError during tokenization. Require ASCII explicitly.
        return sequence.isascii() and sequence.isalpha() and sequence.isupper()
class E1Config(PretrainedConfig):
    """Hugging Face configuration for E1 models.

    The E1 tokenizer is loaded on construction and is authoritative for the pad/bos/eos token
    ids. Explicitly passed ids that disagree are ignored with a warning. ``vocab_size``
    defaults to the tokenizer's vocabulary size; a different explicit value is used but logged.
    """

    model_type = "E1"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(  # type: ignore
        self,
        # Model architecture/initialization
        vocab_size=None,
        hidden_size=4096,
        intermediate_size=16384,
        gated_mlp=False,
        num_hidden_layers=40,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        rms_norm_eps=1e-5,
        initializer_range=0.02,
        dtype="bfloat16",
        gradient_checkpointing=False,
        no_ffn_gradient_checkpointing=False,
        # Tokenization
        pad_token_id=None,
        bos_token_id=None,
        eos_token_id=None,
        tie_word_embeddings=False,
        # Attention implementation & rotary positional embeddings
        global_attention_every_n_layers=0,
        max_num_sequences=512,
        max_num_positions_within_seq=8192,
        max_num_positions_global=1024 * 128,
        rope_theta_within_seq=10000.0,
        rope_theta_global=100000.0,
        clip_qkv=None,
        attn_backend="sdpa",
        **kwargs,
    ) -> None:
        e1_tokenizer = get_tokenizer()
        super().__init__(
            pad_token_id=e1_tokenizer.token_to_id("<pad>"),
            bos_token_id=e1_tokenizer.token_to_id("<bos>"),
            eos_token_id=e1_tokenizer.token_to_id("<eos>"),
            tie_word_embeddings=tie_word_embeddings,
            dtype=dtype,
            **kwargs,
        )

        # --- Transformer shape -------------------------------------------------------
        self.hidden_size = hidden_size
        if intermediate_size is None:
            # Default FFN width: 3x hidden for gated MLPs, 4x otherwise.
            intermediate_size = (3 if gated_mlp else 4) * hidden_size
        self.intermediate_size = intermediate_size
        self.gated_mlp = gated_mlp
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # A missing KV-head count means plain multi-head attention (backward compatibility).
        self.num_key_value_heads = num_attention_heads if num_key_value_heads is None else num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps

        # --- Positions, rotary embeddings and attention --------------------------------
        self.max_num_positions_within_seq = max_num_positions_within_seq
        self.max_num_positions_global = max_num_positions_global
        self.rope_theta_within_seq = rope_theta_within_seq
        self.rope_theta_global = rope_theta_global
        self.max_num_sequences = max_num_sequences
        assert clip_qkv is None or clip_qkv > 0
        self.clip_qkv = clip_qkv
        self.global_attention_every_n_layers = global_attention_every_n_layers
        self.attn_backend = attn_backend

        # --- Training knobs ------------------------------------------------------------
        self.gradient_checkpointing = gradient_checkpointing
        self.no_ffn_gradient_checkpointing = no_ffn_gradient_checkpointing

        # --- Vocabulary: tokenizer size unless explicitly overridden -------------------
        tokenizer_vocab_size = e1_tokenizer.get_vocab_size()
        self.vocab_size = tokenizer_vocab_size
        if vocab_size is not None and vocab_size != tokenizer_vocab_size:
            if vocab_size < tokenizer_vocab_size:
                logger.warning(
                    f"Using vocab_size {vocab_size} smaller than {tokenizer_vocab_size} from tokenizer. MAKE SURE THIS IS INTENTIONAL."
                )
            else:
                logger.warning(f"Using vocab_size {vocab_size} instead of smaller {tokenizer_vocab_size} from tokenizer.")
            self.vocab_size = vocab_size

        # --- Special token ids always come from the tokenizer --------------------------
        for attr_name, requested_id in (
            ("pad_token_id", pad_token_id),
            ("bos_token_id", bos_token_id),
            ("eos_token_id", eos_token_id),
        ):
            tokenizer_id = getattr(self, attr_name)
            if requested_id is not None and requested_id != tokenizer_id:
                logger.warning(f"Ignoring {attr_name}. Using {tokenizer_id} from tokenizer")
class DynamicCache:
    """
    A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
    It stores the key and value states as tensors of shape `[batch_size, seq_len, num_heads, head_dim]`.
    Layers that have not been written yet hold an empty placeholder tensor.
    Args:
        key_cache (`list[torch.Tensor]`): The list of key states.
        value_cache (`list[torch.Tensor]`): The list of value states.
    """

    def __init__(self) -> None:
        self.key_cache: list[torch.Tensor] = []
        self.value_cache: list[torch.Tensor] = []

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches in-place, and return the necessary keys and value states.
        Args:
            key_states (`torch.Tensor`): The new key states to cache of shape [batch_size, seq_len, num_heads, head_dim]
            value_states (`torch.Tensor`): The new value states to cache of shape [batch_size, seq_len, num_heads, head_dim]
            layer_idx (`int`): The index of the layer to update.
        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states of shape [batch_size, seq_len, num_heads, head_dim].
        """
        # Lazy initialization
        if len(self.key_cache) <= layer_idx:
            # There may be skipped layers, fill them with empty tensors
            for _ in range(len(self.key_cache), layer_idx):
                self.key_cache.append(torch.tensor([]))
                self.value_cache.append(torch.tensor([]))
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        elif (
            not self.key_cache[layer_idx].numel()  # prefers not t.numel() to len(t) == 0 to export the model
        ):  # fills previously skipped layers; checking for tensor causes errors
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            # Grow along the sequence axis (dim 1).
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=1)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=1)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        is_empty_layer = (
            len(self.key_cache) == 0  # no cache in any layer
            or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
            or not self.key_cache[layer_idx].numel()  # the layer has no cache
        )
        layer_seq_length = self.key_cache[layer_idx].shape[1] if not is_empty_layer else 0
        return layer_seq_length

    def crop(self, max_length: int) -> None:
        """Crop every cached layer to at most `max_length` tokens.

        A negative `max_length` removes `abs(max_length)` tokens from the end of each layer
        (never going below zero). This is used in assisted decoding and contrastive search.
        Each layer is handled on its own, so an empty (skipped) layer 0 does not prevent
        later layers from being cropped.
        """
        for layer_idx, cached_keys in enumerate(self.key_cache):
            if not cached_keys.numel():
                continue  # skipped layer: nothing to crop
            current_len = cached_keys.shape[1]
            target_len = max(current_len + max_length, 0) if max_length < 0 else max_length
            if current_len <= target_len:
                continue
            self.key_cache[layer_idx] = cached_keys[:, :target_len, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :target_len, ...]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
        for layer_idx in range(len(self.key_cache)):
            if self.key_cache[layer_idx].numel():
                self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
        for layer_idx in range(len(self.key_cache)):
            if self.key_cache[layer_idx].numel():
                self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
                self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
class KVCache:
    """Reuse the KV cache of a shared multi-sequence context across batches.

    Expects batch dicts with ``context`` / ``context_len`` entries (as produced by
    ``E1BatchPreparer.pad_encodings``). Call ``before_forward(batch)`` before the model and
    ``after_forward(batch, outputs)`` after it.

    * Cache miss: the full context+query is run. The returned ``DynamicCache`` is stored,
      cropped to the context and reduced to batch row 0. Context columns are then stripped
      from the inputs and outputs.
    * Cache hit: the stored cache is repeated to the batch size and passed as
      ``past_key_values``, and context columns are stripped from the inputs before the forward.
      Afterwards the same cache object is cropped/reduced back to its stored form.

    At most ``cache_size`` contexts are kept; the oldest inserted one is evicted first.
    """

    def __init__(self, cache_size: int = 4) -> None:
        self.cache_size = cache_size
        # Per-token batch fields whose leading `context_len` columns belong to the context.
        self.tensor_input_field_names = [
            "input_ids",
            "within_seq_position_ids",
            "global_position_ids",
            "sequence_ids",
            "labels",
        ]
        # Model output fields whose context columns are stripped on a cache miss.
        self.tensor_output_field_names = ["logits", "embeddings"]
        # context string -> DynamicCache cropped to the context with a single batch row.
        self.cache_dict: dict[str, DynamicCache] = {}
        # Insertion order of cached contexts, used for FIFO eviction.
        self.cache_queue: list[str] = []

    def reset(self) -> None:
        """Drop every cached context and call ``torch.cuda.empty_cache()``."""
        for k in list(self.cache_dict.keys()):
            del self.cache_dict[k]
        del self.cache_dict
        self.cache_dict = {}
        self.cache_queue = []
        torch.cuda.empty_cache()

    def before_forward(self, batch: dict[str, torch.Tensor]) -> None:
        """Prepare ``batch`` in place for a cached forward pass.

        Skips (with a warning) unless the batch has context info with exactly one distinct
        context and context length. Otherwise it sets ``batch["use_cache"] = True``. On a cache
        hit it also installs ``past_key_values`` and slices the context columns off the input
        fields.

        NOTE(review): on a hit the stored DynamicCache is mutated in place (repeated to the
        batch size). It is only restored by ``after_forward``; if the forward pass raises, the
        stored entry stays repeated/extended.
        """
        contexts: list[str] | None = batch.get("context", None)
        if contexts is None or "context_len" not in batch:
            logger.warning_once(
                "KVCache requires the batch dict to have both `context` and `context_len` keys to trigger. Skipping."
            )
            return
        context_lens: list[int] = list(set(batch["context_len"]))
        contexts: list[str] = list(set(contexts))  # type: ignore[no-redef]
        if len(contexts) != 1 or len(context_lens) != 1:
            logger.warning(
                "SingleContextKVCache requires a single context and context length. "
                "Multiple contexts or context lengths found in a single batch. Skipping."
            )
            return
        batch_size = batch["input_ids"].shape[0]
        unique_context = contexts[0]
        unique_context_len = context_lens[0]
        # Ask the model to return a cache so after_forward can store it on a miss.
        batch["use_cache"] = True
        if unique_context not in self.cache_dict:
            return
        # Cache hit: expand the single stored row to every row of this batch.
        self.cache_dict[unique_context].batch_repeat_interleave(batch_size)
        past_key_values = self.cache_dict[unique_context]
        batch["past_key_values"] = past_key_values
        # Remove context from the input fields
        for field_name in self.tensor_input_field_names:
            if batch.get(field_name, None) is not None:
                batch[field_name] = batch[field_name][:, unique_context_len:]

    def after_forward(self, batch: dict[str, Any], outputs: ModelOutput) -> None:
        """Store or restore the context cache after the model ran, and strip context columns.

        Does nothing for batches without exactly one non-empty context. Requires the model to
        return a ``DynamicCache`` in ``outputs.past_key_values``; otherwise it warns and returns.
        """
        contexts = batch.get("context", None)
        context_lens = batch.get("context_len", [])
        if contexts is None or len(set(contexts)) != 1 or len(set(context_lens)) != 1 or context_lens[0] == 0:
            return
        assert batch["use_cache"]
        unique_context = contexts[0]
        unique_context_len = context_lens[0]
        past_key_values = getattr(outputs, "past_key_values", None)
        if not isinstance(past_key_values, DynamicCache):
            logger.warning_once("KVCache is incompatible with models that don't return a DynamicCache. Skipping.")
            return
        if "past_key_values" not in batch:
            # Cache miss: the model saw context + query, so store the new cache (evicting the
            # oldest entry if full) and strip the context columns from inputs and outputs.
            if len(self.cache_queue) == self.cache_size:
                last_context = self.cache_queue.pop(0)
                if last_context not in self.cache_queue:
                    del self.cache_dict[last_context]
                    torch.cuda.empty_cache()
            self.cache_dict[unique_context] = past_key_values
            self.cache_queue.append(unique_context)
            # Remove context from the input fields
            for field_name in self.tensor_input_field_names:
                if field_name in batch and batch[field_name] is not None:
                    batch[field_name] = batch[field_name][:, unique_context_len:]
            # Remove context from the output fields
            for field_name in self.tensor_output_field_names:
                if field_name in outputs and outputs[field_name] is not None:
                    outputs[field_name] = outputs[field_name][:, unique_context_len:]
            if "hidden_states" in outputs and outputs["hidden_states"] is not None:
                outputs["hidden_states"] = [h[:, unique_context_len:] for h in outputs["hidden_states"]]
        # Both hit and miss: drop query keys/values and keep a single batch row, so the stored
        # entry holds only the context.
        self.cache_dict[unique_context].crop(unique_context_len)
        self.cache_dict[unique_context].batch_select_indices([0])
class AttentionLayerType(Enum):
    """Kind of attention layer, chosen per layer in ``Attention.__init__``."""

    # Uses within-sequence position ids and the `within_seq_*` masks.
    WITHIN_SEQ = "within_seq"
    # Uses global position ids and the `block_causal_*` masks.
    GLOBAL = "global"
class AttentionArgs(TypedDict, total=False):
    """Precomputed attention masks passed to every ``Attention`` layer.

    The flex-attention path reads the ``*_block_mask`` entries; the SDPA and manual paths read
    the ``*_mask_4d`` entries. WITHIN_SEQ layers use ``within_seq_*``; GLOBAL layers use
    ``block_causal_*``.
    """

    within_seq_block_mask: BlockMask | None
    block_causal_block_mask: BlockMask | None
    within_seq_mask_4d: torch.Tensor | None
    block_causal_mask_4d: torch.Tensor | None
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand grouped KV heads so each query head gets its own copy.

    Equivalent to ``torch.repeat_interleave(hidden_states, n_rep, dim=1)``: maps
    ``(batch, num_key_value_heads, seqlen, head_dim)`` to
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. The input is returned
    unchanged when ``n_rep == 1``.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        widened = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        hidden_states = widened.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
class RotaryPositionalEmbedding(nn.Module):
def __init__(
self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: torch.device | None = None
):
super().__init__()
self.dim = dim
self.base = base
self.max_position_embeddings = max_position_embeddings
inv_freq = base ** -(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_sin_cos_cache(seq_len=max_position_embeddings, device=self.inv_freq.device)
@staticmethod
def rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _set_sin_cos_cache(self, seq_len: int, device: torch.device) -> None:
# Different from paper, but it uses a different permutation in order to obtain the same calculation
self.max_seq_len_cached = seq_len
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
angles = torch.outer(t, self.inv_freq.to(device))
angles = torch.cat((angles, angles), dim=1)
self.register_buffer("cos_cached", angles.cos(), persistent=False)
self.register_buffer("sin_cached", angles.sin(), persistent=False)
def forward(
self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.LongTensor, seq_len: int | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
# x: [bsz, seq_len, num_attention_heads, head_size]
device, dtype = q.device, q.dtype
seq_len = position_ids.max().item() + 1 if seq_len is None else seq_len
if seq_len > self.max_seq_len_cached:
self._set_sin_cos_cache(seq_len=seq_len, device=device)
# angles_cached[position_ids] gets us something of shape (batch_size, seq_len, head_dim),
# so unsqueeze dimension -2 to broadcast to (batch_size, seq_len, n_heads, head_dim).
idxs = position_ids.to(device)
cos = self.cos_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs]
sin = self.sin_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs]
# Apply rotary positional embeddings to q and k (treating them as complex numbers). The first half is
# Re[x exp(it)] = Re[x] cos(t) - Im[x] sin(t), while the second half is
# Im[x exp(it)] = Im[x] cos(t) + Re[x] sin(t). This works b/c both halves of cos/sin are the same.
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
return q_embed, k_embed
class Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper."""
def __init__(self, config: E1Config, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_kv_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_kv_heads
self.max_num_seqs = config.max_num_sequences
self.clip_qkv = config.clip_qkv
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
if self.config.global_attention_every_n_layers > 0:
self.layer_type = (
AttentionLayerType.GLOBAL
if (self.layer_idx + 1) % self.config.global_attention_every_n_layers == 0
else AttentionLayerType.WITHIN_SEQ
)
else:
self.layer_type = AttentionLayerType.WITHIN_SEQ
self.rope_theta = (
config.rope_theta_within_seq
if self.layer_type == AttentionLayerType.WITHIN_SEQ
else config.rope_theta_global
)
self.max_position_embeddings = (
config.max_num_positions_within_seq
if self.layer_type == AttentionLayerType.WITHIN_SEQ
else config.max_num_positions_global
)
self.rotary_emb = RotaryPositionalEmbedding(
self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta
)
self.attn_backend = resolve_attention_backend(config.attn_backend)
def prepare_qkv(
self,
hidden_states: torch.Tensor,
position_ids: torch.LongTensor,
past_key_value: DynamicCache | None = None,
use_cache: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, q_len, _ = hidden_states.size()
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
val_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim)
val_states = val_states.view(bsz, q_len, self.num_kv_heads, self.head_dim)
if self.clip_qkv is not None:
query_states = query_states.clamp(-self.clip_qkv, self.clip_qkv)
key_states = key_states.clamp(-self.clip_qkv, self.clip_qkv)
val_states = val_states.clamp(-self.clip_qkv, self.clip_qkv)
query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
if use_cache and past_key_value is not None:
key_states, val_states = past_key_value.update(key_states, val_states, self.layer_idx)
input_dtype = query_states.dtype
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
else:
target_dtype = self.q_proj.weight.dtype
if input_dtype != target_dtype:
logger.warning_once(
f"The input hidden states seems to be silently casted in {input_dtype}. "
f"This might be because you have upcasted embedding or layer norm layers "
f"in {input_dtype}. We will cast back the input in {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
val_states = val_states.to(target_dtype)
return query_states, key_states, val_states
def forward(
self,
hidden_states: torch.Tensor,
within_seq_position_ids: torch.LongTensor,
global_position_ids: torch.LongTensor,
sequence_ids: torch.LongTensor,
attention_args: AttentionArgs | None = None,
past_key_value: DynamicCache | None = None,
output_attentions: bool = False,
output_s_max: bool = False,
use_cache: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None, list[torch.Tensor] | None]:
is_cache_prefilled = (
use_cache and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0
)
query_states, key_states, val_states = self.prepare_qkv(
hidden_states=hidden_states,
position_ids=within_seq_position_ids
if self.layer_type == AttentionLayerType.WITHIN_SEQ
else global_position_ids,
past_key_value=past_key_value,
use_cache=use_cache,
)
attn_output, attn_weights, s_max = self._attn(
query_states=query_states,
key_states=key_states,
val_states=val_states,
sequence_ids=sequence_ids,
attention_args=attention_args,
output_attentions=output_attentions,
output_s_max=output_s_max,
is_cache_prefilled=is_cache_prefilled,
)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value, s_max
def _attn(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
val_states: torch.Tensor,
sequence_ids: torch.Tensor,
attention_args: AttentionArgs | None = None,
output_attentions: bool = False,
output_s_max: bool = False,
is_cache_prefilled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
effective_layer_type = self.layer_type
if is_cache_prefilled and self.layer_type == AttentionLayerType.GLOBAL:
effective_layer_type = AttentionLayerType.WITHIN_SEQ
if output_attentions:
return self._manual_attn(
query_states, key_states, val_states,
sequence_ids=sequence_ids,
attention_args=attention_args,
effective_layer_type=effective_layer_type,
output_s_max=output_s_max,
is_cache_prefilled=is_cache_prefilled,
)
if self.attn_backend == AttentionBackend.KERNELS_FLASH:
if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
attn_output, attn_weights = self._kernels_flash_attn(
query_states, key_states, val_states,
sequence_ids=sequence_ids,
is_cache_prefilled=is_cache_prefilled,
)
else:
attn_output, attn_weights = self._flex_attn(
query_states, key_states, val_states,
attention_args=attention_args,
effective_layer_type=effective_layer_type,
)
elif self.attn_backend == AttentionBackend.FLEX:
attn_output, attn_weights = self._flex_attn(
query_states, key_states, val_states,
attention_args=attention_args,
effective_layer_type=effective_layer_type,
)
elif self.attn_backend == AttentionBackend.SDPA:
attn_output, attn_weights = self._sdpa_attn(
query_states, key_states, val_states,
sequence_ids=sequence_ids,
attention_args=attention_args,
effective_layer_type=effective_layer_type,
is_cache_prefilled=is_cache_prefilled,
)
else:
raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}")
s_max = self._compute_s_max(query_states, key_states) if output_s_max else None
return attn_output, attn_weights, s_max
@torch.no_grad()
def _compute_s_max(
self,
query_states: torch.Tensor, # (B, L, H, D)
key_states: torch.Tensor, # (B, L, Hkv, D)
) -> list[torch.Tensor]:
query_BHLD = query_states.transpose(1, 2).contiguous()
key_BHLD = key_states.transpose(1, 2).contiguous()
key_BHLD = repeat_kv(key_BHLD, self.num_key_value_groups)
scale = 1.0 / (self.head_dim ** 0.5)
q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values * scale
return [s_max_bound[h] for h in range(self.num_heads)]
def _kernels_flash_attn(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
val_states: torch.Tensor,
sequence_ids: torch.Tensor,
is_cache_prefilled: bool = False,
) -> tuple[torch.Tensor, None]:
bsz, q_len = query_states.shape[0], query_states.shape[1]
_, kv_len = key_states.shape[0], key_states.shape[1]
if self.layer_type == AttentionLayerType.GLOBAL and not is_cache_prefilled:
q_sequence_ids = sequence_ids
if q_len < kv_len:
first_token_id = sequence_ids[:, 0].unsqueeze(1)
k_sequence_ids = torch.cat([first_token_id.expand(bsz, kv_len - q_len), sequence_ids], dim=-1)
else:
k_sequence_ids = sequence_ids
else:
if q_len < kv_len:
key_states = key_states[:, -q_len:]
val_states = val_states[:, -q_len:]
q_sequence_ids = k_sequence_ids = sequence_ids
attn_output = kernels_flash_attention_func(
query_states, key_states, val_states,
q_sequence_ids=q_sequence_ids,
k_sequence_ids=k_sequence_ids,
causal=False,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
return attn_output, None
def _flex_attn(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
val_states: torch.Tensor,
attention_args: AttentionArgs | None = None,
effective_layer_type: AttentionLayerType = AttentionLayerType.WITHIN_SEQ,
) -> tuple[torch.Tensor, None]:
bsz, q_len = query_states.shape[0], query_states.shape[1]
if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
block_mask = attention_args["within_seq_block_mask"] if attention_args is not None else None
else:
block_mask = attention_args["block_causal_block_mask"] if attention_args is not None else None
outputs = flex_attention_func(query_states, key_states, val_states, block_mask=block_mask)
outputs = outputs.reshape(bsz, q_len, self.hidden_size).contiguous()
return outputs, None
def _sdpa_attn(
self,
query_states: torch.Tensor, # (B, L, H, D)
key_states: torch.Tensor, # (B, L, Hkv, D)
val_states: torch.Tensor, # (B, L, Hkv, D)
sequence_ids: torch.Tensor,
attention_args: AttentionArgs | None = None,
effective_layer_type: AttentionLayerType = AttentionLayerType.WITHIN_SEQ,
is_cache_prefilled: bool = False,
) -> tuple[torch.Tensor, None]:
bsz, q_len = query_states.shape[:2]
kv_len = key_states.shape[1]
if is_cache_prefilled and q_len < kv_len:
if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
key_states = key_states[:, -q_len:]
val_states = val_states[:, -q_len:]
attention_mask_4d = build_within_seq_mask_4d(sequence_ids) if effective_layer_type == AttentionLayerType.WITHIN_SEQ else None
elif attention_args is not None:
if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
attention_mask_4d = attention_args["within_seq_mask_4d"]
else:
attention_mask_4d = attention_args["block_causal_mask_4d"]
else:
attention_mask_4d = None
query_BHLD = query_states.transpose(1, 2).contiguous()
key_BHLD = key_states.transpose(1, 2).contiguous()
val_BHLD = val_states.transpose(1, 2).contiguous()
key_BHLD = repeat_kv(key_BHLD, self.num_key_value_groups)
val_BHLD = repeat_kv(val_BHLD, self.num_key_value_groups)
context_BHLD = F.scaled_dot_product_attention(query_BHLD, key_BHLD, val_BHLD, attn_mask=attention_mask_4d)
attn_output = context_BHLD.transpose(1, 2).reshape(bsz, q_len, self.hidden_size).contiguous()
return attn_output, None
def _manual_attn(
    self,
    query_states: torch.Tensor,  # (B, L, H, D)
    key_states: torch.Tensor,  # (B, L, Hkv, D)
    val_states: torch.Tensor,  # (B, L, Hkv, D)
    sequence_ids: torch.Tensor,
    attention_args: AttentionArgs | None = None,
    effective_layer_type: AttentionLayerType = AttentionLayerType.WITHIN_SEQ,
    output_s_max: bool = False,
    is_cache_prefilled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
    """Eager (matmul + softmax) attention that also returns the attention probabilities.

    Mask selection is identical to the SDPA path. With a prefilled cache and fewer queries
    than keys, within-sequence layers are restricted to the last ``q_len`` keys/values with a
    freshly built within-sequence mask, and other layers are unmasked. Otherwise the
    precomputed 4D mask comes from ``attention_args``. A boolean mask is assumed: False
    positions are filled with -inf before the softmax.

    Returns:
        ``(attn_output, attn_weights, s_max)``:
        * ``attn_output`` has shape (B, q_len, hidden_size);
        * ``attn_weights`` are the post-softmax probabilities, shaped (B, H, q_len, kv_len),
          where kv_len is taken after any slicing;
        * ``s_max`` is ``self._compute_s_max(query_states, key_states)`` when ``output_s_max``
          is set, else None.
    """
    batch_size, query_len = query_states.shape[:2]
    within_seq = effective_layer_type == AttentionLayerType.WITHIN_SEQ

    if is_cache_prefilled and query_len < key_states.shape[1]:
        if within_seq:
            # Only the keys/values belonging to the current chunk are used for within-seq layers.
            key_states = key_states[:, -query_len:]
            val_states = val_states[:, -query_len:]
            mask = build_within_seq_mask_4d(sequence_ids)
        else:
            mask = None
    elif attention_args is None:
        mask = None
    else:
        mask = attention_args["within_seq_mask_4d" if within_seq else "block_causal_mask_4d"]

    groups = self.num_key_value_groups
    q = query_states.transpose(1, 2).contiguous()
    k = repeat_kv(key_states.transpose(1, 2).contiguous(), groups)
    v = repeat_kv(val_states.transpose(1, 2).contiguous(), groups)

    scale = 1.0 / (self.head_dim ** 0.5)
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if mask is not None:
        scores = scores.masked_fill(mask.logical_not(), float("-inf"))
    probs = F.softmax(scores, dim=-1)

    context = torch.matmul(probs, v)
    attn_output = context.transpose(1, 2).reshape(batch_size, query_len, self.hidden_size).contiguous()
    s_max = self._compute_s_max(query_states, key_states) if output_s_max else None
    return attn_output, probs, s_max
class MLP(nn.Module):
    """Ungated, bias-free feed-forward block: ``w2(act(w1(x)))``."""

    def __init__(self, config: E1Config):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.ffn_dim = inner
        self.hidden_dim = hidden
        # Creation order (w1 then w2) is kept so parameter order and random init are unchanged.
        self.w1 = nn.Linear(hidden, inner, bias=False)
        self.w2 = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        activated = self.act_fn(self.w1(hidden_states))
        return self.w2(activated)
class GLUMLP(nn.Module):
    """Gated, bias-free feed-forward block: ``w2(act(w1(x)) * w3(x))``."""

    def __init__(self, config: E1Config):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.ffn_dim = inner
        self.hidden_dim = hidden
        # Creation order (w1, w2, w3) is kept so parameter order and random init are unchanged.
        self.w1 = nn.Linear(hidden, inner, bias=False)
        self.w2 = nn.Linear(inner, hidden, bias=False)
        self.w3 = nn.Linear(hidden, inner, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate = self.act_fn(self.w1(hidden_states))
        up = self.w3(hidden_states)
        return self.w2(gate * up)
class FFN(nn.Module):
    """Feed-forward sub-block: ``GLUMLP`` when ``config.gated_mlp`` is truthy, otherwise ``MLP``."""

    def __init__(self, config: E1Config):
        super().__init__()
        if config.gated_mlp:
            self.mlp = GLUMLP(config)
        else:
            self.mlp = MLP(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        output = self.mlp(hidden_states)
        return output
@dataclass
class E1ModelOutputWithPast(ModelOutput):
    """Output of the E1 encoder stack (``FAST_E1_ENCODER`` / ``E1Model``).

    Attributes:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states after the last decoder layer and the final RMSNorm.
        past_key_values (`DynamicCache`, *optional*, returned when `use_cache=True`):
            Key/value cache that can be fed back as `past_key_values` on a later call.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            `num_hidden_layers + 1` tensors of shape `(batch_size, sequence_length, hidden_size)`:
            * the embedding output (the input to the first layer);
            * the input to each subsequent layer;
            * finally the RMSNorm-ed output of the last layer (identical to `last_hidden_state`).
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
            One entry per layer holding the attention weights returned by that layer's attention
            module. On the eager path these are post-softmax probabilities of shape
            `(batch_size, num_heads, query_length, key_length)`.
            NOTE(review): confirm which backend is used when attentions are requested.
        s_max (`tuple(list(torch.Tensor))`, *optional*, returned when `output_s_max=True`):
            One entry per layer, as produced by the attention module's `_compute_s_max`.
            NOTE(review): the exact contents are defined outside this section; confirm.
    """
    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: DynamicCache | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    s_max: tuple[list[torch.Tensor], ...] | None = None
@dataclass
class E1MaskedLMOutputWithPast(ModelOutput):
    """Output of ``E1ForMaskedLM``.

    Field order matters (``ModelOutput`` tuple conversion follows it) and is unchanged.
    """
    # Total training loss; equals `mlm_loss` when labels are given, otherwise None.
    loss: torch.FloatTensor | None = None
    # Masked-language-model loss averaged over supervised positions (0.0 when no labels are given).
    mlm_loss: torch.FloatTensor | None = None
    # Vocabulary logits from the MLM head, cast to float32: (batch_size, seq_length, vocab_size).
    logits: torch.FloatTensor | None = None
    # Final (normalised) hidden states of the encoder.
    last_hidden_state: torch.FloatTensor | None = None
    # Key/value cache when use_cache=True.
    past_key_values: DynamicCache | None = None
    # Per-layer hidden states when output_hidden_states=True.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Per-layer attention weights when output_attentions=True.
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # Per-layer s_max values when output_s_max=True.
    s_max: tuple[list[torch.Tensor], ...] | None = None
@dataclass
class E1ClassificationOutputWithPast(ModelOutput):
    """Output of the E1 sequence- and token-classification heads.

    Field order matters (``ModelOutput`` tuple conversion follows it) and is unchanged.
    """
    # Task loss when labels are given, otherwise None.
    loss: torch.FloatTensor | None = None
    # Classifier outputs: (batch, num_labels) for sequence tasks, (batch, seq, num_labels) for token tasks.
    logits: torch.FloatTensor | None = None
    # Final (normalised) hidden states of the encoder.
    last_hidden_state: torch.FloatTensor | None = None
    # Key/value cache when use_cache=True.
    past_key_values: DynamicCache | None = None
    # Per-layer hidden states when output_hidden_states=True.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Per-layer attention weights when output_attentions=True.
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # Per-layer s_max values when output_s_max=True.
    s_max: tuple[list[torch.Tensor], ...] | None = None
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel scale and no bias.

    When the module-level ``layer_norm`` is not None, its ``rms_norm_fn`` is used
    (presumably an optional fused kernel; its import is outside this section). Otherwise
    ``torch.nn.functional.rms_norm`` is used. Either way the result is cast back to the
    input dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.hidden_size = hidden_size

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        if layer_norm is not None:
            normed = layer_norm.rms_norm_fn(
                x=hidden_states,
                weight=self.weight,
                bias=None,  # RMSNorm has no bias
                residual=None,
                eps=self.variance_epsilon,
                dropout_p=0.0,  # no dropout inside the norm
                prenorm=False,
                residual_in_fp32=False,
            )
        else:
            normed = torch.nn.functional.rms_norm(
                hidden_states, (self.hidden_size,), self.weight, self.variance_epsilon
            )
        return normed.to(original_dtype)
class NormAttentionNorm(nn.Module):
    """Pre-norm self-attention with a residual connection, followed by the post-attention norm.

    ``forward`` returns the normalised attention-block output (ready for the FFN), the
    un-normalised residual stream, and the attention module's weights, cache and s_max.
    """

    def __init__(self, config: E1Config, layer_idx: int):
        super().__init__()
        self.self_attn = Attention(config, layer_idx)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        attention_args: AttentionArgs | None = None,
        past_key_value: DynamicCache | None = None,
        output_attentions: bool = False,
        output_s_max: bool = False,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, DynamicCache | None, list[torch.Tensor] | None]:
        attn_input = self.input_layernorm(hidden_states)
        attn_out, attn_weights, cache, s_max = self.self_attn(
            hidden_states=attn_input,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            attention_args=attention_args,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            output_s_max=output_s_max,
            use_cache=use_cache,
        )
        # Residual stream after attention; the caller adds the FFN output back onto it.
        stream = hidden_states + attn_out
        return self.post_attention_layernorm(stream), stream, attn_weights, cache, s_max
class DecoderLayer(nn.Module):
    """One E1 transformer layer: norm -> attention -> residual -> norm -> FFN -> residual."""

    def __init__(self, config: E1Config, layer_idx: int):
        super().__init__()
        self.initializer_range = config.initializer_range
        self.hidden_size = config.hidden_size
        self.norm_attn_norm = NormAttentionNorm(config, layer_idx)
        self.ffn = FFN(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        attention_args: AttentionArgs | None = None,
        past_key_value: DynamicCache | None = None,
        output_attentions: bool = False,
        output_s_max: bool = False,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None, list[torch.Tensor] | None]:
        normed, stream, attn_weights, cache, s_max = self.norm_attn_norm(
            hidden_states=hidden_states,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            attention_args=attention_args,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            output_s_max=output_s_max,
            use_cache=use_cache,
        )
        # Fully connected sub-block with its own residual connection.
        layer_output = stream + self.ffn(normed)
        return layer_output, attn_weights, cache, s_max
class E1PreTrainedModel(PreTrainedModel):
    """Shared base for E1 models: config binding, weight init and attention-backend switching."""

    config_class = E1Config
    config: E1Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DecoderLayer"]
    _transformer_layer_cls = [DecoderLayer]
    _skip_keys_device_placement = "past_key_values"
    all_tied_weights_keys = {}

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise one submodule in place.

        * ``nn.Linear`` / ``nn.Embedding``: weights ~ N(0, initializer_range); biases and the
          padding row are zeroed.
        * ``RMSNorm``: scale set to 1.
        * ``nn.LayerNorm`` (used by the MLM and classification heads): weight 1, bias 0.
          Heads whose weights are missing from a checkpoint then start from the standard init
          rather than whatever the parameters held.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm may be built with elementwise_affine=False, in which case there are no params.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()

    def _backward_compatibility_gradient_checkpointing(self) -> None:
        # Honour a legacy `gradient_checkpointing` flag stored on the config.
        if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
            self.gradient_checkpointing_enable(dict(use_reentrant=False))

    @property
    def _device(self) -> torch.device:
        # Device of the first parameter.
        return next(self.parameters()).device

    @property
    def attn_backend(self) -> str:
        return self.config.attn_backend

    @attn_backend.setter
    def attn_backend(self, backend: str) -> None:
        """Switch the attention backend on the config and on every encoder/attention submodule.

        Raises AssertionError (kept for backward compatibility) for unknown backends.
        """
        assert backend in VALID_ATTENTION_BACKENDS, (
            f"Unsupported attn_backend: {backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
        )
        self.config.attn_backend = backend
        resolved = resolve_attention_backend(backend)
        for module in self.modules():
            if isinstance(module, FAST_E1_ENCODER):
                module._attn_backend = resolved
            elif isinstance(module, Attention):
                module.attn_backend = resolved
class FAST_E1_ENCODER(E1PreTrainedModel, EmbeddingMixin):
    """E1 transformer stack.

    Token embeddings plus sequence-id embeddings feed the decoder layers, followed by a final
    RMSNorm.
    """

    config: E1Config
    config_class = E1Config

    def __init__(self, config: E1Config, **kwargs):
        E1PreTrainedModel.__init__(self, config, **kwargs)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.embed_seq_id = nn.Embedding(config.max_num_sequences, config.hidden_size)
        self.layers = nn.ModuleList([DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = config.gradient_checkpointing
        self.prep_tokens = E1BatchPreparer()
        self._attn_backend = resolve_attention_backend(config.attn_backend)
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.embed_tokens = value

    @torch.inference_mode()
    def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
        """Tokenise ``sequences`` and return the last hidden state.

        When ``return_attention_mask`` is set, also return a mask that is 1 wherever
        ``sequence_ids != -1`` (i.e. non-padding).
        """
        batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
        last_hidden_state = self.forward(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
        if return_attention_mask:
            attention_mask = (batch['sequence_ids'] != -1).long()
            return last_hidden_state, attention_mask
        else:
            return last_hidden_state

    def _build_attention_args(self, sequence_ids: torch.LongTensor, output_attentions: bool) -> AttentionArgs:
        """Precompute only the masks the configured backend needs for an uncached forward pass."""
        attn_backend = self._attn_backend
        has_global_layers = self.config.global_attention_every_n_layers > 0
        # Dense 4D masks are needed by SDPA and whenever attention weights must be materialised.
        needs_4d_masks = (attn_backend == AttentionBackend.SDPA) or output_attentions
        needs_block_causal_flex = (
            (attn_backend == AttentionBackend.FLEX and has_global_layers)
            or (attn_backend == AttentionBackend.KERNELS_FLASH and has_global_layers)
        )
        needs_within_seq_flex = (attn_backend == AttentionBackend.FLEX)
        return AttentionArgs(
            block_causal_block_mask=create_block_causal_mask_optimized(sequence_ids) if needs_block_causal_flex else None,
            within_seq_block_mask=create_within_seq_block_mask(sequence_ids) if needs_within_seq_flex else None,
            within_seq_mask_4d=build_within_seq_mask_4d(sequence_ids) if needs_4d_masks else None,
            block_causal_mask_4d=build_block_causal_mask_4d(sequence_ids) if needs_4d_masks else None,
        )

    # Ignore copy
    def forward(
        self,
        input_ids: torch.LongTensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        past_key_values: DynamicCache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_s_max: bool = False,
        **kwargs
    ) -> E1ModelOutputWithPast:
        """
        Args:
            input_ids: (batch_size, seq_length)
            within_seq_position_ids: (batch_size, seq_length)
                This tensor contains the position of each residue within the sequence itself.
                For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos><pad>"],
                the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]]
            global_position_ids: (batch_size, seq_length)
                This tensor contains the position of each residue within the global sequence.
                For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos><pad>"],
                the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]]
            sequence_ids: (batch_size, seq_length)
                This tensor contains the sequence id of each residue (-1 for padding).
                For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos><pad>"],
                the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]]
            past_key_values: DynamicCache
            use_cache: bool
            output_attentions: bool
            output_hidden_states: bool
            output_s_max: bool
        Returns:
            E1ModelOutputWithPast: Model Outputs
        """
        # Unpacking also enforces that input_ids is 2-D.
        _, seq_length = input_ids.shape
        if self.gradient_checkpointing and self.training and torch.is_grad_enabled():
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        elif not use_cache:
            past_key_values = None

        global_position_ids = global_position_ids.view(-1, seq_length).long()
        within_seq_position_ids = within_seq_position_ids.view(-1, seq_length).long()
        sequence_ids = sequence_ids.view(-1, seq_length).long()

        max_position_id = torch.max(within_seq_position_ids).item()
        min_position_id = torch.min(within_seq_position_ids).item()
        assert max_position_id < self.config.max_num_positions_within_seq and min_position_id >= -1, (
            f"Position ids must be in the range [-1, {self.config.max_num_positions_within_seq}); got max {max_position_id} and min {min_position_id}"
        )

        inputs_embeds = self.embed_tokens(input_ids)
        # Padding positions carry sequence id -1; clamp so they index a valid embedding row.
        inputs_embeds = inputs_embeds + self.embed_seq_id(sequence_ids.clamp(min=0))
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        else:
            target_dtype = self.layers[0].norm_attn_norm.self_attn.q_proj.weight.dtype
        hidden_states = inputs_embeds.to(target_dtype)

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        attention_args: AttentionArgs | None = None
        if past_key_values_length == 0:
            # Masks are only precomputed for uncached passes; cached passes build them inside attention.
            attention_args = self._build_attention_args(sequence_ids, output_attentions)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        full_s_max = () if output_s_max else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)  # type: ignore[operator]
            if self.gradient_checkpointing and self.training and torch.is_grad_enabled():
                # Positional order matches DecoderLayer.forward.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    within_seq_position_ids,
                    global_position_ids,
                    sequence_ids,
                    attention_args,
                    past_key_values,
                    output_attentions,
                    output_s_max,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    within_seq_position_ids=within_seq_position_ids,
                    global_position_ids=global_position_ids,
                    sequence_ids=sequence_ids,
                    attention_args=attention_args,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_s_max=output_s_max,
                    use_cache=use_cache,
                )
            hidden_states, self_attn_weights, present_key_value, s_max = layer_outputs
            if use_cache:
                next_decoder_cache = past_key_values = present_key_value
            if output_attentions:
                all_self_attns += (self_attn_weights,)  # type: ignore[operator]
            if full_s_max is not None:
                full_s_max += (s_max,)  # type: ignore[operator]

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)  # type: ignore[operator]

        next_cache = next_decoder_cache if use_cache else None
        return E1ModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            s_max=full_s_max,
        )
class E1Model(E1PreTrainedModel, EmbeddingMixin):
    """Bare E1 encoder without a task head.

    The encoder is exposed under the ``model`` attribute, matching ``base_model_prefix``.
    """

    config: E1Config
    config_class = E1Config

    def __init__(self, config: E1Config, **kwargs):
        E1PreTrainedModel.__init__(self, config, **kwargs)
        encoder = FAST_E1_ENCODER(config, **kwargs)
        self.model: FAST_E1_ENCODER = encoder
        # Share the encoder's batch preparer so tokenisation is identical through either object.
        self.prep_tokens = encoder.prep_tokens
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.model.set_input_embeddings(value)

    @torch.inference_mode()
    def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
        # Embedding is fully delegated to the wrapped encoder.
        return self.model._embed(sequences, return_attention_mask=return_attention_mask, **kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        past_key_values: DynamicCache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_s_max: bool = False,
        **kwargs,
    ) -> E1ModelOutputWithPast:
        """Run the wrapped encoder; see ``FAST_E1_ENCODER.forward`` for argument semantics."""
        encoder = self.model
        return encoder(
            input_ids=input_ids,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_s_max=output_s_max,
            **kwargs,
        )
class E1ForMaskedLM(E1PreTrainedModel, EmbeddingMixin):
    """E1 encoder with a masked-language-modelling head (Linear -> GELU -> LayerNorm -> Linear)."""

    config: E1Config
    config_class = E1Config

    def __init__(self, config: E1Config, **kwargs):
        E1PreTrainedModel.__init__(self, config, **kwargs)
        self.model: FAST_E1_ENCODER = FAST_E1_ENCODER(config, **kwargs)
        self.vocab_size = config.vocab_size
        self.mlm_head = torch.nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size, bias=True),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps),
            nn.Linear(config.hidden_size, config.vocab_size, bias=True),
        )
        self.gradient_checkpointing = config.gradient_checkpointing
        self.prep_tokens = self.model.prep_tokens
        self.post_init()

    @property
    def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
        return self.model.device_mesh

    @torch.inference_mode()
    def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
        """Tokenise ``sequences`` and return the encoder's last hidden state.

        When ``return_attention_mask`` is set, also return a non-padding mask.
        """
        batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
        last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
        if return_attention_mask:
            attention_mask = (batch['sequence_ids'] != -1).long()
            return last_hidden_state, attention_mask
        else:
            return last_hidden_state

    def forward(
        self,
        input_ids: torch.LongTensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        labels: torch.LongTensor | None = None,
        past_key_values: DynamicCache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_s_max: bool = False,
        **kwargs,
    ) -> E1MaskedLMOutputWithPast:
        """
        Args:
            input_ids: (batch_size, seq_length)
            within_seq_position_ids: (batch_size, seq_length)
                This tensor contains the position of each residue within the sequence itself.
                For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos><pad>"],
                the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]]
            global_position_ids: (batch_size, seq_length)
                This tensor contains the position of each residue within the global sequence.
                For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos><pad>"],
                the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]]
            sequence_ids: (batch_size, seq_length)
                This tensor contains the sequence id of each residue (-1 for padding).
                For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos><pad>"],
                the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]]
            labels: (batch_size, seq_length)
                Target token ids. Positions equal to the padding id or to -100 (the
                PyTorch / Hugging Face ignore index) do not contribute to the loss.
            past_key_values: DynamicCache
            use_cache: bool
            output_attentions: bool
            output_hidden_states: bool
            output_s_max: bool
        Returns:
            E1MaskedLMOutputWithPast: Model Outputs
        """
        outputs: E1ModelOutputWithPast = self.model(
            input_ids=input_ids,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_s_max=output_s_max,
        )
        last_hidden_state = outputs.last_hidden_state
        # Logits are upcast to float32 for a numerically stable loss.
        mlm_logits = self.mlm_head(last_hidden_state).float()

        loss = None
        mlm_loss = 0.0
        if labels is not None:
            mlm_logits_flat = mlm_logits.contiguous().view(-1, self.config.vocab_size)
            mlm_labels_flat = labels.to(mlm_logits_flat.device).contiguous().view(-1)
            # cross_entropy already yields 0 for -100 targets (default ignore_index).
            per_token_loss = F.cross_entropy(mlm_logits_flat, mlm_labels_flat, reduction="none")
            # Exclude padding-id and -100 targets from the average so ignored positions do not
            # inflate the denominator.
            mask = (mlm_labels_flat != self.model.padding_idx) & (mlm_labels_flat != -100)
            n_mlm = mask.sum()
            # clamp avoids division by zero (numerator is 0 then) without a host sync.
            mlm_loss = (per_token_loss * mask.to(per_token_loss)).sum() / n_mlm.clamp(min=1)
            loss = mlm_loss

        return E1MaskedLMOutputWithPast(
            loss=loss,
            mlm_loss=mlm_loss,
            logits=mlm_logits,
            last_hidden_state=last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            s_max=outputs.s_max,
        )
class E1ForSequenceClassification(E1PreTrainedModel, EmbeddingMixin):
    """E1 encoder with a pooled, sequence-level classification/regression head.

    Optional keyword argument ``pooling_types``: a non-empty list/tuple of pooling names
    understood by ``Pooler`` (e.g. 'mean', 'max', 'var'). Anything else falls back to
    ['mean', 'var']. The classifier input width is ``hidden_size * len(pooling_types)``.
    NOTE(review): this assumes ``Pooler`` concatenates one hidden_size-wide vector per pooling
    type, which matches the previous fixed 2 * hidden_size for the two defaults. Confirm for
    'parti'.
    """

    config: E1Config
    config_class = E1Config

    def __init__(self, config: E1Config, **kwargs):
        E1PreTrainedModel.__init__(self, config, **kwargs)
        self.model: FAST_E1_ENCODER = FAST_E1_ENCODER(config, **kwargs)
        self.vocab_size = config.vocab_size
        self.num_labels = config.num_labels
        pooling_types = self._resolve_pooling_types(kwargs.get('pooling_types'))
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size * len(pooling_types), config.hidden_size * 4),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size * 4),
            nn.Linear(config.hidden_size * 4, config.num_labels),
        )
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.gradient_checkpointing = config.gradient_checkpointing
        self.prep_tokens = self.model.prep_tokens
        self.pooler = Pooler(pooling_types)
        self.post_init()

    @staticmethod
    def _resolve_pooling_types(requested) -> list[str]:
        """Return ``requested`` as a list when it is a non-empty sequence of strings, else the default.

        ``isinstance(x, List[str])`` cannot be used here: subscripted generics raise TypeError
        in isinstance checks.
        """
        if (
            isinstance(requested, (list, tuple))
            and len(requested) > 0
            and all(isinstance(name, str) for name in requested)
        ):
            return list(requested)
        return ['mean', 'var']

    @property
    def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
        return self.model.device_mesh

    @torch.inference_mode()
    def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
        """Tokenise ``sequences`` and return the encoder's last hidden state.

        When ``return_attention_mask`` is set, also return a non-padding mask.
        """
        batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
        last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
        if return_attention_mask:
            attention_mask = (batch['sequence_ids'] != -1).long()
            return last_hidden_state, attention_mask
        else:
            return last_hidden_state

    def forward(
        self,
        input_ids: torch.LongTensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        labels: torch.LongTensor | None = None,
        past_key_values: DynamicCache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_s_max: bool = False,
        **kwargs,
    ) -> E1ClassificationOutputWithPast:
        """Encode, pool over non-padding positions (``sequence_ids != -1``) and classify.

        If ``config.problem_type`` is unset, it is inferred from the first labelled batch:
        * regression when ``num_labels == 1``;
        * single-label classification for integer labels;
        * multi-label classification otherwise.
        """
        outputs: E1ModelOutputWithPast = self.model(
            input_ids=input_ids,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_s_max=output_s_max,
        )
        attention_mask = (sequence_ids != -1).long()
        x = outputs.last_hidden_state
        features = self.pooler(x, attention_mask)
        logits = self.classifier(features)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.flatten(), labels.flatten())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return E1ClassificationOutputWithPast(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            s_max=outputs.s_max,
        )
class E1ForTokenClassification(E1PreTrainedModel, EmbeddingMixin):
    """E1 encoder with a per-position classification head.

    The head is applied directly to the unpooled last hidden state of shape
    (batch, seq, hidden_size). Its first Linear therefore takes ``hidden_size`` features; the
    previous ``hidden_size * 2`` could never match and failed at forward time.
    """

    config: E1Config
    config_class = E1Config

    def __init__(self, config: E1Config, **kwargs):
        E1PreTrainedModel.__init__(self, config, **kwargs)
        self.model: FAST_E1_ENCODER = FAST_E1_ENCODER(config, **kwargs)
        self.vocab_size = config.vocab_size
        self.num_labels = config.num_labels
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 4),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size * 4),
            nn.Linear(config.hidden_size * 4, config.num_labels),
        )
        # Default ignore_index=-100 lets callers mask out positions in `labels`.
        self.loss_fct = nn.CrossEntropyLoss()
        self.gradient_checkpointing = config.gradient_checkpointing
        self.prep_tokens = self.model.prep_tokens
        self.post_init()

    @property
    def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
        return self.model.device_mesh

    @torch.inference_mode()
    def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
        """Tokenise ``sequences`` and return the encoder's last hidden state.

        When ``return_attention_mask`` is set, also return a non-padding mask.
        """
        batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
        last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
        if return_attention_mask:
            attention_mask = (batch['sequence_ids'] != -1).long()
            return last_hidden_state, attention_mask
        else:
            return last_hidden_state

    def forward(
        self,
        input_ids: torch.LongTensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        labels: torch.LongTensor | None = None,
        past_key_values: DynamicCache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_s_max: bool = False,
        **kwargs,
    ) -> E1ClassificationOutputWithPast:
        """Encode and classify every position.

        ``logits`` has shape (batch, seq, num_labels). ``labels``, when given, holds one class
        id per position.
        """
        outputs: E1ModelOutputWithPast = self.model(
            input_ids=input_ids,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_s_max=output_s_max,
        )
        x = outputs.last_hidden_state
        logits = self.classifier(x)
        loss = None
        if labels is not None:
            # Same device handling as the sequence-classification head.
            labels = labels.to(logits.device)
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return E1ClassificationOutputWithPast(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            s_max=outputs.s_max,
        )
if __name__ == "__main__":
    # Smoke test: build a tiny random E1ForMaskedLM, run one batch and print tensor shapes.
    import random
    import torch
    from torch import Tensor

    def print_tensor_shapes(prefix: str, obj):
        """Recursively print the shape of every tensor reachable from ``obj``."""
        if isinstance(obj, Tensor):
            print(f"{prefix}{obj.shape}")
            return
        if isinstance(obj, dict):
            for name, value in obj.items():
                print_tensor_shapes(f"{prefix}{name}.", value)
            return
        if isinstance(obj, (list, tuple)):
            for idx, value in enumerate(obj):
                print_tensor_shapes(f"{prefix}[{idx}].", value)
            return
        if hasattr(obj, "__dict__"):
            for name, value in vars(obj).items():
                if not name.startswith("_"):
                    print_tensor_shapes(f"{prefix}{name}.", value)
            return
        print(f"{prefix}{type(obj)}")

    def get_e1_batch(tokenizer, sequences: list[str], device: torch.device):
        """Tokenise ``sequences`` into E1 model kwargs on ``device``."""
        preparer = E1BatchPreparer(data_prep_config=DataPrepConfig(max_num_positions_within_seq=64), tokenizer=tokenizer)
        return preparer.get_batch_kwargs(sequences=sequences, device=device)

    random.seed(0)
    torch.manual_seed(0)
    num_attention_heads = random.choice([2, 4])
    config = E1Config(
        hidden_size=16 * num_attention_heads,
        intermediate_size=64 * num_attention_heads,
        num_hidden_layers=random.choice([1, 2]),
        num_attention_heads=num_attention_heads,
        num_key_value_heads=num_attention_heads,
        max_num_positions_within_seq=128,
        max_num_positions_global=256,
        max_num_sequences=8,
        dtype="float32",
    )
    model = E1ForMaskedLM(config=config).eval()
    tokenizer = get_tokenizer()
    batch = get_e1_batch(tokenizer=tokenizer, sequences=["ACDEFG", "MKTW"], device=torch.device("cpu"))
    batch["labels"] = batch["labels"].clone()

    model_inputs = {
        key: batch[key]
        for key in ("input_ids", "within_seq_position_ids", "global_position_ids", "sequence_ids", "labels")
    }
    with torch.no_grad():
        output = model(**model_inputs)

    print("Batch shape:")
    print_tensor_shapes("", batch)
    print("Output shape:")
    print_tensor_shapes("", output)