ORPO Trainer replacement (#1551)

* WIP use trl ORPOTrainer
* fixes to make ORPO work with trl
* fix the chat template loading
* make sure to handle the special tokens and add_generation_prompt for the assistant turn too
- requirements.txt +1 -1
- src/axolotl/cli/preprocess.py +1 -1
- src/axolotl/cli/train.py +1 -1
- src/axolotl/core/trainer_builder.py +36 -11
- src/axolotl/prompt_strategies/orpo/__init__.py +1 -1
- src/axolotl/prompt_strategies/orpo/chat_template.py +84 -0
- src/axolotl/utils/data/__init__.py +1 -1
- src/axolotl/utils/data/{dpo.py → rl.py} +20 -4
- src/axolotl/utils/trainer.py +3 -3
- tests/core/test_trainer_builder.py +3 -3
requirements.txt
```diff
@@ -39,6 +39,6 @@ s3fs
 gcsfs
 # adlfs
 
-trl
+trl==0.8.5
 zstandard==0.22.0
 fastcore
```
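The pin matters: `ORPOTrainer` and `ORPOConfig` (imported by the trainer builder below) only exist in recent trl releases, so an unpinned trl can fail at import time. A minimal sanity check, not part of this PR:

```python
# Verify the installed trl actually exposes the ORPO classes this PR imports.
import trl

assert hasattr(trl, "ORPOTrainer") and hasattr(trl, "ORPOConfig"), (
    f"trl=={trl.__version__} lacks ORPO support; install trl==0.8.5 as pinned"
)
```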
src/axolotl/cli/preprocess.py
```diff
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         LOG.warning(msg)
         parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
 
-    if parsed_cfg.rl and parsed_cfg.rl != "orpo":
+    if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
         load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
     else:
         load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
```
src/axolotl/cli/train.py
```diff
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     else:
         register_chatml_template()
 
-    if cfg.rl and cfg.rl != "orpo":
+    if cfg.rl: # and cfg.rl != "orpo":
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
```
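With the special case reduced to a comment in both CLI entry points, any truthy `cfg.rl` value (including `orpo`) now routes through `load_rl_datasets`. A hypothetical config exercising the new path; the dataset `path` is a placeholder and the `type` string is an assumption about how it resolves against `axolotl.prompt_strategies.orpo`:

```python
# Hypothetical ORPO config (values illustrative, not from this PR).
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "rl": "orpo",  # any truthy rl value now routes through load_rl_datasets
        "orpo_alpha": 0.1,  # surfaced to trl via the `beta` field, see trainer_builder
        "chat_template": "chatml",
        "datasets": [
            {
                # assumed to resolve under axolotl.prompt_strategies.orpo
                "type": "chat_template.argilla",
                "path": "argilla/some-preference-dataset",  # placeholder path
            }
        ],
    }
)
```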
src/axolotl/core/trainer_builder.py
```diff
@@ -30,7 +30,7 @@ from transformers import (
 )
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer
+from trl import DPOTrainer, ORPOConfig, ORPOTrainer
 from trl.trainer.utils import pad_to_length
 
 from axolotl.loraplus import create_loraplus_optimizer
@@ -810,6 +810,14 @@ class AxolotlDPOTrainer(DPOTrainer):
         return res
 
 
+class AxolotlORPOTrainer(ORPOTrainer):
+    """
+    Extend the base ORPOTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "orpo"]
+
+
 class TrainerBuilderBase(abc.ABC):
     """
     Base class for trainer builder
@@ -1404,7 +1412,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 )
 
 
-class HFDPOTrainerBuilder(TrainerBuilderBase):
+class HFRLTrainerBuilder(TrainerBuilderBase):
     """
     Trainer factory class for DPO Trainer
     """
@@ -1497,7 +1505,15 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
             # default to saving each epoch if not defined
             training_args_kwargs["save_strategy"] = "epoch"
 
-        training_args = TrainingArguments(
+        if self.cfg.orpo_alpha:
+            # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
+            training_args_kwargs["beta"] = self.cfg.orpo_alpha
+
+        training_args_cls = TrainingArguments
+        if self.cfg.rl == "orpo":
+            training_args_cls = ORPOConfig
+
+        training_args = training_args_cls(
             per_device_train_batch_size=self.cfg.micro_batch_size,
             max_steps=self.cfg.max_steps or total_num_steps,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
@@ -1530,17 +1546,26 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             dpo_trainer_kwargs[
                 "precompute_ref_log_probs"
             ] = self.cfg.precompute_ref_log_probs
-        dpo_trainer = AxolotlDPOTrainer(
-            self.model,
-            self.model_ref,
+        if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
+            trainer_cls = AxolotlDPOTrainer
+            dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
+            trainer_cls_args = [self.model, self.model_ref]
+
+            # these aren't used for the ORPO trainer
+            dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
+            dpo_trainer_kwargs["max_target_length"] = None
+            dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
+            dpo_trainer_kwargs["generate_during_eval"] = True
+        elif self.cfg.rl == "orpo":
+            trainer_cls = AxolotlORPOTrainer
+            trainer_cls_args = [self.model]
+        else:
+            raise ValueError(f"Unsupported RL: {self.cfg.rl}")
+        dpo_trainer = trainer_cls(
+            *trainer_cls_args,
             args=training_args,
-            beta=self.cfg.dpo_beta or 0.1,
             train_dataset=self.train_dataset,
             tokenizer=self.tokenizer,
-            max_length=self.cfg.sequence_len,
-            max_target_length=None,
-            max_prompt_length=self.cfg.sequence_len,
-            generate_during_eval=True,
             callbacks=self.get_callbacks(),
             **dpo_trainer_kwargs,
         )
```
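For reference, a stripped-down sketch of what `HFRLTrainerBuilder` now assembles when `rl: orpo` (model name and toy dataset are illustrative; the real builder also wires callbacks, scheduler kwargs, and the peft config). Unlike DPO, ORPO computes its odds-ratio penalty from the policy's own logits, so there is no reference model, hence the single-element `trainer_cls_args`:

```python
# Standalone sketch of the `rl: orpo` path (illustrative model/data, not from this PR).
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ORPOConfig, ORPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# ORPOTrainer consumes plain-text prompt/chosen/rejected columns, which is
# exactly what the argilla transform_fn in chat_template.py emits.
train_dataset = Dataset.from_dict(
    {"prompt": ["2+2="], "chosen": ["4"], "rejected": ["5"]}
)

# trl reuses ORPOConfig's `beta` field for ORPO's alpha, hence the
# training_args_kwargs["beta"] = cfg.orpo_alpha mapping above.
args = ORPOConfig(
    output_dir="./orpo-out",
    beta=0.1,
    max_length=512,
    per_device_train_batch_size=1,
    max_steps=1,
    report_to="none",
)

# Single positional model argument: no reference model, unlike DPO.
trainer = ORPOTrainer(
    model,
    args=args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```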
src/axolotl/prompt_strategies/orpo/__init__.py
```diff
@@ -6,4 +6,4 @@ from functools import partial
 
 from ..base import load as load_base
 
-load = partial(load_base,
+load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
```
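What the partial buys (generic sketch; `load_base`'s full signature is not shown in this diff, so the stub below is an assumption): callers import `load` from this package with `module_base` pre-bound, so ORPO `type:` strings resolve inside `axolotl.prompt_strategies.orpo` rather than the default strategies package.

```python
from functools import partial


def load_base(strategy, cfg, module_base=None, **kwargs):
    """Assumed shape: resolves f"{module_base}.{strategy}" to a transform loader."""
    ...


load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
```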
src/axolotl/prompt_strategies/orpo/chat_template.py
```diff
@@ -78,6 +78,57 @@ class ORPODatasetParsingStrategy:
             )
         return MessageList(messages=messages)
 
+    def get_prompt(self, prompt) -> MessageList:
+        """Map the data to extract everything up to the last turn"""
+        total_msg_len = len(prompt["chosen"])
+        total_msg_turns, remainder = divmod(total_msg_len, 2)
+        assert remainder == 0, "invalid number of turns"
+
+        messages: List[Message] = []
+        if system := prompt.get("system", None):
+            messages.append(Message(role="system", content=system, label=False))
+        for i in range(total_msg_turns):
+            if "prompt" in prompt:
+                messages.append(
+                    Message(role="user", content=prompt["prompt"], label=False)
+                )
+            else:
+                messages.append(
+                    Message(
+                        role="user",
+                        content=prompt["chosen"][i * 2]["content"],
+                        label=False,
+                    )
+                )
+            if i < total_msg_turns - 1:
+                messages.append(
+                    Message(
+                        role="assistant",
+                        content=prompt["chosen"][i * 2 + 1]["content"],
+                        label=False,
+                    )
+                )
+
+        return MessageList(messages=messages)
+
+    def get_chosen(self, prompt) -> MessageList:
+        res = self.get_prompt(prompt)
+        res.messages.append(
+            Message(
+                role="assistant", content=prompt["chosen"][-1]["content"], label=True
+            )
+        )
+        return res
+
+    def get_rejected(self, prompt) -> MessageList:
+        res = self.get_prompt(prompt)
+        res.messages.append(
+            Message(
+                role="assistant", content=prompt["rejected"][-1]["content"], label=True
+            )
+        )
+        return res
+
 
 class ORPOTokenizingStrategy(PromptTokenizingStrategy):
     """
@@ -186,3 +237,36 @@ class ORPOPrompter(Prompter):
             chat_template=self.chat_template,
             tokenize=False,
         ), True
+
+
+def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
+    dataset_parser = ORPODatasetParsingStrategy()
+
+    chat_template_str = chat_templates(cfg.chat_template)
+
+    def transform_fn(sample, tokenizer=None):
+        res = {}
+
+        res["prompt"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
+            add_generation_prompt=True,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )
+        prompt_str_len = len(res["prompt"])
+        res["chosen"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
+            add_generation_prompt=False,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )[prompt_str_len:]
+        res["rejected"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
+            add_generation_prompt=False,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )[prompt_str_len:]
+
+        return res
+
+    return transform_fn
```
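To make the parsing concrete, a worked example with invented sample data (assumes `Message`, `MessageList`, and `ORPODatasetParsingStrategy` from this module are in scope; `label=False` marks turns the tokenizing strategy masks out of the loss):

```python
# Worked example of the parsing strategy above (sample data invented).
sample = {
    "system": "You are a helpful assistant.",
    "chosen": [
        {"role": "user", "content": "Name a prime number."},
        {"role": "assistant", "content": "7"},
    ],
    "rejected": [
        {"role": "user", "content": "Name a prime number."},
        {"role": "assistant", "content": "8"},
    ],
}

parser = ORPODatasetParsingStrategy()

# Everything up to (but excluding) the final assistant turn:
prompt = parser.get_prompt(sample).messages      # [system, user]
# Same prefix plus the final assistant turn, label=True on the completion:
chosen = parser.get_chosen(sample).messages      # [system, user, assistant("7")]
rejected = parser.get_rejected(sample).messages  # [system, user, assistant("8")]
```

Note the trick in `transform_fn`: the prompt is rendered with `add_generation_prompt=True` and its string length is sliced off the chosen/rejected renders, so the assistant-turn special tokens land in the prompt string and `chosen`/`rejected` carry only the completions, the layout trl's `ORPOTrainer` expects.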
src/axolotl/utils/data/__init__.py
```diff
@@ -1,11 +1,11 @@
 """
 Data processing modules
 """
-from axolotl.utils.data.dpo import load_prepare_dpo_datasets  # noqa: F401
 from axolotl.utils.data.pretraining import (  # noqa: F401
     encode_pretraining,
     wrap_pretraining_dataset,
 )
+from axolotl.utils.data.rl import load_prepare_dpo_datasets  # noqa: F401
 from axolotl.utils.data.sft import (  # noqa: F401
     get_dataset_wrapper,
     load_prepare_datasets,
```
src/axolotl/utils/data/{dpo.py → rl.py}
RENAMED

```diff
@@ -1,17 +1,20 @@
 """data handling specific to DPO"""
-
+import inspect
 import logging
+from functools import partial
 from pathlib import Path
 from typing import Any, List
 
 import yaml
-from datasets import concatenate_datasets, load_dataset, load_from_disk
+from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
 
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.prompt_strategies.dpo import load as load_dpo
+from axolotl.prompt_strategies.orpo import load as load_orpo
 from axolotl.utils.data.utils import md5
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.models import load_tokenizer
 
 LOG = logging.getLogger("axolotl")
 
@@ -72,16 +75,29 @@ def load_prepare_dpo_datasets(cfg):
             )
             split_datasets.insert(i, ds)
 
+    tokenizer = None
     for i, data_set in enumerate(split_datasets):
         _type = dataset_cfgs[i]["type"]
         if _type:
             if isinstance(_type, DictDefault):
                 _type = "user_defined.default"
-            ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
-            split_datasets[i] = data_set.map(
+            if _cfg.rl == "orpo":
+                ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
+            else:
+                ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
+            sig = inspect.signature(ds_transform_fn)
+            if "tokenizer" in sig.parameters:
+                if not tokenizer:
+                    tokenizer = load_tokenizer(_cfg)
+                ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
+
+            data_set = data_set.map(
                 ds_transform_fn,
                 desc="Mapping RL Dataset",
            )
+            if isinstance(data_set, DatasetDict):
+                data_set = data_set["train"]
+            split_datasets[i] = data_set
         else:
             # If no `type` is provided, assume the dataset is already in the expected format with
             # "prompt", "chosen" and "rejected" already preprocessed
```
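The interesting part of the rename is the lazy tokenizer binding: ORPO transforms render chat templates and therefore declare a `tokenizer` kwarg, while DPO transforms do not, and `load_tokenizer` is expensive enough that it should run at most once across all datasets. The same pattern in isolation (function names here are illustrative, not from this PR):

```python
import inspect
from functools import partial

from transformers import AutoTokenizer


def orpo_style(sample, tokenizer=None):
    # ORPO-style transforms need a tokenizer to render chat templates
    return {"prompt": tokenizer.apply_chat_template(sample["messages"], tokenize=False)}


def dpo_style(sample):
    # DPO-style transforms are plain dict munging, no tokenizer required
    return {"prompt": sample["question"]}


def bind_tokenizer_if_needed(transform_fn, load_tokenizer_fn, tokenizer=None):
    """Bind a tokenizer only when the transform declares the kwarg."""
    if "tokenizer" in inspect.signature(transform_fn).parameters:
        if tokenizer is None:
            tokenizer = load_tokenizer_fn()  # expensive, so done at most once
        transform_fn = partial(transform_fn, tokenizer=tokenizer)
    return transform_fn, tokenizer


tokenizer = None  # shared across datasets, mirroring the loop above
fn, tokenizer = bind_tokenizer_if_needed(
    orpo_style, lambda: AutoTokenizer.from_pretrained("gpt2"), tokenizer
)
```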
src/axolotl/utils/trainer.py
```diff
@@ -13,7 +13,7 @@ from datasets import set_caching_enabled
 from torch.utils.data import DataLoader, RandomSampler
 from transformers.utils import is_torch_bf16_gpu_available
 
-from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
+from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -340,8 +340,8 @@ def prepare_optim_env(cfg):
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    if cfg.rl in ["dpo", "ipo", "kto_pair"]:
-        trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
+    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
+        trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]
     else:
```
tests/core/test_trainer_builder.py
```diff
@@ -4,7 +4,7 @@ unit tests for axolotl.core.trainer_builder
 
 import pytest
 
-from axolotl.core.trainer_builder import HFDPOTrainerBuilder
+from axolotl.core.trainer_builder import HFRLTrainerBuilder
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
     return load_model(cfg, tokenizer)
 
 
-class TestHFDPOTrainerBuilder:
+class TestHFRLTrainerBuilder:
     """
     TestCase class for DPO trainer builder
     """
 
     def test_build_training_arguments(self, cfg, model, tokenizer):
-        builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
+        builder = HFRLTrainerBuilder(cfg, model, tokenizer)
         training_arguments = builder.build_training_arguments(100)
         assert training_arguments.adam_beta1 == 0.998
         assert training_arguments.adam_beta2 == 0.9
```