support galore once upstreamed into transformers (#1409)
* support galore once upstreamed into transformers
* update module name for llama in readme and fix typing for all linear
* bump trl for deprecation fixes from newer transformers
* include galore as an extra and install in docker image
* fix optim_args type
* fix optim_args
* update dependencies for galore
* add galore to cicd dockerfile
- README.md +19 -0
- cicd/Dockerfile.jinja +2 -2
- docker/Dockerfile +2 -2
- requirements.txt +2 -2
- setup.py +3 -0
- src/axolotl/core/trainer_builder.py +14 -1
- src/axolotl/utils/config/models/input/v0_4_1/__init__.py +9 -0
README.md
CHANGED
```diff
@@ -907,7 +907,26 @@ lr_div_factor: # Learning rate div factor
 # - paged_adamw_8bit
 # - paged_lion_32bit
 # - paged_lion_8bit
+# - galore_adamw
+# - galore_adamw_8bit
+# - galore_adafactor
+# - galore_adamw_layerwise
+# - galore_adamw_8bit_layerwise
+# - galore_adafactor_layerwise
 optimizer:
+# Dictionary of arguments to pass to the optimizer
+optim_args:
+# For Galore Optimizers the following optim_args are available
+# rank:  # type: int
+# update_proj_gap  # type: int
+# scale  # type: float
+# proj_type:  # type: str, default = std
+
+# The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm
+optim_target_modules:
+# - self_attn  # for llama
+# - mlp
+
 # Specify weight decay
 weight_decay:
 # adamw hyperparams
```
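To make the new keys concrete, here is a minimal sketch of how they might be combined in an axolotl config. The key names come from the README documentation above; the specific values for `rank`, `update_proj_gap`, and `scale` are illustrative placeholders, not tuned recommendations from this PR:

```yaml
# Example values are illustrative only; the key names are from the README above.
optimizer: galore_adamw
optim_args:
  rank: 128              # int (placeholder value)
  update_proj_gap: 200   # int (placeholder value)
  scale: 0.25            # float (placeholder value)
  proj_type: std         # str, default = std
# currently only used by the GaLore optimizers
optim_target_modules:
  - self_attn  # for llama
  - mlp
```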
cicd/Dockerfile.jinja
CHANGED
```diff
@@ -23,9 +23,9 @@ RUN git fetch origin +$GITHUB_REF && \
 
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
     fi
 
 # So we can test the Docker image
```
docker/Dockerfile
CHANGED
```diff
@@ -21,9 +21,9 @@ WORKDIR /workspace/axolotl
 
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
     fi
 
 # So we can test the Docker image
```
requirements.txt
CHANGED
```diff
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.9.0
-transformers
+transformers @ git+https://github.com/huggingface/transformers.git@f6261d7d81edd036fc53bfede65fe91f01a661aa
 tokenizers==0.15.0
 bitsandbytes>=0.43.0
 accelerate==0.26.1
@@ -39,5 +39,5 @@ s3fs
 gcsfs
 # adlfs
 
-trl
+trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
 fastcore>=1.5.29
```
setup.py
CHANGED
```diff
@@ -89,5 +89,8 @@ setup(
         "lion-pytorch": [
             "lion-pytorch==0.1.2",
         ],
+        "galore": [
+            "galore_torch",
+        ],
     },
 )
```
src/axolotl/core/trainer_builder.py
CHANGED
```diff
@@ -220,7 +220,7 @@ class AxolotlTrainer(Trainer):
         num_epochs=1,
         bench_data_collator=None,
         eval_data_collator=None,
-        **kwargs
+        **kwargs,
     ):
         self.num_epochs = num_epochs
         self.bench_data_collator = bench_data_collator
@@ -239,6 +239,7 @@ class AxolotlTrainer(Trainer):
         if self.optimizer is None:  # pylint: disable=access-member-before-definition
             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                 self.args,
+                opt_model,
             )
 
             loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
@@ -1150,6 +1151,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["optim"] = (
             self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
         )
+        if self.cfg.optim_args:
+            if isinstance(self.cfg.optim_args, dict):
+                optim_args = ",".join(
+                    [f"{key}={value}" for key, value in self.cfg.optim_args.items()]
+                )
+            else:
+                optim_args = self.cfg.optim_args
+            training_arguments_kwargs["optim_args"] = optim_args
+        if self.cfg.optim_target_modules:
+            training_arguments_kwargs[
+                "optim_target_modules"
+            ] = self.cfg.optim_target_modules
         training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
         training_arguments_kwargs[
             "loraplus_lr_embedding"
```
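The `optim_args` handling above flattens a config dict into the comma-separated `key=value` string form that `TrainingArguments.optim_args` takes. A standalone sketch of that same conversion, for anyone wiring this up outside the trainer builder (the sample values are made up, not recommendations):

```python
from typing import Any, Dict, Union


def serialize_optim_args(optim_args: Union[str, Dict[str, Any]]) -> str:
    """Flatten a dict of optimizer kwargs into the "key=value,key=value"
    string expected by TrainingArguments.optim_args; pass strings through."""
    if isinstance(optim_args, dict):
        return ",".join(f"{key}={value}" for key, value in optim_args.items())
    return optim_args


# Illustrative values only.
print(serialize_optim_args({"rank": 128, "update_proj_gap": 200, "scale": 0.25}))
# -> "rank=128,update_proj_gap=200,scale=0.25"
```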
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
```diff
@@ -313,6 +313,15 @@ class HyperparametersConfig(BaseModel):
     learning_rate: Union[str, float]
     weight_decay: Optional[float] = None
     optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
+    optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
+        default=None, metadata={"help": "Optional arguments to supply to optimizer."}
+    )
+    optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field(
+        default=None,
+        metadata={
+            "help": "The target modules to optimize, i.e. the module names that you would like to train."
+        },
+    )
     torchdistx_path: Optional[str] = None
     lr_scheduler: Optional[SchedulerType] = None
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
```
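For reference, a self-contained sketch (a hypothetical stand-in class, not the real `HyperparametersConfig`) showing the two value shapes these fields accept: `optim_args` as either a pre-formatted string or a dict, and `optim_target_modules` as a module list or the `"all_linear"` literal:

```python
from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field


class GaloreHyperparamsSketch(BaseModel):
    """Hypothetical stand-in mirroring only the two fields added above."""

    optim_args: Optional[Union[str, Dict[str, Any]]] = Field(default=None)
    optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field(
        default=None
    )


# Both shapes validate: a raw "key=value" string or a structured dict.
print(GaloreHyperparamsSketch(optim_args="rank=128,scale=0.25").optim_args)
print(
    GaloreHyperparamsSketch(
        optim_args={"rank": 128, "proj_type": "std"},
        optim_target_modules=["self_attn", "mlp"],
    ).optim_args
)
```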