ekolodin committed
Commit e5a0718 · verified · 1 Parent(s): ac04eb8

Update model weights
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
The diff for this file is too large to render. See raw diff
 
config.json CHANGED
@@ -1,5 +1,5 @@
  {
-   "_name_or_path": "ai-sage/Giga-Embeddings-instruct",
+   "_name_or_path": "/home/jovyan/ekolodin/gigachat-embeddings/ckpt/multitask_prenorm_lr2e-5/checkpoint-6591",
    "add_eos": true,
    "add_pad_token": true,
    "architectures": [
@@ -10,7 +10,7 @@
      "AutoModel": "modeling_gigarembed.GigarEmbedModel"
    },
    "hidden_size": 2048,
-   "is_mask_instruction": false,
+   "is_mask_instruction": true,
    "latent_attention_config": {
      "cross_dim_head": 2048,
      "hidden_dim": 2048,
@@ -21,7 +21,8 @@
    "model_type": "gigarembed",
    "padding_side": "right",
    "text_config": {
-     "_name_or_path": "ai-sage/Giga-Embeddings-instruct",
+     "_attn_implementation_autoset": false,
+     "_name_or_path": "/home/jovyan/ekolodin/models/qiwiembed2.5_3b_pretrain/",
      "activation_checkpoint_layers_num": null,
      "add_cross_attention": false,
      "architectures": [
@@ -78,7 +79,7 @@
      "num_attention_heads": 16,
      "num_beam_groups": 1,
      "num_beams": 1,
-     "num_hidden_layers": 27,
+     "num_hidden_layers": 36,
      "num_key_value_heads": 2,
      "num_return_sequences": 1,
      "output_attentions": false,
@@ -122,5 +123,5 @@
      "vocab_size": 128256
    },
    "torch_dtype": "float32",
-   "transformers_version": "4.40.0.dev0"
+   "transformers_version": "4.46.3"
  }
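As a quick check of the values touched here, the updated config.json can be read straight from the Hub. This is a minimal sketch; the repository id ai-sage/Giga-Embeddings-instruct is an assumption taken from the old _name_or_path, not something this commit records.

import json

from huggingface_hub import hf_hub_download

# Assumed repository id; the commit itself only stores local checkpoint paths.
path = hf_hub_download("ai-sage/Giga-Embeddings-instruct", "config.json")
with open(path) as f:
    cfg = json.load(f)

print(cfg["is_mask_instruction"])               # true after this commit (was false)
print(cfg["text_config"]["num_hidden_layers"])  # 36 (was 27)
print(cfg["transformers_version"])              # "4.46.3"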
config_sentence_transformers.json CHANGED
@@ -1,9 +1,10 @@
  {
    "__version__": {
-     "sentence_transformers": "2.2.2",
-     "transformers": "4.40.0.dev0",
-     "pytorch": "2.0.1+cu118"
+     "sentence_transformers": "3.3.1",
+     "transformers": "4.46.3",
+     "pytorch": "2.1.1+cu121"
    },
    "prompts": {},
-   "default_prompt_name": null
- }
+   "default_prompt_name": null,
+   "similarity_fn_name": "cosine"
+ }
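A minimal sketch of how the newly declared similarity_fn_name is picked up by sentence-transformers 3.x; the repository id and the trust_remote_code flag are assumptions, not part of this diff.

from sentence_transformers import SentenceTransformer

# Assumed repository id; trust_remote_code is needed for the custom GigarEmbedModel code.
model = SentenceTransformer("ai-sage/Giga-Embeddings-instruct", trust_remote_code=True)

emb = model.encode(["query: what is a latent attention pooler?", "passage: some document text"])
# similarity() now uses the cosine function declared in config_sentence_transformers.json.
print(model.similarity(emb[0:1], emb[1:2]))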
configuration_gigarembed.py CHANGED
@@ -76,7 +76,6 @@ class LatentAttentionConfig(PretrainedConfig):
          self.cross_dim_head = cross_dim_head
          self._attn_implementation = "eager"

-
  class BidirectionalLlamaConfig(LlamaConfig):
      model_type = BIDIR_LLAMA_TYPE
      keys_to_ignore_at_inference = ["past_key_values"]
model-00001-of-00003.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:9f5bdcf6ab584d8e3fa1e53ad3061a203125be81e043b11e7c867e04670e8aa7
- size 4913926592
+ oid sha256:f32eedfe2127f8e9507427af1d796c6547df4c8e5795e4ea8b3a22a96e782292
+ size 4930720644
model-00002-of-00003.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:29f27ca086d141ee606037884f8bb93176a61032230ba1ca92ba0c2fe7615cb2
+ oid sha256:edc8c0c52613a2712e8c65b3d8b4249b6e99622c695ee5aec698ca37a5a556d3
  size 4932780264
model-00003-of-00003.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:60c4fbeddfae2f00234d2d9e6646cd7a11914454eaf46e278f85bfb088d1417f
- size 270557856
+ oid sha256:8cb8d3f0fb5c526162adc18efcc9e1c13d07088602775e742037a5e53d1531b9
+ size 3045246736
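The three pointer files above only record an oid and a size; the shards themselves live in LFS storage. A small sketch, assuming a shard has already been downloaded to the working directory, for checking a local file against its pointer:

import hashlib
import os

# Hypothetical local path; oid and size are taken from the pointer for shard 3 above.
path = "model-00003-of-00003.safetensors"
expected_oid = "8cb8d3f0fb5c526162adc18efcc9e1c13d07088602775e742037a5e53d1531b9"
expected_size = 3045246736

h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)

assert os.path.getsize(path) == expected_size
assert h.hexdigest() == expected_oid, "shard does not match its LFS pointer"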
model.safetensors.index.json CHANGED
@@ -1,9 +1,26 @@
  {
    "metadata": {
-     "total_size": 10117234688
+     "total_size": 12908707844
    },
    "weight_map": {
+     "latent_attention_model.cross_attend_blocks.0.fn.to_kv.weight": "model-00001-of-00003.safetensors",
+     "latent_attention_model.cross_attend_blocks.0.fn.to_out.weight": "model-00001-of-00003.safetensors",
+     "latent_attention_model.cross_attend_blocks.0.fn.to_q.weight": "model-00001-of-00003.safetensors",
+     "latent_attention_model.cross_attend_blocks.0.norm.bias": "model-00001-of-00003.safetensors",
+     "latent_attention_model.cross_attend_blocks.0.norm.weight": "model-00001-of-00003.safetensors",
+     "latent_attention_model.cross_attend_blocks.0.norm_context.bias": "model-00001-of-00003.safetensors",
+     "latent_attention_model.cross_attend_blocks.0.norm_context.weight": "model-00001-of-00003.safetensors",
+     "latent_attention_model.cross_attend_blocks.1.fn.net.0.bias": "model-00001-of-00003.safetensors",
+     "latent_attention_model.cross_attend_blocks.1.fn.net.0.weight": "model-00001-of-00003.safetensors",
+     "latent_attention_model.cross_attend_blocks.1.fn.net.2.bias": "model-00001-of-00003.safetensors",
+     "latent_attention_model.cross_attend_blocks.1.fn.net.2.weight": "model-00001-of-00003.safetensors",
+     "latent_attention_model.cross_attend_blocks.1.norm.bias": "model-00001-of-00003.safetensors",
+     "latent_attention_model.cross_attend_blocks.1.norm.weight": "model-00001-of-00003.safetensors",
      "latent_attention_model.latents": "model-00001-of-00003.safetensors",
+     "latent_attention_model.w_lexical.bias": "model-00001-of-00003.safetensors",
+     "latent_attention_model.w_lexical.weight": "model-00001-of-00003.safetensors",
+     "latent_attention_model.w_multi_vector.bias": "model-00001-of-00003.safetensors",
+     "latent_attention_model.w_multi_vector.weight": "model-00001-of-00003.safetensors",
      "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
      "model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
      "model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
@@ -185,6 +202,33 @@
      "model.layers.26.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
      "model.layers.26.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
      "model.layers.26.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
+     "model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
      "model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
      "model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
      "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
@@ -194,6 +238,60 @@
      "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
      "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
      "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
+     "model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.32.input_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.32.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.32.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.32.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.32.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.32.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.32.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.32.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.32.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.33.input_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.33.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.33.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.33.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.33.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.33.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.33.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.33.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.33.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.34.input_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.34.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.34.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.34.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.34.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.34.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.34.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.34.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.34.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.35.input_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.35.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.35.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.35.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.35.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+     "model.layers.35.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.35.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.35.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+     "model.layers.35.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
      "model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
      "model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
      "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
modeling_gigarembed.py CHANGED
@@ -3,6 +3,8 @@ import torch
  import os
  import json
  import numpy as np
+ import torch.nn.functional as F
+
  from functools import partial
  from contextlib import nullcontext
  from transformers import AutoModel, PreTrainedTokenizerFast, BatchEncoding, DataCollatorWithPadding
@@ -227,15 +229,22 @@ def input_transform_func(
  class PreNorm(torch.nn.Module):
      def __init__(self, dim, fn, context_dim = None):
          super().__init__()
-         # TODO remove this layer, we don't use it
+         self.fn = fn
+         self.norm = torch.nn.LayerNorm(dim)
+         self.norm_context = torch.nn.LayerNorm(context_dim) if exists(context_dim) else None

      def forward(self, x, **kwargs):
-         return x
+         x = self.norm(x)
+         if exists(self.norm_context):
+             context = kwargs['context']
+             normed_context = self.norm_context(context)
+             kwargs.update(context = normed_context)
+         return self.fn(x, **kwargs)

  class GEGLU(torch.nn.Module):
      def forward(self, x):
          x, gates = x.chunk(2, dim = -1)
-         return x * torch.nn.functional.gelu(gates)
+         return x * F.gelu(gates)

  class FeedForward(torch.nn.Module):
      def __init__(self, dim, mult = 4):
@@ -275,17 +284,8 @@ class Attention(torch.nn.Module):
          k, v = self.to_kv(context).chunk(2, dim = -1)
          q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

-         attn_weights = torch.matmul(q, k.transpose(-1, -2)) / self.scale
-
-         mask_value = torch.finfo(attn_weights.dtype).min
-         mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
-
-         padding_mask = mask[:, :, None].repeat(self.heads, 1, 1).bool()
-
-         attn_weights = torch.where(padding_mask, attn_weights, mask_value)
-         attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
-
-         out = torch.matmul(attn_weights, v)
+         with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
+             out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
          out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
          return self.to_out(out)

@@ -304,31 +304,35 @@ class LatentAttentionModel(PreTrainedModel):
                      context_dim = dim),
              PreNorm(latent_dim, FeedForward(latent_dim)),
          ])
-         self.output_normalize = config.output_normalize
+
+         self.w_lexical = torch.nn.Linear(latent_dim, 1)
+         self.w_multi_vector = torch.nn.Linear(latent_dim, latent_dim)
+
+         # self.output_normalize = config.output_normalize
          self.register_parameter("latents", torch.nn.Parameter(torch.randn(num_latents, latent_dim)))
+         self._attn_implementation = "eager"

      def forward(self, hiddens, attention_mask: torch.Tensor=None):
          # cross-attention block
          cross_attn, cross_ff = self.cross_attend_blocks
          b, *_, device = *hiddens.shape, hiddens.device
          x = repeat(self.latents, 'n d -> b n d', b = b)
-         hiddens = cross_attn(hiddens, context=x, mask=attention_mask) + hiddens
-         hiddens = cross_ff(hiddens) + hiddens
+         output = cross_attn(hiddens, context=x, mask=attention_mask) + hiddens
+         output = cross_ff(output) + output
          if attention_mask != None:
-             s = torch.sum(hiddens * attention_mask.unsqueeze(-1).float(), dim=1)
-             d = attention_mask.sum(dim=1, keepdim=True).float()
-             hiddens = s / d
-         if self.output_normalize:
-             hiddens = torch.nn.functional.normalize(hiddens, p=2, dim=-1)
-         return hiddens
-
+             s = torch.sum(output * attention_mask.unsqueeze(-1), dim=1)
+             d = attention_mask.sum(dim=1, keepdim=True)
+             output = s / d
+         output = F.normalize(output, p=2, dim=-1)
+         return output
+
  class GigarEmbedModel(PreTrainedModel):
      config_class = GigarEmbedConfig
      _no_split_modules = ["LlamaDecoderLayer", "LatentAttentionModel"]

      def __init__(self, config: GigarEmbedConfig):
          super().__init__(config)
-         self.latent_attention_model = AutoModel.from_config(config.latent_attention_config).float()
+         self.latent_attention_model = AutoModel.from_config(config.latent_attention_config)
          self.model = AutoModel.from_config(
              config.text_config,
          ) if config.text_config is not None else None
@@ -339,12 +343,6 @@ class GigarEmbedModel(PreTrainedModel):
          self.mask_type = config.mask_type
          if config.add_pad_token and self.tokenizer is not None:
              self.add_pad_token()
-
-         self.latent_attention_model.apply(self._init_weights)
-
-     def _init_weights(self, module):
-         if isinstance(module, torch.nn.Linear):
-             torch.nn.init.xavier_normal_(module.weight)

      def add_pad_token(self):
          self.tokenizer.pad_token_id = 0
@@ -360,7 +358,7 @@ class GigarEmbedModel(PreTrainedModel):
          # Mask out the instruction tokens for mean-pooling
          attention_mask[:, :instruction_lens] = 0
          features: GigarEmbedFeatures = {
-             'input_ids': batch_dict['input_ids'],
+             'input_ids': torch.tensor(batch_dict.get('input_ids').to(batch_dict.get('input_ids')).long()),
              'attention_mask': batch_dict['attention_mask'],
              'pool_mask': attention_mask,
          }
@@ -410,12 +408,14 @@ class GigarEmbedModel(PreTrainedModel):
      def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, pool_mask: Optional[torch.Tensor]=None,
                  return_dict: bool=True, **kwargs):
          kwargs.pop('token_type_ids', None)
-         outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
-
-         embeds = self.latent_attention_model(
-             outputs.last_hidden_state,
-             pool_mask,
-         )
+
+         with torch.autocast('cuda', dtype=torch.bfloat16):
+             outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
+
+         if pool_mask is None: pool_mask = attention_mask.clone()
+
+         embeds = self.latent_attention_model(outputs.last_hidden_state, pool_mask)
+
          if not return_dict:
              return (embeds,)
          return {"sentence_embeddings": embeds}
modules.json CHANGED
@@ -10,11 +10,5 @@
      "name": "1",
      "path": "1_Pooling",
      "type": "sentence_transformers.models.Pooling"
-   },
-   {
-     "idx": 2,
-     "name": "2",
-     "path": "2_Normalize",
-     "type": "sentence_transformers.models.Normalize"
    }
- ]
+ ]
sentence_bert_config.json CHANGED
@@ -1,4 +1,4 @@
  {
-   "max_seq_length": 4096,
+   "max_seq_length": null,
    "do_lower_case": false
- }
+ }
special_tokens_map.json CHANGED
@@ -1,5 +1,37 @@
  {
-   "bos_token": "<s>",
-   "eos_token": "</s>",
-   "unk_token": "<unk>"
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "sep_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
  }
tokenizer.json CHANGED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json CHANGED
@@ -2076,8 +2076,16 @@
    "bos_token": "<s>",
    "clean_up_tokenization_spaces": true,
    "eos_token": "</s>",
+   "max_length": 512,
    "model_max_length": 1000000000000000019884624838656,
+   "pad_to_multiple_of": null,
+   "pad_token": "<unk>",
+   "pad_token_type_id": 0,
+   "padding_side": "right",
+   "sep_token": "<unk>",
+   "stride": 0,
    "tokenizer_class": "PreTrainedTokenizerFast",
-   "unk_token": "<unk>",
-   "pad_token": "<unk>"
+   "truncation_side": "right",
+   "truncation_strategy": "longest_first",
+   "unk_token": "<unk>"
  }
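A short sketch exercising the tokenizer settings recorded above (right-side padding, <unk> reused as pad/sep token, and a stored max_length of 512); the repository id is an assumption.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("ai-sage/Giga-Embeddings-instruct")  # assumed repository id
print(tok.padding_side)  # "right"
print(tok.pad_token)     # "<unk>"

# model_max_length stays effectively unbounded, so truncation has to be requested explicitly.
enc = tok("token " * 4096, truncation=True, max_length=512)
print(len(enc["input_ids"]))  # 512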