Set the FSDP_TRANSFORMER_CLS_TO_WRAP env var for the FSDP transformer layer class to wrap (#453)
Browse files
src/axolotl/utils/trainer.py
CHANGED
|
@@ -377,6 +377,10 @@ def setup_fsdp_envs(cfg):
|
|
| 377 |
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
|
| 378 |
if cfg.fsdp_config.fsdp_state_dict_type:
|
| 379 |
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
|
| 381 |
|
| 382 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
|
|
|
| 377 |
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
|
| 378 |
if cfg.fsdp_config.fsdp_state_dict_type:
|
| 379 |
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
|
| 380 |
+
if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap:
|
| 381 |
+
os.environ[
|
| 382 |
+
"FSDP_TRANSFORMER_CLS_TO_WRAP"
|
| 383 |
+
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
| 384 |
|
| 385 |
|
| 386 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|