# evoreign's picture
# Upload modeling.py with huggingface_hub
# d837865 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
class LatentAttentionPooling(nn.Module):
    """Pool a token sequence into one vector using learned latent queries.

    A bank of ``num_latents`` learned query vectors cross-attends over the
    token hidden states. The attended latents are passed through a two-layer
    GELU MLP and then averaged into a single ``hidden_size`` vector per
    sequence.

    Args:
        hidden_size: Feature width of the incoming hidden states; also the
            width of the latents and of the pooled output.
        num_latents: Number of learned query vectors.
        num_heads: Number of heads in the cross-attention layer.
    """

    def __init__(self, hidden_size: int, num_latents: int = 2048, num_heads: int = 8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, hidden_size))
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            batch_first=True,
        )
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, hidden_states, attention_mask=None):
        """Return one pooled vector per sequence.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, hidden_size)``.
            attention_mask: Optional ``(batch, seq_len)`` tensor in which
                entries equal to 0 mark padding positions to be ignored.

        Returns:
            Tensor of shape ``(batch, hidden_size)``. Sequences whose mask is
            entirely zero yield an all-zero vector instead of NaN.

        Raises:
            ValueError: If ``hidden_states`` is not 3-dimensional.
        """
        # Three-way unpacking keeps the implicit rank check on the input.
        batch_size, _, _ = hidden_states.shape
        queries = self.latents.unsqueeze(0).expand(batch_size, -1, -1)

        key_padding_mask = None
        fully_padded = None
        if attention_mask is not None:
            key_padding_mask = attention_mask == 0
            # A row with every key masked makes the attention softmax produce
            # NaN. Un-mask such rows for the attention call and blank their
            # pooled output afterwards.
            fully_padded = key_padding_mask.all(dim=1)
            key_padding_mask = key_padding_mask & ~fully_padded.unsqueeze(1)

        # The attention weights are never used, so skip computing them.
        attn_out, _ = self.cross_attn(
            queries,
            hidden_states,
            hidden_states,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        pooled = self.mlp(attn_out).mean(dim=1)

        if fully_padded is not None:
            pooled = pooled.masked_fill(fully_padded.unsqueeze(-1), 0.0)
        return pooled
class MatryoshkaProjection(nn.Module):
    """Bias-free linear projection whose output can be truncated to a prefix.

    The layer holds a single ``(max_embed_dim, hidden_size)`` weight matrix.
    Projecting with only its first ``k`` rows gives exactly the first ``k``
    coordinates of the full projection, so shorter embeddings are nested
    prefixes of longer ones.

    Args:
        hidden_size: Feature width of the pooled input.
        max_embed_dim: Number of output coordinates of the full projection.
    """

    def __init__(self, hidden_size: int, max_embed_dim: int):
        super().__init__()
        # NOTE(review): unit-variance init, so output scale grows with
        # hidden_size — presumably weights are trained or loaded from a
        # checkpoint; confirm.
        self.weight = nn.Parameter(torch.randn(max_embed_dim, hidden_size))

    def forward(self, pooled, embed_dim=None):
        """Project ``pooled`` to the embedding space.

        Args:
            pooled: Tensor whose last dimension is ``hidden_size``.
            embed_dim: Optional number of leading output coordinates to
                compute. ``None`` (the default) computes all
                ``max_embed_dim`` coordinates.

        Returns:
            Tensor whose last dimension is ``embed_dim`` if given, otherwise
            ``max_embed_dim``.

        Raises:
            ValueError: If ``embed_dim`` is not within
                ``1..max_embed_dim``.
        """
        if embed_dim is None:
            return F.linear(pooled, self.weight)

        max_embed_dim = self.weight.shape[0]
        if not 1 <= embed_dim <= max_embed_dim:
            raise ValueError(
                f"embed_dim must be between 1 and {max_embed_dim}, got {embed_dim}"
            )
        # Slicing rows avoids computing coordinates that would be discarded.
        return F.linear(pooled, self.weight[:embed_dim])
# Module-level side effect: this message is printed every time the file is
# executed or imported.
# NOTE(review): looks like a leftover from the script that generated this
# file — confirm whether it is still wanted.
print("Created modeling.py with custom classes.")