entropy committed on
Commit
9700d2e
·
verified ·
1 Parent(s): 5ca8431

Upload model

Browse files
config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CompressionModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_compression.CompressionConfig",
7
+ "AutoModel": "modeling_compression.CompressionModel"
8
+ },
9
+ "compression_sizes": [
10
+ 512,
11
+ 256,
12
+ 128,
13
+ 64,
14
+ 32
15
+ ],
16
+ "dropout": 0.1,
17
+ "input_size": 768,
18
+ "loss_k_vals": [
19
+ 10,
20
+ 100,
21
+ 256
22
+ ],
23
+ "model_type": "compression_head",
24
+ "torch_dtype": "float32",
25
+ "transformers_version": "4.38.2"
26
+ }
configuration_compression.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import List, Optional

from transformers import PretrainedConfig
3
+
4
class CompressionConfig(PretrainedConfig):
    """Configuration for CompressionModel.

    Describes a stack of compression heads that map a base embedding of
    ``input_size`` dimensions down to each size in ``compression_sizes``.

    Args:
        input_size: Dimensionality of the incoming base embedding.
        compression_sizes: Output sizes, one per compression head.
            Defaults to [512, 256, 128, 64, 32].
        dropout: Dropout probability applied inside each head.
        loss_k_vals: Top-k cutoffs for the additional ranked losses;
            empty list means only the full-matrix loss is used.
    """

    model_type = "compression_head"

    def __init__(self,
                 input_size: int = 768,
                 compression_sizes: Optional[List[int]] = None,
                 dropout: float = 0.1,
                 loss_k_vals: Optional[List[int]] = None,
                 **kwargs
                 ):
        # BUG FIX: the defaults were mutable lists ([512, ...] and []),
        # which Python shares across every call; use None sentinels instead.
        self.input_size = input_size
        self.compression_sizes = ([512, 256, 128, 64, 32]
                                  if compression_sizes is None
                                  else compression_sizes)
        self.dropout = dropout
        self.loss_k_vals = [] if loss_k_vals is None else loss_k_vals

        super().__init__(**kwargs)
21
+
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d62042eb9cbd70af97e6e2abbcfe3fa25972b969d17c421ad173348fee8b4ba
3
+ size 10557544
modeling_compression.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from typing import Tuple, Optional, List
6
+ from dataclasses import dataclass
7
+
8
+ from transformers import PreTrainedModel
9
+ from transformers.utils import ModelOutput
10
+
11
+ from .configuration_compression import CompressionConfig
12
+
13
def cosine_pairwise(embeddings):
    """All-pairs cosine similarity: maps an (n, d) batch to an (n, n) matrix."""
    rows = embeddings.unsqueeze(1)   # (n, 1, d)
    cols = embeddings.unsqueeze(0)   # (1, n, d)
    return F.cosine_similarity(rows, cols, dim=-1)
15
+
16
def cov(tensor, rowvar=True, bias=False):
    """Estimate a covariance matrix (np.cov).

    Rows are treated as variables unless ``rowvar`` is False; ``bias``
    selects 1/N instead of the default 1/(N-1) normalisation.
    """
    if not rowvar:
        tensor = tensor.transpose(-1, -2)
    centered = tensor - tensor.mean(dim=-1, keepdim=True)
    denom = centered.shape[-1] - (0 if bias else 1)
    # conj() makes this correct for complex inputs as well.
    return (centered @ centered.transpose(-1, -2).conj()) / denom
22
+
23
def remove_diag(x):
    """Drop the main diagonal of a square (n, n) matrix, giving (n, n-1)."""
    n = x.shape[0]
    off_diag = ~torch.eye(n, dtype=torch.bool, device=x.device)
    return x[off_diag].view(n, n - 1)
26
+
27
def corrcoef(tensor, rowvar=True):
    """Get Pearson product-moment correlation coefficients (np.corrcoef)."""
    c = cov(tensor, rowvar=rowvar)
    # Per-variable variances sit on the diagonal of the covariance matrix.
    var = c.diagonal(0, -1, -2)
    if var.is_complex():
        var = var.real
    std = var.sqrt()
    # Normalise rows then columns by the standard deviations.
    c = c / std.unsqueeze(-1)
    c = c / std.unsqueeze(-2)
    # Clamp floating-point drift so coefficients stay inside [-1, 1].
    if c.is_complex():
        c = torch.complex(c.real.clamp(-1, 1), c.imag.clamp(-1, 1))
    else:
        c = c.clamp(-1, 1)
    return c
42
+
43
def compute_correlation(base_sims, compressed_sims, rm_diag=True):
    """Mean of (1 - Pearson r) between matching rows of the two matrices.

    With ``rm_diag`` set, the self-similarity diagonal is removed from both
    square inputs before correlating.
    """
    if rm_diag:
        base_sims = remove_diag(base_sims)
        compressed_sims = remove_diag(compressed_sims)

    # (n, 2, m): each row paired with its compressed counterpart.
    paired = torch.stack([base_sims, compressed_sims], dim=1)
    per_row_r = corrcoef(paired)[:, 0, 1]
    return (1 - per_row_r).mean()
51
+
52
def loss_function(base_sims, compressed_sims, k_vals):
    """Stacked decorrelation losses for one compression head.

    Args:
        base_sims: (n, n) pairwise similarities from the base embedding.
        compressed_sims: (n, n) pairwise similarities from the compressed one.
        k_vals: top-k cutoffs; for each k an extra loss is computed over each
            row's k most-similar items as ranked by the base similarities.

    Returns:
        A (1, 1 + len(k_vals)) tensor: full-matrix loss, then one per k.
    """
    outputs = [compute_correlation(base_sims, compressed_sims)]

    if k_vals:
        # Rank neighbours by base similarity; column 0 is self-similarity,
        # so drop it before taking top-k.
        base_ranks = base_sims.argsort(-1, descending=True)[:, 1:]
        # FIX: removed unused local `n = base_ranks.shape[1]`.
        for k in k_vals:
            base_sims_k = torch.gather(base_sims, 1, base_ranks[:, :k])
            compressed_sims_k = torch.gather(compressed_sims, 1, base_ranks[:, :k])
            # Rows are already diagonal-free, so keep every entry.
            outputs.append(compute_correlation(base_sims_k, compressed_sims_k, rm_diag=False))

    return torch.stack(outputs).unsqueeze(0)
64
+
65
class FeedForward(nn.Module):
    """SwiGLU-style feed-forward block: gated projection, then a linear map."""

    def __init__(self, d_in, d_out):
        super().__init__()
        # One matmul produces both the gate and the value halves.
        self.fc1 = nn.Linear(d_in, d_out * 2)
        self.fc2 = nn.Linear(d_out, d_out)

    def forward(self, x):
        gate, value = self.fc1(x).chunk(2, dim=-1)
        return self.fc2(F.silu(gate) * value)
76
+
77
class CompressionHead(nn.Module):
    """Compresses an embedding from d_in to d_out: dropout on the input, then
    a gated feed-forward branch added to a linear skip projection."""

    def __init__(self, d_in, d_out, dropout=0.1):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        self.skip = nn.Linear(d_in, d_out)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        dropped = self.dropout(x)
        # Both paths read the same dropped input; their sum is the output.
        return self.ff(dropped) + self.skip(dropped)
88
+
89
@dataclass
class CompressionModelOutput(ModelOutput):
    """Output container returned by CompressionModel.forward."""

    # Aggregate scalar loss summed across all compression heads.
    loss: Optional[torch.FloatTensor] = None
    # Per-head loss tensors, in the same order as the model's heads.
    losses: Optional[List[torch.FloatTensor]] = None
    # The uncompressed input embedding, passed through unchanged.
    base_embedding: Optional[torch.FloatTensor] = None
    # One compressed embedding per head, in head order.
    compressed_embeddings: Optional[List[torch.FloatTensor]] = None
95
+
96
class CompressionModel(PreTrainedModel):
    """Applies a bank of CompressionHeads to a base embedding and, optionally,
    scores each compressed embedding against the base with a pairwise
    similarity-correlation loss."""

    config_class = CompressionConfig

    def __init__(self, config):
        super().__init__(config)
        # One head per target size, all reading the same base embedding.
        self.heads = nn.ModuleList([CompressionHead(config.input_size, size, config.dropout)
                                    for size in config.compression_sizes])

    def forward(self, embedding, compute_loss=True, return_dict=True):
        """Compress `embedding` with every head.

        Args:
            embedding: batch of base embeddings (presumably (batch, input_size)
                — confirm against the caller).
            compute_loss: when True, compute per-head and total losses.
            return_dict: return a CompressionModelOutput instead of a tuple.

        Returns:
            CompressionModelOutput (or tuple) of loss, per-head losses, the
            base embedding, and the list of compressed embeddings. `loss` and
            `losses` are None when compute_loss is False.
        """
        outputs = []
        losses = None
        loss = None

        if compute_loss:
            losses = []
            emb_sims = cosine_pairwise(embedding)

        for head in self.heads:
            compressed_embedding = head(embedding)
            outputs.append(compressed_embedding)

            if compute_loss:
                comp_sims = cosine_pairwise(compressed_embedding)
                losses.append(loss_function(emb_sims, comp_sims, self.config.loss_k_vals))

        # BUG FIX: torch.cat(losses) ran unconditionally, crashing with
        # compute_loss=False (losses is None); aggregate only when computed.
        if compute_loss:
            loss = torch.cat(losses).sum()

        if not return_dict:
            return (loss, losses, embedding, outputs)

        return CompressionModelOutput(loss=loss,
                                      losses=losses,
                                      base_embedding=embedding,
                                      compressed_embeddings=outputs)
129
+
130
+
131
+
132
+