jupyterjazz committed
Commit 7c77aab · verified · 1 parent: bfadc62

fix-matryoshka-normalization (#24)


- fix: matryoshka normalization (a4de150f6b126c6bb858fd0b999cf4862b75d327)

Files changed (2)
  1. custom_st.py +4 -5
  2. modeling_jina_embeddings_v4.py +1 -0
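
Context for the fix (not part of the commit): Matryoshka-style truncation keeps only the leading dimensions of an L2-normalized embedding, and the resulting slice generally has norm below 1, which silently deflates dot-product cosine scores. Re-normalizing after truncation restores unit norm. A minimal sketch in plain PyTorch, using a hypothetical 2048-dim batch:

```python
import torch

# Hypothetical batch of unit-norm embeddings (illustrative only).
emb = torch.nn.functional.normalize(torch.randn(4, 2048), p=2, dim=-1)

truncated = emb[:, :128]            # Matryoshka truncation to 128 dims
print(truncated.norm(dim=-1))       # norms < 1.0: slices are no longer unit-length

fixed = torch.nn.functional.normalize(truncated, p=2, dim=-1)
print(fixed.norm(dim=-1))           # all 1.0: unit norm restored
```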
custom_st.py CHANGED
```diff
@@ -45,7 +45,6 @@ class Transformer(nn.Module):
         self.model = AutoModel.from_pretrained(
             model_name_or_path, config=self.config, cache_dir=cache_dir, **model_kwargs
         )
-
         self.processor = AutoProcessor.from_pretrained(
             model_name_or_path,
             cache_dir=cache_dir,
@@ -133,14 +132,13 @@ class Transformer(nn.Module):
                 if k.startswith("text_") and k != "text_indices"
             }
             text_indices = features.get("text_indices", [])
-
-            with torch.autocast(device_type=device):
+            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                 text_embeddings = self.model(
                     **text_batch, task_label=task
                 ).single_vec_emb
                 if self.config.truncate_dim:
                     text_embeddings = text_embeddings[:, : self.config.truncate_dim]
-
+                    text_embeddings = torch.nn.functional.normalize(text_embeddings, p=2, dim=-1)
             for i, embedding in enumerate(text_embeddings):
                 all_embeddings.append((text_indices[i], embedding))
 
@@ -152,12 +150,13 @@ class Transformer(nn.Module):
             }
             image_indices = features.get("image_indices", [])
 
-            with torch.autocast(device_type=device):
+            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                 img_embeddings = self.model(
                     **image_batch, task_label=task
                 ).single_vec_emb
                 if self.config.truncate_dim:
                     img_embeddings = img_embeddings[:, : self.config.truncate_dim]
+                    img_embeddings = torch.nn.functional.normalize(img_embeddings, p=2, dim=-1)
 
             for i, embedding in enumerate(img_embeddings):
                 all_embeddings.append((image_indices[i], embedding))
```
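
Both hunks above apply the same pattern to the text and image branches. A condensed sketch of that pattern, with `model`, `batch`, `task`, and `truncate_dim` as stand-ins for the attributes custom_st.py actually uses:

```python
import torch

def embed(model, batch, task, device="cuda", truncate_dim=128):
    # Pin autocast to bfloat16 instead of relying on the device default,
    # then truncate to the Matryoshka dimension and restore unit norm.
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        emb = model(**batch, task_label=task).single_vec_emb
        if truncate_dim:
            emb = emb[:, :truncate_dim]
            emb = torch.nn.functional.normalize(emb, p=2, dim=-1)
    return emb
```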
modeling_jina_embeddings_v4.py CHANGED
```diff
@@ -350,6 +350,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             embeddings = embeddings.single_vec_emb
             if truncate_dim is not None:
                 embeddings = embeddings[:, :truncate_dim]
+                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
         else:
             embeddings = embeddings.multi_vec_emb
         if return_multivector and not return_numpy:
```
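
A hedged end-to-end check, assuming the `encode_text` API shown on the jina-embeddings-v4 model card (`trust_remote_code` loads the patched modeling_jina_embeddings_v4.py):

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)

# task, prompt_name, and truncate_dim follow the model card; after this
# commit the truncated single-vector embedding should be unit-norm again.
emb = model.encode_text(texts=["hello"], task="retrieval", prompt_name="query", truncate_dim=128)
print(torch.linalg.vector_norm(torch.as_tensor(emb[0])))  # expected ≈ 1.0
```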