make sure everything stays in the same dtype when using dpo + FSDP (#1559)
Browse files
src/axolotl/core/trainer_builder.py
CHANGED
|
@@ -54,6 +54,7 @@ from axolotl.utils.collators import (
|
|
| 54 |
MambaDataCollator,
|
| 55 |
V2BatchSamplerDataCollatorForSeq2Seq,
|
| 56 |
)
|
|
|
|
| 57 |
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
| 58 |
from axolotl.utils.schedulers import (
|
| 59 |
get_cosine_schedule_with_min_lr,
|
|
@@ -1569,6 +1570,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
| 1569 |
callbacks=self.get_callbacks(),
|
| 1570 |
**dpo_trainer_kwargs,
|
| 1571 |
)
|
|
|
|
|
|
|
|
|
|
| 1572 |
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
|
| 1573 |
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
|
| 1574 |
dpo_trainer.add_callback(callback)
|
|
|
|
| 54 |
MambaDataCollator,
|
| 55 |
V2BatchSamplerDataCollatorForSeq2Seq,
|
| 56 |
)
|
| 57 |
+
from axolotl.utils.models import ensure_dtype
|
| 58 |
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
| 59 |
from axolotl.utils.schedulers import (
|
| 60 |
get_cosine_schedule_with_min_lr,
|
|
|
|
| 1570 |
callbacks=self.get_callbacks(),
|
| 1571 |
**dpo_trainer_kwargs,
|
| 1572 |
)
|
| 1573 |
+
if self.cfg.fsdp:
|
| 1574 |
+
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
|
| 1575 |
+
|
| 1576 |
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
|
| 1577 |
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
|
| 1578 |
dpo_trainer.add_callback(callback)
|
src/axolotl/utils/models.py
CHANGED
|
@@ -993,3 +993,13 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
|
| 993 |
setup_quantized_peft_meta_for_training(model)
|
| 994 |
|
| 995 |
return model, lora_config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 993 |
setup_quantized_peft_meta_for_training(model)
|
| 994 |
|
| 995 |
return model, lora_config
|
| 996 |
+
|
| 997 |
+
|
| 998 |
+
def ensure_dtype(model, dtype=torch.bfloat16):
    """Cast every weighted submodule of ``model`` to ``dtype`` in place.

    Walks all submodules and, for each one that owns a ``weight`` tensor of a
    different dtype, converts the whole submodule with ``module.to(dtype)``.
    Used after trainer construction (e.g. DPO + FSDP) so the entire model ends
    up in a single, consistent dtype.

    Args:
        model: a ``torch.nn.Module`` whose submodules are converted in place.
        dtype: target ``torch.dtype``; defaults to ``torch.bfloat16``.

    Returns:
        None — the model is modified in place.
    """
    for name, module in model.named_modules():
        # Use getattr instead of a broad try/except AttributeError: the old
        # form also swallowed AttributeErrors raised *inside* module.to(),
        # hiding real bugs. Some modules legitimately have no weight (or
        # weight=None), and those are simply skipped.
        weight = getattr(module, "weight", None)
        if weight is None:
            continue
        if weight.dtype != dtype:
            print(f"Converting module {name}: {weight.dtype} -> {dtype}")
            module.to(dtype)
|