WAN 2.2 14B training broken

#14
by wouterverweirder - opened

I'm not able to run WAN 2.2 14B training. WAN 2.2 5B works. All 14B training runs exit with this error:

Error running job: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!


========================================
Result:
 - 0 completed jobs
 - 1 failure
========================================


wan_2_2_i2v_woven_fabric_01:   0%|          | 0/3000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/ai-toolkit/run.py", line 120, in <module>
    main()
  File "/ai-toolkit/run.py", line 108, in main
    raise e
  File "/ai-toolkit/run.py", line 96, in main
    job.run()
  File "/ai-toolkit/jobs/ExtensionJob.py", line 22, in run
    process.run()
  File "/ai-toolkit/jobs/process/BaseSDTrainProcess.py", line 2154, in run
    loss_dict = self.hook_train_loop(batch_list)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ai-toolkit/extensions_built_in/sd_trainer/SDTrainer.py", line 2023, in hook_train_loop
    loss = self.train_single_accumulation(batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ai-toolkit/extensions_built_in/sd_trainer/SDTrainer.py", line 1930, in train_single_accumulation
    noise_pred = self.predict_noise(
                 ^^^^^^^^^^^^^^^^^^^
  File "/ai-toolkit/extensions_built_in/sd_trainer/SDTrainer.py", line 1175, in predict_noise
    return self.sd.predict_noise(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/ai-toolkit/toolkit/models/base_model.py", line 914, in predict_noise
    noise_pred = self.get_noise_prediction(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py", line 137, in get_noise_prediction
    noise_pred = self.model(
                 ^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py", line 149, in forward
    return self.transformer(
           ^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/diffusers/models/transformers/transformer_wan.py", line 675, in forward
    hidden_states = self._gradient_checkpointing_func(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/diffusers/models/modeling_utils.py", line 305, in _gradient_checkpointing_func
    return torch.utils.checkpoint.checkpoint(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 495, in checkpoint
    ret = function(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/diffusers/models/transformers/transformer_wan.py", line 478, in forward
    self.scale_shift_table + temb.float()
    ~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I worked around it by enabling the Low VRAM option.
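
In case it helps anyone narrow this down: the last frame fails on `self.scale_shift_table + temb.float()`, so one of those two operands is still on the CPU when the block runs. I have not confirmed which one. Below is a rough diagnostic sketch, not part of ai-toolkit, that lists any parameters or buffers of a module that are not on the expected device. The helper names `group_by_device` and `find_off_device` are my own. It only relies on `named_parameters()` and `named_buffers()`, which every `torch.nn.Module` provides, and on each tensor having a `.device` attribute.

```python
from collections import defaultdict


def group_by_device(named_tensors):
    """Group tensor names by the string form of their .device attribute."""
    groups = defaultdict(list)
    for name, tensor in named_tensors:
        groups[str(tensor.device)].append(name)
    return dict(groups)


def find_off_device(module, expected="cuda:0"):
    """Return {device: [names]} for parameters and buffers not on `expected`.

    `module` is assumed to expose named_parameters() and named_buffers(),
    as torch.nn.Module does.
    """
    named = list(module.named_parameters()) + list(module.named_buffers())
    groups = group_by_device(named)
    return {dev: names for dev, names in groups.items() if dev != expected}
```

Calling `find_off_device(...)` on the transformer right before the forward pass (for example near the `self.transformer(` call in `wan22_14b_model.py`) should show whether `scale_shift_table` is among the entries left on `cpu`. If the result is empty, the CPU tensor is more likely the `temb` input than a module parameter.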
