Uploaded LoRA adapters after finetuning on PrimeVul
- compilefcach/UnslothCPOTrainer.py +1404 -0
- compilefcach/UnslothDDPOTrainer.py +744 -0
- compilefcach/UnslothKTOTrainer.py +1629 -0
- compilefcach/UnslothORPOTrainer.py +1413 -0
- compilefcach/UnslothPPOTrainer.py +1566 -0
- compilefcach/UnslothRewardTrainer.py +722 -0
- compilefcach/__pycache__/UnslothCPOTrainer.cpython-311.pyc +0 -0
- compilefcach/__pycache__/UnslothDDPOTrainer.cpython-311.pyc +0 -0
- compilefcach/__pycache__/UnslothKTOTrainer.cpython-311.pyc +0 -0
- compilefcach/__pycache__/UnslothORPOTrainer.cpython-311.pyc +0 -0
- compilefcach/__pycache__/UnslothPPOTrainer.cpython-311.pyc +0 -0
- compilefcach/__pycache__/UnslothRewardTrainer.cpython-311.pyc +0 -0
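
For context, a minimal sketch of how LoRA adapters like these are typically loaded with PEFT. The repo id and base model name below are hypothetical placeholders, since the commit does not state them:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    base_model_id = "your-org/base-model"     # hypothetical: the model the adapters were trained on
    adapter_repo  = "your-org/primevul-lora"  # hypothetical: this repository's Hub id

    base = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype = "auto")
    model = PeftModel.from_pretrained(base, adapter_repo)  # attaches the LoRA weights
    tokenizer = AutoTokenizer.from_pretrained(adapter_repo)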
compilefcach/UnslothCPOTrainer.py
ADDED
@@ -0,0 +1,1404 @@
1 |
+
"""
|
2 |
+
2025.6.8
|
3 |
+
2025.6.12
|
4 |
+
4.53.0
|
5 |
+
0.8.6
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, Dict, EvalLoopOutput, F, List, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainerCallback, Tuple, Union, defaultdict, disable_dropout_in_model, inspect, is_peft_available, is_torch_fx_proxy, is_wandb_available, nn, np, nullcontext, pad_to_length, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, torch, trl_sanitze_kwargs_for_tagging, wandb, warnings, wraps)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothCPOConfig(CPOConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
CPOConfig collects all training arguments related to the [`CPOTrainer`] class.
|
47 |
+
|
48 |
+
Using [`HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
max_length (`int`, defaults to `None`):
|
54 |
+
The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
|
55 |
+
max_prompt_length (`int`, defaults to `None`):
|
56 |
+
The maximum length of the prompt. This argument is required if you want to use the default data collator.
|
57 |
+
max_target_length (`int`, defaults to `None`):
|
58 |
+
The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
|
59 |
+
beta (`float`, defaults to 0.1):
|
60 |
+
The beta factor in CPO loss.
|
61 |
+
label_smoothing (`float`, defaults to 0):
|
62 |
+
The label smoothing factor. This argument is required if you want to use the default data collator.
|
63 |
+
loss_type (`str`, defaults to `sigmoid`):
|
64 |
+
The type of loss to use. This argument is required if you want to use the default data collator.
|
65 |
+
label_pad_token_id (`int`, defaults to `-100`):
|
66 |
+
The label pad token id. This argument is required if you want to use the default data collator.
|
67 |
+
padding_value (`int`, defaults to `None`):
|
68 |
+
The padding value if it is different to the tokenizer's pad_token_id.
|
69 |
+
truncation_mode (`str`, defaults to `keep_end`):
|
70 |
+
The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
|
71 |
+
generate_during_eval (`bool`, defaults to `False`):
|
72 |
+
Whether to sample and log generations during evaluation step.
|
73 |
+
is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
|
74 |
+
If no model is provided, we need to know if the model_init returns an encoder-decoder.
|
75 |
+
disable_dropout (`bool`, defaults to `True`):
|
76 |
+
Whether or not to disable dropouts in `model`.
|
77 |
+
model_init_kwargs (`Optional[Dict]`, *optional*):
|
78 |
+
Dict of Optional kwargs to pass when instantiating the model from a string
|
79 |
+
dataset_num_proc (`Optional[int]`, *optional*):
|
80 |
+
The number of workers to use to tokenize the data. Defaults to None.
|
81 |
+
|
82 |
+
"""
|
83 |
+
vllm_sampling_params: Optional[Any] = field(
|
84 |
+
default = None,
|
85 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
86 |
+
)
|
87 |
+
unsloth_num_chunks : Optional[int] = field(
|
88 |
+
default = -1,
|
89 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
90 |
+
)
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
output_dir = None,
|
94 |
+
overwrite_output_dir = None,
|
95 |
+
do_train = False,
|
96 |
+
do_eval = False,
|
97 |
+
do_predict = False,
|
98 |
+
eval_strategy = 'no',
|
99 |
+
prediction_loss_only = False,
|
100 |
+
per_device_train_batch_size = 4,
|
101 |
+
per_device_eval_batch_size = 4,
|
102 |
+
per_gpu_train_batch_size = None,
|
103 |
+
per_gpu_eval_batch_size = None,
|
104 |
+
gradient_accumulation_steps = 2,
|
105 |
+
eval_accumulation_steps = 2,
|
106 |
+
eval_delay = 0,
|
107 |
+
torch_empty_cache_steps = 250,
|
108 |
+
learning_rate = 5e-05,
|
109 |
+
weight_decay = 0.01,
|
110 |
+
adam_beta1 = 0.9,
|
111 |
+
adam_beta2 = 0.999,
|
112 |
+
adam_epsilon = 1e-08,
|
113 |
+
max_grad_norm = 1.0,
|
114 |
+
num_train_epochs = 3.0,
|
115 |
+
max_steps = -1,
|
116 |
+
lr_scheduler_type = 'linear',
|
117 |
+
warmup_ratio = 0.1,
|
118 |
+
warmup_steps = 0,
|
119 |
+
log_level = 'passive',
|
120 |
+
log_level_replica = 'warning',
|
121 |
+
log_on_each_node = True,
|
122 |
+
logging_dir = None,
|
123 |
+
logging_strategy = 'steps',
|
124 |
+
logging_first_step = False,
|
125 |
+
logging_steps = 1,
|
126 |
+
logging_nan_inf_filter = False,
|
127 |
+
save_strategy = 'steps',
|
128 |
+
save_steps = 500,
|
129 |
+
save_total_limit = None,
|
130 |
+
save_safetensors = True,
|
131 |
+
save_on_each_node = False,
|
132 |
+
save_only_model = False,
|
133 |
+
restore_callback_states_from_checkpoint = False,
|
134 |
+
no_cuda = False,
|
135 |
+
use_cpu = False,
|
136 |
+
use_mps_device = False,
|
137 |
+
seed = 3407,
|
138 |
+
data_seed = 3407,
|
139 |
+
jit_mode_eval = False,
|
140 |
+
use_ipex = False,
|
141 |
+
bf16 = False,
|
142 |
+
fp16 = False,
|
143 |
+
fp16_opt_level = 'O1',
|
144 |
+
half_precision_backend = 'auto',
|
145 |
+
bf16_full_eval = False,
|
146 |
+
fp16_full_eval = False,
|
147 |
+
tf32 = None,
|
148 |
+
local_rank = -1,
|
149 |
+
ddp_backend = None,
|
150 |
+
tpu_num_cores = None,
|
151 |
+
tpu_metrics_debug = False,
|
152 |
+
debug = '',
|
153 |
+
dataloader_drop_last = False,
|
154 |
+
eval_steps = None,
|
155 |
+
dataloader_num_workers = 0,
|
156 |
+
dataloader_prefetch_factor = None,
|
157 |
+
past_index = -1,
|
158 |
+
run_name = None,
|
159 |
+
disable_tqdm = None,
|
160 |
+
remove_unused_columns = True,
|
161 |
+
label_names = None,
|
162 |
+
load_best_model_at_end = False,
|
163 |
+
metric_for_best_model = None,
|
164 |
+
greater_is_better = None,
|
165 |
+
ignore_data_skip = False,
|
166 |
+
fsdp = '',
|
167 |
+
fsdp_min_num_params = 0,
|
168 |
+
fsdp_config = None,
|
169 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
170 |
+
accelerator_config = None,
|
171 |
+
deepspeed = None,
|
172 |
+
label_smoothing_factor = 0.0,
|
173 |
+
optim = 'adamw_8bit',
|
174 |
+
optim_args = None,
|
175 |
+
adafactor = False,
|
176 |
+
group_by_length = False,
|
177 |
+
length_column_name = 'length',
|
178 |
+
report_to = None,
|
179 |
+
ddp_find_unused_parameters = None,
|
180 |
+
ddp_bucket_cap_mb = None,
|
181 |
+
ddp_broadcast_buffers = None,
|
182 |
+
dataloader_pin_memory = True,
|
183 |
+
dataloader_persistent_workers = False,
|
184 |
+
skip_memory_metrics = True,
|
185 |
+
use_legacy_prediction_loop = False,
|
186 |
+
push_to_hub = False,
|
187 |
+
resume_from_checkpoint = None,
|
188 |
+
hub_model_id = None,
|
189 |
+
hub_strategy = 'every_save',
|
190 |
+
hub_token = None,
|
191 |
+
hub_private_repo = None,
|
192 |
+
hub_always_push = False,
|
193 |
+
hub_revision = None,
|
194 |
+
gradient_checkpointing = False,
|
195 |
+
gradient_checkpointing_kwargs = None,
|
196 |
+
include_inputs_for_metrics = False,
|
197 |
+
eval_do_concat_batches = True,
|
198 |
+
fp16_backend = 'auto',
|
199 |
+
push_to_hub_model_id = None,
|
200 |
+
push_to_hub_organization = None,
|
201 |
+
push_to_hub_token = None,
|
202 |
+
mp_parameters = '',
|
203 |
+
auto_find_batch_size = False,
|
204 |
+
full_determinism = False,
|
205 |
+
torchdynamo = None,
|
206 |
+
ray_scope = 'last',
|
207 |
+
ddp_timeout = 1800,
|
208 |
+
torch_compile = False,
|
209 |
+
torch_compile_backend = None,
|
210 |
+
torch_compile_mode = None,
|
211 |
+
include_tokens_per_second = False,
|
212 |
+
include_num_input_tokens_seen = False,
|
213 |
+
neftune_noise_alpha = None,
|
214 |
+
optim_target_modules = None,
|
215 |
+
batch_eval_metrics = False,
|
216 |
+
eval_on_start = False,
|
217 |
+
use_liger_kernel = False,
|
218 |
+
liger_kernel_config = None,
|
219 |
+
eval_use_gather_object = False,
|
220 |
+
average_tokens_across_devices = False,
|
221 |
+
max_length = None,
|
222 |
+
max_prompt_length = None,
|
223 |
+
max_completion_length = None,
|
224 |
+
max_target_length = None,
|
225 |
+
beta = 0.1,
|
226 |
+
label_smoothing = 0,
|
227 |
+
loss_type = 'sigmoid',
|
228 |
+
disable_dropout = True,
|
229 |
+
label_pad_token_id = -100,
|
230 |
+
padding_value = None,
|
231 |
+
truncation_mode = 'keep_end',
|
232 |
+
generate_during_eval = False,
|
233 |
+
is_encoder_decoder = None,
|
234 |
+
model_init_kwargs = None,
|
235 |
+
dataset_num_proc = None,
|
236 |
+
vllm_sampling_params = None,
|
237 |
+
unsloth_num_chunks = -1,
|
238 |
+
**kwargs,
|
239 |
+
):
|
240 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
241 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
242 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
243 |
+
output_dir = 'unsloth_training_checkpoints'
|
244 |
+
save_strategy = 'no'
|
245 |
+
if dataset_num_proc is None:
|
246 |
+
from multiprocessing import cpu_count
|
247 |
+
dataset_num_proc = cpu_count()
|
248 |
+
|
249 |
+
super().__init__(
|
250 |
+
output_dir = output_dir,
|
251 |
+
overwrite_output_dir = overwrite_output_dir,
|
252 |
+
do_train = do_train,
|
253 |
+
do_eval = do_eval,
|
254 |
+
do_predict = do_predict,
|
255 |
+
eval_strategy = eval_strategy,
|
256 |
+
prediction_loss_only = prediction_loss_only,
|
257 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
258 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
259 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
260 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
261 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
262 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
263 |
+
eval_delay = eval_delay,
|
264 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
265 |
+
learning_rate = learning_rate,
|
266 |
+
weight_decay = weight_decay,
|
267 |
+
adam_beta1 = adam_beta1,
|
268 |
+
adam_beta2 = adam_beta2,
|
269 |
+
adam_epsilon = adam_epsilon,
|
270 |
+
max_grad_norm = max_grad_norm,
|
271 |
+
num_train_epochs = num_train_epochs,
|
272 |
+
max_steps = max_steps,
|
273 |
+
lr_scheduler_type = lr_scheduler_type,
|
274 |
+
warmup_ratio = warmup_ratio,
|
275 |
+
warmup_steps = warmup_steps,
|
276 |
+
log_level = log_level,
|
277 |
+
log_level_replica = log_level_replica,
|
278 |
+
log_on_each_node = log_on_each_node,
|
279 |
+
logging_dir = logging_dir,
|
280 |
+
logging_strategy = logging_strategy,
|
281 |
+
logging_first_step = logging_first_step,
|
282 |
+
logging_steps = logging_steps,
|
283 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
284 |
+
save_strategy = save_strategy,
|
285 |
+
save_steps = save_steps,
|
286 |
+
save_total_limit = save_total_limit,
|
287 |
+
save_safetensors = save_safetensors,
|
288 |
+
save_on_each_node = save_on_each_node,
|
289 |
+
save_only_model = save_only_model,
|
290 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
291 |
+
no_cuda = no_cuda,
|
292 |
+
use_cpu = use_cpu,
|
293 |
+
use_mps_device = use_mps_device,
|
294 |
+
seed = seed,
|
295 |
+
data_seed = data_seed,
|
296 |
+
jit_mode_eval = jit_mode_eval,
|
297 |
+
use_ipex = use_ipex,
|
298 |
+
bf16 = bf16,
|
299 |
+
fp16 = fp16,
|
300 |
+
fp16_opt_level = fp16_opt_level,
|
301 |
+
half_precision_backend = half_precision_backend,
|
302 |
+
bf16_full_eval = bf16_full_eval,
|
303 |
+
fp16_full_eval = fp16_full_eval,
|
304 |
+
tf32 = tf32,
|
305 |
+
local_rank = local_rank,
|
306 |
+
ddp_backend = ddp_backend,
|
307 |
+
tpu_num_cores = tpu_num_cores,
|
308 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
309 |
+
debug = debug,
|
310 |
+
dataloader_drop_last = dataloader_drop_last,
|
311 |
+
eval_steps = eval_steps,
|
312 |
+
dataloader_num_workers = dataloader_num_workers,
|
313 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
314 |
+
past_index = past_index,
|
315 |
+
run_name = run_name,
|
316 |
+
disable_tqdm = disable_tqdm,
|
317 |
+
remove_unused_columns = remove_unused_columns,
|
318 |
+
label_names = label_names,
|
319 |
+
load_best_model_at_end = load_best_model_at_end,
|
320 |
+
metric_for_best_model = metric_for_best_model,
|
321 |
+
greater_is_better = greater_is_better,
|
322 |
+
ignore_data_skip = ignore_data_skip,
|
323 |
+
fsdp = fsdp,
|
324 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
325 |
+
fsdp_config = fsdp_config,
|
326 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
327 |
+
accelerator_config = accelerator_config,
|
328 |
+
deepspeed = deepspeed,
|
329 |
+
label_smoothing_factor = label_smoothing_factor,
|
330 |
+
optim = optim,
|
331 |
+
optim_args = optim_args,
|
332 |
+
adafactor = adafactor,
|
333 |
+
group_by_length = group_by_length,
|
334 |
+
length_column_name = length_column_name,
|
335 |
+
report_to = report_to,
|
336 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
337 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
338 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
339 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
340 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
341 |
+
skip_memory_metrics = skip_memory_metrics,
|
342 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
343 |
+
push_to_hub = push_to_hub,
|
344 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
345 |
+
hub_model_id = hub_model_id,
|
346 |
+
hub_strategy = hub_strategy,
|
347 |
+
hub_token = hub_token,
|
348 |
+
hub_private_repo = hub_private_repo,
|
349 |
+
hub_always_push = hub_always_push,
|
350 |
+
hub_revision = hub_revision,
|
351 |
+
gradient_checkpointing = gradient_checkpointing,
|
352 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
353 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
354 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
355 |
+
fp16_backend = fp16_backend,
|
356 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
357 |
+
push_to_hub_organization = push_to_hub_organization,
|
358 |
+
push_to_hub_token = push_to_hub_token,
|
359 |
+
mp_parameters = mp_parameters,
|
360 |
+
auto_find_batch_size = auto_find_batch_size,
|
361 |
+
full_determinism = full_determinism,
|
362 |
+
torchdynamo = torchdynamo,
|
363 |
+
ray_scope = ray_scope,
|
364 |
+
ddp_timeout = ddp_timeout,
|
365 |
+
torch_compile = torch_compile,
|
366 |
+
torch_compile_backend = torch_compile_backend,
|
367 |
+
torch_compile_mode = torch_compile_mode,
|
368 |
+
include_tokens_per_second = include_tokens_per_second,
|
369 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
370 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
371 |
+
optim_target_modules = optim_target_modules,
|
372 |
+
batch_eval_metrics = batch_eval_metrics,
|
373 |
+
eval_on_start = eval_on_start,
|
374 |
+
use_liger_kernel = use_liger_kernel,
|
375 |
+
liger_kernel_config = liger_kernel_config,
|
376 |
+
eval_use_gather_object = eval_use_gather_object,
|
377 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
378 |
+
max_length = max_length,
|
379 |
+
max_prompt_length = max_prompt_length,
|
380 |
+
max_completion_length = max_completion_length,
|
381 |
+
max_target_length = max_target_length,
|
382 |
+
beta = beta,
|
383 |
+
label_smoothing = label_smoothing,
|
384 |
+
loss_type = loss_type,
|
385 |
+
disable_dropout = disable_dropout,
|
386 |
+
label_pad_token_id = label_pad_token_id,
|
387 |
+
padding_value = padding_value,
|
388 |
+
truncation_mode = truncation_mode,
|
389 |
+
generate_during_eval = generate_during_eval,
|
390 |
+
is_encoder_decoder = is_encoder_decoder,
|
391 |
+
model_init_kwargs = model_init_kwargs,
|
392 |
+
dataset_num_proc = dataset_num_proc,**kwargs)
|
393 |
+
self.vllm_sampling_params = vllm_sampling_params
|
394 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
395 |
+
pass
|
396 |
+
|
397 |
+
class _UnslothCPOTrainer(Trainer):
|
398 |
+
r""""""
|
399 |
+
|
400 |
+
_tag_names = ["trl", "cpo"]
|
401 |
+
|
402 |
+
def __init__(
|
403 |
+
self,
|
404 |
+
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
405 |
+
args: Optional[CPOConfig] = None,
|
406 |
+
data_collator: Optional[DataCollator] = None,
|
407 |
+
train_dataset: Optional[Dataset] = None,
|
408 |
+
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
|
409 |
+
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
410 |
+
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
411 |
+
callbacks: Optional[List[TrainerCallback]] = None,
|
412 |
+
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
413 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
414 |
+
peft_config: Optional[Dict] = None,
|
415 |
+
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
|
416 |
+
):
|
417 |
+
if args.model_init_kwargs is None:
|
418 |
+
model_init_kwargs = {}
|
419 |
+
elif not isinstance(model, str):
|
420 |
+
raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
|
421 |
+
else:
|
422 |
+
model_init_kwargs = args.model_init_kwargs
|
423 |
+
model_init_kwargs["torch_dtype"] = (
|
424 |
+
model_init_kwargs["torch_dtype"]
|
425 |
+
if model_init_kwargs["torch_dtype"] in ["auto", None]
|
426 |
+
else getattr(torch, model_init_kwargs["torch_dtype"])
|
427 |
+
)
|
428 |
+
|
429 |
+
if isinstance(model, str):
|
430 |
+
warnings.warn(
|
431 |
+
"You passed a model_id to the CPOTrainer. This will automatically create an "
|
432 |
+
"`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
|
433 |
+
)
|
434 |
+
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
435 |
+
|
436 |
+
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
437 |
+
# has been called in order to properly call autocast if needed.
|
438 |
+
self._peft_has_been_casted_to_bf16 = False
|
439 |
+
|
440 |
+
if not is_peft_available() and peft_config is not None:
|
441 |
+
raise ValueError(
|
442 |
+
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
443 |
+
)
|
444 |
+
elif is_peft_available() and peft_config is not None:
|
445 |
+
# if model is a peft model and we have a peft_config, we merge and unload it first
|
446 |
+
if isinstance(model, PeftModel):
|
447 |
+
model = model.merge_and_unload()
|
448 |
+
|
449 |
+
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
450 |
+
_support_gc_kwargs = hasattr(
|
451 |
+
args, "gradient_checkpointing_kwargs"
|
452 |
+
) and "gradient_checkpointing_kwargs" in list(
|
453 |
+
inspect.signature(prepare_model_for_kbit_training).parameters
|
454 |
+
)
|
455 |
+
|
456 |
+
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
457 |
+
|
458 |
+
if _support_gc_kwargs:
|
459 |
+
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
460 |
+
|
461 |
+
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
462 |
+
elif getattr(args, "gradient_checkpointing", False):
|
463 |
+
# For backward compatibility with older versions of transformers
|
464 |
+
if hasattr(model, "enable_input_require_grads"):
|
465 |
+
model.enable_input_require_grads()
|
466 |
+
else:
|
467 |
+
|
468 |
+
def make_inputs_require_grad(module, input, output):
|
469 |
+
output.requires_grad_(True)
|
470 |
+
|
471 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
472 |
+
|
473 |
+
# get peft model with the given config
|
474 |
+
model = model
|
475 |
+
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
476 |
+
peft_module_casting_to_bf16(model)
|
477 |
+
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
478 |
+
self._peft_has_been_casted_to_bf16 = True
|
479 |
+
|
480 |
+
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
481 |
+
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
482 |
+
# fail or completely fail.
|
483 |
+
elif getattr(args, "gradient_checkpointing", False):
|
484 |
+
# For backward compatibility with older versions of transformers
|
485 |
+
if hasattr(model, "enable_input_require_grads"):
|
486 |
+
model.enable_input_require_grads()
|
487 |
+
else:
|
488 |
+
|
489 |
+
def make_inputs_require_grad(module, input, output):
|
490 |
+
output.requires_grad_(True)
|
491 |
+
|
492 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
493 |
+
|
494 |
+
if args.generate_during_eval and not is_wandb_available():
|
495 |
+
raise ValueError(
|
496 |
+
"`generate_during_eval=True` requires Weights and Biases to be installed."
|
497 |
+
" Please install `wandb` to resolve."
|
498 |
+
)
|
499 |
+
|
500 |
+
if model is not None:
|
501 |
+
self.is_encoder_decoder = model.config.is_encoder_decoder
|
502 |
+
elif args.is_encoder_decoder is None:
|
503 |
+
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
504 |
+
else:
|
505 |
+
self.is_encoder_decoder = args.is_encoder_decoder
|
506 |
+
|
507 |
+
if self.is_encoder_decoder:
|
508 |
+
self.decoder_start_token_id = model.config.decoder_start_token_id
|
509 |
+
self.pad_token_id = model.config.pad_token_id
|
510 |
+
|
511 |
+
if tokenizer is None:
|
512 |
+
raise ValueError("tokenizer must be specified to tokenize a CPO dataset.")
|
513 |
+
if args.max_length is None:
|
514 |
+
warnings.warn(
|
515 |
+
"`max_length` is not set in the CPOConfig's init"
|
516 |
+
" it will default to `512` by default, but you should do it yourself in the future.",
|
517 |
+
UserWarning,
|
518 |
+
)
|
519 |
+
max_length = 512
|
520 |
+
else:
|
521 |
+
max_length = args.max_length
|
522 |
+
if args.max_prompt_length is None:
|
523 |
+
warnings.warn(
|
524 |
+
"`max_prompt_length` is not set in the CPOConfig's init"
|
525 |
+
" it will default to `128` by default, but you should do it yourself in the future.",
|
526 |
+
UserWarning,
|
527 |
+
)
|
528 |
+
max_prompt_length = 128
|
529 |
+
else:
|
530 |
+
max_prompt_length = args.max_prompt_length
|
531 |
+
|
532 |
+
if args.max_target_length is None and self.is_encoder_decoder:
|
533 |
+
warnings.warn(
|
534 |
+
"When using an encoder decoder architecture, you should set `max_target_length` in the CPOConfig's init"
|
535 |
+
" it will default to `128` by default, but you should do it yourself in the future.",
|
536 |
+
UserWarning,
|
537 |
+
)
|
538 |
+
max_target_length = 128
|
539 |
+
else:
|
540 |
+
max_target_length = args.max_target_length
|
541 |
+
|
542 |
+
if data_collator is None:
|
543 |
+
data_collator = DPODataCollatorWithPadding(
|
544 |
+
pad_token_id=tokenizer.pad_token_id,
|
545 |
+
label_pad_token_id=args.label_pad_token_id,
|
546 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
547 |
+
)
|
548 |
+
|
549 |
+
if args.remove_unused_columns:
|
550 |
+
args.remove_unused_columns = False
|
551 |
+
# warn users
|
552 |
+
warnings.warn(
|
553 |
+
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
|
554 |
+
" we have set it for you, but you should do it yourself in the future.",
|
555 |
+
UserWarning,
|
556 |
+
)
|
557 |
+
|
558 |
+
self.use_dpo_data_collator = True
|
559 |
+
else:
|
560 |
+
self.use_dpo_data_collator = False
|
561 |
+
|
562 |
+
if args.disable_dropout:
|
563 |
+
disable_dropout_in_model(model)
|
564 |
+
|
565 |
+
self.max_length = max_length
|
566 |
+
self.generate_during_eval = args.generate_during_eval
|
567 |
+
self.label_pad_token_id = args.label_pad_token_id
|
568 |
+
self.padding_value = args.padding_value if args.padding_value is not None else tokenizer.pad_token_id
|
569 |
+
self.max_prompt_length = max_prompt_length
|
570 |
+
self.truncation_mode = args.truncation_mode
|
571 |
+
self.max_target_length = max_target_length
|
572 |
+
self.tokenizer = tokenizer
|
573 |
+
|
574 |
+
if args.loss_type in ["hinge", "ipo", "kto_pair"] and args.label_smoothing > 0:
|
575 |
+
warnings.warn(
|
576 |
+
"You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
|
577 |
+
)
|
578 |
+
|
579 |
+
self.beta = args.beta
|
580 |
+
self.label_smoothing = args.label_smoothing
|
581 |
+
self.loss_type = args.loss_type
|
582 |
+
|
583 |
+
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
584 |
+
|
585 |
+
# Compute that only on the main process for faster data processing.
|
586 |
+
# see: https://github.com/huggingface/trl/pull/1255
|
587 |
+
with PartialState().local_main_process_first():
|
588 |
+
# tokenize the dataset
|
589 |
+
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
590 |
+
if eval_dataset is not None:
|
591 |
+
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
592 |
+
|
593 |
+
super().__init__(
|
594 |
+
model=model,
|
595 |
+
args=args,
|
596 |
+
data_collator=data_collator,
|
597 |
+
train_dataset=train_dataset,
|
598 |
+
eval_dataset=eval_dataset,
|
599 |
+
tokenizer=tokenizer,
|
600 |
+
model_init=model_init,
|
601 |
+
compute_metrics=compute_metrics,
|
602 |
+
callbacks=callbacks,
|
603 |
+
optimizers=optimizers,
|
604 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
605 |
+
)
|
606 |
+
|
607 |
+
# Add tags for models that have been loaded with the correct transformers version
|
608 |
+
if hasattr(self.model, "add_model_tags"):
|
609 |
+
self.model.add_model_tags(self._tag_names)
|
610 |
+
|
611 |
+
if not hasattr(self, "accelerator"):
|
612 |
+
raise AttributeError(
|
613 |
+
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
614 |
+
)
|
615 |
+
|
616 |
+
def build_tokenized_answer(self, prompt, answer):
|
617 |
+
"""
|
618 |
+
Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
|
619 |
+
It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
|
620 |
+
Reference:
|
621 |
+
https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
622 |
+
"""
|
623 |
+
|
624 |
+
full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
|
625 |
+
prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
|
626 |
+
|
627 |
+
answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
|
628 |
+
answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
|
629 |
+
|
630 |
+
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
|
631 |
+
full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
|
632 |
+
|
633 |
+
# Prepare input tokens for token by token comparison
|
634 |
+
full_input_ids = np.array(full_tokenized["input_ids"])
|
635 |
+
|
636 |
+
if len(full_input_ids) != len(full_concat_input_ids):
|
637 |
+
raise ValueError("Prompt input ids and answer input ids should have the same length.")
|
638 |
+
|
639 |
+
# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
|
640 |
+
# can be merged together when tokenizing prompt+answer. This could result
|
641 |
+
# on the last token from the prompt being different when tokenized on its own
|
642 |
+
# vs when done as prompt+answer.
|
643 |
+
response_token_ids_start_idx = len(prompt_input_ids)
|
644 |
+
|
645 |
+
# If tokenized prompt is different than both prompt+answer, then it means the
|
646 |
+
# last token has changed due to merging.
|
647 |
+
if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
|
648 |
+
response_token_ids_start_idx -= 1
|
649 |
+
|
650 |
+
prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
|
651 |
+
prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
|
652 |
+
|
653 |
+
if len(prompt_input_ids) != len(prompt_attention_mask):
|
654 |
+
raise ValueError("Prompt input ids and attention mask should have the same length.")
|
655 |
+
|
656 |
+
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
|
657 |
+
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
|
658 |
+
|
659 |
+
return dict(
|
660 |
+
prompt_input_ids=prompt_input_ids,
|
661 |
+
prompt_attention_mask=prompt_attention_mask,
|
662 |
+
input_ids=answer_input_ids,
|
663 |
+
attention_mask=answer_attention_mask,
|
664 |
+
)
|
665 |
+
|
666 |
+
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
|
667 |
+
"""Tokenize a single row from a CPO specific dataset.
|
668 |
+
|
669 |
+
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
|
670 |
+
in case the prompt + chosen or prompt + rejected responses is/are too long. First
|
671 |
+
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
|
672 |
+
|
673 |
+
We also create the labels for the chosen/rejected responses, which are of length equal to
|
674 |
+
the sum of the length of the prompt and the chosen/rejected response, with
|
675 |
+
label_pad_token_id for the prompt tokens.
|
676 |
+
"""
|
677 |
+
batch = {}
|
678 |
+
prompt = feature["prompt"]
|
679 |
+
chosen = feature["chosen"]
|
680 |
+
rejected = feature["rejected"]
|
681 |
+
|
682 |
+
if not self.is_encoder_decoder:
|
683 |
+
# Check issues below for more details
|
684 |
+
# 1. https://github.com/huggingface/trl/issues/907
|
685 |
+
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
686 |
+
# 3. https://github.com/LianjiaTech/BELLE/issues/337
|
687 |
+
|
688 |
+
if not isinstance(prompt, str):
|
689 |
+
raise ValueError(f"prompt should be an str but got {type(prompt)}")
|
690 |
+
prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
|
691 |
+
prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
|
692 |
+
|
693 |
+
if not isinstance(chosen, str):
|
694 |
+
raise ValueError(f"chosen should be an str but got {type(chosen)}")
|
695 |
+
chosen_tokens = self.build_tokenized_answer(prompt, chosen)
|
696 |
+
|
697 |
+
if not isinstance(rejected, str):
|
698 |
+
raise ValueError(f"rejected should be an str but got {type(rejected)}")
|
699 |
+
rejected_tokens = self.build_tokenized_answer(prompt, rejected)
|
700 |
+
|
701 |
+
# Last prompt token might get merged by tokenizer and
|
702 |
+
# it should not be included for generation if that happens
|
703 |
+
prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
|
704 |
+
|
705 |
+
chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
|
706 |
+
rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
|
707 |
+
prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
|
708 |
+
|
709 |
+
for k, v in prompt_tokens.items():
|
710 |
+
prompt_tokens[k] = v[:prompt_len_input_ids]
|
711 |
+
|
712 |
+
# Make sure prompts only have one different token at most an
|
713 |
+
# and length only differs by 1 at most
|
714 |
+
num_diff_tokens = sum(
|
715 |
+
[a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
|
716 |
+
)
|
717 |
+
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
|
718 |
+
if num_diff_tokens > 1 or num_diff_len > 1:
|
719 |
+
raise ValueError(
|
720 |
+
"Chosen and rejected prompt_input_ids might only differ on the "
|
721 |
+
"last token due to tokenizer merge ops."
|
722 |
+
)
|
723 |
+
|
724 |
+
# add BOS token to head of prompt
|
725 |
+
prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
|
726 |
+
chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
|
727 |
+
rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
|
728 |
+
|
729 |
+
prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
|
730 |
+
chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
|
731 |
+
rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
|
732 |
+
|
733 |
+
# add EOS token to end of answer
|
734 |
+
chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
|
735 |
+
chosen_tokens["attention_mask"].append(1)
|
736 |
+
|
737 |
+
rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
|
738 |
+
rejected_tokens["attention_mask"].append(1)
|
739 |
+
|
740 |
+
longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
|
741 |
+
|
742 |
+
# if combined sequence is too long, truncate the prompt
|
743 |
+
for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
|
744 |
+
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
745 |
+
if self.truncation_mode == "keep_start":
|
746 |
+
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
747 |
+
answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
|
748 |
+
elif self.truncation_mode == "keep_end":
|
749 |
+
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
750 |
+
answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
|
751 |
+
else:
|
752 |
+
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
|
753 |
+
|
754 |
+
# if that's still too long, truncate the response
|
755 |
+
for answer_tokens in [chosen_tokens, rejected_tokens]:
|
756 |
+
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
757 |
+
for k in ["input_ids", "attention_mask"]:
|
758 |
+
answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
|
759 |
+
|
760 |
+
# Create labels
|
761 |
+
chosen_sequence_tokens = {
|
762 |
+
k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
|
763 |
+
}
|
764 |
+
rejected_sequence_tokens = {
|
765 |
+
k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
|
766 |
+
}
|
767 |
+
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
|
768 |
+
chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
|
769 |
+
self.label_pad_token_id
|
770 |
+
] * len(chosen_tokens["prompt_input_ids"])
|
771 |
+
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
|
772 |
+
rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
|
773 |
+
self.label_pad_token_id
|
774 |
+
] * len(rejected_tokens["prompt_input_ids"])
|
775 |
+
|
776 |
+
for k, toks in {
|
777 |
+
"chosen_": chosen_sequence_tokens,
|
778 |
+
"rejected_": rejected_sequence_tokens,
|
779 |
+
"": prompt_tokens,
|
780 |
+
}.items():
|
781 |
+
for type_key, tokens in toks.items():
|
782 |
+
if type_key == "token_type_ids":
|
783 |
+
continue
|
784 |
+
batch[f"{k}{type_key}"] = tokens
|
785 |
+
|
786 |
+
else:
|
787 |
+
chosen_tokens = self.tokenizer(
|
788 |
+
chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
|
789 |
+
)
|
790 |
+
rejected_tokens = self.tokenizer(
|
791 |
+
rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
|
792 |
+
)
|
793 |
+
prompt_tokens = self.tokenizer(
|
794 |
+
prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
|
795 |
+
)
|
796 |
+
|
797 |
+
batch["chosen_labels"] = chosen_tokens["input_ids"]
|
798 |
+
batch["rejected_labels"] = rejected_tokens["input_ids"]
|
799 |
+
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
|
800 |
+
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
|
801 |
+
|
802 |
+
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
803 |
+
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
804 |
+
labels=torch.tensor(batch["rejected_labels"])
|
805 |
+
)
|
806 |
+
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
807 |
+
labels=torch.tensor(batch["chosen_labels"])
|
808 |
+
)
|
809 |
+
|
810 |
+
return batch
|
811 |
+
|
812 |
+
@staticmethod
|
813 |
+
def concatenated_inputs(
|
814 |
+
batch: Dict[str, Union[List, torch.LongTensor]],
|
815 |
+
is_encoder_decoder: bool = False,
|
816 |
+
label_pad_token_id: int = -100,
|
817 |
+
padding_value: int = 0,
|
818 |
+
device: Optional[torch.device] = None,
|
819 |
+
) -> Dict[str, torch.LongTensor]:
|
820 |
+
"""Concatenate the chosen and rejected inputs into a single tensor.
|
821 |
+
|
822 |
+
Args:
|
823 |
+
batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
|
824 |
+
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
825 |
+
label_pad_token_id: The label pad token id.
|
826 |
+
padding_value: The padding value to use for the concatenated inputs_ids.
|
827 |
+
device: The device for the concatenated inputs.
|
828 |
+
|
829 |
+
Returns:
|
830 |
+
A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
|
831 |
+
"""
|
832 |
+
concatenated_batch = {}
|
833 |
+
|
834 |
+
if is_encoder_decoder:
|
835 |
+
max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
|
836 |
+
else:
|
837 |
+
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
|
838 |
+
|
839 |
+
for k in batch:
|
840 |
+
if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
|
841 |
+
if "labels" in k or is_encoder_decoder:
|
842 |
+
pad_value = label_pad_token_id
|
843 |
+
elif k.endswith("_input_ids"):
|
844 |
+
pad_value = padding_value
|
845 |
+
elif k.endswith("_attention_mask"):
|
846 |
+
pad_value = 0
|
847 |
+
concatenated_key = k.replace("chosen", "concatenated")
|
848 |
+
concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
|
849 |
+
for k in batch:
|
850 |
+
if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
|
851 |
+
if "labels" in k or is_encoder_decoder:
|
852 |
+
pad_value = label_pad_token_id
|
853 |
+
elif k.endswith("_input_ids"):
|
854 |
+
pad_value = padding_value
|
855 |
+
elif k.endswith("_attention_mask"):
|
856 |
+
pad_value = 0
|
857 |
+
concatenated_key = k.replace("rejected", "concatenated")
|
858 |
+
concatenated_batch[concatenated_key] = torch.cat(
|
859 |
+
(
|
860 |
+
concatenated_batch[concatenated_key],
|
861 |
+
pad_to_length(batch[k], max_length, pad_value=pad_value),
|
862 |
+
),
|
863 |
+
dim=0,
|
864 |
+
).to(device=device)
|
865 |
+
|
866 |
+
if is_encoder_decoder:
|
867 |
+
concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
|
868 |
+
concatenated_batch["concatenated_attention_mask"] = (
|
869 |
+
batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
|
870 |
+
)
|
871 |
+
|
872 |
+
return concatenated_batch
|
873 |
+
|
874 |
+
def cpo_loss(
|
875 |
+
self,
|
876 |
+
policy_chosen_logps: torch.FloatTensor,
|
877 |
+
policy_rejected_logps: torch.FloatTensor,
|
878 |
+
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
879 |
+
"""Compute the CPO loss for a batch of policy and reference model log probabilities.
|
880 |
+
|
881 |
+
Args:
|
882 |
+
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
|
883 |
+
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
884 |
+
|
885 |
+
Returns:
|
886 |
+
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
|
887 |
+
The losses tensor contains the CPO loss for each example in the batch.
|
888 |
+
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
889 |
+
"""
|
890 |
+
logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
|
891 |
+
|
892 |
+
# The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
|
893 |
+
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
|
894 |
+
# calculates a conservative CPO loss.
|
895 |
+
if self.loss_type == "sigmoid":
|
896 |
+
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
|
897 |
+
losses = (
|
898 |
+
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
899 |
+
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
900 |
+
)
|
901 |
+
elif self.loss_type == "hinge":
|
902 |
+
losses = torch.relu(1 - self.beta * logits)
|
903 |
+
elif self.loss_type == "ipo":
|
904 |
+
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
|
905 |
+
losses = (logits - 1 / (2 * self.beta)) ** 2
|
906 |
+
else:
|
907 |
+
raise ValueError(
|
908 |
+
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
|
909 |
+
)
|
910 |
+
|
911 |
+
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
|
912 |
+
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
|
913 |
+
|
914 |
+
return losses, chosen_rewards, rejected_rewards
|
915 |
+
|
916 |
+
@staticmethod
|
917 |
+
def get_batch_logps(
|
918 |
+
logits: torch.FloatTensor,
|
919 |
+
labels: torch.LongTensor,
|
920 |
+
average_log_prob: bool = False,
|
921 |
+
label_pad_token_id: int = -100,
|
922 |
+
is_encoder_decoder: bool = False,
|
923 |
+
) -> torch.FloatTensor:
|
924 |
+
"""Compute the log probabilities of the given labels under the given logits.
|
925 |
+
|
926 |
+
Args:
|
927 |
+
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
928 |
+
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
929 |
+
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
930 |
+
label_pad_token_id: The label pad token id.
|
931 |
+
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
932 |
+
|
933 |
+
Returns:
|
934 |
+
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
935 |
+
"""
|
936 |
+
if logits.shape[:-1] != labels.shape:
|
937 |
+
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
938 |
+
|
939 |
+
if not is_encoder_decoder:
|
940 |
+
labels = labels[:, 1:].clone()
|
941 |
+
logits = logits[:, :-1, :]
|
942 |
+
loss_mask = labels != label_pad_token_id
|
943 |
+
|
944 |
+
# dummy token; we'll ignore the losses on these tokens later
|
945 |
+
labels[labels == label_pad_token_id] = 0
|
946 |
+
|
947 |
+
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
|
948 |
+
|
949 |
+
if average_log_prob:
|
950 |
+
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
951 |
+
else:
|
952 |
+
return (per_token_logps * loss_mask).sum(-1)
|
953 |
+
|
954 |
+
def concatenated_forward(
|
955 |
+
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
|
956 |
+
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
957 |
+
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
958 |
+
|
959 |
+
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
960 |
+
"""
|
961 |
+
concatenated_batch = self.concatenated_inputs(
|
962 |
+
batch,
|
963 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
964 |
+
label_pad_token_id=self.label_pad_token_id,
|
965 |
+
padding_value=self.padding_value,
|
966 |
+
device=self.accelerator.device,
|
967 |
+
)
|
968 |
+
len_chosen = batch["chosen_labels"].shape[0]
|
969 |
+
|
970 |
+
model_kwargs = (
|
971 |
+
{
|
972 |
+
"decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
|
973 |
+
}
|
974 |
+
if self.is_encoder_decoder
|
975 |
+
else {}
|
976 |
+
)
|
977 |
+
|
978 |
+
outputs = model(
|
979 |
+
concatenated_batch["concatenated_input_ids"],
|
980 |
+
attention_mask=concatenated_batch["concatenated_attention_mask"],
|
981 |
+
use_cache=False,
|
982 |
+
**model_kwargs,
|
983 |
+
)
|
984 |
+
all_logits = outputs.logits
|
985 |
+
|
986 |
+
def cross_entropy_loss(logits, labels):
|
987 |
+
if not self.is_encoder_decoder:
|
988 |
+
# Shift so that tokens < n predict n
|
989 |
+
logits = logits[..., :-1, :].contiguous()
|
990 |
+
labels = labels[..., 1:].contiguous()
|
991 |
+
# Flatten the tokens
|
992 |
+
loss_fct = nn.CrossEntropyLoss()
|
993 |
+
logits = logits.view(-1, logits.shape[-1])
|
994 |
+
labels = labels.view(-1)
|
995 |
+
# Enable model parallelism
|
996 |
+
labels = labels.to(logits.device)
|
997 |
+
loss = loss_fct(logits, labels)
|
998 |
+
return loss
|
999 |
+
|
1000 |
+
labels = concatenated_batch["concatenated_labels"].clone()
|
1001 |
+
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
|
1002 |
+
|
1003 |
+
all_logps = self.get_batch_logps(
|
1004 |
+
all_logits,
|
1005 |
+
concatenated_batch["concatenated_labels"],
|
1006 |
+
average_log_prob=self.loss_type == "ipo",
|
1007 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1008 |
+
label_pad_token_id=self.label_pad_token_id,
|
1009 |
+
)
|
1010 |
+
|
1011 |
+
chosen_logps = all_logps[:len_chosen]
|
1012 |
+
rejected_logps = all_logps[len_chosen:]
|
1013 |
+
|
1014 |
+
chosen_logits = all_logits[:len_chosen]
|
1015 |
+
rejected_logits = all_logits[len_chosen:]
|
1016 |
+
|
1017 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
|
1018 |
+
|
1019 |
+
    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_nll_loss,
        ) = self.concatenated_forward(model, batch)

        losses, chosen_rewards, rejected_rewards = self.cpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
        )

        loss = losses.mean() + policy_nll_loss
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()

        return loss, metrics

    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        if not self.use_dpo_data_collator:
            warnings.warn(
                "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
            )

        compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

        with compute_loss_context_manager():
            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

        # force log the metrics
        self.store_metrics(metrics, train_eval="train")

        if return_outputs:
            return (loss, metrics)
        return loss

    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
        """Generate samples from the model and reference model for the given batch of inputs."""

        # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
        # the torch cuda amp context manager as some hidden states are silently casted to full precision.
        generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast

        with generate_context_manager():
            policy_output = model.generate(
                input_ids=batch["prompt_input_ids"],
                attention_mask=batch["prompt_attention_mask"],
                max_length=self.max_length,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
        policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)

        return policy_output_decoded

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ):
        if not self.use_dpo_data_collator:
            warnings.warn(
                "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
            )
        if ignore_keys is None:
            if hasattr(model, "config"):
                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

        with torch.no_grad(), prediction_context_manager():
            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

        # force log the metrics
        self.store_metrics(metrics, train_eval="eval")

        if prediction_loss_only:
            return (loss.detach(), None, None)

        # logits for the chosen and rejected samples from model
        logits_dict = {
            "eval_logits/chosen": metrics["eval_logits/chosen"],
            "eval_logits/rejected": metrics["eval_logits/rejected"],
        }
        logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
        logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

        return (loss.detach(), logits, labels)

    def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Overriding built-in evaluation loop to store metrics for each batch.
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """

        # Sample and save to game log if requested (for one batch to save time)
        if self.generate_during_eval:
            # Generate random indices within the range of the total number of samples
            num_samples = len(dataloader.dataset)
            random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)

            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
            random_batch_dataset = dataloader.dataset.select(random_indices)
            random_batch = self.data_collator(random_batch_dataset)
            random_batch = self._prepare_inputs(random_batch)

            policy_output_decoded = self.get_batch_samples(self.model, random_batch)

            self.log(
                {
                    "game_log": wandb.Table(
                        columns=["Prompt", "Policy"],
                        rows=[
                            [prompt, pol[len(prompt) :]]
                            for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
                        ],
                    )
                }
            )
            self.state.log_history.pop()

        # Base evaluation
        initial_output = super().evaluation_loop(
            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
        )

        return initial_output

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        return super().log(logs)
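
    # Illustrative note, not part of the generated file: store_metrics accumulates one
    # value per batch, and log() above flushes the running average. For example, stored
    # nll_loss values [0.9, 1.1] are emitted as logs["nll_loss"] == 1.0 before
    # deferring to Trainer.log.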
    def _shift_right(self, input_ids):
        if self.decoder_start_token_id is None:
            raise ValueError(
                "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
            )

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = self.decoder_start_token_id

        if self.pad_token_id is None:
            raise ValueError("model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)

        return shifted_input_ids
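
    # Illustrative note, not part of the generated file: _shift_right prepends the
    # decoder start token and drops the last position. For example, with
    # decoder_start_token_id == 0:
    #
    #     input_ids         = [[5, 6, 7]]
    #     shifted_input_ids = [[0, 5, 6]]
    #
    # Any -100 label placeholders are then replaced by pad_token_id so the result is
    # valid decoder input for encoder-decoder models.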
    @wraps(Trainer.push_to_hub)
    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Overwrite the `push_to_hub` method in order to force-add the tag "cpo" when pushing the
        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)

        return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
class UnslothCPOTrainer(_UnslothCPOTrainer):
    """

    Initialize CPOTrainer.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForCausalLM`.
        args (`CPOConfig`):
            The CPO config arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used,
            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        tokenizer (`transformers.PreTrainedTokenizerBase`):
            The tokenizer to use for training. This argument is required if you want to use the default data collator.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be used.
        callbacks (`List[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`Dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
            The function to use to compute the metrics. Must take an `EvalPrediction` and return
            a dictionary mapping strings to metric values.

    """
    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        tokenizer = None,
        model_init = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        compute_metrics = None,
        **kwargs
    ):
        if args is None: args = UnslothCPOConfig()
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('cpo_trainer', other_metrics)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            tokenizer = tokenizer,
            model_init = model_init,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            compute_metrics = compute_metrics,
            **kwargs)
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass

pass
compilefcach/UnslothDDPOTrainer.py
ADDED
@@ -0,0 +1,744 @@
"""
|
2 |
+
2025.6.8
|
3 |
+
2025.6.12
|
4 |
+
4.53.0
|
5 |
+
0.8.6
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.ddpo_trainer import (Accelerator, Any, BaseTrainer, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, MODEL_CARD_TEMPLATE, Optional, PerPromptStatTracker, ProjectConfiguration, Tuple, defaultdict, futures, logger, os, set_seed, torch, warn, warnings, whoami)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothDDPOConfig(DDPOConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for DDPOTrainer
|
47 |
+
|
48 |
+
"""
|
49 |
+
vllm_sampling_params: Optional[Any] = field(
|
50 |
+
default = None,
|
51 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
52 |
+
)
|
53 |
+
unsloth_num_chunks : Optional[int] = field(
|
54 |
+
default = -1,
|
55 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
56 |
+
)
|
57 |
+
def __init__(
|
58 |
+
self,
|
59 |
+
exp_name = 'colab_kernel_launcher',
|
60 |
+
run_name = '',
|
61 |
+
seed = 3407,
|
62 |
+
log_with = None,
|
63 |
+
tracker_project_name = 'trl',
|
64 |
+
logdir = 'logs',
|
65 |
+
num_epochs = 100,
|
66 |
+
save_freq = 1,
|
67 |
+
num_checkpoint_limit = 5,
|
68 |
+
mixed_precision = 'fp16',
|
69 |
+
allow_tf32 = True,
|
70 |
+
resume_from = '',
|
71 |
+
sample_num_steps = 50,
|
72 |
+
sample_eta = 1.0,
|
73 |
+
sample_guidance_scale = 5.0,
|
74 |
+
sample_batch_size = 1,
|
75 |
+
sample_num_batches_per_epoch = 2,
|
76 |
+
train_batch_size = 1,
|
77 |
+
train_use_8bit_adam = False,
|
78 |
+
train_learning_rate = 5e-05,
|
79 |
+
train_adam_beta1 = 0.9,
|
80 |
+
train_adam_beta2 = 0.999,
|
81 |
+
train_adam_weight_decay = 0.01,
|
82 |
+
train_adam_epsilon = 1e-08,
|
83 |
+
train_gradient_accumulation_steps = 2,
|
84 |
+
train_max_grad_norm = 1.0,
|
85 |
+
train_num_inner_epochs = 1,
|
86 |
+
train_cfg = True,
|
87 |
+
train_adv_clip_max = 5,
|
88 |
+
train_clip_range = 0.0001,
|
89 |
+
train_timestep_fraction = 1.0,
|
90 |
+
per_prompt_stat_tracking = False,
|
91 |
+
per_prompt_stat_tracking_buffer_size = 16,
|
92 |
+
per_prompt_stat_tracking_min_count = 16,
|
93 |
+
async_reward_computation = False,
|
94 |
+
max_workers = 2,
|
95 |
+
negative_prompts = '',
|
96 |
+
vllm_sampling_params = None,
|
97 |
+
unsloth_num_chunks = -1,
|
98 |
+
**kwargs,
|
99 |
+
):
|
100 |
+
|
101 |
+
super().__init__(
|
102 |
+
exp_name = exp_name,
|
103 |
+
run_name = run_name,
|
104 |
+
seed = seed,
|
105 |
+
log_with = log_with,
|
106 |
+
tracker_project_name = tracker_project_name,
|
107 |
+
logdir = logdir,
|
108 |
+
num_epochs = num_epochs,
|
109 |
+
save_freq = save_freq,
|
110 |
+
num_checkpoint_limit = num_checkpoint_limit,
|
111 |
+
mixed_precision = mixed_precision,
|
112 |
+
allow_tf32 = allow_tf32,
|
113 |
+
resume_from = resume_from,
|
114 |
+
sample_num_steps = sample_num_steps,
|
115 |
+
sample_eta = sample_eta,
|
116 |
+
sample_guidance_scale = sample_guidance_scale,
|
117 |
+
sample_batch_size = sample_batch_size,
|
118 |
+
sample_num_batches_per_epoch = sample_num_batches_per_epoch,
|
119 |
+
train_batch_size = train_batch_size,
|
120 |
+
train_use_8bit_adam = train_use_8bit_adam,
|
121 |
+
train_learning_rate = train_learning_rate,
|
122 |
+
train_adam_beta1 = train_adam_beta1,
|
123 |
+
train_adam_beta2 = train_adam_beta2,
|
124 |
+
train_adam_weight_decay = train_adam_weight_decay,
|
125 |
+
train_adam_epsilon = train_adam_epsilon,
|
126 |
+
train_gradient_accumulation_steps = train_gradient_accumulation_steps,
|
127 |
+
train_max_grad_norm = train_max_grad_norm,
|
128 |
+
train_num_inner_epochs = train_num_inner_epochs,
|
129 |
+
train_cfg = train_cfg,
|
130 |
+
train_adv_clip_max = train_adv_clip_max,
|
131 |
+
train_clip_range = train_clip_range,
|
132 |
+
train_timestep_fraction = train_timestep_fraction,
|
133 |
+
per_prompt_stat_tracking = per_prompt_stat_tracking,
|
134 |
+
per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
|
135 |
+
per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
|
136 |
+
async_reward_computation = async_reward_computation,
|
137 |
+
max_workers = max_workers,
|
138 |
+
negative_prompts = negative_prompts,**kwargs)
|
139 |
+
self.vllm_sampling_params = vllm_sampling_params
|
140 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
141 |
+
pass
|
142 |
+
|
143 |
+
class _UnslothDDPOTrainer(BaseTrainer):
|
144 |
+
""""""
|
145 |
+
|
146 |
+
_tag_names = ["trl", "ddpo"]
|
147 |
+
|
148 |
+
def __init__(
|
149 |
+
self,
|
150 |
+
config: DDPOConfig,
|
151 |
+
reward_function: Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor],
|
152 |
+
prompt_function: Callable[[], Tuple[str, Any]],
|
153 |
+
sd_pipeline: DDPOStableDiffusionPipeline,
|
154 |
+
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
|
155 |
+
):
|
156 |
+
if image_samples_hook is None:
|
157 |
+
warn("No image_samples_hook provided; no images will be logged")
|
158 |
+
|
159 |
+
self.prompt_fn = prompt_function
|
160 |
+
self.reward_fn = reward_function
|
161 |
+
self.config = config
|
162 |
+
self.image_samples_callback = image_samples_hook
|
163 |
+
|
164 |
+
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
|
165 |
+
|
166 |
+
if self.config.resume_from:
|
167 |
+
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
|
168 |
+
if "checkpoint_" not in os.path.basename(self.config.resume_from):
|
169 |
+
# get the most recent checkpoint in this directory
|
170 |
+
checkpoints = list(
|
171 |
+
filter(
|
172 |
+
lambda x: "checkpoint_" in x,
|
173 |
+
os.listdir(self.config.resume_from),
|
174 |
+
)
|
175 |
+
)
|
176 |
+
if len(checkpoints) == 0:
|
177 |
+
raise ValueError(f"No checkpoints found in {self.config.resume_from}")
|
178 |
+
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
|
179 |
+
self.config.resume_from = os.path.join(
|
180 |
+
self.config.resume_from,
|
181 |
+
f"checkpoint_{checkpoint_numbers[-1]}",
|
182 |
+
)
|
183 |
+
|
184 |
+
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
|
185 |
+
|
186 |
+
# number of timesteps within each trajectory to train on
|
187 |
+
self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
|
188 |
+
|
189 |
+
self.accelerator = Accelerator(
|
190 |
+
log_with=self.config.log_with,
|
191 |
+
mixed_precision=self.config.mixed_precision,
|
192 |
+
project_config=accelerator_project_config,
|
193 |
+
# we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
|
194 |
+
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
|
195 |
+
# the total number of optimizer steps to accumulate across.
|
196 |
+
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
|
197 |
+
**self.config.accelerator_kwargs,
|
198 |
+
)
|
199 |
+
|
200 |
+
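
        # Illustrative note, not part of the generated file: with the defaults above
        # (train_gradient_accumulation_steps = 2, sample_num_steps = 50,
        # train_timestep_fraction = 1.0), num_train_timesteps = 50, so Accelerate
        # accumulates over 2 * 50 = 100 micro-batches, i.e. two full trajectories
        # per optimizer step.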
        is_okay, message = self._config_check()
        if not is_okay:
            raise ValueError(message)

        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"

        if self.accelerator.is_main_process:
            self.accelerator.init_trackers(
                self.config.tracker_project_name,
                config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
                init_kwargs=self.config.tracker_kwargs,
            )

        logger.info(f"\n{config}")

        set_seed(self.config.seed, device_specific=True)

        self.sd_pipeline = sd_pipeline

        self.sd_pipeline.set_progress_bar_config(
            position=1,
            disable=not self.accelerator.is_local_main_process,
            leave=False,
            desc="Timestep",
            dynamic_ncols=True,
        )

        # For mixed precision training we cast all non-trainable weights [vae, non-lora text_encoder and non-lora unet] to half-precision
        # as these weights are only used for inference, keeping weights in full precision is not required.
        if self.accelerator.mixed_precision == "fp16":
            inference_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            inference_dtype = torch.bfloat16
        else:
            inference_dtype = torch.float32

        self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)

        trainable_layers = self.sd_pipeline.get_trainable_layers()

        self.accelerator.register_save_state_pre_hook(self._save_model_hook)
        self.accelerator.register_load_state_pre_hook(self._load_model_hook)

        # Enable TF32 for faster training on Ampere GPUs,
        # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.config.allow_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True

        self.optimizer = self._setup_optimizer(
            trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
        )

        self.neg_prompt_embed = self.sd_pipeline.text_encoder(
            self.sd_pipeline.tokenizer(
                [""] if self.config.negative_prompts is None else self.config.negative_prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.sd_pipeline.tokenizer.model_max_length,
            ).input_ids.to(self.accelerator.device)
        )[0]

        if config.per_prompt_stat_tracking:
            self.stat_tracker = PerPromptStatTracker(
                config.per_prompt_stat_tracking_buffer_size,
                config.per_prompt_stat_tracking_min_count,
            )

        # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
        # more memory
        self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast

        if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
            unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
            self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
        else:
            self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)

        if self.config.async_reward_computation:
            self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)

        if config.resume_from:
            logger.info(f"Resuming from {config.resume_from}")
            self.accelerator.load_state(config.resume_from)
            self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
        else:
            self.first_epoch = 0

    def compute_rewards(self, prompt_image_pairs, is_async=False):
        if not is_async:
            rewards = []
            for images, prompts, prompt_metadata in prompt_image_pairs:
                reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
                rewards.append(
                    (
                        torch.as_tensor(reward, device=self.accelerator.device),
                        reward_metadata,
                    )
                )
        else:
            rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
            rewards = [
                (torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
                for reward, reward_metadata in rewards
            ]

        return zip(*rewards)

    def step(self, epoch: int, global_step: int):
        """
        Perform a single step of training.

        Args:
            epoch (int): The current epoch.
            global_step (int): The current global step.

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.
            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.

        Returns:
            global_step (int): The updated global step.

        """
        samples, prompt_image_data = self._generate_samples(
            iterations=self.config.sample_num_batches_per_epoch,
            batch_size=self.config.sample_batch_size,
        )

        # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
        samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
        rewards, rewards_metadata = self.compute_rewards(
            prompt_image_data, is_async=self.config.async_reward_computation
        )

        for i, image_data in enumerate(prompt_image_data):
            image_data.extend([rewards[i], rewards_metadata[i]])

        if self.image_samples_callback is not None:
            self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])

        rewards = torch.cat(rewards)
        rewards = self.accelerator.gather(rewards).cpu().numpy()

        self.accelerator.log(
            {
                "reward": rewards,
                "epoch": epoch,
                "reward_mean": rewards.mean(),
                "reward_std": rewards.std(),
            },
            step=global_step,
        )

        if self.config.per_prompt_stat_tracking:
            # gather the prompts across processes
            prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
            prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
            advantages = self.stat_tracker.update(prompts, rewards)
        else:
            advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

        # ungather advantages; keep the entries corresponding to the samples on this process
        samples["advantages"] = (
            torch.as_tensor(advantages)
            .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
            .to(self.accelerator.device)
        )

        del samples["prompt_ids"]

        total_batch_size, num_timesteps = samples["timesteps"].shape

        for inner_epoch in range(self.config.train_num_inner_epochs):
            # shuffle samples along batch dimension
            perm = torch.randperm(total_batch_size, device=self.accelerator.device)
            samples = {k: v[perm] for k, v in samples.items()}

            # shuffle along time dimension independently for each sample
            # still trying to understand the code below
            perms = torch.stack(
                [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
            )

            for key in ["timesteps", "latents", "next_latents", "log_probs"]:
                samples[key] = samples[key][
                    torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
                    perms,
                ]

            original_keys = samples.keys()
            original_values = samples.values()
            # rebatch them as user defined train_batch_size is different from sample_batch_size
            reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]

            # Transpose the list of original values
            transposed_values = zip(*reshaped_values)
            # Create new dictionaries for each row of transposed values
            samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]

            self.sd_pipeline.unet.train()
            global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
            # ensure optimization step at the end of the inner epoch
            if not self.accelerator.sync_gradients:
                raise ValueError(
                    "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
                )

        if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
            self.accelerator.save_state()

        return global_step
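
    # Illustrative note, not part of the generated file: without per-prompt stat
    # tracking, advantages are globally standardised rewards. For example,
    # rewards = [0.2, 0.8] has mean 0.5 and std 0.3, so advantages ≈ [-1.0, 1.0]
    # (up to the 1e-8 epsilon in the denominator).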
    def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
        """
        Calculate the loss for a batch of an unpacked sample

        Args:
            latents (torch.Tensor):
                The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
            timesteps (torch.Tensor):
                The timesteps sampled from the diffusion model, shape: [batch_size]
            next_latents (torch.Tensor):
                The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
            log_probs (torch.Tensor):
                The log probabilities of the latents, shape: [batch_size]
            advantages (torch.Tensor):
                The advantages of the latents, shape: [batch_size]
            embeds (torch.Tensor):
                The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
                Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds

        Returns:
            loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
            (all of these are of shape (1,))
        """
        with self.autocast():
            if self.config.train_cfg:
                noise_pred = self.sd_pipeline.unet(
                    torch.cat([latents] * 2),
                    torch.cat([timesteps] * 2),
                    embeds,
                ).sample
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
            else:
                noise_pred = self.sd_pipeline.unet(
                    latents,
                    timesteps,
                    embeds,
                ).sample
            # compute the log prob of next_latents given latents under the current model

            scheduler_step_output = self.sd_pipeline.scheduler_step(
                noise_pred,
                timesteps,
                latents,
                eta=self.config.sample_eta,
                prev_sample=next_latents,
            )

            log_prob = scheduler_step_output.log_probs

        advantages = torch.clamp(
            advantages,
            -self.config.train_adv_clip_max,
            self.config.train_adv_clip_max,
        )

        ratio = torch.exp(log_prob - log_probs)

        loss = self.loss(advantages, self.config.train_clip_range, ratio)

        approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)

        clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())

        return loss, approx_kl, clipfrac

    def loss(
        self,
        advantages: torch.Tensor,
        clip_range: float,
        ratio: torch.Tensor,
    ):
        unclipped_loss = -advantages * ratio
        clipped_loss = -advantages * torch.clamp(
            ratio,
            1.0 - clip_range,
            1.0 + clip_range,
        )
        return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
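
    # Illustrative note, not part of the generated file: loss() is the standard PPO
    # clipped surrogate objective. With advantage A = 1.0 and clip_range = 0.0001,
    # a ratio of 1.5 is clamped to 1.0001, so a large policy change cannot be
    # exploited by a single step:
    #
    #     unclipped = -1.0 * 1.5    = -1.5
    #     clipped   = -1.0 * 1.0001 = -1.0001
    #     loss      = max(-1.5, -1.0001) = -1.0001   # the pessimistic branch wins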
    def _setup_optimizer(self, trainable_layers_parameters):
        if self.config.train_use_8bit_adam:
            import bitsandbytes

            optimizer_cls = bitsandbytes.optim.AdamW8bit
        else:
            optimizer_cls = torch.optim.AdamW

        return optimizer_cls(
            trainable_layers_parameters,
            lr=self.config.train_learning_rate,
            betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
            weight_decay=self.config.train_adam_weight_decay,
            eps=self.config.train_adam_epsilon,
        )

    def _save_model_hook(self, models, weights, output_dir):
        self.sd_pipeline.save_checkpoint(models, weights, output_dir)
        weights.pop()  # ensures that accelerate doesn't try to handle saving of the model

    def _load_model_hook(self, models, input_dir):
        self.sd_pipeline.load_checkpoint(models, input_dir)
        models.pop()  # ensures that accelerate doesn't try to handle loading of the model

    def _generate_samples(self, iterations, batch_size):
        """
        Generate samples from the model

        Args:
            iterations (int): Number of iterations to generate samples for
            batch_size (int): Batch size to use for sampling

        Returns:
            samples (List[Dict[str, torch.Tensor]]), prompt_image_pairs (List[List[Any]])
        """
        samples = []
        prompt_image_pairs = []
        self.sd_pipeline.unet.eval()

        sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)

        for _ in range(iterations):
            prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])

            prompt_ids = self.sd_pipeline.tokenizer(
                prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.sd_pipeline.tokenizer.model_max_length,
            ).input_ids.to(self.accelerator.device)
            prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]

            with self.autocast():
                sd_output = self.sd_pipeline(
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=sample_neg_prompt_embeds,
                    num_inference_steps=self.config.sample_num_steps,
                    guidance_scale=self.config.sample_guidance_scale,
                    eta=self.config.sample_eta,
                    output_type="pt",
                )

                images = sd_output.images
                latents = sd_output.latents
                log_probs = sd_output.log_probs

            latents = torch.stack(latents, dim=1)  # (batch_size, num_steps + 1, ...)
            log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
            timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1)  # (batch_size, num_steps)

            samples.append(
                {
                    "prompt_ids": prompt_ids,
                    "prompt_embeds": prompt_embeds,
                    "timesteps": timesteps,
                    "latents": latents[:, :-1],  # each entry is the latent before timestep t
                    "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
                    "log_probs": log_probs,
                    "negative_prompt_embeds": sample_neg_prompt_embeds,
                }
            )
            prompt_image_pairs.append([images, prompts, prompt_metadata])

        return samples, prompt_image_pairs

    def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
        """
        Train on a batch of samples. Main training segment

        Args:
            inner_epoch (int): The current inner epoch
            epoch (int): The current epoch
            global_step (int): The current global step
            batched_samples (List[Dict[str, torch.Tensor]]): The batched samples to train on

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.

        Returns:
            global_step (int): The updated global step
        """
        info = defaultdict(list)
        for _i, sample in enumerate(batched_samples):
            if self.config.train_cfg:
                # concat negative prompts to sample prompts to avoid two forward passes
                embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
            else:
                embeds = sample["prompt_embeds"]

            for j in range(self.num_train_timesteps):
                with self.accelerator.accumulate(self.sd_pipeline.unet):
                    loss, approx_kl, clipfrac = self.calculate_loss(
                        sample["latents"][:, j],
                        sample["timesteps"][:, j],
                        sample["next_latents"][:, j],
                        sample["log_probs"][:, j],
                        sample["advantages"],
                        embeds,
                    )
                    info["approx_kl"].append(approx_kl)
                    info["clipfrac"].append(clipfrac)
                    info["loss"].append(loss)

                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.trainable_layers.parameters()
                            if not isinstance(self.trainable_layers, list)
                            else self.trainable_layers,
                            self.config.train_max_grad_norm,
                        )
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                # Checks if the accelerator has performed an optimization step behind the scenes
                if self.accelerator.sync_gradients:
                    # log training-related stuff
                    info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                    info = self.accelerator.reduce(info, reduction="mean")
                    info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                    self.accelerator.log(info, step=global_step)
                    global_step += 1
                    info = defaultdict(list)
        return global_step

    def _config_check(self) -> Tuple[bool, str]:
        samples_per_epoch = (
            self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
        )
        total_train_batch_size = (
            self.config.train_batch_size
            * self.accelerator.num_processes
            * self.config.train_gradient_accumulation_steps
        )

        if not self.config.sample_batch_size >= self.config.train_batch_size:
            return (
                False,
                f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
            )
        if not self.config.sample_batch_size % self.config.train_batch_size == 0:
            return (
                False,
                f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
            )
        if not samples_per_epoch % total_train_batch_size == 0:
            return (
                False,
                f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
            )
        return True, ""
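
    # Illustrative arithmetic, not from the source: with the UnslothDDPOConfig
    # defaults above on a single process, samples_per_epoch = 1 * 1 * 2 = 2 and
    # total_train_batch_size = 1 * 1 * 2 = 2, so 2 % 2 == 0 and the checks pass.
    # Raising sample_num_batches_per_epoch to 3 would fail the last check (3 % 2 != 0).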
    def train(self, epochs: Optional[int] = None):
        """
        Train the model for a given number of epochs
        """
        global_step = 0
        if epochs is None:
            epochs = self.config.num_epochs
        for epoch in range(self.first_epoch, epochs):
            global_step = self.step(epoch, global_step)

    def create_model_card(self, path: str, model_name: Optional[str] = "TRL DDPO Model") -> None:
        """Creates and saves a model card for a TRL model.

        Args:
            path (`str`): The path to save the model card to.
            model_name (`str`, *optional*): The name of the model, defaults to `TRL DDPO Model`.
        """
        try:
            user = whoami()["name"]
        # handle the offline case
        except Exception:
            warnings.warn("Cannot retrieve user information, assuming you are running in offline mode.")
            return

        if not os.path.exists(path):
            os.makedirs(path)

        model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}")
        with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
            f.write(model_card_content)

    def _save_pretrained(self, save_directory):
        self.sd_pipeline.save_pretrained(save_directory)
        self.create_model_card(save_directory)

class UnslothDDPOTrainer(_UnslothDDPOTrainer):
    """

    The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
    Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
    As of now only Stable Diffusion based pipelines are supported

    Attributes:
        **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more
            details.
        **reward_function** (Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor]) -- Reward function to be used
        **prompt_function** (Callable[[], Tuple[str, Any]]) -- Function to generate prompts to guide model
        **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
        **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images

    """
    def __init__(
        self,
        config,
        reward_function,
        prompt_function,
        sd_pipeline,
        image_samples_hook = None,
        **kwargs
    ):
        if config is None: config = UnslothDDPOConfig()
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('ddpo_trainer', other_metrics)

        super().__init__(
            config = config,
            reward_function = reward_function,
            prompt_function = prompt_function,
            sd_pipeline = sd_pipeline,
            image_samples_hook = image_samples_hook,
            **kwargs)

pass
compilefcach/UnslothKTOTrainer.py
ADDED
@@ -0,0 +1,1629 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
2025.6.8
2025.6.12
4.53.0
0.8.6
__UNSLOTH_VERSIONING__
"""
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, Dict, EvalLoopOutput, F, KTOConfig, KTOTrainer, List, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Tuple, Union, _get_kl_dataset, _process_tokens, _tokenize, concatenate_datasets, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, has_length, inspect, is_peft_available, is_wandb_available, itemgetter, nn, np, nullcontext, pad_to_length, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, torch, tqdm, trl_sanitze_kwargs_for_tagging, wandb, warnings, wraps)


import os
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling

torch_compile_options = {
    "epilogue_fusion" : True,
    "max_autotune" : False,
    "shape_padding" : True,
    "trace.enabled" : False,
    "triton.cudagraphs" : False,
}

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    logits = logits.to(torch.float32)
    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    # loop to reduce peak mem consumption
    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
    logsumexp_values = torch.logsumexp(logits, dim = -1)
    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    return per_token_logps
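# Added illustrative usage (not part of the original file); the shapes below are
# assumptions chosen for the example, and the function is never called at import time.
def _example_selective_log_softmax():
    logits = torch.randn(2, 5, 10)           # (batch, seq_len, vocab)
    index  = torch.randint(0, 10, (2, 5))    # token ids whose log-probs we want
    per_token_logps = selective_log_softmax(logits, index)
    # per_token_logps[b, t] == log_softmax(logits[b, t])[index[b, t]]
    assert per_token_logps.shape == (2, 5)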
@dataclass
class UnslothKTOConfig(KTOConfig):
    """

    KTOConfig collects all training arguments related to the [`KTOTrainer`] class.

    Using [`HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        max_length (`int`, *optional*, defaults to `None`):
            The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
        max_prompt_length (`int`, *optional*, defaults to `None`):
            The maximum length of the prompt. This argument is required if you want to use the default data collator.
        max_completion_length (`int`, *optional*, defaults to `None`):
            The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
        beta (`float`, defaults to 0.1):
            The beta factor in KTO loss. Higher beta means less divergence from the initial policy.
        desirable_weight (`float`, *optional*, defaults to 1.0):
            The desirable losses are weighted by this factor to counter an unequal number of desirable and undesirable pairs.
        undesirable_weight (`float`, *optional*, defaults to 1.0):
            The undesirable losses are weighted by this factor to counter an unequal number of desirable and undesirable pairs.
        label_pad_token_id (`int`, defaults to `-100`):
            The label pad token id. This argument is required if you want to use the default data collator.
        padding_value (`int`, defaults to `0`):
            The padding value if it is different from the tokenizer's pad_token_id.
        truncation_mode (`str`, defaults to `keep_end`):
            The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
        generate_during_eval (`bool`, defaults to `False`):
            Whether to sample and log generations during the evaluation step.
        is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
            If no model is provided, we need to know if the model_init returns an encoder-decoder.
        precompute_ref_log_probs (`bool`, defaults to `False`):
            Flag to precompute reference model log probabilities for training and evaluation datasets. This is useful if you want to train
            without the reference model and reduce the total GPU memory needed.
        model_init_kwargs: (`Optional[Dict]`, *optional*):
            Dict of optional kwargs to pass when instantiating the model from a string.
        ref_model_init_kwargs: (`Optional[Dict]`, *optional*):
            Dict of optional kwargs to pass when instantiating the ref model from a string.
        dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`):
            Number of processes to use for processing the datasets.

    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        hub_revision = None,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = False,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        liger_kernel_config = None,
        eval_use_gather_object = False,
        average_tokens_across_devices = False,
        max_length = None,
        max_prompt_length = None,
        max_completion_length = None,
        beta = 0.1,
        desirable_weight = 1.0,
        undesirable_weight = 1.0,
        label_pad_token_id = -100,
        padding_value = None,
        truncation_mode = 'keep_end',
        generate_during_eval = False,
        is_encoder_decoder = None,
        precompute_ref_log_probs = False,
        model_init_kwargs = None,
        ref_model_init_kwargs = None,
        dataset_num_proc = None,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is far too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()

        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            liger_kernel_config = liger_kernel_config,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            max_length = max_length,
            max_prompt_length = max_prompt_length,
            max_completion_length = max_completion_length,
            beta = beta,
            desirable_weight = desirable_weight,
            undesirable_weight = undesirable_weight,
            label_pad_token_id = label_pad_token_id,
            padding_value = padding_value,
            truncation_mode = truncation_mode,
            generate_during_eval = generate_during_eval,
            is_encoder_decoder = is_encoder_decoder,
            precompute_ref_log_probs = precompute_ref_log_probs,
            model_init_kwargs = model_init_kwargs,
            ref_model_init_kwargs = ref_model_init_kwargs,
            dataset_num_proc = dataset_num_proc,
            **kwargs,
        )
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        pass

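# Added illustrative usage (not part of the original file): building the config
# with a few common overrides; the values here are assumptions, not recommendations.
def _example_build_config():
    return UnslothKTOConfig(
        output_dir = "kto_outputs",
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 2,
        learning_rate = 5e-5,
        max_length = 1024,
        max_prompt_length = 512,
        beta = 0.1,
    )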
class _UnslothKTOTrainer(Trainer):
    r""""""

    _tag_names = ["trl", "kto"]

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module, str] = None,
        ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: KTOConfig = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        data_collator: Optional[DataCollator] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional[Dict] = None,
        compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
        model_adapter_name: Optional[str] = None,
        ref_adapter_name: Optional[str] = None,
    ):
        if type(args) == TrainingArguments:
            raise ValueError("Please use `KTOConfig` instead of `TrainingArguments`.")

        if args.model_init_kwargs is None:
            model_init_kwargs = {}
        elif not isinstance(model, str):
            raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
        else:
            model_init_kwargs = args.model_init_kwargs
            model_init_kwargs["torch_dtype"] = (
                model_init_kwargs["torch_dtype"]
                if model_init_kwargs["torch_dtype"] in ["auto", None]
                else getattr(torch, model_init_kwargs["torch_dtype"])
            )

        if args.ref_model_init_kwargs is None:
            ref_model_init_kwargs = {}
        elif not isinstance(ref_model, str):
            raise ValueError(
                "You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
            )
        else:
            ref_model_init_kwargs = args.ref_model_init_kwargs
            ref_model_init_kwargs["torch_dtype"] = (
                ref_model_init_kwargs["torch_dtype"]
                if ref_model_init_kwargs["torch_dtype"] in ["auto", None]
                else getattr(torch, ref_model_init_kwargs["torch_dtype"])
            )

        if isinstance(model, str):
            warnings.warn(
                "You passed a model_id to the KTOTrainer. This will automatically create an "
                "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
            )
            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

        if isinstance(ref_model, str):
            warnings.warn(
                "You passed a ref model_id to the KTOTrainer. This will automatically create an "
                "`AutoModelForCausalLM`"
            )
            ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)

        # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
        # has been called, in order to properly call autocast if needed.
        self._peft_has_been_casted_to_bf16 = False

        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            # if model is a peft model and we have a peft_config, we merge and unload it first
            if isinstance(model, PeftModel):
                model = model.merge_and_unload()

            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
                _support_gc_kwargs = hasattr(
                    args, "gradient_checkpointing_kwargs"
                ) and "gradient_checkpointing_kwargs" in list(
                    inspect.signature(prepare_model_for_kbit_training).parameters
                )

                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                if _support_gc_kwargs:
                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
            elif getattr(args, "gradient_checkpointing", False):
                # For backward compatibility with older versions of transformers
                if hasattr(model, "enable_input_require_grads"):
                    model.enable_input_require_grads()
                else:

                    def make_inputs_require_grad(module, input, output):
                        output.requires_grad_(True)

                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

            # the model is used as-is here (the PEFT adapters are expected to have been applied upstream)
            model = model
            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
                peft_module_casting_to_bf16(model)
                # If args.bf16 we need to explicitly call `generate` with the torch amp autocast context manager
                self._peft_has_been_casted_to_bf16 = True

        # For models that use gradient_checkpointing, we need to attach a hook that enables input
        # to explicitly have `requires_grad=True`, otherwise training will either silently
        # fail or completely fail.
        elif getattr(args, "gradient_checkpointing", False):
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        if args.generate_during_eval and not is_wandb_available():
            raise ValueError(
                "`generate_during_eval=True` requires Weights and Biases to be installed."
                " Please install with `pip install wandb` to resolve."
            )

        if model is not None:
            self.is_encoder_decoder = model.config.is_encoder_decoder
        elif args.is_encoder_decoder is None:
            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
        else:
            self.is_encoder_decoder = args.is_encoder_decoder

        self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
        self.model_adapter_name = model_adapter_name
        self.ref_adapter_name = ref_adapter_name

        if ref_model:
            self.ref_model = ref_model
        elif self.is_peft_model or args.precompute_ref_log_probs:
            # The `model` with adapters turned off will be used as the reference model
            self.ref_model = None
        else:
            self.ref_model = create_reference_model(model)

        if tokenizer is None:
            raise ValueError(
                "A tokenizer must be specified when using the default DPODataCollatorWithPadding"
            )
        if args.max_length is None:
            warnings.warn(
                "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init;"
                " it will be set to `512` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_length = 512
        if args.max_length is not None:
            max_length = args.max_length

        if args.max_prompt_length is None:
            warnings.warn(
                "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init;"
                " it will be set to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_prompt_length = 128
        if args.max_prompt_length is not None:
            max_prompt_length = args.max_prompt_length

        max_completion_length = None
        if args.max_completion_length is None and self.is_encoder_decoder:
            warnings.warn(
                "When using DPODataCollatorWithPadding with an encoder-decoder architecture, you should set `max_completion_length` in the KTOTrainer's init;"
                " it will be set to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_completion_length = 128
        if args.max_completion_length is not None and self.is_encoder_decoder:
            max_completion_length = args.max_completion_length

        if data_collator is None:
            data_collator = DPODataCollatorWithPadding(
                pad_token_id=tokenizer.pad_token_id,
                label_pad_token_id=args.label_pad_token_id,
                is_encoder_decoder=self.is_encoder_decoder,
            )

            if args.remove_unused_columns:
                args.remove_unused_columns = False
                # warn users
                warnings.warn(
                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig;"
                    " we have set it for you, but you should do it yourself in the future.",
                    UserWarning,
                )

            self.use_dpo_data_collator = True
        else:
            self.use_dpo_data_collator = False

        # disable dropout in the model and reference model
        disable_dropout_in_model(model)
        if self.ref_model is not None:
            disable_dropout_in_model(self.ref_model)

        self.max_length = max_length
        self.generate_during_eval = args.generate_during_eval
        self.label_pad_token_id = args.label_pad_token_id
        self.padding_value = args.padding_value if args.padding_value is not None else tokenizer.pad_token_id
        self.max_prompt_length = max_prompt_length
        self.truncation_mode = args.truncation_mode
        self.max_completion_length = max_completion_length
        self.tokenizer = tokenizer
        self.precompute_ref_log_probs = args.precompute_ref_log_probs

        # Since ref_logps are precomputed on the first call to get_train/eval_dataloader,
        # keep track of the first call to avoid recomputation on future calls
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False

        # metric
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # KTO parameters
        self.beta = args.beta
        self.desirable_weight = args.desirable_weight
        self.undesirable_weight = args.undesirable_weight

        with PartialState().local_main_process_first():
            # Shuffle the datasets
            train_dataset = train_dataset.shuffle(seed=args.data_seed)
            if eval_dataset is not None:
                eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
            # Tokenize and prepare the training datasets
            train_dataset = train_dataset.map(
                _tokenize,
                fn_kwargs={"tokenizer": self.tokenizer},
                batched=True,
                desc="Tokenizing train dataset",
            )
            # Get KL datasets
            total_batch_size = (
                max(torch.cuda.device_count(), 1) * args.per_device_train_batch_size * args.gradient_accumulation_steps
            )
            if total_batch_size <= 1:
                raise ValueError(
                    "Batch size is 1 (too small). KTO will not work properly because the KL term will be equivalent to the implied reward."
                )
            # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
            # i.e., [x_1, y_1], ..., [x_n, y_n] --> [x_1, y_n], ..., [x_n, y_1] = [x'_1, y'_1], ..., [x'_n, y'_n]
            train_kl_dataset = train_dataset.map(
                _get_kl_dataset, batched=True, batch_size=total_batch_size, desc="Extracting KL train dataset"
            )
            # Prepare the datasets
            fn_kwargs = {
                "prefix": "",
                "is_encoder_decoder": self.is_encoder_decoder,
                "tokenizer": self.tokenizer,
                "max_length": self.max_length,
                "truncation_mode": self.truncation_mode,
                "label_pad_token_id": self.label_pad_token_id,
                "max_prompt_length": self.max_prompt_length,
            }
            train_dataset = train_dataset.map(
                _process_tokens,
                fn_kwargs=fn_kwargs,
                num_proc=args.dataset_num_proc,
                desc="Processing tokenized train dataset",
            )
            fn_kwargs["prefix"] = "KL_"
            train_kl_dataset = train_kl_dataset.map(
                _process_tokens,
                fn_kwargs=fn_kwargs,
                num_proc=args.dataset_num_proc,
                remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
                desc="Processing tokenized train KL dataset",
            )

            # merge the datasets
            train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)

            if eval_dataset is not None:
                # Tokenize
                eval_dataset = eval_dataset.map(
                    _tokenize,
                    fn_kwargs={"tokenizer": self.tokenizer},
                    batched=True,
                    desc="Tokenizing eval dataset",
                )
                # Get KL dataset
                eval_kl_dataset = eval_dataset.map(
                    _get_kl_dataset, batched=True, batch_size=total_batch_size, desc="Extracting eval KL dataset"
                )
                # Process
                fn_kwargs = {
                    "prefix": "",
                    "is_encoder_decoder": self.is_encoder_decoder,
                    "tokenizer": self.tokenizer,
                    "max_length": self.max_length,
                    "truncation_mode": self.truncation_mode,
                    "label_pad_token_id": self.label_pad_token_id,
                    "max_prompt_length": self.max_prompt_length,
                }
                eval_dataset = eval_dataset.map(
                    _process_tokens,
                    fn_kwargs=fn_kwargs,
                    num_proc=args.dataset_num_proc,
                    desc="Processing tokenized eval dataset",
                )
                fn_kwargs["prefix"] = "KL_"
                eval_kl_dataset = eval_kl_dataset.map(
                    _process_tokens,
                    fn_kwargs=fn_kwargs,
                    num_proc=args.dataset_num_proc,
                    remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
                    desc="Processing tokenized eval KL dataset",
                )

                # merge the datasets
                eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)

            desirable = train_dataset.filter(
                lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
            )
            undesirable = train_dataset.filter(
                lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
            )

            if len(desirable) != len(undesirable):
                # The lower and upper bounds come from Eq. [8] of https://arxiv.org/abs/2402.01306
                des_weight_lower_bound = round((len(undesirable) * self.undesirable_weight / len(desirable)) * 1, 2)
                des_weight_upper_bound = round((len(undesirable) * self.undesirable_weight / len(desirable)) * 1.33, 2)
                und_weight_lower_bound = round((len(desirable) * self.desirable_weight / len(undesirable)) / 1.33, 2)
                und_weight_upper_bound = round((len(desirable) * self.desirable_weight / len(undesirable)) / 1, 2)

                des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
                und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound

                if not (des_weight_in_range or und_weight_in_range):
                    warnings.warn(
                        f"""
                        You have different amounts of desirable/positive and undesirable/negative examples, but the
                        weights on the desirable and undesirable losses don't seem to be in an ideal range. Based
                        on your data, we recommend EITHER desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}]
                        or undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH).
                        See the documentation on how to optimally set these weights.""",
                        UserWarning,
                    )

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        if not hasattr(self, "accelerator"):
            raise AttributeError(
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )

        # Deepspeed ZeRO-3 does not support precompute_ref_log_probs
        if self.is_deepspeed_enabled:
            if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
                raise ValueError(
                    "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
                )

        if self.ref_model is None:
            if not (self.is_peft_model or self.precompute_ref_log_probs):
                raise ValueError(
                    "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
                )
        else:
            if self.is_deepspeed_enabled:
                self.ref_model = self._prepare_deepspeed(self.ref_model)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

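    # Added note (illustrative, not in the original file): a minimal construction
    # sketch under assumed names -- `model`, `config`, `train_ds`, `tokenizer` are
    # placeholders the caller would define.
    #
    #     trainer = _UnslothKTOTrainer(
    #         model = model,            # a causal LM, or a PEFT model (then ref_model may stay None)
    #         ref_model = None,         # None => the adapters-off model serves as the reference
    #         args = config,            # an UnslothKTOConfig instance
    #         train_dataset = train_ds, # expects "prompt", "completion" and boolean "label" columns
    #         tokenizer = tokenizer,
    #     )
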
    def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
        import deepspeed  # local import: only needed (and installed) when DeepSpeed is enabled
        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

        if model is not None:
            if hasattr(model, "config"):
                hidden_size = (
                    max(model.config.hidden_sizes)
                    if getattr(model.config, "hidden_sizes", None)
                    else getattr(model.config, "hidden_size", None)
                )
                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
                    config_kwargs.update(
                        {
                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                        }
                    )

        # If ZeRO-3 is used, we shard both the active and reference model.
        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
        if config_kwargs["zero_optimization"]["stage"] != 3:
            config_kwargs["zero_optimization"]["stage"] = 0
        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
        model.eval()
        return model

    @contextmanager
    def null_ref_context(self):
        """Context manager for handling null reference model (that is, peft adapter manipulation)."""
        with self.accelerator.unwrap_model(
            self.model
        ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
            if self.ref_adapter_name:
                self.model.set_adapter(self.ref_adapter_name)
            yield
            if self.ref_adapter_name:
                self.model.set_adapter(self.model_adapter_name or "default")

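    # Added note: compute_reference_log_probs and get_batch_loss_metrics below use
    # null_ref_context so that, when ref_model is None and the policy is a PEFT
    # model, the adapters are temporarily disabled (or swapped to ref_adapter_name)
    # and the base weights act as the reference model.
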
    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
        """

        if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
            dataloader_params = {
                "batch_size": self.args.per_device_train_batch_size,
                "collate_fn": self.data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
                "shuffle": False,
            }

            # prepare dataloader
            data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
            reference_completion_logps = []
            reference_KL_logps = []

            for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
                reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)

                reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
                reference_completion_logps.append(reference_completion_logp.cpu())

                reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
                reference_KL_logps.append(reference_KL_logp.cpu())

            self.train_dataset = self.train_dataset.add_column(
                name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
            )
            self.train_dataset = self.train_dataset.add_column(
                name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
            )

            self._precomputed_train_ref_log_probs = True

        return super().get_train_dataloader()

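    # Added note: with precompute_ref_log_probs=True, the first call above
    # materializes "reference_logps" and "reference_KL_logps" columns on
    # self.train_dataset, so later epochs skip the reference forward passes.
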
    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
            dataloader_params = {
                "batch_size": self.args.per_device_eval_batch_size,
                "collate_fn": self.data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
                "shuffle": False,
            }

            # prepare dataloader
            data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))

            reference_completion_logps = []
            reference_KL_logps = []

            for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
                reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)

                reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
                reference_completion_logps.append(reference_completion_logp.cpu())

                reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
                reference_KL_logps.append(reference_KL_logp.cpu())

            eval_dataset = eval_dataset.add_column(
                name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
            )
            eval_dataset = eval_dataset.add_column(
                name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
            )

            # Save the computed reference_logps and reference_KL_logps to the eval_dataset for subsequent runs
            if self.eval_dataset is not None:
                self.eval_dataset = eval_dataset
            self._precomputed_eval_ref_log_probs = True

        return super().get_eval_dataloader(eval_dataset=eval_dataset)

    def compute_reference_log_probs(self, padded_batch: Dict) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
        with torch.no_grad():
            if self.ref_model is None:
                with self.null_ref_context():
                    if self.is_encoder_decoder:
                        completion_logits = self.model(
                            padded_batch["prompt_input_ids"],
                            attention_mask=padded_batch["prompt_attention_mask"],
                            decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
                            labels=padded_batch["completion_labels"],
                        ).logits

                        KL_logits = self.model(
                            padded_batch["KL_prompt_input_ids"],
                            attention_mask=padded_batch["KL_prompt_attention_mask"],
                            decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
                            labels=padded_batch["KL_completion_labels"],
                        ).logits
                    else:
                        completion_logits = self.model(
                            padded_batch["completion_input_ids"],
                            attention_mask=padded_batch["completion_attention_mask"],
                        ).logits

                        KL_logits = self.model(
                            padded_batch["KL_completion_input_ids"],
                            attention_mask=padded_batch["KL_completion_attention_mask"],
                        ).logits
            else:
                if self.is_encoder_decoder:
                    completion_logits = self.ref_model(
                        padded_batch["prompt_input_ids"],
                        attention_mask=padded_batch["prompt_attention_mask"],
                        decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
                        labels=padded_batch["completion_labels"],
                    ).logits

                    KL_logits = self.ref_model(
                        padded_batch["KL_prompt_input_ids"],
                        attention_mask=padded_batch["KL_prompt_attention_mask"],
                        decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
                        labels=padded_batch["KL_completion_labels"],
                    ).logits
                else:
                    completion_logits = self.ref_model(
                        padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
                    ).logits

                    KL_logits = self.ref_model(
                        padded_batch["KL_completion_input_ids"],
                        attention_mask=padded_batch["KL_completion_attention_mask"],
                    ).logits

        completion_logps = self.get_batch_logps(
            completion_logits,
            padded_batch["completion_labels"],
            average_log_prob=False,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        KL_logps = self.get_batch_logps(
            KL_logits,
            padded_batch["KL_completion_labels"],
            average_log_prob=False,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        return completion_logps, KL_logps

    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
        """
        if logits.shape[:-1] != labels.shape:
            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

        if not is_encoder_decoder:
            labels = labels[:, 1:].clone()
            logits = logits[:, :-1, :]
        else:
            # Fixes enc-dec RuntimeError
            labels = labels.clone()

        loss_mask = labels != label_pad_token_id

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == label_pad_token_id] = 0

        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)

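    # Added illustrative sketch (shapes are assumptions, not from this file):
    #     logits = torch.randn(2, 6, 10)                  # (batch, seq, vocab)
    #     labels = torch.full((2, 6), -100)
    #     labels[:, 3:] = torch.randint(0, 10, (2, 3))    # only the last 3 tokens are scored
    #     _UnslothKTOTrainer.get_batch_logps(logits, labels)  # -> shape (2,): summed log-probs
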
1043 |
+
def forward(
|
1044 |
+
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
|
1045 |
+
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
1046 |
+
if self.is_encoder_decoder:
|
1047 |
+
with torch.no_grad():
|
1048 |
+
KL_logits = model(
|
1049 |
+
batch["KL_prompt_input_ids"],
|
1050 |
+
attention_mask=batch["KL_prompt_attention_mask"],
|
1051 |
+
decoder_input_ids=batch.get("KL_completion_decoder_input_ids"),
|
1052 |
+
labels=batch["KL_completion_labels"],
|
1053 |
+
).logits
|
1054 |
+
|
1055 |
+
completion_logits = model(
|
1056 |
+
batch["prompt_input_ids"],
|
1057 |
+
attention_mask=batch["prompt_attention_mask"],
|
1058 |
+
decoder_input_ids=batch.get("completion_decoder_input_ids"),
|
1059 |
+
labels=batch["completion_labels"],
|
1060 |
+
).logits
|
1061 |
+
else:
|
1062 |
+
with torch.no_grad():
|
1063 |
+
KL_logits = model(
|
1064 |
+
batch["KL_completion_input_ids"],
|
1065 |
+
attention_mask=batch["KL_completion_attention_mask"],
|
1066 |
+
).logits
|
1067 |
+
|
1068 |
+
completion_logits = model(
|
1069 |
+
batch["completion_input_ids"],
|
1070 |
+
attention_mask=batch["completion_attention_mask"],
|
1071 |
+
).logits
|
1072 |
+
|
1073 |
+
completion_logps = self.get_batch_logps(
|
1074 |
+
completion_logits,
|
1075 |
+
batch["completion_labels"],
|
1076 |
+
average_log_prob=False,
|
1077 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1078 |
+
label_pad_token_id=self.label_pad_token_id,
|
1079 |
+
)
|
1080 |
+
|
1081 |
+
KL_logps = self.get_batch_logps(
|
1082 |
+
KL_logits,
|
1083 |
+
batch["KL_completion_labels"],
|
1084 |
+
average_log_prob=False,
|
1085 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1086 |
+
label_pad_token_id=self.label_pad_token_id,
|
1087 |
+
)
|
1088 |
+
|
1089 |
+
if completion_logps.shape[0] != len(batch["label"]):
|
1090 |
+
raise ValueError(
|
1091 |
+
"There is a mismatch between the number of examples in this batch and the number of "
|
1092 |
+
"examples for which an output sequence was predicted."
        )

        chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
        rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]

        chosen_logps = completion_logps[chosen_idx, ...]
        rejected_logps = completion_logps[rejected_idx, ...]

        chosen_logits = completion_logits[chosen_idx, ...]
        rejected_logits = completion_logits[rejected_idx, ...]

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)

    def kto_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        policy_KL_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_KL_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the KTO loss for a batch of policy and reference model log probabilities.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
            policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
            reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)

        Returns:
            A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL).
            The losses tensor contains the KTO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
            The KL tensor contains the detached KL divergence estimate between the policy and reference models.
        """
        kl = (policy_KL_logps - reference_KL_logps).mean().detach()
        kl = self.accelerator.gather(kl).mean().clamp(min=0)

        if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
            chosen_logratios = policy_chosen_logps - reference_chosen_logps
            chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
            chosen_rewards = self.beta * chosen_logratios.detach()
        else:
            # lists can't be empty -- if they are, then accelerate.gather will hang
            chosen_losses = torch.Tensor([]).to(self.accelerator.device)
            chosen_rewards = torch.Tensor([]).to(self.accelerator.device)

        if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
            rejected_logratios = policy_rejected_logps - reference_rejected_logps
            rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
            rejected_rewards = self.beta * rejected_logratios.detach()
        else:
            # lists can't be empty -- if they are, then accelerate.gather will hang
            rejected_losses = torch.Tensor([]).to(self.accelerator.device)
            rejected_rewards = torch.Tensor([]).to(self.accelerator.device)

        losses = torch.cat(
            (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
            0,
        )

        return losses, chosen_rewards, rejected_rewards, kl
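    # A minimal numeric sketch of `kto_loss` above (illustrative values only, not
    # from the source): with beta = 0.1, a chosen log-ratio of 0.5, and a batch KL
    # estimate of 0.2, the per-example chosen loss is
    #     1 - sigmoid(0.1 * (0.5 - 0.2)) = 1 - sigmoid(0.03) ≈ 0.4925
    # and its reward is 0.1 * 0.5 = 0.05. Rejected examples mirror this with the
    # KL term and the log-ratio swapped: 1 - sigmoid(beta * (kl - logratio)).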
    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
    ):
        """Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}
        batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_KL_logps,
        ) = self.forward(model, batch)

        # if reference_logps in batch use them, otherwise use the reference model
        if "reference_logps" in batch:
            chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
            rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]

            reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
            reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
            reference_KL_logps = batch["reference_KL_logps"]
        else:
            with torch.no_grad():
                if self.ref_model is None:
                    with self.null_ref_context():
                        (
                            reference_chosen_logps,
                            reference_rejected_logps,
                            _,
                            _,
                            reference_KL_logps,
                        ) = self.forward(self.model, batch)
                else:
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
                        _,
                        _,
                        reference_KL_logps,
                    ) = self.forward(self.ref_model, batch)

        losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            policy_KL_logps,
            reference_chosen_logps,
            reference_rejected_logps,
            reference_KL_logps,
        )

        num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
        num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

        all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
        all_num_rejected = self.accelerator.gather(num_rejected).sum().item()

        if all_num_chosen > 0:
            metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
            metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
            metrics["count/chosen"] = all_num_chosen

        if all_num_rejected > 0:
            metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
            metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
            metrics["count/rejected"] = all_num_rejected

        metrics["kl"] = kl.item()

        return losses.nanmean(), metrics
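    # Note: sums and counts (rather than per-batch means) are gathered across
    # processes here so that `log` below can recover exact averages even when
    # different ranks see different numbers of chosen vs. rejected examples.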
    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        if not self.use_dpo_data_collator:
            warnings.warn(
                "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
            )
        compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

        with compute_loss_context_manager():
            loss, metrics = self.get_batch_loss_metrics(model, inputs)

        # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
        loss = loss.to(self.args.device)
        # force log the metrics
        if self.accelerator.is_main_process:
            self.store_metrics(metrics, train_eval="train")

        if return_outputs:
            return (loss, metrics)
        return loss
    def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None
        return SequentialSampler(self.train_dataset)

    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
        """Generate samples from the model and reference model for the given batch of inputs."""

        # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
        # the torch cuda amp context manager as some hidden states are silently casted to full precision.
        generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast

        with generate_context_manager():
            policy_output = model.generate(
                input_ids=batch["prompt_input_ids"],
                attention_mask=batch["prompt_attention_mask"],
                max_length=self.max_length,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )

            # if reference_output in batch use that otherwise use the reference model
            if "reference_output" in batch:
                reference_output = batch["reference_output"]
            else:
                if self.ref_model is None:
                    with self.null_ref_context():
                        reference_output = self.model.generate(
                            input_ids=batch["prompt_input_ids"],
                            attention_mask=batch["prompt_attention_mask"],
                            max_length=self.max_length,
                            do_sample=True,
                            pad_token_id=self.tokenizer.pad_token_id,
                        )
                else:
                    reference_output = self.ref_model.generate(
                        input_ids=batch["prompt_input_ids"],
                        attention_mask=batch["prompt_attention_mask"],
                        max_length=self.max_length,
                        do_sample=True,
                        pad_token_id=self.tokenizer.pad_token_id,
                    )

        policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
        policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)

        reference_output = pad_to_length(reference_output, self.max_length, self.tokenizer.pad_token_id)
        reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True)

        return policy_output_decoded, reference_output_decoded
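    # Note: both generations are padded to `self.max_length` before decoding so
    # the policy and reference outputs line up row-for-row when they are zipped
    # into the evaluation table in `evaluation_loop` below.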
    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ):
        if not self.use_dpo_data_collator:
            warnings.warn(
                "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
            )
        if ignore_keys is None:
            if hasattr(model, "config"):
                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
        with torch.no_grad(), prediction_context_manager():
            loss, metrics = self.get_batch_loss_metrics(model, inputs)

        # force log the metrics
        if self.accelerator.is_main_process:
            self.store_metrics(metrics, train_eval="eval")

        if prediction_loss_only:
            return (loss.detach(), None, None)

        # logits for the chosen and rejected samples from model
        logits_dict = {
            "eval_logits/chosen": metrics["logits/chosen"],
            "eval_logits/rejected": metrics["logits/rejected"],
        }
        logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
        logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

        return (loss.detach(), logits, labels)
    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Overriding built-in evaluation loop to store metrics for each batch.
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """

        # Sample and save to game log if requested (for one batch to save time)
        if self.generate_during_eval:
            # Generate random indices within the range of the total number of samples
            num_samples = len(dataloader.dataset)
            random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)

            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
            random_batch_dataset = dataloader.dataset.select(random_indices)
            random_batch = self.data_collator(random_batch_dataset)
            random_batch = self._prepare_inputs(random_batch)

            target_indices = [i for i in range(len(random_batch["kl"])) if random_batch["kl"][i] is False]
            target_batch = {
                "prompt_input_ids": itemgetter(*target_indices)(random_batch["prompt_input_ids"]),
                "prompt_attention_mask": itemgetter(*target_indices)(random_batch["prompt_attention_mask"]),
                "prompt": itemgetter(*target_indices)(random_batch["prompt"]),
            }
            policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch)

            self.log(
                {
                    "game_log": wandb.Table(
                        columns=["Prompt", "Policy", "Ref Model"],
                        rows=[
                            [prompt, pol[len(prompt) :], ref[len(prompt) :]]
                            for prompt, pol, ref in zip(
                                target_batch["prompt"], policy_output_decoded, ref_output_decoded
                            )
                        ],
                    )
                }
            )
            self.state.log_history.pop()

        # Base evaluation
        initial_output = super().evaluation_loop(
            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
        )

        return initial_output
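    # (The `self.state.log_history.pop()` call above appears to drop the entry
    # that `self.log` just appended to the trainer state, presumably so the
    # generation table is not duplicated in `log_history`; this reading is an
    # inference, not stated in the source.)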
    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # accumulate average metrics from sums and lengths
        for split in ["chosen", "rejected"]:
            if f"count/{split}" in self._stored_metrics[train_eval]:
                count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
                logs[f"{train_eval}/rewards/{split}"] = (
                    torch.Tensor(self._stored_metrics[train_eval][f"rewards/{split}_sum"]).sum().item() / count_sum
                )
                logs[f"{train_eval}/logps/{split}"] = (
                    torch.Tensor(self._stored_metrics[train_eval][f"logps/{split}_sum"]).sum().item() / count_sum
                )
                for key in [f"count/{split}", f"rewards/{split}_sum", f"logps/{split}_sum"]:
                    del self._stored_metrics[train_eval][key]
        # calculate reward margin
        if f"{train_eval}/rewards/chosen" in logs and f"{train_eval}/rewards/rejected" in logs:
            logs[f"{train_eval}/rewards/margins"] = (
                logs[f"{train_eval}/rewards/chosen"] - logs[f"{train_eval}/rewards/rejected"]
            )
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[f"{train_eval}/{key}"] = torch.Tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        return super().log(logs)
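    # For example (illustrative numbers, not from the source): if two logging
    # steps stored rewards/chosen_sum = [1.2, 0.8] and count/chosen = [4, 4],
    # the logged train/rewards/chosen is (1.2 + 0.8) / (4 + 4) = 0.25.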
    @wraps(Trainer.push_to_hub)
    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Overwrite the `push_to_hub` method in order to force-add the tag "kto" when pushing the
        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)

        return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
class UnslothKTOTrainer(_UnslothKTOTrainer):
    """

    Initialize KTOTrainer.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForCausalLM`.
        ref_model (`PreTrainedModelWrapper`):
            Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
            reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
        args (`KTOConfig`):
            The arguments to use for training.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        tokenizer (`transformers.PreTrainedTokenizerBase`):
            The tokenizer to use for training. This argument is required if you want to use the default data collator.
        data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be used.
        callbacks (`List[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`Dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
        disable_dropout (`bool`, defaults to `True`):
            Whether or not to disable dropouts in `model` and `ref_model`.
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
            The function to use to compute the metrics. Must take an `EvalPrediction` and return
            a dictionary mapping strings to metric values.
        model_adapter_name (`str`, defaults to `None`):
            Name of the train target PEFT adapter, when using LoRA with multiple adapters.
        ref_adapter_name (`str`, defaults to `None`):
            Name of the reference PEFT adapter, when using LoRA with multiple adapters.

    """
    def __init__(
        self,
        model = None,
        ref_model = None,
        args = None,
        train_dataset = None,
        eval_dataset = None,
        tokenizer = None,
        data_collator = None,
        model_init = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        compute_metrics = None,
        model_adapter_name = None,
        ref_adapter_name = None,
        **kwargs
    ):
        if args is None: args = UnslothKTOConfig()
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('kto_trainer', other_metrics)

        super().__init__(
            model = model,
            ref_model = ref_model,
            args = args,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            tokenizer = tokenizer,
            data_collator = data_collator,
            model_init = model_init,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            compute_metrics = compute_metrics,
            model_adapter_name = model_adapter_name,
            ref_adapter_name = ref_adapter_name, **kwargs)
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass

    pass
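A minimal usage sketch for the wrapper above (everything on the right-hand side is an illustrative assumption, not part of the uploaded files: the model id, the dataset, and the hyperparameters are placeholders; a KTO dataset needs `prompt`, `completion`, and boolean `label` columns):

from unsloth import FastLanguageModel
from datasets import load_dataset

# Hypothetical 4-bit base model; any Unsloth-supported checkpoint works.
model, tokenizer = FastLanguageModel.from_pretrained("unsloth/llama-3-8b-bnb-4bit")
model = FastLanguageModel.get_peft_model(model, r = 16)

train_dataset = load_dataset("trl-lib/kto-mix-14k", split = "train")  # placeholder dataset

trainer = UnslothKTOTrainer(
    model = model,
    args = UnslothKTOConfig(per_device_train_batch_size = 2, max_steps = 60),
    train_dataset = train_dataset,
    tokenizer = tokenizer,
)
trainer.train()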
compilefcach/UnslothORPOTrainer.py
ADDED
@@ -0,0 +1,1413 @@
"""
2025.6.8
2025.6.12
4.53.0
0.8.6
__UNSLOTH_VERSIONING__
"""
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, Dict, EvalLoopOutput, F, List, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, Trainer, TrainerCallback, Tuple, Union, deepcopy, defaultdict, disable_dropout_in_model, inspect, is_peft_available, is_torch_fx_proxy, is_wandb_available, nn, np, nullcontext, pad_to_length, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, torch, trl_sanitze_kwargs_for_tagging, wandb, warnings, wraps)


import os
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling

torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : False,
    "shape_padding"     : True,
    "trace.enabled"     : False,
    "triton.cudagraphs" : False,
}

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    logits = logits.to(torch.float32)
    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    # loop to reduce peak mem consumption
    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
    logsumexp_values = torch.logsumexp(logits, dim = -1)
    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    return per_token_logps
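# A quick sanity check of the identity used above (illustrative, not part of the
# uploaded file): gathering from a full log_softmax must agree with the fused
# version, since log_softmax(x)_i = x_i - logsumexp(x).
#     logits = torch.randn(2, 5, 10)
#     ids = torch.randint(0, 10, (2, 5))
#     full = torch.log_softmax(logits.float(), dim = -1).gather(-1, ids.unsqueeze(-1)).squeeze(-1)
#     assert torch.allclose(selective_log_softmax(logits, ids), full, atol = 1e-5)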
@dataclass
class UnslothORPOConfig(ORPOConfig):
    """

    ORPOConfig collects all training arguments related to the [`ORPOTrainer`] class.

    Using [`HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        max_length (`int`, defaults to `None`):
            The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
        max_prompt_length (`int`, defaults to `None`):
            The maximum length of the prompt. This argument is required if you want to use the default data collator.
        max_completion_length (`int`, defaults to `None`):
            The maximum length of the completions. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
        beta (`float`, defaults to 0.1):
            The beta factor in ORPO loss (lambda/alpha in paper/code) that is the weight of the relative loss ratio in the SFT loss.
        label_pad_token_id (`int`, defaults to `-100`):
            The label pad token id. This argument is required if you want to use the default data collator.
        padding_value (`int`, defaults to `None`):
            The padding value if it is different from the tokenizer's pad_token_id.
        truncation_mode (`str`, defaults to `keep_end`):
            The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
        generate_during_eval (`bool`, defaults to `False`):
            Whether to sample and log generations during the evaluation step.
        is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
            If no model is provided, we need to know if the model_init returns an encoder-decoder.
        disable_dropout (`bool`, defaults to `True`):
            Whether or not to disable dropouts in `model`.
        model_init_kwargs (`Optional[Dict]`, *optional*):
            Dict of optional kwargs to pass when instantiating the model from a string.
        dataset_num_proc (`Optional[int]`, *optional*):
            The number of workers to use to tokenize the data. Defaults to None.

    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        hub_revision = None,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = False,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        liger_kernel_config = None,
        eval_use_gather_object = False,
        average_tokens_across_devices = False,
        max_length = None,
        max_prompt_length = None,
        max_completion_length = None,
        beta = 0.1,
        disable_dropout = True,
        label_pad_token_id = -100,
        padding_value = None,
        truncation_mode = 'keep_end',
        generate_during_eval = False,
        is_encoder_decoder = None,
        model_init_kwargs = None,
        dataset_num_proc = None,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()

        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            liger_kernel_config = liger_kernel_config,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            max_length = max_length,
            max_prompt_length = max_prompt_length,
            max_completion_length = max_completion_length,
            beta = beta,
            disable_dropout = disable_dropout,
            label_pad_token_id = label_pad_token_id,
            padding_value = padding_value,
            truncation_mode = truncation_mode,
            generate_during_eval = generate_during_eval,
            is_encoder_decoder = is_encoder_decoder,
            model_init_kwargs = model_init_kwargs,
            dataset_num_proc = dataset_num_proc, **kwargs)
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
    pass
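# Illustrative only (values assumed, not from the uploaded file): the config is
# constructed like any TrainingArguments subclass, e.g.
#     args = UnslothORPOConfig(per_device_train_batch_size = 2,
#                              gradient_accumulation_steps = 4,
#                              beta = 0.1, max_length = 1024, max_steps = 60)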

class _UnslothORPOTrainer(Trainer):
    r""""""

    _tag_names = ["trl", "orpo"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: Optional[ORPOConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional[Dict] = None,
        compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
    ):
        if args.model_init_kwargs is None:
            model_init_kwargs = {}
        elif not isinstance(model, str):
            raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
        else:
            model_init_kwargs = args.model_init_kwargs
            model_init_kwargs["torch_dtype"] = (
                model_init_kwargs["torch_dtype"]
                if model_init_kwargs["torch_dtype"] in ["auto", None]
                else getattr(torch, model_init_kwargs["torch_dtype"])
            )

        if isinstance(model, str):
            warnings.warn(
                "You passed a model_id to the ORPOTrainer. This will automatically create an "
                "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
            )
            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

        # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
        # has been called in order to properly call autocast if needed.
        self._peft_has_been_casted_to_bf16 = False

        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            # if model is a peft model and we have a peft_config, we merge and unload it first
            if isinstance(model, PeftModel):
                model = model.merge_and_unload()

            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
                _support_gc_kwargs = hasattr(
                    args, "gradient_checkpointing_kwargs"
                ) and "gradient_checkpointing_kwargs" in list(
                    inspect.signature(prepare_model_for_kbit_training).parameters
                )

                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                if _support_gc_kwargs:
                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
            elif getattr(args, "gradient_checkpointing", False):
                # For backward compatibility with older versions of transformers
                if hasattr(model, "enable_input_require_grads"):
                    model.enable_input_require_grads()
                else:

                    def make_inputs_require_grad(module, input, output):
                        output.requires_grad_(True)

                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

            # get peft model with the given config
            model = model
            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
                peft_module_casting_to_bf16(model)
                # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
                self._peft_has_been_casted_to_bf16 = True

        # For models that use gradient_checkpointing, we need to attach a hook that enables input
        # to explicitly have `requires_grad=True`, otherwise training will either silently
        # fail or completely fail.
        elif getattr(args, "gradient_checkpointing", False):
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        if args.generate_during_eval and not is_wandb_available():
            raise ValueError(
                "`generate_during_eval=True` requires Weights and Biases to be installed."
                " Please install `wandb` to resolve."
            )

        if model is not None:
            self.is_encoder_decoder = model.config.is_encoder_decoder
        elif args.is_encoder_decoder is None:
            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
        else:
            self.is_encoder_decoder = args.is_encoder_decoder

        if self.is_encoder_decoder:
            self.decoder_start_token_id = model.config.decoder_start_token_id
            self.pad_token_id = model.config.pad_token_id

        if tokenizer is None:
            raise ValueError("tokenizer must be specified to tokenize an ORPO dataset.")
        if args.max_length is None:
            warnings.warn(
                "`max_length` is not set in the ORPOConfig's init"
                " it will default to `512` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_length = 512
        else:
            max_length = args.max_length
        if args.max_prompt_length is None:
            warnings.warn(
                "`max_prompt_length` is not set in the ORPOConfig's init"
                " it will default to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_prompt_length = 128
        else:
            max_prompt_length = args.max_prompt_length

        if args.max_completion_length is None and self.is_encoder_decoder:
            warnings.warn(
                "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
                " it will default to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            self.max_completion_length = 128
        else:
            self.max_completion_length = args.max_completion_length

        if data_collator is None:
            data_collator = DPODataCollatorWithPadding(
                pad_token_id=tokenizer.pad_token_id,
                label_pad_token_id=args.label_pad_token_id,
                is_encoder_decoder=self.is_encoder_decoder,
            )

            if args.remove_unused_columns:
                args.remove_unused_columns = False
                # warn users
                warnings.warn(
                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
                    " we have set it for you, but you should do it yourself in the future.",
                    UserWarning,
                )

            self.use_dpo_data_collator = True
        else:
            self.use_dpo_data_collator = False

        if args.disable_dropout:
            disable_dropout_in_model(model)

        self.max_length = max_length
        self.generate_during_eval = args.generate_during_eval
        self.label_pad_token_id = args.label_pad_token_id
        self.padding_value = args.padding_value if args.padding_value is not None else tokenizer.pad_token_id
        self.max_prompt_length = max_prompt_length
        self.truncation_mode = args.truncation_mode
        self.tokenizer = tokenizer

        self.beta = args.beta

        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # Compute that only on the main process for faster data processing.
        # see: https://github.com/huggingface/trl/pull/1255
        with PartialState().local_main_process_first():
            # tokenize the dataset
            train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
            if eval_dataset is not None:
                eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        if not hasattr(self, "accelerator"):
            raise AttributeError(
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )
+
    def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

        if model is not None:
            if hasattr(model, "config"):
                hidden_size = (
                    max(model.config.hidden_sizes)
                    if getattr(model.config, "hidden_sizes", None)
                    else getattr(model.config, "hidden_size", None)
                )
                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
                    config_kwargs.update(
                        {
                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                        }
                    )

        # If ZeRO-3 is used, we shard both the active and reference model.
        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
        if config_kwargs["zero_optimization"]["stage"] != 3:
            config_kwargs["zero_optimization"]["stage"] = 0
        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
        model.eval()
        return model

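    # A minimal worked example of the ZeRO-3 bucket sizing above (illustrative only;
    # the hidden size is hypothetical):
    #     >>> hidden_size = 4096
    #     >>> hidden_size * hidden_size             # reduce_bucket_size
    #     16777216
    #     >>> 10 * hidden_size                      # stage3_param_persistence_threshold
    #     40960
    #     >>> int(0.9 * hidden_size * hidden_size)  # stage3_prefetch_bucket_size
    #     15099494
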
    def build_tokenized_answer(self, prompt, answer):
        """
        Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
        It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
        Reference:
            https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
        """

        full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
        prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]

        answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
        answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]

        # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
        full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])

        # Prepare input tokens for token by token comparison
        full_input_ids = np.array(full_tokenized["input_ids"])

        if len(full_input_ids) != len(full_concat_input_ids):
            raise ValueError("Prompt input ids and answer input ids should have the same length.")

        # On some tokenizers, like the Llama-2 tokenizer, there are occasions where tokens
        # can be merged together when tokenizing prompt+answer. This could result
        # in the last token from the prompt being different when tokenized on its own
        # vs when done as prompt+answer.
        response_token_ids_start_idx = len(prompt_input_ids)

        # If the tokenized prompt differs from the prompt part of prompt+answer, then the
        # last token has changed due to merging.
        if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
            response_token_ids_start_idx -= 1

        prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
        prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]

        if len(prompt_input_ids) != len(prompt_attention_mask):
            raise ValueError("Prompt input ids and attention mask should have the same length.")

        answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
        answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]

        return dict(
            prompt_input_ids=prompt_input_ids,
            prompt_attention_mask=prompt_attention_mask,
            input_ids=answer_input_ids,
            attention_mask=answer_attention_mask,
        )

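    # A minimal sketch of the merge issue handled above (illustrative; actual ids depend
    # on the tokenizer, and the checkpoint name is only an example):
    #     >>> tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    #     >>> enc = lambda s: tok(s, add_special_tokens=False)["input_ids"]
    #     >>> enc("Hello" + " world") == enc("Hello") + enc(" world")   # may be False!
    #     >>> enc("Hello" + " world")[len(enc("Hello")):]               # consistent answer ids
    # The one-token back-off above covers the case where the prompt's last token merged
    # with the first answer token.
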
    def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
        """Tokenize a single row from an ORPO-specific dataset.

        At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
        in case the prompt + chosen or prompt + rejected responses is/are too long. First
        we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

        We also create the labels for the chosen/rejected responses, which are of length equal to
        the sum of the length of the prompt and the chosen/rejected response, with
        label_pad_token_id for the prompt tokens.
        """
        batch = {}
        prompt = feature["prompt"]
        chosen = feature["chosen"]
        rejected = feature["rejected"]

        if not self.is_encoder_decoder:
            # Check issues below for more details
            # 1. https://github.com/huggingface/trl/issues/907
            # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
            # 3. https://github.com/LianjiaTech/BELLE/issues/337

            if not isinstance(prompt, str):
                raise ValueError(f"prompt should be an str but got {type(prompt)}")
            prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
            prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

            if not isinstance(chosen, str):
                raise ValueError(f"chosen should be an str but got {type(chosen)}")
            chosen_tokens = self.build_tokenized_answer(prompt, chosen)

            if not isinstance(rejected, str):
                raise ValueError(f"rejected should be an str but got {type(rejected)}")
            rejected_tokens = self.build_tokenized_answer(prompt, rejected)

            # Last prompt token might get merged by tokenizer and
            # it should not be included for generation if that happens
            prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])

            chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
            rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
            prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)

            for k, v in prompt_tokens.items():
                prompt_tokens[k] = v[:prompt_len_input_ids]

            # Make sure prompts only have one different token at most
            # and length only differs by 1 at most
            num_diff_tokens = sum(
                [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
            )
            num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
            if num_diff_tokens > 1 or num_diff_len > 1:
                raise ValueError(
                    "Chosen and rejected prompt_input_ids might only differ on the "
                    "last token due to tokenizer merge ops."
                )

            # add BOS token to head of prompt
            prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
            chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
            rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]

            prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
            chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
            rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]

            # add EOS token to end of answer
            chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
            chosen_tokens["attention_mask"].append(1)

            rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
            rejected_tokens["attention_mask"].append(1)

            longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

            # if combined sequence is too long, truncate the prompt
            for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                    if self.truncation_mode == "keep_start":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
                    elif self.truncation_mode == "keep_end":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
                    else:
                        raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

            # if that's still too long, truncate the response
            for answer_tokens in [chosen_tokens, rejected_tokens]:
                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                    for k in ["input_ids", "attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]

            # Create labels
            chosen_sequence_tokens = {
                k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            rejected_sequence_tokens = {
                k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
            chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
                self.label_pad_token_id
            ] * len(chosen_tokens["prompt_input_ids"])
            rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
            rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
                self.label_pad_token_id
            ] * len(rejected_tokens["prompt_input_ids"])

            for k, toks in {
                "chosen_": chosen_sequence_tokens,
                "rejected_": rejected_sequence_tokens,
                "": prompt_tokens,
            }.items():
                for type_key, tokens in toks.items():
                    if type_key == "token_type_ids":
                        continue
                    batch[f"{k}{type_key}"] = tokens

        else:
            chosen_tokens = self.tokenizer(
                chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
            )
            rejected_tokens = self.tokenizer(
                rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
            )
            prompt_tokens = self.tokenizer(
                prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
            )

            batch["chosen_labels"] = chosen_tokens["input_ids"]
            batch["rejected_labels"] = rejected_tokens["input_ids"]
            batch["prompt_input_ids"] = prompt_tokens["input_ids"]
            batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

            if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
                batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                    labels=torch.tensor(batch["rejected_labels"])
                )
                batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                    labels=torch.tensor(batch["chosen_labels"])
                )

        return batch

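    # A minimal sketch of the label masking performed above (hypothetical token ids;
    # -100 is the default label_pad_token_id):
    #     >>> prompt_ids = [1, 15043]                  # BOS + prompt
    #     >>> answer_ids = [3186, 2]                   # answer + EOS
    #     >>> labels = (prompt_ids + answer_ids)[:]
    #     >>> labels[: len(prompt_ids)] = [-100] * len(prompt_ids)
    #     >>> labels
    #     [-100, -100, 3186, 2]
    # Only the response tokens contribute to the NLL term of the ORPO loss.
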
    @staticmethod
    def concatenated_inputs(
        batch: Dict[str, Union[List, torch.LongTensor]],
        is_encoder_decoder: bool = False,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.LongTensor]:
        """Concatenate the chosen and rejected inputs into a single tensor.

        Args:
            batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
            is_encoder_decoder: Whether the model is an encoder-decoder model.
            label_pad_token_id: The label pad token id.
            padding_value: The padding value to use for the concatenated input_ids.
            device: The device for the concatenated inputs.

        Returns:
            A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
        """
        concatenated_batch = {}

        if is_encoder_decoder:
            max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
        else:
            max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])

        for k in batch:
            if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
                if "labels" in k or is_encoder_decoder:
                    pad_value = label_pad_token_id
                elif k.endswith("_input_ids"):
                    pad_value = padding_value
                elif k.endswith("_attention_mask"):
                    pad_value = 0
                concatenated_key = k.replace("chosen", "concatenated")
                concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
        for k in batch:
            if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
                if "labels" in k or is_encoder_decoder:
                    pad_value = label_pad_token_id
                elif k.endswith("_input_ids"):
                    pad_value = padding_value
                elif k.endswith("_attention_mask"):
                    pad_value = 0
                concatenated_key = k.replace("rejected", "concatenated")
                concatenated_batch[concatenated_key] = torch.cat(
                    (
                        concatenated_batch[concatenated_key],
                        pad_to_length(batch[k], max_length, pad_value=pad_value),
                    ),
                    dim=0,
                ).to(device=device)

        if is_encoder_decoder:
            concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
            concatenated_batch["concatenated_attention_mask"] = (
                batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
            )

        return concatenated_batch

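    # Shape sketch for the concatenation above: with batch_size B and per-side lengths
    # Lc, Lr, both sides are padded to L = max(Lc, Lr) and stacked along dim 0, so
    # `concatenated_input_ids` is (2 * B, L): rows [0, B) are chosen, rows [B, 2B) are
    # rejected. This pairs with the `all_logps[:len_chosen]` / `all_logps[len_chosen:]`
    # split in `concatenated_forward` below.
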
    def odds_ratio_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute ORPO's odds ratio (OR) loss for a batch of policy model log probabilities.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)

        Returns:
            A tuple of five values: (losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen).
            The losses tensor contains the ORPO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
            `log_odds_ratio` is the mean `log(sigmoid(log_odds))` and `log_odds_chosen` is the mean log odds
            of the chosen responses over the rejected responses; both are returned for logging purposes.
        """

        # Derived from Eqs. (4) and (7) of https://arxiv.org/abs/2403.07691 by using log identities and exp(log(P(y|x))) = P(y|x)
        log_odds = (policy_chosen_logps - policy_rejected_logps) - (
            torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
        )
        sig_ratio = F.sigmoid(log_odds)
        ratio = torch.log(sig_ratio)
        losses = self.beta * ratio

        chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
        rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()

        return losses, chosen_rewards, rejected_rewards, torch.mean(ratio).item(), torch.mean(log_odds).item()

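    # The identity used above, written out (a sketch; logp_w, logp_l are the per-sequence
    # average log probabilities, so exp(logp) lies in (0, 1)):
    #
    #     odds(y|x) = P(y|x) / (1 - P(y|x))
    #     log_odds  = log odds_w - log odds_l
    #               = (logp_w - logp_l) - (log(1 - exp(logp_w)) - log(1 - exp(logp_l)))
    #
    # where log(1 - exp(logp)) is computed stably as torch.log1p(-torch.exp(logp)).
    # The per-example term is beta * log(sigmoid(log_odds)); it is *subtracted* from the
    # NLL term in `get_batch_loss_metrics`, so minimizing the loss maximizes the odds
    # ratio of chosen over rejected.
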
    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
            label_pad_token_id: The label pad token id.
            is_encoder_decoder: Whether the model is an encoder-decoder model.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
        """
        if logits.shape[:-1] != labels.shape:
            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

        if not is_encoder_decoder:
            labels = labels[:, 1:].clone()
            logits = logits[:, :-1, :]
        loss_mask = labels != label_pad_token_id

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == label_pad_token_id] = 0

        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)

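    # A minimal sketch of the gather step above (shapes only; values are random):
    #     >>> logits = torch.randn(2, 5, 32000)           # (batch, seq, vocab)
    #     >>> labels = torch.randint(0, 32000, (2, 5))    # (batch, seq)
    #     >>> logps = torch.gather(logits.log_softmax(-1), 2, labels.unsqueeze(2)).squeeze(2)
    #     >>> logps.shape
    #     torch.Size([2, 5])
    # Masked summation (or averaging) over dim -1 then yields one log probability per sequence.
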
    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        len_chosen = batch["chosen_labels"].shape[0]

        model_kwargs = (
            {
                "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
            }
            if self.is_encoder_decoder
            else {}
        )

        outputs = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
            use_cache=False,
            **model_kwargs,
        )
        all_logits = outputs.logits

        def cross_entropy_loss(logits, labels):
            if not self.is_encoder_decoder:
                # Shift so that tokens < n predict n
                logits = logits[..., :-1, :].contiguous()
                labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            logits = logits.view(-1, logits.shape[-1])
            labels = labels.view(-1)
            # Enable model parallelism
            labels = labels.to(logits.device)
            loss = loss_fct(logits, labels)
            return loss

        if self.is_encoder_decoder:
            labels = concatenated_batch["concatenated_labels"].clone()
        else:
            labels = concatenated_batch["concatenated_input_ids"].clone()

        chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

        all_logps = self.get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
            average_log_prob=True,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_nll_loss,
        ) = self.concatenated_forward(model, batch)

        losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
            policy_chosen_logps, policy_rejected_logps
        )
        # full ORPO loss
        loss = policy_nll_loss - losses.mean()

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
        metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
        metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen

        return loss, metrics

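    # The full objective computed above, as an equation (beta = self.beta):
    #
    #     L_ORPO = L_NLL(chosen) - beta * E[ log sigmoid(log_odds(chosen, rejected)) ]
    #
    # i.e. a standard SFT loss on the chosen responses plus a penalty that pushes the
    # odds of the chosen response above the odds of the rejected one, following
    # https://arxiv.org/abs/2403.07691 under the sign convention used here.
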
    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        if not self.use_dpo_data_collator:
            warnings.warn(
                "compute_loss is only implemented for DPODataCollatorWithPadding; you passed a data collator that is different from "
                "DPODataCollatorWithPadding, so you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator."
            )

        compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

        with compute_loss_context_manager():
            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

        # force log the metrics
        self.store_metrics(metrics, train_eval="train")

        if return_outputs:
            return (loss, metrics)
        return loss

    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
        """Generate samples from the model for the given batch of inputs."""

        # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
        # the torch cuda amp context manager as some hidden states are silently cast to full precision.
        generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast

        with generate_context_manager():
            policy_output = model.generate(
                input_ids=batch["prompt_input_ids"],
                attention_mask=batch["prompt_attention_mask"],
                max_length=self.max_length,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
        policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)

        return policy_output_decoded

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ):
        if not self.use_dpo_data_collator:
            warnings.warn(
                "prediction_step is only implemented for DPODataCollatorWithPadding; you passed a data collator that is different from "
                "DPODataCollatorWithPadding, so you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator."
            )
        if ignore_keys is None:
            if hasattr(model, "config"):
                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

        with torch.no_grad(), prediction_context_manager():
            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

        # force log the metrics
        self.store_metrics(metrics, train_eval="eval")

        if prediction_loss_only:
            return (loss.detach(), None, None)

        # logits for the chosen and rejected samples from model
        logits_dict = {
            "eval_logits/chosen": metrics["eval_logits/chosen"],
            "eval_logits/rejected": metrics["eval_logits/rejected"],
        }
        logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
        logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

        return (loss.detach(), logits, labels)

    def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Overriding built-in evaluation loop to store metrics for each batch.
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """

        # Sample and save to game log if requested (for one batch to save time)
        if self.generate_during_eval:
            # Generate random indices within the range of the total number of samples
            num_samples = len(dataloader.dataset)
            random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)

            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
            random_batch_dataset = dataloader.dataset.select(random_indices)
            random_batch = self.data_collator(random_batch_dataset)
            random_batch = self._prepare_inputs(random_batch)

            policy_output_decoded = self.get_batch_samples(self.model, random_batch)

            self.log(
                {
                    "game_log": wandb.Table(
                        columns=["Prompt", "Policy"],
                        rows=[
                            [prompt, pol[len(prompt) :]]
                            for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
                        ],
                    )
                }
            )
            self.state.log_history.pop()

        # Base evaluation
        initial_output = super().evaluation_loop(
            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
        )

        return initial_output

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        return super().log(logs)

    def _shift_right(self, input_ids):
        if self.decoder_start_token_id is None:
            raise ValueError(
                "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
            )

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = self.decoder_start_token_id

        if self.pad_token_id is None:
            raise ValueError("model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)

        return shifted_input_ids

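    # A minimal sketch of the right shift above (hypothetical ids; decoder_start_token_id = 0,
    # pad_token_id = 0):
    #     >>> labels  = [42, -100, 17]
    #     >>> shifted = [0, 42, -100]    # prepend start token, drop the last position
    #     >>> shifted = [0, 42, 0]       # remaining -100 labels replaced by pad_token_id
    # The decoder is then teacher-forced: position t sees labels < t as input.
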
    @wraps(Trainer.push_to_hub)
    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Overwrite the `push_to_hub` method in order to force-add the tag "orpo" when pushing the
        model to the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)

        return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
class UnslothORPOTrainer(_UnslothORPOTrainer):
    """

    Initialize ORPOTrainer.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForSequenceClassification`.
        args (`ORPOConfig`):
            The ORPO config arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used,
            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        tokenizer (`transformers.PreTrainedTokenizerBase`):
            The tokenizer to use for training. This argument is required if you want to use the default data collator.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be used.
        callbacks (`List[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`Dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
            The function to use to compute the metrics. Must take an `EvalPrediction` and return
            a dictionary mapping strings to metric values.

    """
    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        tokenizer = None,
        model_init = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        compute_metrics = None,
        **kwargs
    ):
        if args is None: args = UnslothORPOConfig()
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('orpo_trainer', other_metrics)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            tokenizer = tokenizer,
            model_init = model_init,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            compute_metrics = compute_metrics,
            **kwargs,
        )
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
    pass
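# A minimal usage sketch (hypothetical model/dataset/tokenizer objects; not part of
# this module):
#     >>> trainer = UnslothORPOTrainer(
#     ...     model = model,                # an Unsloth/PEFT causal LM
#     ...     args = UnslothORPOConfig(output_dir = "outputs", max_length = 1024),
#     ...     train_dataset = dataset,      # rows with "prompt", "chosen", "rejected"
#     ...     tokenizer = tokenizer,
#     ... )
#     >>> trainer.train()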
compilefcach/UnslothPPOTrainer.py
ADDED
@@ -0,0 +1,1566 @@
"""
2025.6.8
2025.6.12
4.53.0
0.8.6
__UNSLOTH_VERSIONING__
"""
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from trl.trainer.ppo_trainer import (Accelerator, Adam, AdaptiveKLController, BaseTrainer, Callable, DataCollatorForLanguageModeling, Dataset, F, FixedKLController, List, MODEL_CARD_TEMPLATE, Optional, PPOConfig, PPODecorators, PPOTrainer, PreTrainedModelWrapper, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast, ProjectConfiguration, RunningMoments, SUPPORTED_ARCHITECTURES, Union, WANDB_PADDING, clip_by_value, convert_to_scalar, create_reference_model, datasets, entropy_from_logits, flatten_dict, gather_object, inspect, is_npu_available, is_torch_greater_2_0, is_xpu_available, logprobs_from_logits, masked_mean, masked_var, masked_whiten, math, np, nullcontext, os, set_seed, stack_dicts, stats_to_np, time, torch, typing, unwrap_model_for_generation, version, warnings, whoami)


import os
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling

torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : False,
    "shape_padding"     : True,
    "trace.enabled"     : False,
    "triton.cudagraphs" : False,
}

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    logits = logits.to(torch.float32)
    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    # loop to reduce peak mem consumption
    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
    logsumexp_values = torch.logsumexp(logits, dim = -1)
    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    return per_token_logps
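
# A quick equivalence check for the helper above (a sketch; shapes are arbitrary):
#     >>> logits = torch.randn(2, 4, 100)
#     >>> index = torch.randint(0, 100, (2, 4))
#     >>> a = selective_log_softmax(logits, index)
#     >>> b = torch.gather(logits.float().log_softmax(-1), -1, index.unsqueeze(-1)).squeeze(-1)
#     >>> torch.allclose(a, b, atol = 1e-5)
#     True
# Computing `x_i - logsumexp(x)` directly avoids materializing the full
# (batch, seq, vocab) log-softmax tensor just to read one entry per position.
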
@dataclass
class UnslothPPOConfig(PPOConfig):
    """

    Configuration class for PPOTrainer

    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks: Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        exp_name = 'colab_kernel_launcher',
        seed = 3407,
        log_with = None,
        task_name = None,
        model_name = 'gpt2',
        query_dataset = 'imdb',
        reward_model = 'sentiment-analysis:lvwerra/distilbert-imdb',
        remove_unused_columns = True,
        tracker_project_name = 'trl',
        steps = 20000,
        learning_rate = 5e-05,
        adap_kl_ctrl = True,
        init_kl_coef = 0.2,
        kl_penalty = 'kl',
        target = 6,
        horizon = 10000,
        gamma = 1,
        lam = 0.95,
        cliprange = 0.2,
        cliprange_value = 0.2,
        vf_coef = 0.1,
        batch_size = 128,
        forward_batch_size = None,
        mini_batch_size = 128,
        gradient_accumulation_steps = 2,
        world_size = None,
        ppo_epochs = 4,
        max_grad_norm = None,
        optimize_cuda_cache = None,
        optimize_device_cache = False,
        early_stopping = False,
        target_kl = 1,
        compare_steps = 1,
        ratio_threshold = 10.0,
        use_score_scaling = False,
        use_score_norm = False,
        score_clip = None,
        whiten_rewards = False,
        is_encoder_decoder = None,
        is_peft_model = None,
        backward_batch_size = None,
        global_backward_batch_size = None,
        global_batch_size = None,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is far too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')

        super().__init__(
            exp_name = exp_name,
            seed = seed,
            log_with = log_with,
            task_name = task_name,
            model_name = model_name,
            query_dataset = query_dataset,
            reward_model = reward_model,
            remove_unused_columns = remove_unused_columns,
            tracker_project_name = tracker_project_name,
            steps = steps,
            learning_rate = learning_rate,
            adap_kl_ctrl = adap_kl_ctrl,
            init_kl_coef = init_kl_coef,
            kl_penalty = kl_penalty,
            target = target,
            horizon = horizon,
            gamma = gamma,
            lam = lam,
            cliprange = cliprange,
            cliprange_value = cliprange_value,
            vf_coef = vf_coef,
            batch_size = batch_size,
            forward_batch_size = forward_batch_size,
            mini_batch_size = mini_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            world_size = world_size,
            ppo_epochs = ppo_epochs,
            max_grad_norm = max_grad_norm,
            optimize_cuda_cache = optimize_cuda_cache,
            optimize_device_cache = optimize_device_cache,
            early_stopping = early_stopping,
            target_kl = target_kl,
            compare_steps = compare_steps,
            ratio_threshold = ratio_threshold,
            use_score_scaling = use_score_scaling,
            use_score_norm = use_score_norm,
            score_clip = score_clip,
            whiten_rewards = whiten_rewards,
            is_encoder_decoder = is_encoder_decoder,
            is_peft_model = is_peft_model,
            backward_batch_size = backward_batch_size,
            global_backward_batch_size = global_backward_batch_size,
            global_batch_size = global_batch_size,
            **kwargs,
        )
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
    pass

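# How the derived batch sizes relate (a sketch, assuming TRL's PPOConfig semantics; the
# world_size factor is filled in by the trainer below from the Accelerator):
#     >>> backward_batch_size        = mini_batch_size * gradient_accumulation_steps
#     >>> global_backward_batch_size = backward_batch_size * world_size
#     >>> global_batch_size          = batch_size * world_size
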
class _UnslothPPOTrainer(BaseTrainer):
    """"""

    _tag_names = ["trl", "ppo"]

    def __init__(
        self,
        config: Optional[PPOConfig] = None,
        model: Optional[PreTrainedModelWrapper] = None,
        ref_model: Optional[PreTrainedModelWrapper] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        data_collator: Optional[typing.Callable] = None,
        num_shared_layers: Optional[int] = None,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    ):
"""
|
175 |
+
Initialize PPOTrainer.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
config (`PPOConfig`):
|
179 |
+
Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more details.
|
180 |
+
model (`PreTrainedModelWrapper`):
|
181 |
+
Hugging Face transformer model with a value head.
|
182 |
+
ref_model (`PreTrainedModelWrapper`):
|
183 |
+
Hugging Face transformer model with a casual language modelling head. Used for KL penalty
|
184 |
+
tokenizer (`transformers.PreTrainedTokenizerBase`):
|
185 |
+
Hugging Face tokenizer
|
186 |
+
dataset (Optional[Union[`torch.utils.data.Dataset`, `datasets.Dataset`]]):
|
187 |
+
PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset
|
188 |
+
will be preprocessed by removing the columns that are not used by the model. If none is passed,
|
189 |
+
a warning will be raised in a multi-GPU setting.
|
190 |
+
optimizer (Optional[`torch.optim.Optimizer`]):
|
191 |
+
Optimizer used for training. If `None`, the `Adam` is used as default.
|
192 |
+
data_collator (Optional[function]):
|
193 |
+
Data collator function.
|
194 |
+
num_shared_layers (Optional[int]):
|
195 |
+
Number of shared layers between the model and the reference model. If `None`, all layers are shared.
|
196 |
+
used only if `ref_model` is `None`.
|
197 |
+
lr_scheduler (Optional[`torch.optim.lr_scheduler`]):
|
198 |
+
Learning rate scheduler used for training.
|
199 |
+
"""
|
200 |
+
        super().__init__(config)

        # initial seed for reproducible experiments
        set_seed(config.seed)

        # Step 0: check positional arguments validity
        if not isinstance(config, PPOConfig):
            raise ValueError(f"config must be a PPOConfig, got {type(config)}")
        if not isinstance(tokenizer, PreTrainedTokenizerBase):
            raise ValueError(
                f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}"
            )
        if not isinstance(model, SUPPORTED_ARCHITECTURES):
            raise ValueError(
                f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}"
            )
        # Step 1: Initialize Accelerator
        self.accelerator = Accelerator(
            log_with=config.log_with,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            project_config=ProjectConfiguration(**config.project_kwargs),
            **config.accelerator_kwargs,
        )

        # Step 1.1 Runtime variables filled by the accelerator
        config.world_size = self.accelerator.num_processes
        config.global_backward_batch_size = config.backward_batch_size * config.world_size
        config.global_batch_size = config.batch_size * config.world_size

        self.model = model
        self.model_params = filter(lambda p: p.requires_grad, self.model.parameters())
        self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
        self.is_peft_model = getattr(self.model, "is_peft_model", False)
        config.is_encoder_decoder = self.is_encoder_decoder
        config.is_peft_model = self.is_peft_model

        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
        self.accelerator.init_trackers(
            config.tracker_project_name,
            config=dict(trl_ppo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
            init_kwargs=config.tracker_kwargs,
        )
        self.is_using_text_environment = getattr(config, "use_text_environment", False)

        if isinstance(ref_model, SUPPORTED_ARCHITECTURES):
            self.ref_model = ref_model
            if num_shared_layers is not None:
                warnings.warn(
                    "num_shared_layers is ignored when ref_model is provided. Two different models are used for the "
                    "model and the reference model and no layers are shared.",
                    UserWarning,
                )
        elif ref_model is None and not self.is_peft_model:
            self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers)
        elif self.is_peft_model:
            self.ref_model = None
        else:
            raise ValueError(
                f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported "
                f"architectures are: {SUPPORTED_ARCHITECTURES} "
            )
        self.optional_peft_ctx = (
            self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter
            if self.is_peft_model
            else nullcontext
        )

        if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)):
            raise ValueError(
                "tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast"
            )
        self.tokenizer = tokenizer

        if dataset is not None and not (isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, Dataset)):
            raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset")
        elif dataset is None:
            warnings.warn(
                "No dataset is provided. Make sure to set config.batch_size to the correct value before training.",
                UserWarning,
            )
        self.dataset = dataset
        self._signature_columns = None
        if self.dataset is not None:
            self.dataloader = self.prepare_dataloader(self.dataset, data_collator)
        elif self.dataset is None and self.accelerator.num_processes > 1:
            warnings.warn(
                "No dataset is provided. In a multi-GPU setting, this will lead to an error. You should"
                " prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`"
                " and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please "
                " refer to the documentation for more details.",
                UserWarning,
            )
            self.dataloader = None
        else:
            self.dataloader = None

        # Step 3: Initialize optimizer and data collator
        self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
        if optimizer is None:
            self.optimizer = Adam(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.config.learning_rate,
            )
        else:
            self.optimizer = optimizer

        self.lr_scheduler = lr_scheduler
        if self.lr_scheduler is not None:
            lr_scheduler_class = (
                torch.optim.lr_scheduler._LRScheduler
                if not is_torch_greater_2_0()
                else torch.optim.lr_scheduler.LRScheduler
            )

            if not isinstance(self.lr_scheduler, lr_scheduler_class):
                raise ValueError(
                    "lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)"
                )

        if self.config.adap_kl_ctrl:
            self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon)
        else:
            self.kl_ctl = FixedKLController(self.config.init_kl_coef)

        # Safety checkers for DS integration
        is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
            self.accelerator.state, "deepspeed_plugin"
        )

        (
            self.model,
            self.optimizer,
            self.data_collator,
            self.dataloader,
            self.lr_scheduler,
        ) = self.accelerator.prepare(
            self.model,
            self.optimizer,
            self.data_collator,
            self.dataloader,
            self.lr_scheduler,
        )
        if is_deepspeed_used:
            # Quantized models are already set on the correct device
            if not self.is_peft_model and not (
                getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False)
                or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False)
            ):
                self.ref_model = self._prepare_deepspeed(self.ref_model)
        else:
            self.ref_model = self.accelerator.prepare(self.ref_model)

        # In a distributed setup, only logging needs to be performed on the main process
        # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
        # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11
        self.is_distributed = self.accelerator.num_processes > 1

        # init the current step
        self.current_step = 0

        # init variables for pushing model to hub
        if config.push_to_hub_if_best_kwargs:
            if "repo_id" not in config.push_to_hub_if_best_kwargs:
                raise ValueError("You have to specify repo_id in order to push the model to the hub!")
            self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs
            self.compare_step = 0
            self.highest_reward = torch.tensor(-float("inf"))

        # post process for PP
        if not getattr(self.model, "is_sequential_parallel", False):
            self.current_device = self.accelerator.device
        else:
            if is_xpu_available():
                self.current_device = torch.device("xpu:0")
            elif is_npu_available():
                self.current_device = torch.device("npu:0")
            else:
                self.current_device = torch.device("cuda:0")

        PPODecorators.optimize_device_cache = self.config.optimize_device_cache

        self.running = RunningMoments(self.accelerator)

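    # -----------------------------------------------------------------------
    # Hedged usage sketch (illustration only): how the components validated in
    # `__init__` above are typically assembled. The model id and the
    # `AutoModelForCausalLMWithValueHead` import are assumptions about the
    # surrounding TRL setup, not something this file guarantees.
    @staticmethod
    def _example_setup_sketch():
        from transformers import AutoTokenizer
        from trl import AutoModelForCausalLMWithValueHead  # assumed available

        model_name = "gpt2"  # hypothetical model id
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token
        # A causal LM wrapped with a value head, as required by the
        # `SUPPORTED_ARCHITECTURES` check in `__init__`.
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
        return model, tokenizer
    # -----------------------------------------------------------------------
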
    def _filter_kwargs(self, kwargs, target_func):
        """
        Filter the keyword arguments that are supported by the target function.

        Args:
            kwargs (dict):
                Keyword arguments
            target_func (function):
                Target function
        """
        return {k: v for k, v in kwargs.items() if k in inspect.signature(target_func).parameters.keys()}

    def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset], data_collator=None):
        """
        Prepare the dataloader for training.

        Args:
            dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]):
                PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset
                will be preprocessed by removing the columns that are not used by the model.
            data_collator (Optional[function]):
                Data collator function.

        Returns:
            `torch.utils.data.DataLoader`: PyTorch dataloader
        """
        if isinstance(dataset, Dataset):
            dataset = self._remove_unused_columns(dataset)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            collate_fn=data_collator,
            shuffle=True,
            drop_last=True,
        )
        return dataloader

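    # -----------------------------------------------------------------------
    # Hedged sketch (illustration only): `prepare_dataloader` above amounts to
    # building a shuffled, drop-last DataLoader at the PPO batch size. The toy
    # dataset below is an assumption for demonstration.
    @staticmethod
    def _example_dataloader_sketch():
        import torch

        toy_dataset = torch.utils.data.TensorDataset(torch.arange(10))
        dataloader = torch.utils.data.DataLoader(
            toy_dataset,
            batch_size = 4,   # mirrors config.batch_size
            shuffle = True,   # PPO samples batches in random order
            drop_last = True, # partial batches are dropped
        )
        return dataloader
    # -----------------------------------------------------------------------
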
    # Adapted from transformers.Trainer._set_signature_columns_if_needed
    def _set_signature_columns_if_needed(self):
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys())
            # label => sentiment | we need query and response for logging purposes
            self._signature_columns += ["label", "query", "response"]

    # Adapted from transformers.Trainer._remove_unused_columns
    def _remove_unused_columns(self, dataset: "Dataset"):
        if not self.config.remove_unused_columns:
            return dataset
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        ignored_columns = list(set(dataset.column_names) - set(signature_columns))

        columns = [k for k in signature_columns if k in dataset.column_names]

        if version.parse(datasets.__version__) < version.parse("1.4.0"):
            dataset.set_format(
                type=dataset.format["type"],
                columns=columns,
                format_kwargs=dataset.format["format_kwargs"],
            )
            return dataset
        else:
            return dataset.remove_columns(ignored_columns)

    def generate(
        self,
        query_tensor: Union[torch.Tensor, List[torch.Tensor]],
        length_sampler: Optional[Callable] = None,
        batch_size: int = 4,
        return_prompt: bool = True,
        generate_ref_response: bool = False,
        **generation_kwargs,
    ):
        """
        Generate a response with the model given the query tensor.
        Calls the `generate` method of the model.

        Args:
            query_tensor (`torch.LongTensor`):
                A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`).
            length_sampler (`Callable`, *optional*):
                Callable that returns the number of newly generated tokens.
            batch_size (`int`, *optional*):
                Batch size used for generation, defaults to `4`.
            return_prompt (`bool`, *optional*):
                If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`.
            generate_ref_response (`bool`, *optional*):
                If set to `True` the reference response is also generated, defaults to `False`.
            generation_kwargs (dict[str, Any]):
                Keyword arguments for generation.

        Returns:
            `torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens.
        """
        if generate_ref_response:
            ref_model = self.model if self.is_peft_model else self.ref_model
        if isinstance(query_tensor, List):
            response = self._generate_batched(
                self.model,
                query_tensor,
                length_sampler=length_sampler,
                batch_size=batch_size,
                return_prompt=return_prompt,
                **generation_kwargs,
            )
            if generate_ref_response:
                ref_response = self._generate_batched(
                    ref_model,
                    query_tensor,
                    length_sampler=length_sampler,
                    batch_size=batch_size,
                    return_prompt=return_prompt,
                    **generation_kwargs,
                )

        else:
            if len(query_tensor.shape) == 2:
                raise ValueError(
                    "query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)"
                )

            if length_sampler is not None:
                generation_kwargs["max_new_tokens"] = length_sampler()

            with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
                response = unwrapped_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)

            if generate_ref_response:
                with unwrap_model_for_generation(
                    ref_model, self.accelerator, is_peft_model=self.is_peft_model
                ) as unwrapped_model:
                    ref_response = unwrapped_model.generate(
                        input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
                    )

            if not return_prompt and not self.is_encoder_decoder:
                response = response[:, query_tensor.shape[0] :]
                if generate_ref_response:
                    ref_response = ref_response[:, query_tensor.shape[0] :]

        if generate_ref_response:
            return response, ref_response
        return response

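    # -----------------------------------------------------------------------
    # Hedged usage sketch (illustration only) of `generate` above: a single
    # 1-D query tensor goes in, response tokens come out. The generation
    # kwargs shown are assumptions, not required values.
    def _example_generate_sketch(self):
        query = self.tokenizer.encode("Hello", return_tensors = "pt")[0]  # shape (seq_len,)
        response = self.generate(
            query,
            return_prompt = False,  # return only the newly generated tokens
            max_new_tokens = 16,    # forwarded to model.generate
            do_sample = True,
            pad_token_id = self.tokenizer.eos_token_id,
        )
        return response
    # -----------------------------------------------------------------------
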
    def _generate_batched(
        self,
        model: PreTrainedModelWrapper,
        query_tensors: List[torch.Tensor],
        length_sampler: Optional[Callable] = None,
        batch_size: int = 4,
        return_prompt: bool = True,
        pad_to_multiple_of: Optional[int] = None,
        remove_padding: bool = True,
        **generation_kwargs,
    ):
        outputs = []

        padding_side_default = self.tokenizer.padding_side
        if not self.is_encoder_decoder:
            self.tokenizer.padding_side = "left"

        # in case we have fewer examples than bs
        batch_size = min(len(query_tensors), batch_size)

        for i in range(0, len(query_tensors), batch_size):
            if length_sampler is not None:
                generation_kwargs["max_new_tokens"] = length_sampler()

            # prevent overflow if query tensors are not an even multiple of bs
            end_index = min(len(query_tensors), i + batch_size)

            batch = query_tensors[i:end_index]
            batch_mask = [torch.ones_like(element) for element in batch]
            inputs = {"input_ids": batch, "attention_mask": batch_mask}

            padded_inputs = self.tokenizer.pad(
                inputs,
                padding=True,
                max_length=None,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors="pt",
            ).to(self.current_device)

            with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                generations = unwrapped_model.generate(**padded_inputs, **generation_kwargs)

            for generation, mask in zip(generations, padded_inputs["attention_mask"]):
                if not self.is_encoder_decoder:
                    output = generation[(1 - mask).sum() :]  # remove padding
                else:
                    output = generation

                if not return_prompt and not self.is_encoder_decoder:
                    output = output[(mask).sum() :]  # remove prompt

                if remove_padding and self.tokenizer.eos_token_id in output:
                    pad_mask = output == self.tokenizer.eos_token_id
                    pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item()
                    output = output[: pad_start + 1]  # keep the eos token at the end

                outputs.append(output)

        self.tokenizer.padding_side = padding_side_default
        return outputs

    def _step_safety_checker(
        self,
        batch_size: int,
        queries: List[torch.LongTensor],
        responses: List[torch.LongTensor],
        scores: List[torch.FloatTensor],
        masks: Optional[List[torch.LongTensor]] = None,
    ):
        """
        Check if the input data is valid for training.

        Args:
            batch_size (int):
                Batch size from the config file.
            queries (List[`torch.LongTensor`]):
                List of tensors containing the encoded queries of shape (`query_length`)
            responses (List[`torch.LongTensor`]):
                List of tensors containing the encoded responses of shape (`response_length`)
            scores (List[`torch.FloatTensor`]):
                List of tensors containing the scores.
            masks (List[`torch.LongTensor`], *optional*):
                List of optional tensors containing the masks of shape (`query_length` + `response_length`)
        Returns:
            `tuple`: The processed input data.
        """
        for name, tensor_list in zip(["queries", "responses", "scores"], [queries, responses, scores]):
            if not isinstance(tensor_list, list):
                raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
            if not isinstance(tensor_list[0], torch.Tensor):
                raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
            if batch_size is not None and len(tensor_list) != batch_size:
                raise ValueError(
                    f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}"
                )

        # move queries, scores and responses to the correct device
        queries = [tensor.to(self.current_device) for tensor in queries]
        responses = [tensor.to(self.current_device) for tensor in responses]
        scores = [tensor.to(self.current_device) for tensor in scores]
        masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None

        # squeeze scores if needed
        for i, score in enumerate(scores):
            if score.dim() > 1:
                raise ValueError(f"Scores must be 1-dimensional - got {score.dim()} for {score}")
            elif score.dim() == 1:
                scores[i] = score.squeeze()

        return queries, responses, scores, masks

    @PPODecorators.empty_device_cache()
    def step(
        self,
        queries: List[torch.LongTensor],
        responses: List[torch.LongTensor],
        scores: List[torch.FloatTensor],
        response_masks: Optional[List[torch.LongTensor]] = None,
    ):
        """
        Run a PPO optimisation step given a list of queries, model responses, and rewards.

        Args:
            queries (List[`torch.LongTensor`]):
                List of tensors containing the encoded queries of shape (`query_length`)
            responses (List[`torch.LongTensor`]):
                List of tensors containing the encoded responses of shape (`response_length`)
            scores (List[`torch.FloatTensor`]):
                List of tensors containing the scores.
            response_masks (List[`torch.FloatTensor`], *optional*):
                List of tensors containing masks of the response tokens.

        Returns:
            `dict[str, Any]`: A summary of the training statistics
        """
        bs = self.config.batch_size

        queries, responses, scores, response_masks = self._step_safety_checker(
            bs, queries, responses, scores, response_masks
        )
        scores = torch.tensor(scores, device=self.current_device)
        if self.config.use_score_scaling:
            # Score scaling
            scores_mean, scores_std = self.running.update(scores)
            tensor_to_kwargs = dict(dtype=scores.dtype, device=scores.device)
            score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps
            if self.config.use_score_norm:
                scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor
            else:
                scores /= score_scaling_factor

        if self.config.score_clip is not None:
            # Score clipping
            scores_dtype = scores.dtype
            scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)

        # if we want to push best model to the hub
        if hasattr(self, "highest_reward"):
            if self.compare_step % self.config.compare_steps == 0:
                curr_mean_reward = scores.mean()
                # if the best reward ever seen
                if curr_mean_reward > self.highest_reward:
                    self.highest_reward = curr_mean_reward
                    # push model to hub
                    self.push_to_hub(**self.push_to_hub_kwargs)
            self.compare_step += 1

        timing = dict()
        t0 = time.time()

        t = time.time()

        model_inputs = self.prepare_model_inputs(queries, responses)

        if self.is_distributed:
            pad_first = self.tokenizer.padding_side == "left"

            model_inputs["input_ids"] = self.accelerator.pad_across_processes(
                model_inputs["input_ids"],
                dim=1,
                pad_index=self.tokenizer.pad_token_id,
                pad_first=pad_first,
            )
            model_inputs["attention_mask"] = self.accelerator.pad_across_processes(
                model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first
            )
            if self.is_encoder_decoder:
                model_inputs["decoder_input_ids"] = self.accelerator.pad_across_processes(
                    model_inputs["decoder_input_ids"],
                    dim=1,
                    pad_index=self.tokenizer.pad_token_id,
                    pad_first=pad_first,
                )
                model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes(
                    model_inputs["decoder_attention_mask"],
                    dim=1,
                    pad_index=0,
                    pad_first=pad_first,
                )

        model_inputs_names = list(model_inputs.keys())

        full_kl_penalty = self.config.kl_penalty == "full"

        with torch.no_grad():
            all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
                self.model,
                queries,
                responses,
                model_inputs,
                response_masks=response_masks,
                return_logits=full_kl_penalty,
            )
            with self.optional_peft_ctx():
                ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
                    self.model if self.is_peft_model else self.ref_model,
                    queries,
                    responses,
                    model_inputs,
                    return_logits=full_kl_penalty,
                )

        timing["time/ppo/forward_pass"] = time.time() - t

        with torch.no_grad():
            t = time.time()
            if full_kl_penalty:
                active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False)
                ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False)

                rewards, non_score_reward, kls = self.compute_rewards(
                    scores, active_full_logprobs, ref_full_logprobs, masks
                )
            else:
                rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
            timing["time/ppo/compute_rewards"] = time.time() - t

            t = time.time()
            values, advantages, returns = self.compute_advantages(values, rewards, masks)
            timing["time/ppo/compute_advantages"] = time.time() - t

        # upcast to float32 to avoid dataset issues
        batch_dict = {
            "queries": queries,
            "responses": responses,
            "logprobs": all_logprobs.to(torch.float32),
            "values": values.to(torch.float32),
            "masks": masks,
            "advantages": advantages,
            "returns": returns,
        }
        batch_dict.update(model_inputs)

        t = time.time()
        all_stats = []
        early_stop = False
        for _ in range(self.config.ppo_epochs):
            if early_stop:
                break
            b_inds = np.random.permutation(bs)
            for backward_batch_start in range(0, bs, self.config.backward_batch_size):
                backward_batch_end = backward_batch_start + self.config.backward_batch_size
                backward_batch_inds = b_inds[backward_batch_start:backward_batch_end]

                for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
                    mini_batch_end = mini_batch_start + self.config.mini_batch_size
                    mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end]
                    mini_batch_dict = {
                        "logprobs": batch_dict["logprobs"][mini_batch_inds],
                        "values": batch_dict["values"][mini_batch_inds],
                        "masks": batch_dict["masks"][mini_batch_inds],
                        # hacks: the queries and responses are ragged.
                        "queries": [batch_dict["queries"][i] for i in mini_batch_inds],
                        "responses": [batch_dict["responses"][i] for i in mini_batch_inds],
                        "advantages": batch_dict["advantages"][mini_batch_inds],
                        "returns": batch_dict["returns"][mini_batch_inds],
                    }
                    for k in model_inputs_names:
                        mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
                    with self.accelerator.accumulate(self.model):
                        model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}

                        logprobs, logits, vpreds, _ = self.batched_forward_pass(
                            self.model,
                            mini_batch_dict["queries"],
                            mini_batch_dict["responses"],
                            model_inputs,
                            return_logits=True,
                        )
                        train_stats = self.train_minibatch(
                            mini_batch_dict["logprobs"],
                            mini_batch_dict["values"],
                            logprobs,
                            logits,
                            vpreds,
                            mini_batch_dict["masks"],
                            mini_batch_dict["advantages"],
                            mini_batch_dict["returns"],
                        )
                        all_stats.append(train_stats)

            # typically, early stopping is done at the epoch level
            if self.config.early_stopping:
                policykl = train_stats["policy/policykl"]
                early_stop = self._early_stop(policykl)
                if early_stop:
                    break

        timing["time/ppo/optimize_step"] = time.time() - t

        t = time.time()
        train_stats = stack_dicts(all_stats)

        # reshape advantages/ratios such that they are not averaged.
        train_stats["policy/advantages"] = torch.flatten(train_stats["policy/advantages"]).unsqueeze(0)
        train_stats["policy/advantages"] = torch.nan_to_num(train_stats["policy/advantages"], WANDB_PADDING)
        train_stats["policy/ratio"] = torch.flatten(train_stats["policy/ratio"]).unsqueeze(0)

        stats = self.record_step_stats(
            scores=scores,
            logprobs=all_logprobs,
            ref_logprobs=ref_logprobs,
            non_score_reward=non_score_reward,
            train_stats=train_stats,
            kl_coef=self.kl_ctl.value,
            masks=masks,
            queries=queries,
            responses=responses,
            kls=kls,
        )
        # Gather/Reduce stats from all processes
        if self.is_distributed:
            stats = self.gather_stats(stats)
        stats = stats_to_np(stats)
        timing["time/ppo/calc_stats"] = time.time() - t
        stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"]

        # Update the KL control - multiply the batch_size by the number of processes
        self.kl_ctl.update(
            stats["objective/kl"],
            self.config.batch_size * self.accelerator.num_processes,
        )

        # Log the total ppo time
        timing["time/ppo/total"] = time.time() - t0
        stats.update(timing)

        # post-process stats for tensorboard and other loggers
        if self.config.log_with != "wandb":
            stats = convert_to_scalar(stats)

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return stats

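    # -----------------------------------------------------------------------
    # Hedged usage sketch (illustration only) of `step` above: one PPO
    # iteration is generate -> score -> step. The dataloader field and the
    # constant reward are assumptions; a real setup would decode the responses
    # and score them with a reward model, and `log_stats` expects "query" /
    # "response" text columns in the batch when logging to wandb.
    def _example_ppo_loop_sketch(self):
        for batch in self.dataloader:
            query_tensors = list(batch["input_ids"])  # list of 1-D LongTensors (assumed field)
            response_tensors = self.generate(
                query_tensors, return_prompt = False, max_new_tokens = 16,
            )
            # Dummy scalar reward per sample, standing in for a reward model.
            rewards = [torch.tensor(1.0) for _ in response_tensors]
            stats = self.step(query_tensors, response_tensors, rewards)
            self.log_stats(stats, batch, rewards)
    # -----------------------------------------------------------------------
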
    def _early_stop(self, policykl):
        r"""
        Handles the early stopping logic. If the policy KL is greater than the target KL, the gradients are zeroed and
        the optimization step is skipped.
        This also handles the multi-gpu case where the policy KL is averaged across all processes.

        Args:
            policykl (torch.Tensor):
                the policy KL

        Returns:
            `bool`: whether to early stop or not
        """
        early_stop = False
        if not self.config.early_stopping:
            return early_stop

        if not self.is_distributed and policykl > 1.5 * self.config.target_kl:
            self.optimizer.zero_grad()
            early_stop = True
        elif self.is_distributed:
            import torch.distributed as dist

            # Wait for all processes to finish
            dist.barrier()

            # all reduce the policykl and average it across processes
            dist.all_reduce(policykl, dist.ReduceOp.SUM)
            policykl /= self.accelerator.num_processes

            if policykl > 1.5 * self.config.target_kl:
                self.optimizer.zero_grad()
                early_stop = True
        return early_stop

    def gather_stats(self, stats):
        """
        Gather stats from all processes. Useful in the context of distributed training.

        Args:
            stats (dict[str, Any]):
                a dictionary of stats to be gathered. The stats should contain torch tensors.

        Returns:
            `dict[str, Any]`: A dictionary of stats with the tensors gathered.
        """
        import torch.distributed as dist

        # Wait for all processes to finish
        dist.barrier()

        for k, v in stats.items():
            if isinstance(v, torch.Tensor):
                dist.all_reduce(v.to(self.accelerator.device), dist.ReduceOp.SUM)
                v /= self.accelerator.num_processes
            stats[k] = v
        return stats

    def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
        if self.is_encoder_decoder:
            input_data = self.data_collator(
                [{"input_ids": q, "attention_mask": torch.ones_like(q)} for q in queries]
            ).to(self.current_device)

            decoder_inputs = self.data_collator(
                [{"input_ids": r, "attention_mask": torch.ones_like(r)} for r in responses]
            ).to(self.current_device)

            input_data["decoder_input_ids"] = decoder_inputs["input_ids"]
            input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"]
        else:
            input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
            input_data = self.data_collator(
                [{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids]
            ).to(self.current_device)

        input_data.pop("labels", None)  # we don't want to compute LM losses
        return input_data

    @PPODecorators.empty_device_cache()
    def batched_forward_pass(
        self,
        model: PreTrainedModelWrapper,
        queries: torch.Tensor,
        responses: torch.Tensor,
        model_inputs: dict,
        return_logits: bool = False,
        response_masks: Optional[torch.Tensor] = None,
    ):
        """
        Calculate model outputs in multiple batches.

        Args:
            queries (`torch.LongTensor`):
                List of tensors containing the encoded queries, shape (`batch_size`, `query_length`)
            responses (`torch.LongTensor`):
                List of tensors containing the encoded responses, shape (`batch_size`, `response_length`)
            return_logits (`bool`, *optional*, defaults to `False`):
                Whether to return all_logits. Set to `False` if logits are not needed, to reduce memory consumption.
        Returns:
            (tuple):
                - all_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
                    shape (`batch_size`, `response_length`)
                - all_logits (`torch.FloatTensor` or `None`): Logits of the model, returned only when
                    `return_logits=True`, shape (`batch_size`, `response_length`, `vocab_size`)
                - all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`)
                - all_masks (`torch.LongTensor`): Masks over the response tokens, shape (`batch_size`, `response_length`)
        """
        bs = len(queries)
        fbs = self.config.mini_batch_size
        all_logprobs = []
        all_logits = []
        all_masks = []
        all_values = []

        model.eval()

        for i in range(math.ceil(bs / fbs)):
            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
            query_batch = queries[i * fbs : (i + 1) * fbs]
            response_batch = responses[i * fbs : (i + 1) * fbs]
            if response_masks is not None:
                response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
            logits, _, values = model(**input_kwargs)

            if self.is_encoder_decoder:
                input_ids = input_kwargs["decoder_input_ids"]
                attention_mask = input_kwargs["decoder_attention_mask"]
            else:
                input_ids = input_kwargs["input_ids"]
                attention_mask = input_kwargs["attention_mask"]

            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
            masks = torch.zeros_like(attention_mask)
            masks[:, :-1] = attention_mask[:, 1:]

            for j in range(len(query_batch)):
                if self.is_encoder_decoder:
                    # The decoder sentence always starts at index 1 after padding in Enc-Dec models
                    start = 1
                    end = attention_mask[j, :].sum() - 1
                else:
                    start = len(query_batch[j]) - 1  # logprobs starts from the second query token
                    if attention_mask[j, 0] == 0:  # offset left padding
                        start += attention_mask[j, :].nonzero()[0]
                    end = start + len(response_batch[j])
                    if response_masks is not None:
                        response_masks_batch[j] = torch.cat(
                            (torch.zeros_like(query_batch[j]), response_masks_batch[j])
                        )[1:]

                masks[j, :start] = 0
                masks[j, end:] = 0
                if response_masks is not None:
                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]

            if return_logits:
                all_logits.append(logits)
            else:
                del logits
            all_values.append(values)
            all_logprobs.append(logprobs)
            all_masks.append(masks)

        return (
            torch.cat(all_logprobs),
            torch.cat(all_logits)[:, :-1] if return_logits else None,
            torch.cat(all_values)[:, :-1],
            torch.cat(all_masks)[:, :-1],
        )

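    # -----------------------------------------------------------------------
    # Hedged sketch (illustration only) of the shifted log-prob gather used in
    # `batched_forward_pass` above: logits at position t score the token at
    # position t+1, hence the `[:, :-1]` / `[:, 1:]` alignment. Pure torch
    # math on dummy tensors.
    @staticmethod
    def _example_logprob_gather_sketch():
        import torch.nn.functional as F

        logits = torch.randn(2, 5, 11)           # (batch, seq_len, vocab)
        input_ids = torch.randint(0, 11, (2, 5))
        logp = F.log_softmax(logits[:, :-1, :], dim = -1)
        # log-prob of each realized next token, shape (batch, seq_len - 1)
        token_logp = torch.gather(logp, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
        return token_logp
    # -----------------------------------------------------------------------
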
    @PPODecorators.empty_device_cache()
    def train_minibatch(
        self,
        old_logprobs: torch.FloatTensor,
        values: torch.FloatTensor,
        logprobs: torch.FloatTensor,
        logits: torch.FloatTensor,
        vpreds: torch.FloatTensor,
        mask: torch.LongTensor,
        advantages: torch.FloatTensor,
        returns: torch.FloatTensor,
    ):
        """
        Train one PPO minibatch

        Args:
            old_logprobs (`torch.FloatTensor`):
                Log probabilities of the model at rollout time, shape [mini_batch_size, response_length]
            values (`torch.FloatTensor`):
                Values of the value head at rollout time, shape [mini_batch_size, response_length]
            logprobs (`torch.FloatTensor`):
                Current log probabilities of the model, shape [mini_batch_size, response_length]
            logits (`torch.FloatTensor`):
                Logits of the model, shape [mini_batch_size, response_length, vocab_size]
            vpreds (`torch.FloatTensor`):
                Current values of the value head, shape [mini_batch_size, response_length]
            mask (`torch.LongTensor`):
                Mask over the response tokens, shape [mini_batch_size, response_length]
            advantages (`torch.FloatTensor`):
                Advantages, shape [mini_batch_size, response_length]
            returns (`torch.FloatTensor`):
                Returns, shape [mini_batch_size, response_length]

        Returns:
            train_stats (dict[str, `torch.Tensor`]):
                Dictionary of training statistics
        """
        self.model.train()
        loss_p, loss_v, train_stats = self.loss(
            old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns
        )
        loss = loss_p + loss_v
        self.accelerator.backward(loss)
        if self.config.max_grad_norm is not None:
            if self.accelerator.sync_gradients:
                self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm)
        self.optimizer.step()
        # we call optimizer.zero_grad() every time and let `accelerator` handle accumulation
        # see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code
        self.optimizer.zero_grad()
        return train_stats

    def compute_rewards(
        self,
        scores: torch.FloatTensor,
        logprobs: torch.FloatTensor,
        ref_logprobs: torch.FloatTensor,
        masks: torch.LongTensor,
    ):
        """
        Compute per token rewards from scores and KL-penalty.

        Args:
            scores (`torch.FloatTensor`):
                Scores from the reward model, shape (`batch_size`)
            logprobs (`torch.FloatTensor`):
                Log probabilities of the model, shape (`batch_size`, `response_length`)
            ref_logprobs (`torch.FloatTensor`):
                Log probabilities of the reference model, shape (`batch_size`, `response_length`)

        Returns:
            `torch.FloatTensor`: Per token rewards, shape (`batch_size`, `response_length`)
            `torch.FloatTensor`: Non score rewards, shape (`batch_size`, `response_length`)
            `torch.FloatTensor`: KL penalty, shape (`batch_size`, `response_length`)
        """
        rewards, non_score_rewards, kls = [], [], []
        for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
            # compute KL penalty (from difference in logprobs)
            kl = self._kl_penalty(logprob, ref_logprob)
            kls.append(kl)
            non_score_reward = -self.kl_ctl.value * kl
            non_score_rewards.append(non_score_reward)
            reward = non_score_reward.clone()
            last_non_masked_index = mask.nonzero()[-1]

            # reward is preference model score + KL penalty
            reward[last_non_masked_index] += score
            rewards.append(reward)
        return torch.stack(rewards), torch.stack(non_score_rewards), torch.stack(kls)

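    # -----------------------------------------------------------------------
    # Hedged numeric sketch (illustration only) of the reward shaping in
    # `compute_rewards` above: every token pays -kl_coef * kl, and the scalar
    # reward-model score is added on the last non-masked token only.
    @staticmethod
    def _example_reward_shaping_sketch():
        kl = torch.tensor([0.1, 0.2, 0.3])  # per-token KL for one sample
        mask = torch.tensor([1, 1, 0])      # last response token is index 1
        kl_coef, score = 0.2, 1.0
        reward = -kl_coef * kl
        last = mask.nonzero()[-1]
        reward[last] += score               # -> [-0.02, 0.96, -0.06]
        return reward
    # -----------------------------------------------------------------------
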
    def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor) -> torch.FloatTensor:
        if self.config.kl_penalty == "kl":
            return logprob - ref_logprob

        if self.config.kl_penalty == "abs":
            return (logprob - ref_logprob).abs()

        if self.config.kl_penalty == "mse":
            return 0.5 * (logprob - ref_logprob).square()

        if self.config.kl_penalty == "full":
            # Flipping the arguments is required due to this issue: https://github.com/pytorch/pytorch/issues/57459
            return F.kl_div(ref_logprob, logprob, log_target=True, reduction="none").sum(-1)

        raise NotImplementedError

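    # -----------------------------------------------------------------------
    # Hedged sketch (illustration only) of the `_kl_penalty` variants above on
    # dummy log-probs: 'kl' is the plain difference, 'abs' its magnitude,
    # 'mse' a squared penalty.
    @staticmethod
    def _example_kl_penalty_sketch():
        logprob = torch.tensor([-1.0, -2.0])
        ref_logprob = torch.tensor([-1.5, -1.0])
        diff = logprob - ref_logprob    # 'kl'  -> [ 0.5,  -1.0]
        abs_pen = diff.abs()            # 'abs' -> [ 0.5,   1.0]
        mse_pen = 0.5 * diff.square()   # 'mse' -> [0.125,  0.5]
        return diff, abs_pen, mse_pen
    # -----------------------------------------------------------------------
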
    def compute_advantages(
        self,
        values: torch.FloatTensor,
        rewards: torch.FloatTensor,
        mask: torch.FloatTensor,
    ):
        lastgaelam = 0
        advantages_reversed = []
        gen_len = rewards.shape[-1]

        values = values * mask
        rewards = rewards * mask

        if self.config.whiten_rewards:
            rewards = masked_whiten(rewards, mask, shift_mean=False)

        for t in reversed(range(gen_len)):
            nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
            delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
            lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)

        returns = advantages + values
        advantages = masked_whiten(advantages, mask)
        advantages = advantages.detach()
        return values, advantages, returns

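    # -----------------------------------------------------------------------
    # Hedged sketch (illustration only) of the GAE recursion in
    # `compute_advantages` above, written for a single sequence:
    #   delta_t = r_t + gamma * V(t+1) - V(t)
    #   A_t     = delta_t + gamma * lam * A_{t+1}
    @staticmethod
    def _example_gae_sketch():
        rewards = torch.tensor([0.0, 0.0, 1.0])
        values = torch.tensor([0.5, 0.6, 0.7])
        gamma, lam = 1.0, 0.95
        lastgaelam, advantages_reversed = 0.0, []
        for t in reversed(range(3)):
            nextvalue = values[t + 1] if t < 2 else 0.0
            delta = rewards[t] + gamma * nextvalue - values[t]
            lastgaelam = delta + gamma * lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        return torch.stack(advantages_reversed[::-1])  # returns = advantages + values
    # -----------------------------------------------------------------------
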
    def loss(
        self,
        old_logprobs: torch.FloatTensor,
        values: torch.FloatTensor,
        logits: torch.FloatTensor,
        vpreds: torch.FloatTensor,
        logprobs: torch.FloatTensor,
        mask: torch.LongTensor,
        advantages: torch.FloatTensor,
        returns: torch.FloatTensor,
    ):
        """
        Calculate policy and value losses.

        Args:
            old_logprobs (`torch.FloatTensor`):
                Log probabilities of the model at rollout time, shape (`batch_size`, `response_length`)
            values (`torch.FloatTensor`):
                Values of the value head at rollout time, shape (`batch_size`, `response_length`)
            logits (`torch.FloatTensor`):
                Logits of the model, shape (`batch_size`, `response_length`, `vocab_size`)
            vpreds (`torch.FloatTensor`):
                Current values of the value head, shape (`batch_size`, `response_length`)
            logprobs (`torch.FloatTensor`):
                Current log probabilities of the model, shape (`batch_size`, `response_length`)
        """

        vpredclipped = clip_by_value(
            vpreds,
            values - self.config.cliprange_value,
            values + self.config.cliprange_value,
        )

        vf_losses1 = (vpreds - returns) ** 2
        vf_losses2 = (vpredclipped - returns) ** 2
        vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
        vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask)

        ratio = torch.exp(logprobs - old_logprobs)

        pg_losses = -advantages * ratio
        pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)

        pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask)
        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask)

        loss = pg_loss + self.config.vf_coef * vf_loss

        avg_ratio = masked_mean(ratio, mask).item()
        if avg_ratio > self.config.ratio_threshold:
            warnings.warn(
                f"The average ratio of batch ({avg_ratio:.2f}) exceeds threshold {self.config.ratio_threshold:.2f}. Skipping batch."
            )
            pg_loss = pg_loss * 0.0
            vf_loss = vf_loss * 0.0
            loss = loss * 0.0

        entropy = masked_mean(entropy_from_logits(logits), mask)

        approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)
        policykl = masked_mean(old_logprobs - logprobs, mask)

        return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask)
        value_mean, value_var = masked_mean(values, mask), masked_var(values, mask)

        stats = dict(
            loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()),
            policy=dict(
                entropy=entropy.detach(),
                approxkl=approxkl.detach(),
                policykl=policykl.detach(),
                clipfrac=pg_clipfrac.detach(),
                advantages=advantages.detach(),
                advantages_mean=masked_mean(advantages, mask).detach(),
                ratio=ratio.detach(),
            ),
            returns=dict(mean=return_mean.detach(), var=return_var.detach()),
            val=dict(
                vpred=masked_mean(vpreds, mask).detach(),
                error=masked_mean((vpreds - returns) ** 2, mask).detach(),
                clipfrac=vf_clipfrac.detach(),
                mean=value_mean.detach(),
                var=value_var.detach(),
            ),
        )
        return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats)

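    # -----------------------------------------------------------------------
    # Hedged sketch (illustration only) of the clipped surrogate objective in
    # `loss` above: the policy term is max(-A * ratio, -A * clip(ratio)),
    # which caps how far a single update can move the policy.
    @staticmethod
    def _example_clipped_surrogate_sketch():
        ratio = torch.tensor([0.5, 1.0, 1.5])
        advantages = torch.tensor([1.0, 1.0, 1.0])
        cliprange = 0.2
        pg1 = -advantages * ratio
        pg2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
        pg_loss = torch.max(pg1, pg2).mean()  # -> -(0.5 + 1.0 + 1.2) / 3
        return pg_loss
    # -----------------------------------------------------------------------
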
    def record_step_stats(self, kl_coef: float, **data):
        """
        Record training step statistics.

        Args:
            kl_coef (`float`):
                KL coefficient
            data (`dict`):
                Dictionary of training step data

        Returns:
            stats (`dict`):
                Dictionary of training step statistics
        """
        mask = data.pop("masks")

        kls = data.pop("kls")
        kl_list = ((kls) * mask).sum(axis=-1)
        mean_kl = kl_list.mean()
        mean_entropy = (-data["logprobs"] * mask).sum(axis=-1).mean()

        mean_non_score_reward = masked_mean(
            data["non_score_reward"], mask
        )  # non_score_reward has shape (`batch_size`, `response_length`)
        mean_scores = data["scores"].mean()  # scores has shape (`batch_size`,)
        std_scores = data["scores"].std()

        if mean_kl.item() < -1.0:
            # warn users
            warnings.warn(
                f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor of failed training."
                " Sometimes this happens because the generation kwargs are not correctly set. Please make sure"
                " that the generation kwargs are set correctly, or review your training hyperparameters."
            )

        stats = {
            "objective/kl": mean_kl,
            "objective/kl_dist": kl_list,
            "objective/logprobs": data["logprobs"],
            "objective/ref_logprobs": data["ref_logprobs"],
            "objective/kl_coef": kl_coef,
            "objective/entropy": mean_entropy,
            "ppo/mean_non_score_reward": mean_non_score_reward,
            "ppo/mean_scores": mean_scores,
            "ppo/std_scores": std_scores,
        }

        # Log text properties
        query_lens = torch.tensor([len(query) for query in data["queries"]], dtype=torch.float)
        response_lens = torch.tensor([len(response) for response in data["responses"]], dtype=torch.float)

        stats["tokens/queries_len_mean"] = torch.mean(query_lens).cpu().numpy().item()
        stats["tokens/queries_len_std"] = torch.std(query_lens).cpu().numpy().item()
        stats["tokens/queries_dist"] = query_lens.cpu().numpy()
        stats["tokens/responses_len_mean"] = torch.mean(response_lens).cpu().numpy().item()
        stats["tokens/responses_len_std"] = torch.std(response_lens).cpu().numpy().item()
        stats["tokens/responses_dist"] = response_lens.cpu().numpy()

        for k, v in data["train_stats"].items():
            stats[f"ppo/{k}"] = torch.mean(v, axis=0)
        stats["ppo/val/var_explained"] = 1 - stats["ppo/val/error"] / stats["ppo/returns/var"]
        return stats

    def log_stats(
        self,
        stats: dict,
        batch: dict,
        rewards: List[torch.FloatTensor],
        columns_to_log: typing.Iterable[str] = ("query", "response"),
    ):
        """
        A function that logs all the training stats. Call it at the end of each epoch.

        Args:
            stats (dict[str, Any]):
                A dictionary of training stats.
            batch (dict[str, Any]):
                A dictionary of batch data; this contains the queries and responses.
            rewards (`List[torch.FloatTensor]`):
                A tensor of rewards.
        """

        # all gather stats
        if not isinstance(rewards, torch.Tensor):
            rewards = torch.tensor(rewards).to(self.current_device)
        rewards = self.accelerator.gather(rewards).flatten()

        if self.config.log_with == "wandb":
            import wandb

            if any(column_to_log not in batch.keys() for column_to_log in columns_to_log):
                raise ValueError(f"Columns to log {columns_to_log} are not present in the batch {batch.keys()}.")

            batch_list = [batch[column_to_log] for column_to_log in columns_to_log]
            if self.is_distributed:
                gathered_batch_list = []
                for b in batch_list:
                    flattened = gather_object(b)
                    gathered_batch_list.append(flattened)
                batch_list = gathered_batch_list

        # Log only if we are in the main process
        if self.accelerator.is_main_process:
            logs = {}

            # Log stats
            if "query" not in batch.keys() and "response" not in batch.keys():
                # warn the user that the game logs will not be logged
                warnings.warn(
                    "The game logs will not be logged because the batch does not contain the keys 'query' and "
                    "'response'. "
                )
            elif self.config.log_with == "wandb":
                table_rows = [list(r) for r in zip(*batch_list, rewards.cpu().tolist())]
                logs.update({"game_log": wandb.Table(columns=[*columns_to_log, "reward"], rows=table_rows)})

            logs.update(stats)

            # manually cast in fp32 for bf16 torch tensors
            for k, v in logs.items():
                if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16:
                    logs[k] = v.float()

            logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item()
            logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
            logs["env/reward_dist"] = rewards.cpu().numpy()

            if self.config.log_with == "tensorboard":
                # update the current step
                self.current_step += 1

            self.accelerator.log(
                logs,
                step=self.current_step if self.config.log_with == "tensorboard" else None,
            )

    def create_model_card(self, path: str, model_name: Optional[str] = "TRL Model") -> None:
        """Creates and saves a model card for a TRL model.

        Args:
            path (`str`): The path to save the model card to.
            model_name (`str`, *optional*): The name of the model, defaults to `TRL Model`.
        """
        try:
            user = whoami()["name"]
        # handle the offline case
        except Exception:
            warnings.warn("Cannot retrieve user information; assuming you are running in offline mode.")
            return

        if not os.path.exists(path):
            os.makedirs(path)

        model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}")
        with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
            f.write(model_card_content)

    def _save_pretrained(self, save_directory: str) -> None:
        self.accelerator.unwrap_model(self.model).save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)
        self.create_model_card(save_directory)

    def _show_tokens(self, tokens, masks):
        from rich import print
        from rich.text import Text

        text = Text()

        for _i, (token, mask) in enumerate(zip(tokens, masks)):
            if mask == 1:
                text.append(self.tokenizer.decode(token.item()), style="black on deep_sky_blue1")
                text.append(" ")
            else:
                text.append(self.tokenizer.decode(token.item()), style="black on cyan3")
                text.append(" ")
        print(text)

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
1450 |
+
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
1451 |
+
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
1452 |
+
config_kwargs = deepspeed_plugin.deepspeed_config
|
1453 |
+
if model is not None:
|
1454 |
+
if hasattr(model, "config"):
|
1455 |
+
hidden_size = (
|
1456 |
+
max(model.config.hidden_sizes)
|
1457 |
+
if getattr(model.config, "hidden_sizes", None)
|
1458 |
+
else getattr(model.config, "hidden_size", None)
|
1459 |
+
)
|
1460 |
+
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
1461 |
+
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
1462 |
+
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
1463 |
+
config_kwargs.update(
|
1464 |
+
{
|
1465 |
+
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
1466 |
+
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
1467 |
+
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
1468 |
+
}
|
1469 |
+
)
|
1470 |
+
|
1471 |
+
# If ZeRO-3 is used, we shard both the active and reference model.
|
1472 |
+
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
1473 |
+
if config_kwargs["zero_optimization"]["stage"] != 3:
|
1474 |
+
config_kwargs["zero_optimization"]["stage"] = 0
|
1475 |
+
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
1476 |
+
model.eval()
|
1477 |
+
return model
|
1478 |
+
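# Illustrative arithmetic (hypothetical, not part of the generated file): for a
# model with hidden_size = 4096, the ZeRO-3 overrides above resolve to
#   zero_optimization.reduce_bucket_size                 = 4096 * 4096       = 16_777_216
#   zero_optimization.stage3_param_persistence_threshold = 10 * 4096         = 40_960
#   zero_optimization.stage3_prefetch_bucket_size        = 0.9 * 4096 * 4096 = 15_099_494.4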
class UnslothPPOTrainer(_UnslothPPOTrainer):
    """

    The PPOTrainer uses Proximal Policy Optimization to optimise language models.
    Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here:
    https://github.com/openai/summarize-from-feedback

    Attributes:
        **config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more
            details.
        **model** (`PreTrainedModelWrapper`) -- Model to be optimized, Hugging Face transformer model with a value head.
            Check the documentation of `PreTrainedModelWrapper` for more details.
        **ref_model** (`PreTrainedModelWrapper`, *optional*) -- Reference model to be used for the KL penalty, Hugging Face
            transformer model with a causal language modelling head. Check the documentation of `PreTrainedModelWrapper`
            for more details. If no reference model is provided, the trainer will create a reference model with the same
            architecture as the model to be optimized, with shared layers.
        **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the
            data. Check the documentation of `transformers.PreTrainedTokenizer` and
            `transformers.PreTrainedTokenizerFast` for more details.
        **dataset** (Union[`torch.utils.data.Dataset`, `datasets.Dataset`], *optional*) -- PyTorch dataset or Hugging
            Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided, the dataloader must be
            created outside the trainer; users need to design their own dataloader and make sure the batch
            size that is used is the same as the one specified in the configuration object.
        **optimizer** (`torch.optim.Optimizer`, *optional*) -- Optimizer to be used for training. If no optimizer is
            provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration
            object.
        **data_collator** (DataCollatorForLanguageModeling, *optional*) -- Data collator to be used for training and
            passed along the dataloader.
        **num_shared_layers** (int, *optional*) -- Number of layers to be shared between the model and the reference
            model, if no reference model is passed. If no number is provided, all the layers will be shared.
        **lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training.

    """
    def __init__(
        self,
        config = None,
        model = None,
        ref_model = None,
        tokenizer = None,
        dataset = None,
        optimizer = None,
        data_collator = None,
        num_shared_layers = None,
        lr_scheduler = None,
        **kwargs
    ):
        if config is None: config = UnslothPPOConfig()
        args = config  # the generated code below refers to the config object as `args`
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('ppo_trainer', other_metrics)

        super().__init__(
            config = config,
            model = model,
            ref_model = ref_model,
            tokenizer = tokenizer,
            dataset = dataset,
            optimizer = optimizer,
            data_collator = data_collator,
            num_shared_layers = num_shared_layers,
            lr_scheduler = lr_scheduler,
            **kwargs)
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass

pass
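# Minimal usage sketch (hypothetical, not part of the generated file); `my_model`,
# `my_tokenizer` and `my_dataset` are placeholders for a value-head model, its
# tokenizer and a query dataset:
#
#   config  = UnslothPPOConfig(batch_size = 8, learning_rate = 1e-5)
#   trainer = UnslothPPOTrainer(config = config, model = my_model,
#                               tokenizer = my_tokenizer, dataset = my_dataset)
#   for batch in trainer.dataloader:
#       # generate responses, score them with a reward model, then:
#       # stats = trainer.step(queries, responses, rewards)
#       ...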
compilefcach/UnslothRewardTrainer.py
ADDED
@@ -0,0 +1,722 @@
"""
2025.6.8
2025.6.12
4.53.0
0.8.6
__UNSLOTH_VERSIONING__
"""
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from trl.trainer.reward_trainer import (Any, Callable, DataCollator, Dataset, Dict, EvalPrediction, FrozenInstanceError, List, Optional, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, RewardConfig, RewardDataCollatorWithPadding, RewardTrainer, Trainer, TrainerCallback, TrainingArguments, Tuple, Union, compute_accuracy, inspect, is_peft_available, nested_detach, nn, prepare_model_for_kbit_training, replace, torch, warnings)


import os
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling

torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : False,
    "shape_padding"     : True,
    "trace.enabled"     : False,
    "triton.cudagraphs" : False,
}

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    logits = logits.to(torch.float32)
    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    # loop to reduce peak mem consumption
    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
    logsumexp_values = torch.logsumexp(logits, dim = -1)
    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    return per_token_logps
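# Sanity-check sketch (hypothetical, not part of the generated file):
# `selective_log_softmax` should match gathering from a full log-softmax, since
# log_softmax(x)_i = x_i - logsumexp(x), while skipping the full
# (batch, seq, vocab) log-softmax allocation.
#
#   logits = torch.randn(2, 8, 128)            # (batch, seq, vocab)
#   index  = torch.randint(0, 128, (2, 8))     # token ids to score
#   ref = torch.log_softmax(logits.float(), dim = -1)
#   ref = ref.gather(-1, index.unsqueeze(-1)).squeeze(-1)
#   assert torch.allclose(selective_log_softmax(logits, index), ref, atol = 1e-5)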
@dataclass
class UnslothRewardConfig(RewardConfig):
    """

    RewardConfig collects all training arguments related to the [`RewardTrainer`] class.

    Using [`HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        max_length (`int`, *optional*, defaults to `None`):
            The maximum length of the sequences in the batch. This argument is required if you want to use the
            default data collator.
        gradient_checkpointing (`bool`, *optional*, defaults to `True`):
            If True, use gradient checkpointing to save memory at the expense of a slower backward pass.

    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        hub_revision = None,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = False,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        liger_kernel_config = None,
        eval_use_gather_object = False,
        average_tokens_across_devices = False,
        max_length = None,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'

        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            liger_kernel_config = liger_kernel_config,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            max_length = max_length,
            **kwargs)
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        pass

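# Minimal usage sketch (hypothetical, not part of the generated file): the config
# guards its learning rate to the interval (1e-7, 1] and, when no output_dir is
# given with the default save settings, redirects checkpoints and disables saving.
#
#   cfg = UnslothRewardConfig(learning_rate = 5e-5, per_device_train_batch_size = 4)
#   UnslothRewardConfig(learning_rate = 5.0)   # raises OverflowError
#   UnslothRewardConfig(learning_rate = 1e-9)  # raises FloatingPointError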
class _UnslothRewardTrainer(Trainer):
    r""""""

    _tag_names = ["trl", "reward-trainer"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
        args: Optional[RewardConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        max_length: Optional[int] = None,
        peft_config: Optional[Dict] = None,
    ):
        """
        Initialize RewardTrainer.

        Args:
            model (`transformers.PreTrainedModel`):
                The model to train, preferably an `AutoModelForSequenceClassification`.
            args (`RewardConfig`):
                The arguments to use for training.
            data_collator (`transformers.DataCollator`):
                The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used,
                which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
            train_dataset (`datasets.Dataset`):
                The dataset to use for training.
            eval_dataset (`datasets.Dataset`):
                The dataset to use for evaluation.
            tokenizer (`transformers.PreTrainedTokenizerBase`):
                The tokenizer to use for training. This argument is required if you want to use the default data collator.
            model_init (`Callable[[], transformers.PreTrainedModel]`):
                The model initializer to use for training. If None is specified, the default model initializer will be used.
            compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional*, defaults to `compute_accuracy`):
                The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
            callbacks (`List[transformers.TrainerCallback]`):
                The callbacks to use for training.
            optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
                The optimizer and scheduler to use for training.
            preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
                The function to use to preprocess the logits before computing the metrics.
            max_length (`int`, defaults to `None`):
                The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
            peft_config (`Dict`, defaults to `None`):
                The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
        """
        if type(args) == TrainingArguments:
            warnings.warn(
                "Using `transformers.TrainingArguments` for `args` is deprecated and will be removed in a future version. Please use `RewardConfig` instead.",
                FutureWarning,
            )
            if max_length is not None:
                warnings.warn(
                    "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
                    FutureWarning,
                )
        else:
            if max_length is not None and args.max_length is not None:
                raise ValueError(
                    "You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once."
                )
            if max_length is not None and args.max_length is None:
                warnings.warn(
                    "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
                    FutureWarning,
                )
        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            if not isinstance(model, PeftModel):
                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
                    _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
                        inspect.signature(prepare_model_for_kbit_training).parameters
                    )

                    prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                    if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        warnings.warn(
                            "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
                            "Please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
                        )
                    elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

                    model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

                model = model  # kept from the generated source; a no-op here

        if compute_metrics is None:
            compute_metrics = compute_accuracy

        if data_collator is None:
            if tokenizer is None:
                raise ValueError(
                    "max_length or a tokenizer must be specified when using the default RewardDataCollatorWithPadding"
                )
            if type(args) == TrainingArguments:
                if max_length is None:
                    warnings.warn(
                        "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
                        " It will be set to `512` by default, but you should do it yourself in the future.",
                        UserWarning,
                    )
                    max_length = 512
            else:
                if max_length is None and args.max_length is None:
                    warnings.warn(
                        "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
                        " It will be set to `512` by default, but you should do it yourself in the future.",
                        UserWarning,
                    )
                    max_length = 512
                if max_length is None and args.max_length is not None:
                    max_length = args.max_length

            data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length)

            if args.remove_unused_columns:
                try:  # for bc before https://github.com/huggingface/transformers/pull/25435
                    args.remove_unused_columns = False
                except FrozenInstanceError:
                    args = replace(args, remove_unused_columns=False)
                # warn users
                warnings.warn(
                    "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig;"
                    " we have set it for you, but you should do it yourself in the future.",
                    UserWarning,
                )

            self.use_reward_data_collator = True
        else:
            self.use_reward_data_collator = False
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        if not self.use_reward_data_collator:
            warnings.warn(
                "The current compute_loss is implemented for RewardDataCollatorWithPadding;"
                " if you are using a custom data collator, make sure you know what you are doing or"
                " implement your own compute_loss method."
            )
        rewards_chosen = model(
            input_ids=inputs["input_ids_chosen"],
            attention_mask=inputs["attention_mask_chosen"],
            return_dict=True,
        )["logits"]
        rewards_rejected = model(
            input_ids=inputs["input_ids_rejected"],
            attention_mask=inputs["attention_mask_rejected"],
            return_dict=True,
        )["logits"]
        # calculate loss, optionally modulate with margin
        if "margin" in inputs:
            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
        else:
            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()

        if return_outputs:
            return loss, {
                "rewards_chosen": rewards_chosen,
                "rewards_rejected": rewards_rejected,
            }
        return loss

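    # The pairwise objective above is the Bradley-Terry style reward-modelling
    # loss: loss = -log sigmoid(r_chosen - r_rejected [- margin]). A hypothetical
    # numeric check (not part of the generated file):
    #
    #   r_c, r_r = torch.tensor([2.0]), torch.tensor([0.5])
    #   -torch.nn.functional.logsigmoid(r_c - r_r)  # ~= 0.2014
    #   # equal rewards give log(2) ~= 0.6931; the loss -> 0 as r_c - r_r grows
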
    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        with torch.no_grad():
            loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)

        if prediction_loss_only:
            return (loss, None, None)

        loss = loss.detach()
        logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
        logits = nested_detach(logits)
        # Stack accepted against rejected, mean over logits
        # and softmax to get preferences between accepted and rejected to sum to 1
        logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T

        labels = torch.zeros(logits.shape[0])
        labels = self._prepare_inputs(labels)

        return loss, logits, labels
class UnslothRewardTrainer(_UnslothRewardTrainer):
    """

    The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the
    `transformers.Trainer` class and inherits all of its attributes and methods. It is recommended to use
    an `AutoModelForSequenceClassification` as the reward model. The reward model should be trained on a dataset
    of paired examples, where each example is a tuple of two sequences. The reward model should be trained to
    predict which example in the pair is more relevant to the task at hand.

    The reward trainer expects a very specific format for the dataset. The dataset should contain at least four entries
    if you don't use the default `RewardDataCollatorWithPadding` data collator. The entries should be named
    - `input_ids_chosen`
    - `attention_mask_chosen`
    - `input_ids_rejected`
    - `attention_mask_rejected`

    Optionally, you can also pass a `margin` entry to the dataset. This entry should contain the margin used to modulate the
    loss of the reward model as outlined in https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/.
    If you don't pass a margin, no margin will be used.

    """
    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        tokenizer = None,
        model_init = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        max_length = None,
        peft_config = None,
        **kwargs
    ):
        if args is None: args = UnslothRewardConfig()
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('reward_trainer', other_metrics)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            tokenizer = tokenizer,
            model_init = model_init,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            max_length = max_length,
            peft_config = peft_config,
            **kwargs)
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass

pass
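# Minimal data-format sketch (hypothetical, not part of the generated file) for the
# paired examples the default `RewardDataCollatorWithPadding` expects; token ids
# are made up:
#
#   row = {
#       "input_ids_chosen":        [1, 341, 89, 2],
#       "attention_mask_chosen":   [1, 1, 1, 1],
#       "input_ids_rejected":      [1, 77, 2],
#       "attention_mask_rejected": [1, 1, 1],
#       # optional: "margin": 0.5,
#   }
#   trainer = UnslothRewardTrainer(model = model, args = UnslothRewardConfig(max_length = 512),
#                                  train_dataset = dataset, tokenizer = tokenizer)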
compilefcach/__pycache__/UnslothCPOTrainer.cpython-311.pyc
ADDED
Binary file (68.7 kB)

compilefcach/__pycache__/UnslothDDPOTrainer.cpython-311.pyc
ADDED
Binary file (38.7 kB)

compilefcach/__pycache__/UnslothKTOTrainer.cpython-311.pyc
ADDED
Binary file (81.8 kB)

compilefcach/__pycache__/UnslothORPOTrainer.cpython-311.pyc
ADDED
Binary file (69.8 kB)

compilefcach/__pycache__/UnslothPPOTrainer.cpython-311.pyc
ADDED
Binary file (83.5 kB)

compilefcach/__pycache__/UnslothRewardTrainer.cpython-311.pyc
ADDED
Binary file (33.1 kB)