Is there a way to train jina-embeddings-v3 with the Transformers Trainer API? (Encountering NameError in rotary.py)

#129
by saadaltohamy

Hi everyone,

I'm attempting to fine-tune the jina-embeddings-v3 model for a sequence classification task. My goal is to use the standard Hugging Face transformers.Trainer API for this.

To achieve this, I've been following these steps:

  1. Loading the base model with AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True, config=modified_config), where modified_config may have lora_main_params_trainable adjusted.
  2. Creating a custom PyTorch nn.Module wrapper around this base Jina model. This wrapper:
    • Adds a classification head (nn.Linear).
    • Implements mean pooling over the base model's last_hidden_state (as the model doesn't use a CLS token for sequence representation).
    • Handles passing an adapter_mask to the base model's forward pass if a specific LoRA adaptation (like "classification") is targeted.
    • Computes CrossEntropyLoss if labels are provided.
  3. Setting requires_grad so that only the LoRA adapter parameters (identified by "parametrizations" in their names) and the new classification head are trainable, keeping the remaining base-model weights frozen (when lora_main_params_trainable=False). A minimal sketch of this setup follows the list.
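For concreteness, here is roughly what my setup looks like. This is a minimal sketch, not Jina's reference code: the class name JinaClassifier, num_labels, and the _adaptation_map/adapter_mask conventions are my own assumptions about the remote-code API and should be checked against the actual implementation.

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel

class JinaClassifier(nn.Module):
    def __init__(self, model_name="jinaai/jina-embeddings-v3",
                 num_labels=2, task="classification"):
        super().__init__()
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        config.lora_main_params_trainable = False  # keep base weights frozen
        self.base = AutoModel.from_pretrained(
            model_name, trust_remote_code=True, config=config
        )
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        # Assumption: the remote code maps task names to LoRA adapter indices
        # via config._adaptation_map; verify against the actual config.
        adaptation_map = getattr(config, "_adaptation_map", None) or {}
        self.task_id = adaptation_map.get(task)

    @staticmethod
    def mean_pool(hidden, attention_mask):
        # Mean pooling over non-padding tokens (no CLS token is used).
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        extra = {}
        if self.task_id is not None:
            # Assumed adapter_mask convention: one adapter id per example.
            extra["adapter_mask"] = torch.full(
                (input_ids.size(0),), self.task_id,
                dtype=torch.int32, device=input_ids.device,
            )
        out = self.base(input_ids=input_ids, attention_mask=attention_mask, **extra)
        pooled = self.mean_pool(out.last_hidden_state, attention_mask)
        logits = self.classifier(pooled)
        loss = nn.CrossEntropyLoss()(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}

model = JinaClassifier()
# Train only the LoRA parametrizations and the new classification head.
for name, param in model.named_parameters():
    param.requires_grad = "parametrizations" in name or name.startswith("classifier")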

The forward pass of my custom wrapped model works correctly when tested on tokenized inputs. However, as soon as I call trainer.train(), I consistently hit a NameError during the backward pass; the snippet below reproduces it without involving the Trainer.
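A single manual backward pass on a dummy batch (reusing the model instance from the sketch above) is enough to trigger the error:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3")
batch = tokenizer(["a sample sentence"], return_tensors="pt", padding=True)
out = model(**batch, labels=torch.tensor([0]))  # forward pass succeeds
out["loss"].backward()  # raises NameError: name 'apply_rotary' is not defined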

Here's the relevant part of the traceback:

Traceback (most recent call last):
  File "/path/to/your/script.py", line X, in <your_training_initiation_function_or_cell> # Adjust path
    trainer.train()
  File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2245, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2560, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 3782, in training_step
    self.accelerator.backward(loss, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/accelerate/accelerator.py", line 2359, in backward
    loss.backward(**kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/_tensor.py", line 626, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.11/dist-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/usr/local/lib/python3.11/dist-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/jinaai/xlm-roberta-flash-implementation/2b6bc3f30750b3a9648fe9b63448c09920efe9be/rotary.py", line 257, in backward
    apply_rotary( # <--- THIS IS THE LINE CAUSING THE NameError
    ^^^^^^^^^^^^
NameError: name 'apply_rotary' is not defined

The error NameError: name 'apply_rotary' is not defined appears to originate from the model's custom rotary.py (the copy cached from jinaai/xlm-roberta-flash-implementation). This suggests a scoping bug in the model's own implementation of the backward pass for its rotary position embeddings: the apply_rotary function is not in scope inside the custom torch.autograd.Function's backward method.
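For anyone unfamiliar with this failure mode, here is a self-contained illustration of how a torch.autograd.Function can succeed in forward yet raise NameError in backward when a name is bound only under an optional import. I have not verified this against the cached file; nonexistent_kernel_lib is a deliberately hypothetical module standing in for whatever fast kernel (presumably flash-attn's Triton rotary kernel) rotary.py tries to import.

import torch

try:
    # Stand-in for an optional fast-kernel import; this module does not exist,
    # so the import fails and `apply_rotary` is never bound in this namespace.
    from nonexistent_kernel_lib import apply_rotary
except ImportError:
    pass

class RotaryLike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # The forward path has a pure-PyTorch fallback, so it succeeds.
        return x * 2

    @staticmethod
    def backward(ctx, grad):
        # The backward path references the name unconditionally.
        return apply_rotary(grad)  # NameError: name 'apply_rotary' is not defined

x = torch.ones(3, requires_grad=True)
RotaryLike.apply(x).sum().backward()  # forward is fine, backward raises NameError

If rotary.py follows this pattern, the forward pass would silently take a fallback path while only the backward pass exposes the missing import, which matches the behavior I'm seeing.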

My questions are:

  1. Has anyone else encountered this NameError when trying to fine-tune this Jina model (or related Jina embedding models that might use the same rotary.py module)?
  2. Is there a known fix or workaround for this issue that would allow training with the transformers.Trainer?
  3. Given this internal error, what is the recommended approach by Jina AI for fine-tuning these models for a downstream task like sequence classification if the standard Trainer approach hits this roadblock? (The model card mentions SentenceTransformerTrainer, but my aim was to use the standard transformers.Trainer with a classification head.)

Any insights or suggestions would be greatly appreciated!

Thanks!
