WAN 2.2 14B training broken #14
by wouterverweirder - opened
I'm not able to run WAN 2.2 14B training. WAN 2.2 5B works. All 14B training runs exit with this error:
Error running job: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
========================================
Result:
- 0 completed jobs
- 1 failure
========================================
wan_2_2_i2v_woven_fabric_01: 0%| | 0/3000 [00:00<?, ?it/s]Traceback (most recent call last):
File "/ai-toolkit/run.py", line 120, in <module>
main()
File "/ai-toolkit/run.py", line 108, in main
raise e
File "/ai-toolkit/run.py", line 96, in main
job.run()
File "/ai-toolkit/jobs/ExtensionJob.py", line 22, in run
process.run()
File "/ai-toolkit/jobs/process/BaseSDTrainProcess.py", line 2154, in run
loss_dict = self.hook_train_loop(batch_list)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ai-toolkit/extensions_built_in/sd_trainer/SDTrainer.py", line 2023, in hook_train_loop
loss = self.train_single_accumulation(batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ai-toolkit/extensions_built_in/sd_trainer/SDTrainer.py", line 1930, in train_single_accumulation
noise_pred = self.predict_noise(
^^^^^^^^^^^^^^^^^^^
File "/ai-toolkit/extensions_built_in/sd_trainer/SDTrainer.py", line 1175, in predict_noise
return self.sd.predict_noise(
^^^^^^^^^^^^^^^^^^^^^^
File "/ai-toolkit/toolkit/models/base_model.py", line 914, in predict_noise
noise_pred = self.get_noise_prediction(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py", line 137, in get_noise_prediction
noise_pred = self.model(
^^^^^^^^^^^
File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py", line 149, in forward
return self.transformer(
^^^^^^^^^^^^^^^^^
File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/diffusers/models/transformers/transformer_wan.py", line 675, in forward
hidden_states = self._gradient_checkpointing_func(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/diffusers/models/modeling_utils.py", line 305, in _gradient_checkpointing_func
return torch.utils.checkpoint.checkpoint(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 495, in checkpoint
ret = function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/uv/environments-v2/script-912247c0edd68a55/lib/python3.12/site-packages/diffusers/models/transformers/transformer_wan.py", line 478, in forward
self.scale_shift_table + temb.float()
~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
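The add that blows up is self.scale_shift_table + temb.float() in the WAN transformer block: the learned modulation table is apparently still on the CPU while temb has already been moved to cuda:0. Here is a minimal sketch of the same mismatch with a stand-in module of my own (not ai-toolkit or diffusers code, and it needs a CUDA device to run):

```python
import torch
import torch.nn as nn

class ModulatedBlock(nn.Module):
    """Minimal stand-in for the WAN transformer block that fails above."""
    def __init__(self, dim: int = 8):
        super().__init__()
        # Learned modulation table, analogous to scale_shift_table in transformer_wan.py.
        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim))

    def forward(self, temb: torch.Tensor) -> torch.Tensor:
        # This is the kind of add that raises when the parameter is still on the CPU
        # while temb has already been moved to cuda:0.
        return self.scale_shift_table + temb.float()

block = ModulatedBlock()                      # parameters left on the CPU
temb = torch.randn(1, 6, 8, device="cuda:0")  # activations on the GPU

try:
    block(temb)  # RuntimeError: ... two devices, cuda:0 and cpu!
except RuntimeError as err:
    print(err)

# Moving the whole module to the GPU makes the same call succeed, which is
# presumably what the 14B offload path is missing for some of its weights.
block.to(temb.device)
print(block(temb).shape)  # torch.Size([1, 6, 8])
```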
I worked around it by enabling the Low VRAM option.
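If anyone wants to confirm which weights are being left behind before flipping that toggle, something like this should list them (a hypothetical debugging helper, not part of ai-toolkit; transformer stands for the loaded WAN 2.2 14B transformer module):

```python
import torch

def find_cpu_parameters(module: torch.nn.Module, expected: str = "cuda:0"):
    """Return (name, device) pairs for parameters/buffers not on the expected device."""
    stragglers = []
    for name, tensor in list(module.named_parameters()) + list(module.named_buffers()):
        if str(tensor.device) != expected:
            stragglers.append((name, str(tensor.device)))
    return stragglers

# Example usage inside a training script, right before the first forward pass:
# print(find_cpu_parameters(transformer))
```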