Troubleshooting RuntimeError in Sentence Transformers with Multi-Process Pool

Discussion #85, opened by SamMaggioli


Background

When using Sentence Transformers with a multi-process pool (start_multi_process_pool), the following error is raised:
RuntimeError: Serialization of parametrized modules is only supported through state_dict()

Current Situation

This error prevents our code from running Sentence Transformers in a multi-process environment: the pool fails to start, so no embeddings are computed.

Error Details

RuntimeError                              Traceback (most recent call last)
Cell In[2], line 11
      4 model = SentenceTransformer("jinaai/jina-embeddings-v3", 
      5     trust_remote_code=True, 
      6     model_kwargs={'default_task': 'text-matching'}, 
      7     device='cuda',
      8     truncate_dim=128)
     10 # Start the multi-process pool on all available CUDA devices
---> 11 pool = model.start_multi_process_pool()
     13 # Compute the embeddings using the multi-process pool
     14 emb = model.encode_multi_process(sentences, pool, show_progress_bar=True, batch_size=32)

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/sentence_transformers/SentenceTransformer.py:857, in SentenceTransformer.start_multi_process_pool(self, target_devices)
    851 for device_id in target_devices:
    852     p = ctx.Process(
    853         target=SentenceTransformer._encode_multi_process_worker,
    854         args=(device_id, self, input_queue, output_queue),
    855         daemon=True,
    856     )
--> 857     p.start()
    858     processes.append(p)
    860 return {"input": input_queue, "output": output_queue, "processes": processes}

File /opt/conda/envs/pytorch/lib/python3.10/multiprocessing/process.py:121, in BaseProcess.start(self)
    118 assert not _current_process._config.get('daemon'), \
    119        'daemonic processes are not allowed to have children'
    120 _cleanup()
--> 121 self._popen = self._Popen(self)
    122 self._sentinel = self._popen.sentinel
    123 # Avoid a refcycle if the target function holds an indirect
    124 # reference to the process object (see bpo-30775)

File /opt/conda/envs/pytorch/lib/python3.10/multiprocessing/context.py:288, in SpawnProcess._Popen(process_obj)
    285 @staticmethod
    286 def _Popen(process_obj):
    287     from .popen_spawn_posix import Popen
--> 288     return Popen(process_obj)

File /opt/conda/envs/pytorch/lib/python3.10/multiprocessing/popen_spawn_posix.py:32, in Popen.__init__(self, process_obj)
     30 def __init__(self, process_obj):
     31     self._fds = []
---> 32     super().__init__(process_obj)

File /opt/conda/envs/pytorch/lib/python3.10/multiprocessing/popen_fork.py:19, in Popen.__init__(self, process_obj)
     17 self.returncode = None
     18 self.finalizer = None
---> 19 self._launch(process_obj)

File /opt/conda/envs/pytorch/lib/python3.10/multiprocessing/popen_spawn_posix.py:47, in Popen._launch(self, process_obj)
     45 try:
     46     reduction.dump(prep_data, fp)
---> 47     reduction.dump(process_obj, fp)
     48 finally:
     49     set_spawning_popen(None)

File /opt/conda/envs/pytorch/lib/python3.10/multiprocessing/reduction.py:60, in dump(obj, file, protocol)
     58 def dump(obj, file, protocol=None):
     59     '''Replacement for pickle.dump() using ForkingPickler.'''
---> 60     ForkingPickler(file, protocol).dump(obj)

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/utils/parametrize.py:340, in _inject_new_class.<locals>.getstate(self)
    339 def getstate(self):
--> 340     raise RuntimeError(
    341         "Serialization of parametrized modules is only "
    342         "supported through state_dict(). See:\n"
    343         "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
    344         "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
    345     )

RuntimeError: Serialization of parametrized modules is only supported through state_dict(). See:
https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training
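The last frames of the traceback show the mechanism: the pool starts its workers with the spawn start method, which pickles each Process object (including its args, and therefore the model itself) through ForkingPickler, and the parametrized modules refuse to be pickled. The sketch below reproduces that failure without PyTorch or a GPU. `ParametrizedStandIn` and `try_spawn_pickle` are hypothetical names; the stand-in only mimics the `getstate` hook shown in the traceback and is not the real torch class.

```python
from multiprocessing.reduction import ForkingPickler


class ParametrizedStandIn:
    """Hypothetical stand-in for a torch module with parametrizations (e.g. LoRA)."""

    def __init__(self, weight):
        self.weight = weight

    def __getstate__(self):
        # Mimics torch.nn.utils.parametrize: pickling is refused outright.
        raise RuntimeError(
            "Serialization of parametrized modules is only "
            "supported through state_dict()."
        )


def try_spawn_pickle(obj):
    """Serialize obj with the pickler spawn uses; return the error text, or None on success."""
    try:
        ForkingPickler.dumps(obj)
    except RuntimeError as exc:
        return str(exc)
    return None


# A plain, picklable payload serializes fine.
print(try_spawn_pickle({"device": "cuda:0", "weights": [0.1, 0.2]}))
# Anything holding a parametrized module fails, like the model passed in
# args=(device_id, self, input_queue, output_queue).
print(try_spawn_pickle(("cuda:0", ParametrizedStandIn([0.1, 0.2]))))
```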
Reply (Jina AI org):

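Hi @SamMaggioli, this error happens because the model's custom LoRA layers are implemented as PyTorch parametrized modules, which cannot be pickled, and the multi-process pool has to pickle the model to send it to each worker process. There are several ways to address this. It was already discussed here, and I've suggested a possible solution in that thread. Hope this helps!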
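One general way around this kind of error (not necessarily the solution suggested in the linked thread) is to never send the model across the process boundary: give each worker only picklable configuration, such as the model name and keyword arguments from the original snippet, and let the worker build its own copy. The sketch below keeps everything in a single process so it runs anywhere; `WorkerConfig`, `StandInModel`, `build_model`, and `encode_worker` are hypothetical names. With Sentence Transformers, `build_model` would presumably construct `SentenceTransformer(...)` inside the spawned process, which is an assumption, not something confirmed in this thread.

```python
import pickle
from dataclasses import dataclass, field
from multiprocessing.reduction import ForkingPickler
from typing import List, Optional, Tuple


@dataclass
class WorkerConfig:
    """Hypothetical picklable recipe for the model a worker should load."""
    model_name: str
    device: str
    model_kwargs: dict = field(default_factory=dict)
    truncate_dim: Optional[int] = None


class StandInModel:
    """Hypothetical stand-in for a model whose LoRA layers refuse pickling."""

    def __init__(self, config: WorkerConfig):
        self.config = config

    def __getstate__(self):
        raise RuntimeError("Serialization of parametrized modules is only "
                           "supported through state_dict().")

    def encode(self, sentences: List[str]) -> List[List[float]]:
        # Placeholder "embedding" (sentence length); a real model returns vectors.
        return [[float(len(s))] for s in sentences]


def build_model(config: WorkerConfig) -> StandInModel:
    # Assumption: real code would call something like
    # SentenceTransformer(config.model_name, trust_remote_code=True, ...) here.
    return StandInModel(config)


def encode_worker(payload: Tuple[WorkerConfig, List[str]]) -> List[List[float]]:
    """Runs in the worker: builds its own model, so nothing unpicklable is sent."""
    config, sentences = payload
    model = build_model(config)
    return model.encode(sentences)


config = WorkerConfig(
    model_name="jinaai/jina-embeddings-v3",
    device="cuda:0",
    model_kwargs={"default_task": "text-matching"},
    truncate_dim=128,
)
payload = (config, ["hello", "multi-process pool"])

# Serialize the payload the way the spawn start method would, then restore it
# as the worker process would receive it.
restored = pickle.loads(bytes(ForkingPickler.dumps(payload)))
print(encode_worker(restored))  # [[5.0], [18.0]]
```

In a real pool, `encode_worker` (or a loop around it reading batches from a queue) would be the Process target, and only the config and the sentence batches would travel through the queues. The trade-off is that every worker loads the weights itself, which costs startup time and memory on each device.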
