DPO cleanup (#1126)
* cleanup dpo to be a little more extensible, add zephyr/nectar strategy
* fix eos slash
* support for eval split
* fix kwargs
* handle empty evals
* don't load peft model for dpo
* ensure dpo training args get bf16 for peft if applicable
* fix duplicate kwargs for bf16
* make sure to respect the configured lr scheduler
* support trainer callback to push config to wandb
* set dataloader preload args
* ensure that we are loading the lora when merging
* Update src/axolotl/utils/data.py
Co-authored-by: Agus <[email protected]>
* support local datasets for dpo
Co-authored-by: Agus <[email protected]>
* chore: lint
* dpo/kto/ipo smoke tests w lora, simplify dpo dataset type names
* add split to dpo tests
* fix rebase/merging error
* handle edge case w logging
* use accelerator for dpo datasets so it doesn't break the logger
* missing args
* validate checkpoint is an adapter for now
* log warning when dataset strategy is not loadable
---------
Co-authored-by: Agus <[email protected]>
- src/axolotl/cli/__init__.py +2 -77
- src/axolotl/core/trainer_builder.py +86 -23
- src/axolotl/prompt_strategies/dpo/__init__.py +21 -0
- src/axolotl/prompt_strategies/dpo/chatml.py +85 -0
- src/axolotl/prompt_strategies/dpo/zephyr.py +21 -0
- src/axolotl/train.py +6 -1
- src/axolotl/utils/data.py +49 -1
- src/axolotl/utils/models.py +11 -3
- src/axolotl/utils/trainer.py +2 -1
- tests/e2e/test_dpo.py +157 -0
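For context on how the pieces below fit together: a dataset's `type` now names a transform strategy as `<module>.<function>` under `axolotl.prompt_strategies.dpo`. A minimal sketch of a DPO run config, mirroring the new smoke tests (the dataset path and split here are illustrative, not prescriptive):

    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "rl": "dpo",
            "datasets": [
                # resolved to axolotl.prompt_strategies.dpo.chatml, function intel
                {
                    "path": "Intel/orca_dpo_pairs",
                    "split": "train",
                    "type": "chatml.intel",
                },
            ],
            # optional: a held-out split here enables eval during DPO training
            "test_datasets": None,
        }
    )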
src/axolotl/cli/__init__.py

@@ -17,7 +17,6 @@ import yaml
 # add src to the pythonpath so we don't need to pip install this
 from accelerate.commands.config import config_args
 from art import text2art
-from datasets import concatenate_datasets, load_dataset
 from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError
 from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
@@ -30,7 +29,7 @@ from axolotl.utils.config import (
     normalize_config,
     validate_config,
 )
-from axolotl.utils.data import prepare_dataset
+from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process
 from axolotl.utils.mlflow_ import setup_mlflow_env_vars
@@ -343,81 +342,7 @@ def load_rl_datasets(
     cfg: DictDefault,
     cli_args: TrainerCliArgs,  # pylint: disable=unused-argument
 ) -> TrainDatasetMeta:
-    train_datasets: List[Any] = []
-    for i, ds_cfg in enumerate(cfg.datasets):
-        train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"]))
-    # eval_dataset = load_dataset(
-    #     cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"]
-    # )
-    eval_dataset = None
-
-    def argilla_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
-        if "system" in sample and sample["system"]:
-            sample["prompt"] = (
-                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
-                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
-            )
-        else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
-        sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
-        sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
-        return sample
-
-    def intel_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
-        if "system" in sample and sample["system"]:
-            sample["prompt"] = (
-                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
-                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
-            )
-        else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
-        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
-        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
-        return sample
-
-    def apply_chatml(sample):  # pylint: disable=possibly-unused-variable
-        if "system" in sample and sample["system"]:
-            sample["prompt"] = (
-                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
-                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-            )
-        else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
-        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
-        return sample
-
-    def ultra_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
-        if "system" in sample and sample["system"]:
-            sample["prompt"] = (
-                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
-                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-            )
-        else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-        sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
-        sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
-        return sample
-
-    for i, data_set in enumerate(train_datasets):
-        _type = cfg.datasets[i]["type"]
-        ds_type_fn = locals()[_type]
-        train_datasets[i] = data_set.map(
-            ds_type_fn,
-            desc="Mapping RL Dataset",
-        )
-    train_dataset = concatenate_datasets(train_datasets)
-
-    # eval_dataset = eval_dataset.map(intel_apply_chatml)
-
+    train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
     total_num_steps = int(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
src/axolotl/core/trainer_builder.py

@@ -12,14 +12,19 @@ from abc import abstractmethod
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Optional, Type, Union
+from typing import List, Optional, Type, Union

 import torch
 import transformers
 from datasets import Dataset
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
-from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
+from transformers import (
+    EarlyStoppingCallback,
+    Trainer,
+    TrainerCallback,
+    TrainingArguments,
+)
 from transformers.trainer_utils import seed_worker
 from trl import DPOTrainer

@@ -460,6 +465,7 @@ class TrainerBuilderBase(abc.ABC):
     _train_dataset = None
     _eval_dataset = None
     _model_ref = None
+    _peft_config = None

     def __init__(self, cfg, model, tokenizer):
         self.cfg = cfg
@@ -490,13 +496,26 @@ class TrainerBuilderBase(abc.ABC):
     def eval_dataset(self, dataset):
         self._eval_dataset = dataset

+    @property
+    def peft_config(self):
+        return self._peft_config
+
+    @peft_config.setter
+    def peft_config(self, peft_config):
+        self._peft_config = peft_config
+
     @abstractmethod
     def build(self, total_num_steps):
         pass

-    @abstractmethod
-    def get_callbacks(self):
-        pass
+    def get_callbacks(self) -> List[TrainerCallback]:
+        callbacks = []
+        if self.cfg.use_wandb:
+            callbacks.append(
+                SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
+            )
+
+        return callbacks

     @abstractmethod
     def get_post_trainer_create_callbacks(self, trainer):
@@ -504,12 +523,6 @@ class TrainerBuilderBase(abc.ABC):
         Callbacks added after the trainer is created, usually b/c these need access to the trainer
         """

-
-class HFCausalTrainerBuilder(TrainerBuilderBase):
-    """
-    Build the HuggingFace training args/trainer for Causal models
-    """
-
     def hook_pre_create_training_args(self, training_arguments_kwargs):
         # TODO
         return training_arguments_kwargs
@@ -526,10 +539,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         # TODO
         return trainer

+
+class HFCausalTrainerBuilder(TrainerBuilderBase):
+    """
+    Build the HuggingFace training args/trainer for Causal models
+    """
+
     def get_callbacks(self):
-        callbacks = []
+        callbacks = super().get_callbacks()
         callbacks.append(GPUStatsCallback(self.cfg))
-        callbacks.append(EvalFirstStepCallback)
+        callbacks.append(EvalFirstStepCallback())

         if self.cfg.relora_steps:
             callbacks.append(ReLoRACallback(self.cfg))
@@ -538,7 +557,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             hasattr(self.model, "use_bettertransformer")
             and self.model.use_bettertransformer is True
         ):
-            callbacks.append(SaveBetterTransformerModelCallback)
+            callbacks.append(SaveBetterTransformerModelCallback())

         if self.cfg.use_wandb:
             callbacks.append(
@@ -931,7 +950,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
     """

     def get_callbacks(self):
-        callbacks = []
+        callbacks = super().get_callbacks()
         return callbacks

     def get_post_trainer_create_callbacks(self, trainer):
@@ -949,21 +968,60 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
        ]:
            if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
                training_args_kwargs[arg] = getattr(self.cfg, arg)
+
+        if self.cfg.hub_model_id:
+            training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
+            training_args_kwargs["push_to_hub"] = True
+            training_args_kwargs["hub_private_repo"] = True
+            training_args_kwargs["hub_always_push"] = True
+
+            if self.cfg.hub_strategy:
+                training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
+
+        if self.cfg.save_safetensors is not None:
+            training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
+
+        if self.eval_dataset:
+            training_args_kwargs["evaluation_strategy"] = "steps"
+            training_args_kwargs["eval_steps"] = self.cfg.eval_steps
+        else:
+            training_args_kwargs["evaluation_strategy"] = "no"
+        if self.cfg.bf16 or self.cfg.bfloat16:
+            training_args_kwargs["bf16"] = True
+
+        training_args_kwargs["lr_scheduler_type"] = (
+            self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
+        )
+        training_args_kwargs["lr_scheduler_kwargs"] = (
+            self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
+        )
+
+        if self.cfg.dataloader_pin_memory is not None:
+            training_args_kwargs[
+                "dataloader_pin_memory"
+            ] = self.cfg.dataloader_pin_memory
+        if self.cfg.dataloader_num_workers is not None:
+            training_args_kwargs[
+                "dataloader_num_workers"
+            ] = self.cfg.dataloader_num_workers
+        if self.cfg.dataloader_prefetch_factor is not None:
+            training_args_kwargs[
+                "dataloader_prefetch_factor"
+            ] = self.cfg.dataloader_prefetch_factor
+
         training_args = TrainingArguments(
             per_device_train_batch_size=self.cfg.micro_batch_size,
-            max_steps=total_num_steps,
+            max_steps=self.cfg.max_steps or total_num_steps,
             remove_unused_columns=False,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
             learning_rate=self.cfg.learning_rate,
-            evaluation_strategy="no",
-            # eval_steps=self.cfg.eval_steps,
             save_strategy="steps",
             save_steps=self.cfg.save_steps,
             output_dir=self.cfg.output_dir,
             warmup_steps=self.cfg.warmup_steps,
-            bf16=True,
             gradient_checkpointing=self.cfg.gradient_checkpointing,
-            gradient_checkpointing_kwargs={"use_reentrant": False},
+            gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
+            or {"use_reentrant": False},
             logging_first_step=True,
             logging_steps=1,
             optim=self.cfg.optimizer,
@@ -982,22 +1040,27 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
                 dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
         elif self.cfg.rl == "kto_pair":
             dpo_trainer_kwargs["loss_type"] = "kto_pair"
-
+        if self.eval_dataset:
+            dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
+        if self.cfg.adapter and self.peft_config:
+            dpo_trainer_kwargs["peft_config"] = self.peft_config
         dpo_trainer = DPOTrainer(
             self.model,
             self.model_ref,
             args=training_args,
             beta=self.cfg.dpo_beta or 0.1,
             train_dataset=self.train_dataset,
-            # eval_dataset=self.eval_dataset,
-            eval_dataset=None,
             tokenizer=self.tokenizer,
             max_length=self.cfg.sequence_len,
             max_target_length=None,
             max_prompt_length=self.cfg.sequence_len,
             generate_during_eval=True,
+            callbacks=self.get_callbacks(),
             **dpo_trainer_kwargs,
         )
+        dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
+        for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
+            dpo_trainer.add_callback(callback)

         return dpo_trainer
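Taken together, these builder changes give the DPO path the same lifecycle as the causal path: shared callbacks from the base class, eval support, hub pushing, and the post-create hooks. A hypothetical sketch of driving the builder directly, assuming cfg, model, model_ref, peft_config, tokenizer, and the datasets are already loaded:

    builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
    builder.model_ref = model_ref
    builder.peft_config = peft_config         # consumed via the new property
    builder.train_dataset = train_dataset
    builder.eval_dataset = eval_dataset       # flips evaluation_strategy to "steps"
    trainer = builder.build(total_num_steps)  # wires callbacks + post-create hooks
    trainer.train()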
src/axolotl/prompt_strategies/dpo/__init__.py

@@ -0,0 +1,21 @@
+"""
+module for DPO style dataset transform strategies
+"""
+
+import importlib
+import logging
+
+LOG = logging.getLogger("axolotl")
+
+
+def load(strategy, cfg):
+    try:
+        load_fn = strategy.split(".")[-1]
+        strategy = ".".join(strategy.split(".")[:-1])
+        mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
+        func = getattr(mod, load_fn)
+        load_kwargs = {}
+        return func(cfg, **load_kwargs)
+    except Exception:  # pylint: disable=broad-exception-caught
+        LOG.warning(f"unable to load strategy {strategy}")
+        return None
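To make the dotted-name resolution concrete, this is what `load` does for a dataset `type` of `chatml.intel` (a sketch; `cfg` is assumed to be an already-loaded config object):

    import importlib

    strategy = "chatml.intel"                         # the dataset's `type` field
    load_fn = strategy.split(".")[-1]                 # -> "intel"
    module_path = ".".join(strategy.split(".")[:-1])  # -> "chatml"
    mod = importlib.import_module(f".{module_path}", "axolotl.prompt_strategies.dpo")
    transform_fn = getattr(mod, load_fn)(cfg)         # factory returns the map function

One subtlety: by the time the warning fires in the except branch, `strategy` has already been reassigned to the module portion of the dotted name, so the logged name is truncated.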
src/axolotl/prompt_strategies/dpo/chatml.py

@@ -0,0 +1,85 @@
+"""
+DPO strategies for chatml
+"""
+
+
+def argilla(
+    cfg,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
+def intel(cfg):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    For Intel Orca DPO Pairs
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
+def prompt_pairs(cfg):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
+def ultra(cfg):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    for ultrafeedback binarized conversations
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
+        return sample
+
+    return transform_fn
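A worked example of the `intel` transform (field names follow Intel Orca DPO Pairs; the values are illustrative):

    transform_fn = intel(cfg=None)  # cfg is unused by this strategy
    sample = {
        "system": "You are a helpful assistant.",
        "question": "What is 2+2?",
        "chosen": "4",
        "rejected": "5",
    }
    sample = transform_fn(sample)
    # sample["prompt"]   == "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    #                       "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\n"
    # sample["chosen"]   == "4<|im_end|>"
    # sample["rejected"] == "5<|im_end|>"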
src/axolotl/prompt_strategies/dpo/zephyr.py

@@ -0,0 +1,21 @@
+"""
+DPO strategies for zephyr
+"""
+
+
+def nectar(cfg):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        data = {}
+        data["prompt"] = (
+            "<|system|>\n</s>\n"
+            "<|user|>\n"
+            f"{sample['prompt']}</s>\n"
+            "<|assistant|>\n"
+        )
+        answers = sorted(sample["answers"], key=lambda x: x["rank"])
+        data["chosen"] = answers[-1]["answer"]
+        data["rejected"] = answers[-2]["answer"]
+
+        return data
+
+    return transform_fn
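A worked example of the `nectar` transform: answers are sorted by their `rank` field in ascending order, and the last two entries become chosen/rejected (values are illustrative; which direction of rank means "better" is a property of the dataset, not of this code):

    transform_fn = nectar(cfg=None)
    sample = {
        "prompt": "What is 2+2?",
        "answers": [
            {"answer": "five", "rank": 1},
            {"answer": "four", "rank": 3},
            {"answer": "4", "rank": 2},
        ],
    }
    data = transform_fn(sample)
    # sorted ranks -> [1, 2, 3]
    # data["chosen"]   == "four"  (rank 3, last after sorting)
    # data["rejected"] == "4"     (rank 2, second to last)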
src/axolotl/train.py

@@ -96,7 +96,12 @@ def train(
         freeze_parameters_except(model, cfg.unfrozen_parameters)

     trainer = setup_trainer(
-        cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps
+        cfg,
+        train_dataset,
+        eval_dataset,
+        (model, model_ref, peft_config),
+        tokenizer,
+        total_num_steps,
     )

     if hasattr(model, "config"):
src/axolotl/utils/data.py

@@ -4,7 +4,7 @@ import hashlib
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 from datasets import (
@@ -21,6 +21,7 @@ from transformers import PreTrainedTokenizerBase
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.datasets import TokenizedPromptDataset
 from axolotl.prompt_strategies import load
+from axolotl.prompt_strategies.dpo import load as load_dpo
 from axolotl.prompt_tokenizers import (
     AlpacaMultipleChoicePromptTokenizingStrategy,
     AlpacaPromptTokenizingStrategy,
@@ -850,3 +851,50 @@ def encode_packed_pretraining(
             chunked_data[feature].append(collated_features[feature].squeeze(0))

     return chunked_data
+
+
+def load_prepare_dpo_datasets(cfg):
+    def load_split(dataset_cfgs, _cfg):
+        split_datasets: List[Any] = []
+        for i, ds_cfg in enumerate(dataset_cfgs):
+            if ds_cfg["ds_type"] == "json":
+                for data_file in ds_cfg["data_files"]:
+                    data_files = {ds_cfg["split"]: data_file}
+                    ds = load_dataset(  # pylint: disable=invalid-name
+                        "json",
+                        data_files=data_files,
+                        split=ds_cfg["split"],
+                    )
+                    split_datasets.insert(i, ds)
+            else:
+                ds = load_dataset(  # pylint: disable=invalid-name
+                    ds_cfg["path"],
+                    split=ds_cfg["split"],
+                )
+                split_datasets.insert(i, ds)
+
+        for i, data_set in enumerate(split_datasets):
+            _type = dataset_cfgs[i]["type"]
+            if _type:
+                ds_transform_fn = load_dpo(_type, _cfg)
+                split_datasets[i] = data_set.map(
+                    ds_transform_fn,
+                    desc="Mapping RL Dataset",
+                )
+            else:
+                # If no `type` is provided, assume the dataset is already in the expected format with
+                # "prompt", "chosen" and "rejected" already preprocessed
+                split_datasets[i] = data_set
+
+        return concatenate_datasets(split_datasets)
+
+    with zero_first(is_main_process()):
+        train_dataset = load_split(cfg.datasets, cfg)
+
+        eval_dataset = None
+        if cfg.test_datasets:
+            eval_dataset = load_split(cfg.test_datasets, cfg)
+        if not eval_dataset:
+            eval_dataset = None
+
+    return train_dataset, eval_dataset
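This is also where local DPO datasets land (the `ds_type == "json"` branch). A minimal usage sketch, with a placeholder file path and a dataset that already has "prompt"/"chosen"/"rejected" columns, so `type` is omitted and the map step is skipped:

    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "datasets": [
                {
                    "ds_type": "json",
                    "data_files": ["data/dpo_pairs.jsonl"],  # placeholder path
                    "split": "train",
                }
            ],
            "test_datasets": None,  # no eval split -> eval_dataset comes back None
        }
    )
    train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)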
@@ -682,7 +682,12 @@ def load_model(
     lora_config = None
     if not reference_model or cfg.lora_model_dir:
-        model, lora_config = load_adapter(model, cfg, cfg.adapter)
+        # if we're not loading the reference model, then we're loading the model for training
+        # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
+        if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
+            _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
+        else:
+            model, lora_config = load_adapter(model, cfg, cfg.adapter)

     if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
         model.to(f"cuda:{cfg.local_rank}")

@@ -770,8 +775,8 @@ def find_all_linear_names(model):
     return list(lora_module_names)


-def load_lora(model, cfg, inference=False):
-    # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+def load_lora(model, cfg, inference=False, config_only=False):
+    # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]

     from peft import LoraConfig, PeftModel, get_peft_model

@@ -794,6 +799,9 @@ def load_lora(model, cfg, inference=False):
         task_type="CAUSAL_LM",
     )

+    if config_only:
+        return None, lora_config
+
     if cfg.lora_model_dir:
         LOG.debug("Loading pretrained PEFT - LoRA")
         model_kwargs: Any = {}
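For readers unfamiliar with the `config_only` pattern: the point is to build a `peft.LoraConfig` without calling `get_peft_model`, so the DPO trainer can apply the adapter itself. A minimal sketch, assuming illustrative hyperparameters rather than the values axolotl derives from `cfg`:

```python
# Sketch of the config-only pattern: construct a peft LoraConfig and hand it to
# the trainer instead of wrapping the base model with get_peft_model().
# The r/alpha/dropout/target_modules values below are illustrative assumptions.
from peft import LoraConfig


def build_lora_config(r: int = 64, alpha: int = 32, dropout: float = 0.1) -> LoraConfig:
    return LoraConfig(
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        target_modules=["q_proj", "v_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )


lora_config = build_lora_config()
# With config_only=True semantics, no model is returned -- only the config,
# mirroring the `return None, lora_config` early exit in the diff above.
```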
@@ -316,9 +316,10 @@ def prepare_optim_env(cfg):


 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    if cfg.rl:
+    if cfg.rl in ["dpo", "ipo", "kto_pair"]:
         trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
         trainer_builder.model_ref = model[1]
+        trainer_builder.peft_config = model[2]
     else:
         trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer)
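The `peft_config` stashed on the builder is ultimately what lets trl manage the adapter. A hedged sketch of the hand-off to `trl.DPOTrainer` (the wiring function below is assumed for illustration, not axolotl's actual builder code; argument names follow trl's public API at the time of this PR):

```python
# Hedged sketch: passing peft_config through to trl's DPOTrainer so that the
# trainer wraps the base model itself rather than receiving a peft model.
from trl import DPOTrainer


def build_dpo_trainer(model, ref_model, training_args, tokenizer,
                      train_dataset, eval_dataset, peft_config):
    # When peft_config is given, `model` stays the unwrapped base model and
    # ref_model may be None (trl derives the reference from the base weights).
    return DPOTrainer(
        model,
        ref_model,
        args=training_args,
        beta=0.1,  # illustrative default
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
    )
```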
@@ -0,0 +1,157 @@ tests/e2e/test_dpo.py (new file)
+"""
+E2E tests for lora llama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_rl_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestDPOLlamaLora(unittest.TestCase):
+    """
+    Test case for DPO Llama models using LoRA
+    """
+
+    @with_temp_dir
+    def test_dpo_lora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 64,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "special_tokens": {},
+                "rl": "dpo",
+                "datasets": [
+                    {
+                        "path": "Intel/orca_dpo_pairs",
+                        "type": "chatml.intel",
+                        "split": "train",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "paged_adamw_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "warmup_steps": 5,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": True},
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+
+    @with_temp_dir
+    def test_kto_pair_lora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 64,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "special_tokens": {},
+                "rl": "kto_pair",
+                "datasets": [
+                    {
+                        "path": "Intel/orca_dpo_pairs",
+                        "type": "chatml.intel",
+                        "split": "train",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "paged_adamw_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "warmup_steps": 5,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": True},
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+
+    @with_temp_dir
+    def test_ipo_lora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 64,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "special_tokens": {},
+                "rl": "ipo",
+                "datasets": [
+                    {
+                        "path": "Intel/orca_dpo_pairs",
+                        "type": "chatml.intel",
+                        "split": "train",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "paged_adamw_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "warmup_steps": 5,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": True},
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
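The `with_temp_dir` decorator comes from the test utilities. A plausible sketch of what such a helper does (the real implementation in `tests/e2e/utils.py` may differ in detail):

```python
# Hedged sketch of a with_temp_dir test decorator: create a scratch directory,
# pass it to the test as `temp_dir`, and clean it up even when the test fails.
import shutil
import tempfile
from functools import wraps


def with_temp_dir(test_func):
    @wraps(test_func)
    def wrapper(self):
        temp_dir = tempfile.mkdtemp()
        try:
            test_func(self, temp_dir)
        finally:
            shutil.rmtree(temp_dir)
    return wrapper
```

Each test asserts that the adapter checkpoint (`adapter_model.safetensors`) exists at step 20, matching `max_steps`. A single smoke test can be run with, e.g., `pytest tests/e2e/test_dpo.py -k test_dpo_lora`.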
 
		