Upload model
Browse files- config.json +26 -0
- configuration_compression.py +21 -0
- model.safetensors +3 -0
- modeling_compression.py +132 -0
config.json
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"CompressionModel"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "configuration_compression.CompressionConfig",
|
7 |
+
"AutoModel": "modeling_compression.CompressionModel"
|
8 |
+
},
|
9 |
+
"compression_sizes": [
|
10 |
+
512,
|
11 |
+
256,
|
12 |
+
128,
|
13 |
+
64,
|
14 |
+
32
|
15 |
+
],
|
16 |
+
"dropout": 0.1,
|
17 |
+
"input_size": 768,
|
18 |
+
"loss_k_vals": [
|
19 |
+
10,
|
20 |
+
100,
|
21 |
+
256
|
22 |
+
],
|
23 |
+
"model_type": "compression_head",
|
24 |
+
"torch_dtype": "float32",
|
25 |
+
"transformers_version": "4.38.2"
|
26 |
+
}
|
configuration_compression.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional

from transformers import PretrainedConfig
|
3 |
+
|
4 |
+
class CompressionConfig(PretrainedConfig):
    """Configuration for a stack of embedding-compression heads.

    Args:
        input_size: width of the incoming base embedding.
        compression_sizes: output width of each compression head
            (defaults to [512, 256, 128, 64, 32]).
        dropout: dropout probability applied inside each head.
        loss_k_vals: top-k neighborhood sizes used by the training loss;
            an empty list disables the top-k terms.
    """

    model_type = "compression_head"

    def __init__(self,
                 input_size: int = 768,
                 compression_sizes: Optional[List[int]] = None,
                 dropout: float = 0.1,
                 loss_k_vals: Optional[List[int]] = None,
                 **kwargs
                 ):
        # BUG FIX: the original used mutable list defaults ([512, ...] and []),
        # which are shared across every instantiation; use None sentinels and
        # build a fresh list per instance instead.
        self.input_size = input_size
        self.compression_sizes = ([512, 256, 128, 64, 32]
                                  if compression_sizes is None
                                  else compression_sizes)
        self.dropout = dropout
        self.loss_k_vals = [] if loss_k_vals is None else loss_k_vals

        super().__init__(**kwargs)
|
21 |
+
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7d62042eb9cbd70af97e6e2abbcfe3fa25972b969d17c421ad173348fee8b4ba
|
3 |
+
size 10557544
|
modeling_compression.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from typing import Tuple, Optional, List
|
6 |
+
from dataclasses import dataclass
|
7 |
+
|
8 |
+
from transformers import PreTrainedModel
|
9 |
+
from transformers.utils import ModelOutput
|
10 |
+
|
11 |
+
from .configuration_compression import CompressionConfig
|
12 |
+
|
13 |
+
def cosine_pairwise(embeddings):
    """Return the matrix of cosine similarities between all row pairs."""
    rows = embeddings.unsqueeze(1)
    cols = embeddings.unsqueeze(0)
    return F.cosine_similarity(rows, cols, dim=2)
|
15 |
+
|
16 |
+
def cov(tensor, rowvar=True, bias=False):
    """Estimate a covariance matrix (np.cov)."""
    # With rowvar=False the columns are the variables, so flip them into rows.
    if not rowvar:
        tensor = tensor.transpose(-1, -2)
    centered = tensor - tensor.mean(dim=-1, keepdim=True)
    # bias=False gives the unbiased (N-1) estimator, matching numpy's default.
    ddof = 0 if bias else 1
    scale = 1 / (centered.shape[-1] - ddof)
    return scale * centered @ centered.transpose(-1, -2).conj()
|
22 |
+
|
23 |
+
def remove_diag(x):
    """Drop the diagonal of a square (n, n) matrix, yielding shape (n, n-1)."""
    n = x.shape[0]
    off_diagonal = ~torch.eye(n, dtype=bool, device=x.device)
    return x.masked_select(off_diagonal).view(n, n - 1)
|
26 |
+
|
27 |
+
def corrcoef(tensor, rowvar=True):
    """Get Pearson product-moment correlation coefficients (np.corrcoef)."""
    c = cov(tensor, rowvar=rowvar)
    # Per-variable variances sit on the diagonal of the covariance matrix.
    var = c.diagonal(0, -1, -2)
    if var.is_complex():
        var = var.real
    std = var.sqrt()
    # Normalize rows then columns by the standard deviations.
    c = c / std.unsqueeze(-1)
    c = c / std.unsqueeze(-2)
    # Rounding can push coefficients slightly outside [-1, 1]; clamp them back.
    if c.is_complex():
        c = torch.complex(c.real.clamp(-1, 1), c.imag.clamp(-1, 1))
    else:
        c = c.clamp(-1, 1)
    return c
|
42 |
+
|
43 |
+
def compute_correlation(base_sims, compressed_sims, rm_diag=True):
    """Mean of (1 - Pearson correlation) between matching rows of the two
    similarity matrices; rm_diag drops each row's self-similarity first."""
    if rm_diag:
        base_sims = remove_diag(base_sims)
        compressed_sims = remove_diag(compressed_sims)

    # Shape (n, 2, m): each row pairs a base row with its compressed row.
    paired = torch.stack([base_sims, compressed_sims], dim=1)
    row_corr = corrcoef(paired)[:, 0, 1]
    return (1 - row_corr).mean()
|
51 |
+
|
52 |
+
def loss_function(base_sims, compressed_sims, k_vals):
    """Correlation-preservation loss between base and compressed similarities.

    Computes 1 - Pearson correlation over full rows (diagonal removed), plus
    one extra term per k in `k_vals` restricted to each row's top-k base
    neighbors.  Returns a (1, 1 + len(k_vals)) tensor of loss terms.
    """
    outputs = [compute_correlation(base_sims, compressed_sims)]

    if k_vals:
        # Neighbor indices sorted by descending base similarity; column 0
        # (presumably the self-match) is dropped before taking top-k.
        base_ranks = base_sims.argsort(-1, descending=True)[:, 1:]
        # NOTE: removed the original's unused local `n = base_ranks.shape[1]`.
        for k in k_vals:
            base_sims_k = torch.gather(base_sims, 1, base_ranks[:, :k])
            compressed_sims_k = torch.gather(compressed_sims, 1, base_ranks[:, :k])
            outputs.append(compute_correlation(base_sims_k, compressed_sims_k, rm_diag=False))

    return torch.stack(outputs).unsqueeze(0)
|
64 |
+
|
65 |
+
class FeedForward(nn.Module):
|
66 |
+
def __init__(self, d_in, d_out):
|
67 |
+
super().__init__()
|
68 |
+
self.fc1 = nn.Linear(d_in, d_out*2)
|
69 |
+
self.fc2 = nn.Linear(d_out, d_out)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
x = self.fc1(x)
|
73 |
+
x1, x2 = x.chunk(2, dim=-1)
|
74 |
+
x = self.fc2(F.silu(x1) * x2)
|
75 |
+
return x
|
76 |
+
|
77 |
+
class CompressionHead(nn.Module):
    """Projects a d_in embedding to d_out via a gated feed-forward branch
    added to a plain linear skip branch; dropout is applied to the input."""

    def __init__(self, d_in, d_out, dropout=0.1):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        self.skip = nn.Linear(d_in, d_out)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Both branches see the same dropped-out input.
        dropped = self.dropout(x)
        return self.ff(dropped) + self.skip(dropped)
|
88 |
+
|
89 |
+
@dataclass
class CompressionModelOutput(ModelOutput):
    """Output container returned by CompressionModel.forward."""

    # Scalar total loss (sum over all heads); None when no loss was computed.
    loss: Optional[torch.FloatTensor] = None
    # Per-head loss tensors as produced by loss_function; None when skipped.
    losses: Optional[List[torch.FloatTensor]] = None
    # The uncompressed input embedding, passed through unchanged.
    base_embedding: Optional[torch.FloatTensor] = None
    # One compressed embedding per head, in config.compression_sizes order.
    compressed_embeddings: Optional[List[torch.FloatTensor]] = None
|
95 |
+
|
96 |
+
class CompressionModel(PreTrainedModel):
    """Stack of heads that compress a base embedding to smaller widths.

    One CompressionHead per entry in config.compression_sizes maps the
    config.input_size-wide embedding down to that size.  The training loss
    measures how well pairwise cosine similarities survive compression.
    """

    config_class = CompressionConfig

    def __init__(self, config):
        super().__init__(config)
        # One head per requested compressed size.
        self.heads = nn.ModuleList([CompressionHead(config.input_size, size, config.dropout)
                                    for size in config.compression_sizes])

    def forward(self, embedding, compute_loss=True, return_dict=True):
        """Compress `embedding` with every head.

        Args:
            embedding: batch of base embeddings; assumed 2-D (batch, input_size)
                since the loss builds a pairwise similarity matrix over dim 0
                — TODO confirm against callers.
            compute_loss: also compute the similarity-preservation loss.
            return_dict: return a CompressionModelOutput instead of a tuple.
        """
        outputs = []
        losses = None
        loss = None

        if compute_loss:
            losses = []
            emb_sims = cosine_pairwise(embedding)

        for head in self.heads:
            compressed_embedding = head(embedding)
            outputs.append(compressed_embedding)

            if compute_loss:
                comp_sims = cosine_pairwise(compressed_embedding)
                losses.append(loss_function(emb_sims, comp_sims, self.config.loss_k_vals))

        # BUG FIX: the original ran `loss = torch.cat(losses).sum()`
        # unconditionally, which raises when compute_loss=False (losses is
        # None); only aggregate when losses were actually collected.
        if compute_loss:
            loss = torch.cat(losses).sum()

        if not return_dict:
            return (loss, losses, embedding, outputs)

        return CompressionModelOutput(loss=loss,
                                      losses=losses,
                                      base_embedding=embedding,
                                      compressed_embeddings=outputs)
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
|