import torch
import torch.nn as nn
import torch.nn.functional as F


class LatentAttentionPooling(nn.Module):
    """Pools token-level hidden states into a single vector by cross-attending
    a bank of learned latent vectors over the token sequence."""

    def __init__(self, hidden_size, num_latents=2048, num_heads=8):
        super().__init__()
        # Learned latent array; each latent acts as a query over the token sequence.
        self.latents = nn.Parameter(torch.randn(num_latents, hidden_size))
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            batch_first=True,
        )
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, hidden_states, attention_mask=None):
        # hidden_states: (batch, seq_len, hidden_size)
        bsz, seq_len, d = hidden_states.shape
        # Broadcast the shared latents across the batch: (batch, num_latents, hidden_size).
        queries = self.latents.unsqueeze(0).expand(bsz, -1, -1)
        # Convert a 1/0 attention mask into a key padding mask (True = padded position to ignore).
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = attention_mask == 0
        # Latents attend over the token sequence (keys/values are the hidden states).
        attn_out, _ = self.cross_attn(
            queries,
            hidden_states,
            hidden_states,
            key_padding_mask=key_padding_mask,
        )
        # Transform each attended latent, then average over latents: (batch, hidden_size).
        return self.mlp(attn_out).mean(dim=1)


class MatryoshkaProjection(nn.Module):
    """Projects the pooled vector to the maximum embedding width; smaller
    Matryoshka embeddings are obtained downstream by truncating the output."""

    def __init__(self, hidden_size, max_embed_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(max_embed_dim, hidden_size))

    def forward(self, pooled):
        # pooled: (batch, hidden_size) -> (batch, max_embed_dim)
        return F.linear(pooled, self.weight)


print("Created modeling.py with custom classes.")
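# The block below is an illustrative usage sketch, not part of the original
# modeling.py: it wires the two modules together with placeholder sizes
# (hidden_size=768, max_embed_dim=512) and random tensors, then applies the
# usual Matryoshka step of truncating and L2-normalizing the full embedding.
if __name__ == "__main__":
    hidden_size, max_embed_dim = 768, 512
    pooler = LatentAttentionPooling(hidden_size, num_latents=64, num_heads=8)
    projection = MatryoshkaProjection(hidden_size, max_embed_dim)

    # Stand-in encoder output: 2 sequences of 16 tokens, last 4 tokens padded.
    hidden_states = torch.randn(2, 16, hidden_size)
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    attention_mask[:, 12:] = 0

    pooled = pooler(hidden_states, attention_mask)  # (2, hidden_size)
    full = projection(pooled)                       # (2, max_embed_dim)

    # Matryoshka-style truncation: keep a prefix of the embedding, then renormalize.
    for dim in (64, 128, 256, max_embed_dim):
        emb = F.normalize(full[:, :dim], p=2, dim=-1)
        print(dim, tuple(emb.shape))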