"""Embedding compression model: trains one head per target size to preserve
the pairwise cosine-similarity structure of the input embeddings."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List
from dataclasses import dataclass

from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_compression import CompressionConfig


def cosine_pairwise(embeddings):
    """Pairwise cosine-similarity matrix of shape (n, n) for n embeddings."""
    return F.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)


def cov(tensor, rowvar=True, bias=False):
    """Estimate a covariance matrix (torch port of np.cov)."""
    tensor = tensor if rowvar else tensor.transpose(-1, -2)
    tensor = tensor - tensor.mean(dim=-1, keepdim=True)
    factor = 1 / (tensor.shape[-1] - int(not bool(bias)))
    return factor * tensor @ tensor.transpose(-1, -2).conj()


def remove_diag(x):
    """Drop the diagonal of a square matrix, returning shape (n, n - 1)."""
    n = x.shape[0]
    return x.masked_select(~torch.eye(n, dtype=torch.bool, device=x.device)).view(n, n - 1)


def corrcoef(tensor, rowvar=True):
    """Pearson product-moment correlation coefficients (torch port of np.corrcoef)."""
    covariance = cov(tensor, rowvar=rowvar)
    variance = covariance.diagonal(0, -1, -2)
    if variance.is_complex():
        variance = variance.real
    stddev = variance.sqrt()
    covariance /= stddev.unsqueeze(-1)
    covariance /= stddev.unsqueeze(-2)
    # Clamp to [-1, 1] to guard against floating-point drift, as np.corrcoef does.
    if covariance.is_complex():
        covariance.real.clip_(-1, 1)
        covariance.imag.clip_(-1, 1)
    else:
        covariance.clip_(-1, 1)
    return covariance


def compute_correlation(base_sims, compressed_sims, rm_diag=True):
    """Mean Pearson distance (1 - r) between matching rows of two similarity matrices."""
    if rm_diag:
        base_sims = remove_diag(base_sims)
        compressed_sims = remove_diag(compressed_sims)
    inputs = torch.stack([base_sims, compressed_sims], dim=1)  # (n, 2, m)
    return (1 - corrcoef(inputs)[:, 0, 1]).mean()


def loss_function(base_sims, compressed_sims, k_vals):
    """Full-matrix correlation loss, plus one extra term per k restricted to
    each row's top-k neighbours under the base similarities."""
    outputs = [compute_correlation(base_sims, compressed_sims)]
    if k_vals:
        # Rank neighbours by base similarity; drop column 0 (self-similarity).
        base_ranks = base_sims.argsort(-1, descending=True)[:, 1:]
        for k in k_vals:
            base_sims_k = torch.gather(base_sims, 1, base_ranks[:, :k])
            compressed_sims_k = torch.gather(compressed_sims, 1, base_ranks[:, :k])
            outputs.append(compute_correlation(base_sims_k, compressed_sims_k, rm_diag=False))
    return torch.stack(outputs).unsqueeze(0)  # (1, 1 + len(k_vals))


class FeedForward(nn.Module):
    """SwiGLU-style gated feed-forward block."""

    def __init__(self, d_in, d_out):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_out * 2)
        self.fc2 = nn.Linear(d_out, d_out)

    def forward(self, x):
        x = self.fc1(x)
        x1, x2 = x.chunk(2, dim=-1)  # gate and value halves
        return self.fc2(F.silu(x1) * x2)


class CompressionHead(nn.Module):
    """Projects d_in-dim embeddings to d_out dims via a gated FFN plus a linear skip."""

    def __init__(self, d_in, d_out, dropout=0.1):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        self.skip = nn.Linear(d_in, d_out)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout(x)
        return self.ff(x) + self.skip(x)


@dataclass
class CompressionModelOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    losses: Optional[List[torch.FloatTensor]] = None
    base_embedding: Optional[torch.FloatTensor] = None
    compressed_embeddings: Optional[List[torch.FloatTensor]] = None


class CompressionModel(PreTrainedModel):
    config_class = CompressionConfig

    def __init__(self, config):
        super().__init__(config)
        # One compression head per target size.
        self.heads = nn.ModuleList(
            [CompressionHead(config.input_size, size, config.dropout) for size in config.compression_sizes]
        )

    def forward(self, embedding, compute_loss=True, return_dict=True):
        outputs = []
        losses = None
        if compute_loss:
            losses = []
            emb_sims = cosine_pairwise(embedding)
        for head in self.heads:
            compressed_embedding = head(embedding)
            outputs.append(compressed_embedding)
            if compute_loss:
                comp_sims = cosine_pairwise(compressed_embedding)
                losses.append(loss_function(emb_sims, comp_sims, self.config.loss_k_vals))
        # Guard the reduction: losses is None when compute_loss=False.
        loss = torch.cat(losses).sum() if compute_loss else None
        if not return_dict:
            return (loss, losses, embedding, outputs)
        return CompressionModelOutput(
            loss=loss,
            losses=losses,
            base_embedding=embedding,
            compressed_embeddings=outputs,
        )
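

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). It assumes
# CompressionConfig accepts input_size, compression_sizes, dropout, and
# loss_k_vals as keyword arguments, matching the attributes read above;
# the random batch stands in for real base embeddings. Note the batch size
# must exceed max(loss_k_vals) so each row has enough non-self neighbours.
if __name__ == "__main__":
    config = CompressionConfig(
        input_size=768,
        compression_sizes=[256, 64],
        dropout=0.1,
        loss_k_vals=[5, 10],
    )
    model = CompressionModel(config)
    embedding = torch.randn(32, config.input_size)  # batch of 32 base embeddings
    out = model(embedding)
    # One compressed embedding per head, plus a scalar training loss.
    print(out.loss, [e.shape for e in out.compressed_embeddings])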