nan committed · verified
Commit a5838bd · 1 Parent(s): 4c2a7cb

feat-remove-paddings-0623 (#22)


- fix: remove the padding tokens when a list of multivectors are returned (ef1876f5b9dbe290d7a58ff16fc37367217d32c5)
- fix: fix the bug when return_numpy is false (77d5a29ef8ef2396f56106d8ed882e322b7dc9be)
- fix: fix the bug when return_numpy is false (205b18f42bb9bd3ee57bab31cdd1b3116b6d762b)
- fix: fix the bug when return_numpy is false (6bb8cf2b7575b11b19c8f9780efda6f0b1a61708)
- fix: fix the bug (3ad717f7eaca1d26701063a16cfbe5f40ebaf551)
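For context on the padding-removal fix listed above: with `return_multivector=True`, every sequence in a batch comes back padded to the batch's longest sequence, and this commit trims each returned multivector down to its real tokens using the attention mask. A minimal standalone sketch of that trimming step (the tensor shapes and variable names here are illustrative, not the model's actual internals):

import torch

# Illustrative batch: 2 sequences padded to 5 tokens, embedding dim 4.
multi_vec_emb = torch.randn(2, 5, 4)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# Boolean-mask indexing keeps only the real tokens, so each trimmed
# multivector has its own length: (3, 4) and (5, 4) here.
valid_tokens = attention_mask.bool()
trimmed = [emb[mask] for emb, mask in zip(multi_vec_emb, valid_tokens)]
print([tuple(t.shape) for t in trimmed])  # [(3, 4), (5, 4)]

The ragged result also explains the `return_numpy` fixes in the same list: tensors of different lengths cannot be concatenated into one rectangular numpy array, so numpy output only remains valid for single-vector embeddings or a single-input multivector call.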

Files changed (1)
  1. modeling_jina_embeddings_v4.py +29 -6
modeling_jina_embeddings_v4.py CHANGED
@@ -127,11 +127,13 @@ class JinaEmbeddingsV4ModelOutput:
         vlm_last_hidden_states (torch.Tensor, optional): Last hidden states of the VLM.
         single_vec_emb (torch.Tensor, optional): Single-vector embeddings.
         multi_vec_emb (torch.Tensor, optional): Multi-vector embeddings.
+        attention_mask (torch.Tensor, optional): Attention mask.
     """
 
     vlm_last_hidden_states: Optional[torch.Tensor] = None
     single_vec_emb: Optional[torch.Tensor] = None
     multi_vec_emb: Optional[torch.Tensor] = None
+    attention_mask: Optional[torch.Tensor] = None
 
 
 class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
@@ -312,6 +314,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             ),
             single_vec_emb=single_vec_emb,
             multi_vec_emb=multi_vec_emb,
+            attention_mask=attention_mask,
         )
 
     def _process_batches(
@@ -331,6 +334,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             shuffle=False,
             collate_fn=processor_fn,
         )
+        if return_multivector and len(data) > 1:
+            assert not return_numpy, "`return_numpy` is not supported when `return_multivector=True` and more than one data is encoded"
         results = []
         self.eval()
         for batch in tqdm(dataloader, desc=desc):
@@ -340,17 +345,23 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                     device_type=torch.device(self.device).type, dtype=torch.bfloat16
                 ):
                     embeddings = self(**batch, task_label=task_label)
+                    attention_mask = embeddings.attention_mask
                     if not return_multivector:
                         embeddings = embeddings.single_vec_emb
                         if truncate_dim is not None:
                             embeddings = embeddings[:, :truncate_dim]
                     else:
                         embeddings = embeddings.multi_vec_emb
-                    results.append(
-                        embeddings.cpu()
-                        if return_numpy
-                        else list(torch.unbind(embeddings))
-                    )
+                    if return_multivector and not return_numpy:
+                        valid_tokens = attention_mask.bool()
+                        embeddings = [emb[mask] for emb, mask in zip(embeddings, valid_tokens)]
+                        results.append(embeddings)
+                    else:
+                        results.append(
+                            embeddings.cpu()
+                            if return_numpy
+                            else list(torch.unbind(embeddings))
+                        )
         if return_numpy:
             return np.concatenate([result.numpy() for result in results], axis=0)
         return [item for sublist in results for item in sublist]
@@ -436,6 +447,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         )
 
         return_list = isinstance(texts, list)
+
+        # If return_multivector is True and encoding multiple texts, ignore return_numpy
+        if return_multivector and return_list and len(texts) > 1:
+            if return_numpy:
+                print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(texts) > 1`")
+            return_numpy = False
 
         if isinstance(texts, str):
             texts = [texts]
@@ -484,7 +501,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
             batch_size: Number of images to process at once
             return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
-            return_numpy: Whether to return numpy arrays instead of torch tensors
+            return_numpy: Whether to return numpy arrays instead of torch tensors. If `return_multivector` is `True` and more than one image is encoded, this parameter is ignored.
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             max_pixels: Maximum number of pixels to process per image
 
@@ -501,6 +518,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
         return_list = isinstance(images, list)
 
+        # If return_multivector is True and encoding multiple images, ignore return_numpy
+        if return_multivector and return_list and len(images) > 1:
+            if return_numpy:
+                print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(images) > 1`")
+            return_numpy = False
+
         # Convert single image to list
         if isinstance(images, (str, Image.Image)):
             images = [images]
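Taken together, the changed encoding path can be exercised as below. This is a hypothetical usage sketch: the repo id, the `trust_remote_code` loading route, and the `encode_text` signature with a `task` argument are assumptions inferred from how Hugging Face custom models are typically used, not something this diff shows.

from transformers import AutoModel

# Assumed loading path for a Hub model that ships custom modeling code.
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4", trust_remote_code=True
)

texts = ["a short query", "a much longer passage about multimodal retrieval"]

# With return_multivector=True and more than one input, return_numpy is
# ignored (the code prints a warning) and the call returns a list of
# torch tensors, each of shape (num_real_tokens, dim) with padding removed.
multivecs = model.encode_text(
    texts, task="retrieval", return_multivector=True, return_numpy=True
)
print([tuple(v.shape) for v in multivecs])  # per-text lengths differ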