feat-remove-paddings-0623 (#22)
- fix: remove the padding tokens when a list of multivectors is returned (ef1876f5b9dbe290d7a58ff16fc37367217d32c5)
- fix: fix the bug when return_numpy is false (77d5a29ef8ef2396f56106d8ed882e322b7dc9be)
- fix: fix the bug when return_numpy is false (205b18f42bb9bd3ee57bab31cdd1b3116b6d762b)
- fix: fix the bug when return_numpy is false (6bb8cf2b7575b11b19c8f9780efda6f0b1a61708)
- fix: fix the bug (3ad717f7eaca1d26701063a16cfbe5f40ebaf551)
modeling_jina_embeddings_v4.py
CHANGED
```diff
@@ -127,11 +127,13 @@ class JinaEmbeddingsV4ModelOutput:
         vlm_last_hidden_states (torch.Tensor, optional): Last hidden states of the VLM.
         single_vec_emb (torch.Tensor, optional): Single-vector embeddings.
         multi_vec_emb (torch.Tensor, optional): Multi-vector embeddings.
+        attention_mask (torch.Tensor, optional): Attention mask.
     """
 
     vlm_last_hidden_states: Optional[torch.Tensor] = None
     single_vec_emb: Optional[torch.Tensor] = None
     multi_vec_emb: Optional[torch.Tensor] = None
+    attention_mask: Optional[torch.Tensor] = None
 
 
 class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
```
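The new `attention_mask` field records which positions of `multi_vec_emb` are real tokens rather than batch padding. A minimal sketch of the relationship, with illustrative shapes (not the repo's code):

```python
import torch

# Illustrative shapes: a batch of 2 texts padded to seq_len=5, embedding dim 128.
multi_vec_emb = torch.randn(2, 5, 128)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],   # text 0: 3 real tokens, 2 padding
                               [1, 1, 1, 1, 1]])  # text 1: 5 real tokens

# Keeping only valid positions yields per-text matrices of different lengths,
# which is exactly what the padding-removal change below does.
per_text = [emb[mask.bool()] for emb, mask in zip(multi_vec_emb, attention_mask)]
print([t.shape for t in per_text])  # [torch.Size([3, 128]), torch.Size([5, 128])]
```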
```diff
@@ -312,6 +314,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             ),
             single_vec_emb=single_vec_emb,
             multi_vec_emb=multi_vec_emb,
+            attention_mask=attention_mask,
         )
 
     def _process_batches(
```
```diff
@@ -331,6 +334,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             shuffle=False,
             collate_fn=processor_fn,
         )
+        if return_multivector and len(data) > 1:
+            assert not return_numpy, "`return_numpy` is not supported when `return_multivector=True` and more than one data is encoded"
         results = []
         self.eval()
         for batch in tqdm(dataloader, desc=desc):
```
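This assert acts as an internal safety net in `_process_batches`; the public encode paths further down downgrade the same condition to a printed warning and simply ignore `return_numpy`, so callers going through those paths never trip the assertion.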
```diff
@@ -340,17 +345,23 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                 device_type=torch.device(self.device).type, dtype=torch.bfloat16
             ):
                 embeddings = self(**batch, task_label=task_label)
+                attention_mask = embeddings.attention_mask
                 if not return_multivector:
                     embeddings = embeddings.single_vec_emb
                     if truncate_dim is not None:
                         embeddings = embeddings[:, :truncate_dim]
                 else:
                     embeddings = embeddings.multi_vec_emb
-            results.append(
-                embeddings.cpu()
-                if return_numpy
-                else list(torch.unbind(embeddings))
-            )
+            if return_multivector and not return_numpy:
+                valid_tokens = attention_mask.bool()
+                embeddings = [emb[mask] for emb, mask in zip(embeddings, valid_tokens)]
+                results.append(embeddings)
+            else:
+                results.append(
+                    embeddings.cpu()
+                    if return_numpy
+                    else list(torch.unbind(embeddings))
+                )
         if return_numpy:
             return np.concatenate([result.numpy() for result in results], axis=0)
         return [item for sublist in results for item in sublist]
```
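The reason `return_numpy` cannot be honored here: once padding tokens are stripped, each item contributes a different number of token vectors, so the results can no longer be packed into one rectangular array without losing item boundaries. A small illustration (not the repo's code):

```python
import numpy as np

# Illustrative shapes after padding removal: 3 and 5 valid tokens respectively.
doc_a = np.zeros((3, 128))
doc_b = np.zeros((5, 128))

print(np.concatenate([doc_a, doc_b], axis=0).shape)  # (8, 128): document boundaries are lost
# np.stack([doc_a, doc_b])  # would raise ValueError: all input arrays must have the same shape
```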
```diff
@@ -436,6 +447,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         )
 
         return_list = isinstance(texts, list)
+
+        # If return_multivector is True and encoding multiple texts, ignore return_numpy
+        if return_multivector and return_list and len(texts) > 1:
+            if return_numpy:
+                print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(texts) > 1`")
+            return_numpy = False
 
         if isinstance(texts, str):
             texts = [texts]
```
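Downstream effect of this guard, sketched as a hypothetical call. The method name `encode_texts` and the minimal argument set are assumptions for illustration (the real entry point may require extra arguments such as a task label); the argument names `return_multivector` and `return_numpy` come from this diff:

```python
from transformers import AutoModel

# Assumption: the model loads via trust_remote_code; required arguments beyond
# these (e.g. a task label) are omitted for brevity.
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)

embs = model.encode_texts(
    texts=["a short query", "a much longer passage with more tokens"],
    return_multivector=True,
    return_numpy=True,  # prints the warning above, then is ignored
)
# embs is a list of torch.Tensor with per-text shapes (num_tokens_i, dim);
# unequal lengths are why a single numpy array cannot be returned.
```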
```diff
@@ -484,7 +501,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
             batch_size: Number of images to process at once
             return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
-            return_numpy: Whether to return numpy arrays instead of torch tensors
+            return_numpy: Whether to return numpy arrays instead of torch tensors. If `return_multivector` is `True` and more than one image is encoded, this parameter is ignored.
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             max_pixels: Maximum number of pixels to process per image
 
```
```diff
@@ -501,6 +518,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
         return_list = isinstance(images, list)
 
+        # If return_multivector is True and encoding multiple images, ignore return_numpy
+        if return_multivector and return_list and len(images) > 1:
+            if return_numpy:
+                print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(images) > 1`")
+            return_numpy = False
+
         # Convert single image to list
         if isinstance(images, (str, Image.Image)):
             images = [images]
```
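The image path mirrors the text path. Continuing the sketch above (again, `encode_images` and the minimal argument set are assumptions, not documented API):

```python
embs = model.encode_images(
    images=["photo_a.jpg", "https://example.com/photo_b.png"],
    return_multivector=True,
    return_numpy=True,  # likewise ignored with a printed warning
)
# One tensor per image, each (num_visual_tokens_i, dim), with padding
# positions already removed via the attention_mask on the model output.
```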