Apple silicon / Non-CUDA support (flash_attn)

#49
by simonsv - opened

When you try to load the model on Apple silicon, you get the following error:

ImportError: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.

This happens because the model requires flash_attn, which so far only supports CUDA-capable GPUs.

Is there a way to disable the Flash Attention dependency so the model can be tried on macOS?
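For example, is there a load-time switch? A sketch of what I'd hope would work (the model id is assumed here, and I haven't verified that the repo's custom code honors this argument):

```python
import transformers

# Hoped-for usage: request the eager attention implementation at load time
# instead of flash_attention_2. Model id assumed for illustration; the repo
# ships custom modeling code, hence trust_remote_code=True.
model = transformers.AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4",
    trust_remote_code=True,
    attn_implementation="eager",
)
```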

Jina AI org

Hi @simonsv, this should be fixed now.

@jupyterjazz,
Roughly 4 days ago I was testing on Apple silicon: loading the config via config = transformers.AutoConfig.from_pretrained(..., torch_dtype=torch.float16), setting config.text_config._attn_implementation = "eager" to resolve the Flash Attention 2 error, and then creating model = transformers.AutoModel.from_pretrained(..., config=config). I tested sentence_transformers.SentenceTransformer(..., device="mps", config_kwargs=config.to_dict(), model_kwargs={"torch_dtype": torch.float16}) the same way. As you mentioned, your fix for Flash Attention 2 removed the need to load and edit the config manually.
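For reference, this is roughly the workaround I had been using (the model id is assumed for illustration; with your fix this should no longer be necessary):

```python
import torch
import transformers

MODEL_ID = "jinaai/jina-embeddings-v4"  # assumed model id for illustration

# Load the config first so the attention implementation can be overridden
config = transformers.AutoConfig.from_pretrained(
    MODEL_ID, trust_remote_code=True, torch_dtype=torch.float16
)
# Force eager attention to avoid the flash_attn requirement on Apple silicon
config.text_config._attn_implementation = "eager"

model = transformers.AutoModel.from_pretrained(
    MODEL_ID, config=config, trust_remote_code=True, torch_dtype=torch.float16
).to("mps")
```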

After downloading the fix, however, I can no longer embed images via transformers.AutoModel(...).encode_image or sentence_transformers.SentenceTransformer(...).encode. Both were working previously, but now both raise an error. Below is the traceback for sentence_transformers, although both paths fail at modeling_jina_embeddings_v4.py:290.

Note: I've tried manually loading the old config file but hit the same error. The first time I loaded the model today I saw the notice about new Python files, and I see that a few changed ~4 days ago. I plan to keep tracking this as I have time; let me know if there are any tests you'd like me to run.
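For completeness, a minimal repro (the model id is assumed; the rest mirrors how I've been calling the model and matches the traceback below):

```python
import torch
from sentence_transformers import SentenceTransformer

# Assumed model id; loading mirrors how I've been using the model on mps
model = SentenceTransformer(
    "jinaai/jina-embeddings-v4",
    trust_remote_code=True,
    device="mps",
    model_kwargs={"torch_dtype": torch.float16},
)

# This call raises the RuntimeError shown in the traceback below
image_embeddings = model.encode(
    sentences=["https://i.ibb.co/nQNGqL0/beach1.jpg"],
    task="retrieval",
)
```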

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[9], line 2
      1 # Encode image/document
----> 2 image_embeddings = sentence_transformers_model.encode(
      3     sentences=["https://i.ibb.co/nQNGqL0/beach1.jpg"],
      4     task="retrieval",
      5 )
      7 print(f"image_embeddings.shape = {image_embeddings.shape}")

File .venv/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File .venv/lib/python3.10/site-packages/sentence_transformers/SentenceTransformer.py:1052, in SentenceTransformer.encode(self, sentences, prompt_name, prompt, batch_size, show_progress_bar, output_value, precision, convert_to_numpy, convert_to_tensor, device, normalize_embeddings, truncate_dim, pool, chunk_size, **kwargs)
   1049 features.update(extra_features)
   1051 with torch.no_grad():
-> 1052     out_features = self.forward(features, **kwargs)
   1053     if self.device.type == "hpu":
   1054         out_features = copy.deepcopy(out_features)

File .venv/lib/python3.10/site-packages/sentence_transformers/SentenceTransformer.py:1133, in SentenceTransformer.forward(self, input, **kwargs)
   1127             module_kwarg_keys = self.module_kwargs.get(module_name, [])
   1128         module_kwargs = {
   1129             key: value
   1130             for key, value in kwargs.items()
   1131             if key in module_kwarg_keys or (hasattr(module, "forward_kwargs") and key in module.forward_kwargs)
   1132         }
-> 1133     input = module(input, **module_kwargs)
   1134 return input

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File custom_st.py:162, in Transformer.forward(self, features, task, truncate_dim)
    159 image_indices = features.get("image_indices", [])
    161 with torch.autocast(device_type=device, dtype=torch.bfloat16):
--> 162     img_embeddings = self.model(
    163         **image_batch, task_label=task
    164     ).single_vec_emb
    165     if truncate_dim:
    166         img_embeddings = img_embeddings[:, :truncate_dim]

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File .venv/lib/python3.10/site-packages/peft/peft_model.py:2759, in PeftModelForFeatureExtraction.forward(self, input_ids, attention_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   2757     with self._enable_peft_forward_hooks(**kwargs):
   2758         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 2759         return self.base_model(
   2760             input_ids=input_ids,
   2761             attention_mask=attention_mask,
   2762             inputs_embeds=inputs_embeds,
   2763             output_attentions=output_attentions,
   2764             output_hidden_states=output_hidden_states,
   2765             return_dict=return_dict,
   2766             **kwargs,
   2767         )
   2769 batch_size = _get_batch_size(input_ids, inputs_embeds)
   2770 if attention_mask is not None:
   2771     # concat prompt attention mask

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File .venv/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:193, in BaseTuner.forward(self, *args, **kwargs)
    192 def forward(self, *args: Any, **kwargs: Any):
--> 193     return self.model.forward(*args, **kwargs)

File modeling_jina_embeddings_v4.py:290, in JinaEmbeddingsV4Model.forward(self, task_label, input_ids, attention_mask, output_vlm_last_hidden_states, **kwargs)
    278 """
    279 Forward pass through the model. Returns both single-vector and multi-vector embeddings.
    280 Args:
   (...)
    287         multi_vec_emb (torch.Tensor, optional): Multi-vector embeddings.
    288 """
    289 # Forward pass through the VLM
--> 290 hidden_states = self.get_last_hidden_states(
    291     input_ids=input_ids,
    292     attention_mask=attention_mask,
    293     task_label=task_label,
    294     **kwargs,
    295 )  # (batch_size, seq_length, hidden_size)
    296 # Compute the embeddings
    297 single_vec_emb = self.get_single_vector_embeddings(
    298     hidden_states=hidden_states,
    299     attention_mask=attention_mask,
    300     input_ids=input_ids,
    301 )

File modeling_jina_embeddings_v4.py:189, in JinaEmbeddingsV4Model.get_last_hidden_states(self, task_label, input_ids, attention_mask, **kwargs)
    182 position_ids, rope_deltas = self.model.get_rope_index(
    183     input_ids=input_ids,
    184     image_grid_thw=kwargs.get("image_grid_thw", None),
    185     attention_mask=attention_mask,
    186 )
    188 kwargs["output_hidden_states"] = True
--> 189 outputs = super().forward(
    190     task_label=task_label,
    191     input_ids=input_ids,
    192     attention_mask=attention_mask,
    193     **kwargs,
    194     position_ids=position_ids,
    195     rope_deltas=rope_deltas,
    196     use_cache=False,
    197 )
    199 hidden_states = outputs.hidden_states
    200 if not hidden_states:

File .venv/lib/python3.10/site-packages/transformers/utils/generic.py:943, in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
    940     set_attribute_for_modules(self, "_is_top_level_module", False)
    942 try:
--> 943     output = func(self, *args, **kwargs)
    944     if is_requested_to_return_tuple or (is_configured_to_return_tuple and is_top_level_module):
    945         output = output.to_tuple()

File qwen2_5_vl.py:2235, in Qwen2_5_VLForConditionalGeneration.forward(self, task_label, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts)
   2230 output_hidden_states = (
   2231     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
   2232 )
   2233 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-> 2235 outputs = self.model(
   2236     task_label=task_label,
   2237     input_ids=input_ids,
   2238     pixel_values=pixel_values,
   2239     pixel_values_videos=pixel_values_videos,
   2240     image_grid_thw=image_grid_thw,
   2241     video_grid_thw=video_grid_thw,
   2242     second_per_grid_ts=second_per_grid_ts,
   2243     position_ids=position_ids,
   2244     attention_mask=attention_mask,
   2245     past_key_values=past_key_values,
   2246     inputs_embeds=inputs_embeds,
   2247     use_cache=use_cache,
   2248     output_attentions=output_attentions,
   2249     output_hidden_states=output_hidden_states,
   2250     return_dict=return_dict,
   2251     cache_position=cache_position,
   2252 )
   2254 hidden_states = outputs[0]
   2255 logits = self.lm_head(hidden_states)

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File qwen2_5_vl.py:1996, in Qwen2_5_VLModel.forward(self, task_label, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts)
   1993     image_mask = mask_expanded.to(inputs_embeds.device)
   1995     image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-> 1996     inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
   1998 if pixel_values_videos is not None:
   1999     video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)

RuntimeError: Unexpected floating ScalarType in at::autocast::prioritize
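One guess from my side: custom_st.py:161 wraps the image forward in torch.autocast(device_type=device, dtype=torch.bfloat16), and MPS autocast may not accept bfloat16 on my torch build, which would explain the ScalarType complaint. A sketch of the local patch I'd try (unverified; the names come from the traceback above):

```python
import torch

# Sketch: replace the fixed bfloat16 autocast in custom_st.py's forward with a
# dtype the backend supports. Assumes MPS only autocasts float16 on this torch
# build; CUDA keeps bfloat16 as before. `device`, `image_batch`, and `task`
# come from the surrounding forward() shown in the traceback.
autocast_dtype = torch.bfloat16 if device == "cuda" else torch.float16
with torch.autocast(device_type=device, dtype=autocast_dtype):
    img_embeddings = self.model(**image_batch, task_label=task).single_vec_emb
```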
Jina AI org

Hi @zboyles, thanks for reporting this. I was unable to reproduce the issue in my environment. Can you help me reproduce it? I'd like to know your environment details, especially the torch version, and also the code you're running if it differs from the examples in the README.
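Something like this would capture the details that matter here (nothing model-specific assumed):

```python
import platform

import sentence_transformers
import torch
import transformers

# Environment details that help narrow this down
print("platform:", platform.platform())
print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("sentence-transformers:", sentence_transformers.__version__)
print("mps available:", torch.backends.mps.is_available())
```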
