Apple silicon / Non-CUDA support (flash_attn)
When you try to load the model on Apple silicon, you get the following error:

```
ImportError: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.
```
This is because the model requires flash_attn, which so far only supports machines with CUDA-capable GPUs.
Is there a way to disable the Flash Attention dependency in order to try this model on macOS?
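For reference, a minimal load along these lines reproduces the error (the repo id is spelled out here only for illustration):

```python
import transformers

# On an Apple silicon Mac without flash_attn installed, this load raises the
# ImportError quoted above, since the model code requests Flash Attention 2.
model = transformers.AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4",  # repo id assumed for illustration
    trust_remote_code=True,
)
```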
@jupyterjazz,
Roughly 4 days ago I was testing on Apple silicon: I loaded the config via `config = transformers.AutoConfig.from_pretrained(..., torch_dtype=torch.float16)`, modified it with `config.text_config._attn_implementation = "eager"` to resolve the Flash Attention 2 error, and created the model with `model = transformers.AutoModel.from_pretrained(..., config=config)`. Similarly, I tested `sentence_transformers.SentenceTransformer(..., device="mps", config_kwargs=config.to_dict(), model_kwargs={"torch_dtype": torch.float16})`. As you mentioned, your fix for Flash Attention 2 has removed the need to load and edit the config manually.
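For completeness, a sketch of that earlier workaround (the repo id is written out here for illustration; my actual code elides it):

```python
import torch
import transformers
from sentence_transformers import SentenceTransformer

MODEL_ID = "jinaai/jina-embeddings-v4"  # repo id assumed for illustration

# Load the config and force eager attention on the text tower so the model
# code no longer requests Flash Attention 2.
config = transformers.AutoConfig.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16, trust_remote_code=True
)
config.text_config._attn_implementation = "eager"

# transformers path
model = transformers.AutoModel.from_pretrained(
    MODEL_ID, config=config, trust_remote_code=True
)

# sentence-transformers path on MPS
st_model = SentenceTransformer(
    MODEL_ID,
    device="mps",
    trust_remote_code=True,
    config_kwargs=config.to_dict(),
    model_kwargs={"torch_dtype": torch.float16},
)
```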
After downloading the fix, however, I can no longer embed images via `transformers.AutoModel.from_pretrained(...).encode_image` nor via `sentence_transformers.SentenceTransformer(...).encode`. These were previously working, but now I get an error with both. Below is the log for `sentence_transformers`, although both it and `transformers` fail at `modeling_jina_embeddings_v4.py:290`.
Note: I've tried manually loading the old config file but hit the same error. The first time I loaded the model today I remember seeing the notice about new Python files, and I see a few of them changed ~4 days ago. I plan to keep tracking the situation as I have time. Let me know if there are any tests you'd like me to run.
```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[9], line 2
1 # Encode image/document
----> 2 image_embeddings = sentence_transformers_model.encode(
3 sentences=["https://i.ibb.co/nQNGqL0/beach1.jpg"],
4 task="retrieval",
5 )
7 print(f"image_embeddings.shape = {image_embeddings.shape}")
File .venv/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File .venv/lib/python3.10/site-packages/sentence_transformers/SentenceTransformer.py:1052, in SentenceTransformer.encode(self, sentences, prompt_name, prompt, batch_size, show_progress_bar, output_value, precision, convert_to_numpy, convert_to_tensor, device, normalize_embeddings, truncate_dim, pool, chunk_size, **kwargs)
1049 features.update(extra_features)
1051 with torch.no_grad():
-> 1052 out_features = self.forward(features, **kwargs)
1053 if self.device.type == "hpu":
1054 out_features = copy.deepcopy(out_features)
File .venv/lib/python3.10/site-packages/sentence_transformers/SentenceTransformer.py:1133, in SentenceTransformer.forward(self, input, **kwargs)
1127 module_kwarg_keys = self.module_kwargs.get(module_name, [])
1128 module_kwargs = {
1129 key: value
1130 for key, value in kwargs.items()
1131 if key in module_kwarg_keys or (hasattr(module, "forward_kwargs") and key in module.forward_kwargs)
1132 }
-> 1133 input = module(input, **module_kwargs)
1134 return input
File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File custom_st.py:162, in Transformer.forward(self, features, task, truncate_dim)
159 image_indices = features.get("image_indices", [])
161 with torch.autocast(device_type=device, dtype=torch.bfloat16):
--> 162 img_embeddings = self.model(
163 **image_batch, task_label=task
164 ).single_vec_emb
165 if truncate_dim:
166 img_embeddings = img_embeddings[:, :truncate_dim]
File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File .venv/lib/python3.10/site-packages/peft/peft_model.py:2759, in PeftModelForFeatureExtraction.forward(self, input_ids, attention_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
2757 with self._enable_peft_forward_hooks(**kwargs):
2758 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 2759 return self.base_model(
2760 input_ids=input_ids,
2761 attention_mask=attention_mask,
2762 inputs_embeds=inputs_embeds,
2763 output_attentions=output_attentions,
2764 output_hidden_states=output_hidden_states,
2765 return_dict=return_dict,
2766 **kwargs,
2767 )
2769 batch_size = _get_batch_size(input_ids, inputs_embeds)
2770 if attention_mask is not None:
2771 # concat prompt attention mask
File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File .venv/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:193, in BaseTuner.forward(self, *args, **kwargs)
192 def forward(self, *args: Any, **kwargs: Any):
--> 193 return self.model.forward(*args, **kwargs)
File modeling_jina_embeddings_v4.py:290, in JinaEmbeddingsV4Model.forward(self, task_label, input_ids, attention_mask, output_vlm_last_hidden_states, **kwargs)
278 """
279 Forward pass through the model. Returns both single-vector and multi-vector embeddings.
280 Args:
(...)
287 multi_vec_emb (torch.Tensor, optional): Multi-vector embeddings.
288 """
289 # Forward pass through the VLM
--> 290 hidden_states = self.get_last_hidden_states(
291 input_ids=input_ids,
292 attention_mask=attention_mask,
293 task_label=task_label,
294 **kwargs,
295 ) # (batch_size, seq_length, hidden_size)
296 # Compute the embeddings
297 single_vec_emb = self.get_single_vector_embeddings(
298 hidden_states=hidden_states,
299 attention_mask=attention_mask,
300 input_ids=input_ids,
301 )
File modeling_jina_embeddings_v4.py:189, in JinaEmbeddingsV4Model.get_last_hidden_states(self, task_label, input_ids, attention_mask, **kwargs)
182 position_ids, rope_deltas = self.model.get_rope_index(
183 input_ids=input_ids,
184 image_grid_thw=kwargs.get("image_grid_thw", None),
185 attention_mask=attention_mask,
186 )
188 kwargs["output_hidden_states"] = True
--> 189 outputs = super().forward(
190 task_label=task_label,
191 input_ids=input_ids,
192 attention_mask=attention_mask,
193 **kwargs,
194 position_ids=position_ids,
195 rope_deltas=rope_deltas,
196 use_cache=False,
197 )
199 hidden_states = outputs.hidden_states
200 if not hidden_states:
File .venv/lib/python3.10/site-packages/transformers/utils/generic.py:943, in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
940 set_attribute_for_modules(self, "_is_top_level_module", False)
942 try:
--> 943 output = func(self, *args, **kwargs)
944 if is_requested_to_return_tuple or (is_configured_to_return_tuple and is_top_level_module):
945 output = output.to_tuple()
File qwen2_5_vl.py:2235, in Qwen2_5_VLForConditionalGeneration.forward(self, task_label, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts)
2230 output_hidden_states = (
2231 output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2232 )
2233 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-> 2235 outputs = self.model(
2236 task_label=task_label,
2237 input_ids=input_ids,
2238 pixel_values=pixel_values,
2239 pixel_values_videos=pixel_values_videos,
2240 image_grid_thw=image_grid_thw,
2241 video_grid_thw=video_grid_thw,
2242 second_per_grid_ts=second_per_grid_ts,
2243 position_ids=position_ids,
2244 attention_mask=attention_mask,
2245 past_key_values=past_key_values,
2246 inputs_embeds=inputs_embeds,
2247 use_cache=use_cache,
2248 output_attentions=output_attentions,
2249 output_hidden_states=output_hidden_states,
2250 return_dict=return_dict,
2251 cache_position=cache_position,
2252 )
2254 hidden_states = outputs[0]
2255 logits = self.lm_head(hidden_states)
File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File qwen2_5_vl.py:1996, in Qwen2_5_VLModel.forward(self, task_label, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts)
1993 image_mask = mask_expanded.to(inputs_embeds.device)
1995 image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-> 1996 inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
1998 if pixel_values_videos is not None:
1999 video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
RuntimeError: Unexpected floating ScalarType in at::autocast::prioritize
```
Hi @zboyles, thanks for reporting this. I was unable to reproduce the issue in my environment. Can you help me reproduce it? I'd like to know your environment details, especially the torch version, and also the code you're running if it's different from the examples in the README.
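For example, the output of a quick snippet like this would help:

```python
import platform
import torch, transformers, sentence_transformers, peft

print("platform:", platform.platform())
print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("sentence-transformers:", sentence_transformers.__version__)
print("peft:", peft.__version__)
print("MPS available:", torch.backends.mps.is_available())
```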