import torch
import torch.nn as nn
import torch.nn.functional as F


class LatentAttentionPooling(nn.Module):
    """Pools token embeddings into a fixed-size vector via cross-attention
    from a learned set of latent queries (Perceiver-style)."""

    def __init__(self, hidden_size, num_latents=2048, num_heads=8):
        super().__init__()
        # Learned latent queries, shared across the batch.
        self.latents = nn.Parameter(torch.randn(num_latents, hidden_size))
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True
        )
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, hidden_states, attention_mask=None):
        bsz, seq_len, d = hidden_states.shape
        # Broadcast the latent queries to the batch: (bsz, num_latents, d).
        queries = self.latents.unsqueeze(0).expand(bsz, -1, -1)
        key_padding_mask = None
        if attention_mask is not None:
            # nn.MultiheadAttention expects True at padded key positions.
            key_padding_mask = attention_mask == 0
        # Latents attend over the token sequence (keys and values).
        attn_out, _ = self.cross_attn(
            queries, hidden_states, hidden_states,
            key_padding_mask=key_padding_mask,
        )
        # MLP refinement, then mean over the latent axis -> (bsz, d).
        return self.mlp(attn_out).mean(dim=1)


class MatryoshkaProjection(nn.Module):
    """Projects the pooled vector to the maximum embedding width; shorter
    Matryoshka embeddings are taken as prefixes of this output."""

    def __init__(self, hidden_size, max_embed_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(max_embed_dim, hidden_size))

    def forward(self, pooled):
        # (bsz, hidden_size) x weight^T -> (bsz, max_embed_dim).
        return F.linear(pooled, self.weight)


print("Created modeling.py with custom classes.")
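
# --- Minimal usage sketch (illustration only, not part of modeling.py) ---
# All dimensions below are hypothetical: hidden_size=768, batch of 4,
# 128-token sequences, and a 256-dim Matryoshka truncation; the two
# modules chain as pool -> project -> truncate + renormalize.
if __name__ == "__main__":
    pool = LatentAttentionPooling(hidden_size=768, num_latents=512, num_heads=8)
    proj = MatryoshkaProjection(hidden_size=768, max_embed_dim=1024)

    hidden_states = torch.randn(4, 128, 768)        # (batch, seq_len, hidden)
    attention_mask = torch.ones(4, 128, dtype=torch.long)
    attention_mask[:, 100:] = 0                     # trailing positions are padding

    pooled = pool(hidden_states, attention_mask)    # (4, 768)
    full = proj(pooled)                             # (4, 1024) full-width embedding
    # Matryoshka use: keep a prefix of the embedding, then L2-normalize.
    small = F.normalize(full[:, :256], dim=-1)      # (4, 256)
    print(pooled.shape, full.shape, small.shape)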