jupyterjazz committed
Commit 12eb796 · Parent(s): 7c77aab

refactor: multi-vec, st truncation, etc


Signed-off-by: jupyterjazz <[email protected]>

Files changed (4)
  1. config.json +1 -1
  2. custom_st.py +5 -5
  3. modeling_jina_embeddings_v4.py +5 -7
  4. modules.json +1 -1
config.json CHANGED

```diff
@@ -55,6 +55,6 @@
   "vocab_size": 151936,
   "truncate_dim": null,
   "task_names": ["retrieval", "text-matching", "code"],
-  "matryoshka_dims": [128, 256, 512, 1024],
+  "matryoshka_dims": [128, 256, 512, 1024, 2048],
   "_attn_implementation": "flash_attention_2"
 }
```
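Adding 2048 to matryoshka_dims exposes what is presumably the model's full single-vector width as an explicit target, alongside the truncated sizes. As a minimal sketch of what matryoshka truncation amounts to (the helper below is illustrative, not code from this repo): slice the leading components, then re-normalize.

```python
import torch

def truncate_and_normalize(emb: torch.Tensor, dim: int) -> torch.Tensor:
    """Keep the first `dim` components, then L2-normalize.

    Matryoshka-trained embeddings stay usable after slicing, but cosine
    similarity only behaves if the truncated vectors are re-normalized.
    """
    return torch.nn.functional.normalize(emb[:, :dim], p=2, dim=-1)

full = torch.randn(4, 2048)                 # stand-in for single-vector embeddings
small = truncate_and_normalize(full, 128)   # 128 is one of matryoshka_dims
assert small.shape == (4, 128)
```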
custom_st.py CHANGED

```diff
@@ -103,7 +103,7 @@ class Transformer(nn.Module):
         return encoding
 
     def forward(
-        self, features: Dict[str, torch.Tensor], task: Optional[str] = None
+        self, features: Dict[str, torch.Tensor], task: Optional[str] = None, truncate_dim: Optional[int] = None
     ) -> Dict[str, torch.Tensor]:
         self.model.eval()
 
@@ -136,8 +136,8 @@ class Transformer(nn.Module):
             text_embeddings = self.model(
                 **text_batch, task_label=task
             ).single_vec_emb
-            if self.config.truncate_dim:
-                text_embeddings = text_embeddings[:, : self.config.truncate_dim]
+            if truncate_dim:
+                text_embeddings = text_embeddings[:, : truncate_dim]
             text_embeddings = torch.nn.functional.normalize(text_embeddings, p=2, dim=-1)
             for i, embedding in enumerate(text_embeddings):
                 all_embeddings.append((text_indices[i], embedding))
@@ -154,8 +154,8 @@ class Transformer(nn.Module):
             img_embeddings = self.model(
                 **image_batch, task_label=task
             ).single_vec_emb
-            if self.config.truncate_dim:
-                img_embeddings = img_embeddings[:, : self.config.truncate_dim]
+            if truncate_dim:
+                img_embeddings = img_embeddings[:, : truncate_dim]
             img_embeddings = torch.nn.functional.normalize(img_embeddings, p=2, dim=-1)
 
             for i, embedding in enumerate(img_embeddings):
```
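With truncate_dim threaded through forward() (and declared in modules.json below), truncation becomes a per-call choice rather than a value baked into config.json. A hedged usage sketch via sentence-transformers, assuming the library routes the kwargs declared in modules.json from encode() to this module's forward():

```python
from sentence_transformers import SentenceTransformer

# Model id is an assumption; trust_remote_code is needed to load custom_st.Transformer.
model = SentenceTransformer("jinaai/jina-embeddings-v4", trust_remote_code=True)

emb = model.encode(
    ["What is matryoshka representation learning?"],
    task="retrieval",   # forwarded to Transformer.forward via modules.json "kwargs"
    truncate_dim=256,   # likewise forwarded; slices before normalization
)
print(emb.shape)  # expected (1, 256) under the assumptions above
```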
modeling_jina_embeddings_v4.py CHANGED

```diff
@@ -127,13 +127,11 @@ class JinaEmbeddingsV4ModelOutput:
         vlm_last_hidden_states (torch.Tensor, optional): Last hidden states of the VLM.
         single_vec_emb (torch.Tensor, optional): Single-vector embeddings.
         multi_vec_emb (torch.Tensor, optional): Multi-vector embeddings.
-        attention_mask (torch.Tensor, optional): Attention mask.
     """
 
     vlm_last_hidden_states: Optional[torch.Tensor] = None
     single_vec_emb: Optional[torch.Tensor] = None
     multi_vec_emb: Optional[torch.Tensor] = None
-    attention_mask: Optional[torch.Tensor] = None
 
 
 class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
@@ -314,7 +312,6 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             ),
             single_vec_emb=single_vec_emb,
             multi_vec_emb=multi_vec_emb,
-            attention_mask=attention_mask,
         )
 
     def _process_batches(
@@ -345,17 +342,18 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                 device_type=torch.device(self.device).type, dtype=torch.bfloat16
             ):
                 embeddings = self(**batch, task_label=task_label)
-                attention_mask = embeddings.attention_mask
                 if not return_multivector:
                     embeddings = embeddings.single_vec_emb
                     if truncate_dim is not None:
                         embeddings = embeddings[:, :truncate_dim]
-                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
                 else:
                     embeddings = embeddings.multi_vec_emb
+
                 if return_multivector and not return_numpy:
-                    valid_tokens = attention_mask.bool()
-                    embeddings = [emb[mask] for emb, mask in zip(embeddings, valid_tokens)]
+                    valid_tokens = batch["attention_mask"].bool()
+                    embeddings = [
+                        emb[mask] for emb, mask in zip(embeddings, valid_tokens)
+                    ]
                     results.append(embeddings)
                 else:
                     results.append(
```
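The net effect in _process_batches: instead of echoing the attention mask back through the model output, the padding mask is taken straight from the input batch and used to drop padded positions from each multi-vector embedding. A small self-contained sketch of that masking step (shapes are made up for illustration):

```python
import torch

multi_vec = torch.randn(2, 5, 128)               # (batch, seq_len, dim)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],  # 3 real tokens, 2 pad
                               [1, 1, 1, 1, 1]]) # no padding
valid_tokens = attention_mask.bool()

# One variable-length (n_tokens, dim) tensor per input, padding removed.
per_input = [emb[mask] for emb, mask in zip(multi_vec, valid_tokens)]
print([t.shape for t in per_input])  # [torch.Size([3, 128]), torch.Size([5, 128])]
```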
modules.json CHANGED

```diff
@@ -4,6 +4,6 @@
     "name": "transformer",
     "path": "",
     "type": "custom_st.Transformer",
-    "kwargs": ["task"]
+    "kwargs": ["task", "truncate_dim"]
   }
 ]
```
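For context, the full modules.json after this change would read roughly as below; the lines outside the hunk (including the idx field) are an assumption based on typical sentence-transformers module configs, since the diff only shows lines 4-9.

```json
[
  {
    "idx": 0,
    "name": "transformer",
    "path": "",
    "type": "custom_st.Transformer",
    "kwargs": ["task", "truncate_dim"]
  }
]
```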