Commit 12eb796
Parent(s): 7c77aab

refactor: multi-vec, st truncation, etc

Signed-off-by: jupyterjazz <[email protected]>
- config.json +1 -1
- custom_st.py +5 -5
- modeling_jina_embeddings_v4.py +5 -7
- modules.json +1 -1
config.json CHANGED
@@ -55,6 +55,6 @@
   "vocab_size": 151936,
   "truncate_dim": null,
   "task_names": ["retrieval", "text-matching", "code"],
-  "matryoshka_dims": [128, 256, 512, 1024],
+  "matryoshka_dims": [128, 256, 512, 1024, 2048],
   "_attn_implementation": "flash_attention_2"
 }
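The added 2048 entry registers what is presumably the model's full single-vector width as an explicitly allowed Matryoshka dimension. As a rough sketch of what selecting one of these dims means downstream (shapes here are illustrative, not read from the model):

import torch

full = torch.randn(4, 2048)  # stand-in for full-width single-vector embeddings
dim = 256                    # any entry of "matryoshka_dims"

truncated = full[:, :dim]    # Matryoshka truncation keeps the leading dims
truncated = torch.nn.functional.normalize(truncated, p=2, dim=-1)  # re-normalize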
custom_st.py CHANGED
@@ -103,7 +103,7 @@ class Transformer(nn.Module):
         return encoding
 
     def forward(
-        self, features: Dict[str, torch.Tensor], task: Optional[str] = None
+        self, features: Dict[str, torch.Tensor], task: Optional[str] = None, truncate_dim: Optional[int] = None
     ) -> Dict[str, torch.Tensor]:
         self.model.eval()
 
@@ -136,8 +136,8 @@ class Transformer(nn.Module):
             text_embeddings = self.model(
                 **text_batch, task_label=task
             ).single_vec_emb
-            if
-            text_embeddings = text_embeddings[:, :
+            if truncate_dim:
+                text_embeddings = text_embeddings[:, : truncate_dim]
             text_embeddings = torch.nn.functional.normalize(text_embeddings, p=2, dim=-1)
             for i, embedding in enumerate(text_embeddings):
                 all_embeddings.append((text_indices[i], embedding))
@@ -154,8 +154,8 @@ class Transformer(nn.Module):
             img_embeddings = self.model(
                 **image_batch, task_label=task
             ).single_vec_emb
-            if
-            img_embeddings = img_embeddings[:, :
+            if truncate_dim:
+                img_embeddings = img_embeddings[:, : truncate_dim]
             img_embeddings = torch.nn.functional.normalize(img_embeddings, p=2, dim=-1)
 
             for i, embedding in enumerate(img_embeddings):
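Both the text and image paths now share the same truncate-then-normalize pattern. A minimal sketch of that pattern in isolation; maybe_truncate is a hypothetical helper, not a function in this repo:

from typing import Optional

import torch

def maybe_truncate(emb: torch.Tensor, truncate_dim: Optional[int]) -> torch.Tensor:
    # Slice first, then L2-normalize: a prefix of a unit vector is
    # generally no longer unit length, so the order matters.
    if truncate_dim:
        emb = emb[:, :truncate_dim]
    return torch.nn.functional.normalize(emb, p=2, dim=-1)

print(maybe_truncate(torch.randn(2, 2048), truncate_dim=128).shape)  # torch.Size([2, 128])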
modeling_jina_embeddings_v4.py CHANGED
@@ -127,13 +127,11 @@ class JinaEmbeddingsV4ModelOutput:
         vlm_last_hidden_states (torch.Tensor, optional): Last hidden states of the VLM.
         single_vec_emb (torch.Tensor, optional): Single-vector embeddings.
         multi_vec_emb (torch.Tensor, optional): Multi-vector embeddings.
-        attention_mask (torch.Tensor, optional): Attention mask.
     """
 
     vlm_last_hidden_states: Optional[torch.Tensor] = None
     single_vec_emb: Optional[torch.Tensor] = None
     multi_vec_emb: Optional[torch.Tensor] = None
-    attention_mask: Optional[torch.Tensor] = None
 
 
 class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
@@ -314,7 +312,6 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             ),
             single_vec_emb=single_vec_emb,
             multi_vec_emb=multi_vec_emb,
-            attention_mask=attention_mask,
         )
 
     def _process_batches(
@@ -345,17 +342,18 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                 device_type=torch.device(self.device).type, dtype=torch.bfloat16
            ):
                 embeddings = self(**batch, task_label=task_label)
-                attention_mask = embeddings.attention_mask
                 if not return_multivector:
                     embeddings = embeddings.single_vec_emb
                     if truncate_dim is not None:
                         embeddings = embeddings[:, :truncate_dim]
-                        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
                 else:
                     embeddings = embeddings.multi_vec_emb
+
                 if return_multivector and not return_numpy:
-                    valid_tokens = attention_mask.bool()
-                    embeddings = [emb[mask] for emb, mask in zip(embeddings, valid_tokens)]
+                    valid_tokens = batch["attention_mask"].bool()
+                    embeddings = [
+                        emb[mask] for emb, mask in zip(embeddings, valid_tokens)
+                    ]
                     results.append(embeddings)
                 else:
                     results.append(
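With attention_mask dropped from the output dataclass, the multi-vector padding trim now reads the mask straight from the input batch. A small self-contained sketch of that trim, with invented shapes:

import torch

embeddings = torch.randn(2, 5, 128)  # (batch, seq_len, dim), right-padded
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

valid_tokens = attention_mask.bool()
per_doc = [emb[mask] for emb, mask in zip(embeddings, valid_tokens)]
print([tuple(t.shape) for t in per_doc])  # [(3, 128), (5, 128)]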
modules.json CHANGED
@@ -4,6 +4,6 @@
     "name": "transformer",
     "path": "",
     "type": "custom_st.Transformer",
-    "kwargs": ["task"]
+    "kwargs": ["task", "truncate_dim"]
   }
 ]
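Declaring truncate_dim in "kwargs" lets Sentence Transformers forward an encode(..., truncate_dim=...) argument through to custom_st.Transformer.forward, just as it already does for task. A hedged usage sketch, assuming a sentence-transformers release recent enough to honor the "kwargs" field of modules.json:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("jinaai/jina-embeddings-v4", trust_remote_code=True)

# Both kwargs are routed to the custom module's forward().
emb = model.encode(
    ["how do matryoshka embeddings work?"],
    task="retrieval",
    truncate_dim=128,  # should be one of config.json's "matryoshka_dims"
)
print(emb.shape)  # expected: (1, 128)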