starsofchance committed
Commit 7e4d3fd · verified · 1 parent: a105bb3

Uploaded LoRA adapters after fine-tuning on PrimeVul

compilefcach/UnslothCPOTrainer.py ADDED
@@ -0,0 +1,1404 @@
1
+ """
2
+ 2025.6.8
3
+ 2025.6.12
4
+ 4.53.0
5
+ 0.8.6
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, Dict, EvalLoopOutput, F, List, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainerCallback, Tuple, Union, defaultdict, disable_dropout_in_model, inspect, is_peft_available, is_torch_fx_proxy, is_wandb_available, nn, np, nullcontext, pad_to_length, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, torch, trl_sanitze_kwargs_for_tagging, wandb, warnings, wraps)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
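The compiled helper above uses the identity log_softmax(x)_i = x_i - logsumexp(x). As a quick sanity check, a minimal standalone sketch (illustration only, not part of the uploaded file) comparing it with the equivalent F.log_softmax + gather formulation:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 11)          # (batch, seq_len, vocab), toy values
index  = torch.randint(0, 11, (2, 5))   # token ids whose log-probs we want

reference = F.log_softmax(logits.float(), dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
manual = (torch.gather(logits.float(), -1, index.unsqueeze(-1)).squeeze(-1)
          - torch.logsumexp(logits.float(), dim=-1))
assert torch.allclose(reference, manual, atol=1e-6)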
42
+ @dataclass
43
+ class UnslothCPOConfig(CPOConfig):
44
+ """
45
+
46
+ CPOConfig collects all training arguments related to the [`CPOTrainer`] class.
47
+
48
+ Using [`HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ max_length (`int`, defaults to `None`):
54
+ The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
55
+ max_prompt_length (`int`, defaults to `None`):
56
+ The maximum length of the prompt. This argument is required if you want to use the default data collator.
57
+ max_target_length (`int`, defaults to `None`):
58
+ The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
59
+ beta (`float`, defaults to 0.1):
60
+ The beta factor in CPO loss.
61
+ label_smoothing (`float`, defaults to 0):
62
+ The label smoothing factor. This argument is required if you want to use the default data collator.
63
+ loss_type (`str`, defaults to `sigmoid`):
64
+ The type of loss to use. This argument is required if you want to use the default data collator.
65
+ label_pad_token_id (`int`, defaults to `-100`):
66
+ The label pad token id. This argument is required if you want to use the default data collator.
67
+ padding_value (`int`, defaults to `None`):
68
+ The padding value if it is different to the tokenizer's pad_token_id.
69
+ truncation_mode (`str`, defaults to `keep_end`):
70
+ The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
71
+ generate_during_eval (`bool`, defaults to `False`):
72
+ Whether to sample and log generations during evaluation step.
73
+ is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
74
+ If no model is provided, we need to know if the model_init returns an encoder-decoder.
75
+ disable_dropout (`bool`, defaults to `True`):
76
+ Whether or not to disable dropouts in `model`.
77
+ model_init_kwargs (`Optional[Dict]`, *optional*):
78
+ Dict of Optional kwargs to pass when instantiating the model from a string
79
+ dataset_num_proc (`Optional[int]`, *optional*):
80
+ The number of workers to use to tokenize the data. Defaults to None.
81
+
82
+ """
83
+ vllm_sampling_params: Optional[Any] = field(
84
+ default = None,
85
+ metadata = {'help': 'vLLM SamplingParams'},
86
+ )
87
+ unsloth_num_chunks : Optional[int] = field(
88
+ default = -1,
89
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
90
+ )
91
+ def __init__(
92
+ self,
93
+ output_dir = None,
94
+ overwrite_output_dir = None,
95
+ do_train = False,
96
+ do_eval = False,
97
+ do_predict = False,
98
+ eval_strategy = 'no',
99
+ prediction_loss_only = False,
100
+ per_device_train_batch_size = 4,
101
+ per_device_eval_batch_size = 4,
102
+ per_gpu_train_batch_size = None,
103
+ per_gpu_eval_batch_size = None,
104
+ gradient_accumulation_steps = 2,
105
+ eval_accumulation_steps = 2,
106
+ eval_delay = 0,
107
+ torch_empty_cache_steps = 250,
108
+ learning_rate = 5e-05,
109
+ weight_decay = 0.01,
110
+ adam_beta1 = 0.9,
111
+ adam_beta2 = 0.999,
112
+ adam_epsilon = 1e-08,
113
+ max_grad_norm = 1.0,
114
+ num_train_epochs = 3.0,
115
+ max_steps = -1,
116
+ lr_scheduler_type = 'linear',
117
+ warmup_ratio = 0.1,
118
+ warmup_steps = 0,
119
+ log_level = 'passive',
120
+ log_level_replica = 'warning',
121
+ log_on_each_node = True,
122
+ logging_dir = None,
123
+ logging_strategy = 'steps',
124
+ logging_first_step = False,
125
+ logging_steps = 1,
126
+ logging_nan_inf_filter = False,
127
+ save_strategy = 'steps',
128
+ save_steps = 500,
129
+ save_total_limit = None,
130
+ save_safetensors = True,
131
+ save_on_each_node = False,
132
+ save_only_model = False,
133
+ restore_callback_states_from_checkpoint = False,
134
+ no_cuda = False,
135
+ use_cpu = False,
136
+ use_mps_device = False,
137
+ seed = 3407,
138
+ data_seed = 3407,
139
+ jit_mode_eval = False,
140
+ use_ipex = False,
141
+ bf16 = False,
142
+ fp16 = False,
143
+ fp16_opt_level = 'O1',
144
+ half_precision_backend = 'auto',
145
+ bf16_full_eval = False,
146
+ fp16_full_eval = False,
147
+ tf32 = None,
148
+ local_rank = -1,
149
+ ddp_backend = None,
150
+ tpu_num_cores = None,
151
+ tpu_metrics_debug = False,
152
+ debug = '',
153
+ dataloader_drop_last = False,
154
+ eval_steps = None,
155
+ dataloader_num_workers = 0,
156
+ dataloader_prefetch_factor = None,
157
+ past_index = -1,
158
+ run_name = None,
159
+ disable_tqdm = None,
160
+ remove_unused_columns = True,
161
+ label_names = None,
162
+ load_best_model_at_end = False,
163
+ metric_for_best_model = None,
164
+ greater_is_better = None,
165
+ ignore_data_skip = False,
166
+ fsdp = '',
167
+ fsdp_min_num_params = 0,
168
+ fsdp_config = None,
169
+ fsdp_transformer_layer_cls_to_wrap = None,
170
+ accelerator_config = None,
171
+ deepspeed = None,
172
+ label_smoothing_factor = 0.0,
173
+ optim = 'adamw_8bit',
174
+ optim_args = None,
175
+ adafactor = False,
176
+ group_by_length = False,
177
+ length_column_name = 'length',
178
+ report_to = None,
179
+ ddp_find_unused_parameters = None,
180
+ ddp_bucket_cap_mb = None,
181
+ ddp_broadcast_buffers = None,
182
+ dataloader_pin_memory = True,
183
+ dataloader_persistent_workers = False,
184
+ skip_memory_metrics = True,
185
+ use_legacy_prediction_loop = False,
186
+ push_to_hub = False,
187
+ resume_from_checkpoint = None,
188
+ hub_model_id = None,
189
+ hub_strategy = 'every_save',
190
+ hub_token = None,
191
+ hub_private_repo = None,
192
+ hub_always_push = False,
193
+ hub_revision = None,
194
+ gradient_checkpointing = False,
195
+ gradient_checkpointing_kwargs = None,
196
+ include_inputs_for_metrics = False,
197
+ eval_do_concat_batches = True,
198
+ fp16_backend = 'auto',
199
+ push_to_hub_model_id = None,
200
+ push_to_hub_organization = None,
201
+ push_to_hub_token = None,
202
+ mp_parameters = '',
203
+ auto_find_batch_size = False,
204
+ full_determinism = False,
205
+ torchdynamo = None,
206
+ ray_scope = 'last',
207
+ ddp_timeout = 1800,
208
+ torch_compile = False,
209
+ torch_compile_backend = None,
210
+ torch_compile_mode = None,
211
+ include_tokens_per_second = False,
212
+ include_num_input_tokens_seen = False,
213
+ neftune_noise_alpha = None,
214
+ optim_target_modules = None,
215
+ batch_eval_metrics = False,
216
+ eval_on_start = False,
217
+ use_liger_kernel = False,
218
+ liger_kernel_config = None,
219
+ eval_use_gather_object = False,
220
+ average_tokens_across_devices = False,
221
+ max_length = None,
222
+ max_prompt_length = None,
223
+ max_completion_length = None,
224
+ max_target_length = None,
225
+ beta = 0.1,
226
+ label_smoothing = 0,
227
+ loss_type = 'sigmoid',
228
+ disable_dropout = True,
229
+ label_pad_token_id = -100,
230
+ padding_value = None,
231
+ truncation_mode = 'keep_end',
232
+ generate_during_eval = False,
233
+ is_encoder_decoder = None,
234
+ model_init_kwargs = None,
235
+ dataset_num_proc = None,
236
+ vllm_sampling_params = None,
237
+ unsloth_num_chunks = -1,
238
+ **kwargs,
239
+ ):
240
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
241
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is far too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
242
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
243
+ output_dir = 'unsloth_training_checkpoints'
244
+ save_strategy = 'no'
245
+ if dataset_num_proc is None:
246
+ from multiprocessing import cpu_count
247
+ dataset_num_proc = cpu_count()
248
+
249
+ super().__init__(
250
+ output_dir = output_dir,
251
+ overwrite_output_dir = overwrite_output_dir,
252
+ do_train = do_train,
253
+ do_eval = do_eval,
254
+ do_predict = do_predict,
255
+ eval_strategy = eval_strategy,
256
+ prediction_loss_only = prediction_loss_only,
257
+ per_device_train_batch_size = per_device_train_batch_size,
258
+ per_device_eval_batch_size = per_device_eval_batch_size,
259
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
260
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
261
+ gradient_accumulation_steps = gradient_accumulation_steps,
262
+ eval_accumulation_steps = eval_accumulation_steps,
263
+ eval_delay = eval_delay,
264
+ torch_empty_cache_steps = torch_empty_cache_steps,
265
+ learning_rate = learning_rate,
266
+ weight_decay = weight_decay,
267
+ adam_beta1 = adam_beta1,
268
+ adam_beta2 = adam_beta2,
269
+ adam_epsilon = adam_epsilon,
270
+ max_grad_norm = max_grad_norm,
271
+ num_train_epochs = num_train_epochs,
272
+ max_steps = max_steps,
273
+ lr_scheduler_type = lr_scheduler_type,
274
+ warmup_ratio = warmup_ratio,
275
+ warmup_steps = warmup_steps,
276
+ log_level = log_level,
277
+ log_level_replica = log_level_replica,
278
+ log_on_each_node = log_on_each_node,
279
+ logging_dir = logging_dir,
280
+ logging_strategy = logging_strategy,
281
+ logging_first_step = logging_first_step,
282
+ logging_steps = logging_steps,
283
+ logging_nan_inf_filter = logging_nan_inf_filter,
284
+ save_strategy = save_strategy,
285
+ save_steps = save_steps,
286
+ save_total_limit = save_total_limit,
287
+ save_safetensors = save_safetensors,
288
+ save_on_each_node = save_on_each_node,
289
+ save_only_model = save_only_model,
290
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
291
+ no_cuda = no_cuda,
292
+ use_cpu = use_cpu,
293
+ use_mps_device = use_mps_device,
294
+ seed = seed,
295
+ data_seed = data_seed,
296
+ jit_mode_eval = jit_mode_eval,
297
+ use_ipex = use_ipex,
298
+ bf16 = bf16,
299
+ fp16 = fp16,
300
+ fp16_opt_level = fp16_opt_level,
301
+ half_precision_backend = half_precision_backend,
302
+ bf16_full_eval = bf16_full_eval,
303
+ fp16_full_eval = fp16_full_eval,
304
+ tf32 = tf32,
305
+ local_rank = local_rank,
306
+ ddp_backend = ddp_backend,
307
+ tpu_num_cores = tpu_num_cores,
308
+ tpu_metrics_debug = tpu_metrics_debug,
309
+ debug = debug,
310
+ dataloader_drop_last = dataloader_drop_last,
311
+ eval_steps = eval_steps,
312
+ dataloader_num_workers = dataloader_num_workers,
313
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
314
+ past_index = past_index,
315
+ run_name = run_name,
316
+ disable_tqdm = disable_tqdm,
317
+ remove_unused_columns = remove_unused_columns,
318
+ label_names = label_names,
319
+ load_best_model_at_end = load_best_model_at_end,
320
+ metric_for_best_model = metric_for_best_model,
321
+ greater_is_better = greater_is_better,
322
+ ignore_data_skip = ignore_data_skip,
323
+ fsdp = fsdp,
324
+ fsdp_min_num_params = fsdp_min_num_params,
325
+ fsdp_config = fsdp_config,
326
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
327
+ accelerator_config = accelerator_config,
328
+ deepspeed = deepspeed,
329
+ label_smoothing_factor = label_smoothing_factor,
330
+ optim = optim,
331
+ optim_args = optim_args,
332
+ adafactor = adafactor,
333
+ group_by_length = group_by_length,
334
+ length_column_name = length_column_name,
335
+ report_to = report_to,
336
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
337
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
338
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
339
+ dataloader_pin_memory = dataloader_pin_memory,
340
+ dataloader_persistent_workers = dataloader_persistent_workers,
341
+ skip_memory_metrics = skip_memory_metrics,
342
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
343
+ push_to_hub = push_to_hub,
344
+ resume_from_checkpoint = resume_from_checkpoint,
345
+ hub_model_id = hub_model_id,
346
+ hub_strategy = hub_strategy,
347
+ hub_token = hub_token,
348
+ hub_private_repo = hub_private_repo,
349
+ hub_always_push = hub_always_push,
350
+ hub_revision = hub_revision,
351
+ gradient_checkpointing = gradient_checkpointing,
352
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
353
+ include_inputs_for_metrics = include_inputs_for_metrics,
354
+ eval_do_concat_batches = eval_do_concat_batches,
355
+ fp16_backend = fp16_backend,
356
+ push_to_hub_model_id = push_to_hub_model_id,
357
+ push_to_hub_organization = push_to_hub_organization,
358
+ push_to_hub_token = push_to_hub_token,
359
+ mp_parameters = mp_parameters,
360
+ auto_find_batch_size = auto_find_batch_size,
361
+ full_determinism = full_determinism,
362
+ torchdynamo = torchdynamo,
363
+ ray_scope = ray_scope,
364
+ ddp_timeout = ddp_timeout,
365
+ torch_compile = torch_compile,
366
+ torch_compile_backend = torch_compile_backend,
367
+ torch_compile_mode = torch_compile_mode,
368
+ include_tokens_per_second = include_tokens_per_second,
369
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
370
+ neftune_noise_alpha = neftune_noise_alpha,
371
+ optim_target_modules = optim_target_modules,
372
+ batch_eval_metrics = batch_eval_metrics,
373
+ eval_on_start = eval_on_start,
374
+ use_liger_kernel = use_liger_kernel,
375
+ liger_kernel_config = liger_kernel_config,
376
+ eval_use_gather_object = eval_use_gather_object,
377
+ average_tokens_across_devices = average_tokens_across_devices,
378
+ max_length = max_length,
379
+ max_prompt_length = max_prompt_length,
380
+ max_completion_length = max_completion_length,
381
+ max_target_length = max_target_length,
382
+ beta = beta,
383
+ label_smoothing = label_smoothing,
384
+ loss_type = loss_type,
385
+ disable_dropout = disable_dropout,
386
+ label_pad_token_id = label_pad_token_id,
387
+ padding_value = padding_value,
388
+ truncation_mode = truncation_mode,
389
+ generate_during_eval = generate_during_eval,
390
+ is_encoder_decoder = is_encoder_decoder,
391
+ model_init_kwargs = model_init_kwargs,
392
+ dataset_num_proc = dataset_num_proc,**kwargs)
393
+ self.vllm_sampling_params = vllm_sampling_params
394
+ self.unsloth_num_chunks = unsloth_num_chunks
395
+ pass
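For orientation, a hedged usage sketch of constructing this config (illustration only, not part of the uploaded file; the import path is an assumption based on the file name):

# Assumed import path; adjust to wherever this module actually lives.
from UnslothCPOTrainer import UnslothCPOConfig

args = UnslothCPOConfig(
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,
    learning_rate = 5e-5,
    max_length = 1024,
    max_prompt_length = 512,
    beta = 0.1,
)
# With the defaults above, output_dir falls back to 'unsloth_training_checkpoints'
# (with save_strategy switched to 'no') and dataset_num_proc defaults to cpu_count().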
396
+
397
+ class _UnslothCPOTrainer(Trainer):
398
+ r""""""
399
+
400
+ _tag_names = ["trl", "cpo"]
401
+
402
+ def __init__(
403
+ self,
404
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
405
+ args: Optional[CPOConfig] = None,
406
+ data_collator: Optional[DataCollator] = None,
407
+ train_dataset: Optional[Dataset] = None,
408
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
409
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
410
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
411
+ callbacks: Optional[List[TrainerCallback]] = None,
412
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
413
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
414
+ peft_config: Optional[Dict] = None,
415
+ compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
416
+ ):
417
+ if args.model_init_kwargs is None:
418
+ model_init_kwargs = {}
419
+ elif not isinstance(model, str):
420
+ raise ValueError("You passed model_init_kwargs to the CPOTrainer. But your model is already instantiated.")
421
+ else:
422
+ model_init_kwargs = args.model_init_kwargs
423
+ model_init_kwargs["torch_dtype"] = (
424
+ model_init_kwargs["torch_dtype"]
425
+ if model_init_kwargs["torch_dtype"] in ["auto", None]
426
+ else getattr(torch, model_init_kwargs["torch_dtype"])
427
+ )
428
+
429
+ if isinstance(model, str):
430
+ warnings.warn(
431
+ "You passed a model_id to the CPOTrainer. This will automatically create an "
432
+ "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
433
+ )
434
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
435
+
436
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
437
+ # has been called in order to properly call autocast if needed.
438
+ self._peft_has_been_casted_to_bf16 = False
439
+
440
+ if not is_peft_available() and peft_config is not None:
441
+ raise ValueError(
442
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
443
+ )
444
+ elif is_peft_available() and peft_config is not None:
445
+ # if model is a peft model and we have a peft_config, we merge and unload it first
446
+ if isinstance(model, PeftModel):
447
+ model = model.merge_and_unload()
448
+
449
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
450
+ _support_gc_kwargs = hasattr(
451
+ args, "gradient_checkpointing_kwargs"
452
+ ) and "gradient_checkpointing_kwargs" in list(
453
+ inspect.signature(prepare_model_for_kbit_training).parameters
454
+ )
455
+
456
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
457
+
458
+ if _support_gc_kwargs:
459
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
460
+
461
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
462
+ elif getattr(args, "gradient_checkpointing", False):
463
+ # For backward compatibility with older versions of transformers
464
+ if hasattr(model, "enable_input_require_grads"):
465
+ model.enable_input_require_grads()
466
+ else:
467
+
468
+ def make_inputs_require_grad(module, input, output):
469
+ output.requires_grad_(True)
470
+
471
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
472
+
473
+ # get peft model with the given config
474
+ model = model
475
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
476
+ peft_module_casting_to_bf16(model)
477
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
478
+ self._peft_has_been_casted_to_bf16 = True
479
+
480
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
481
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
482
+ # fail or completely fail.
483
+ elif getattr(args, "gradient_checkpointing", False):
484
+ # For backward compatibility with older versions of transformers
485
+ if hasattr(model, "enable_input_require_grads"):
486
+ model.enable_input_require_grads()
487
+ else:
488
+
489
+ def make_inputs_require_grad(module, input, output):
490
+ output.requires_grad_(True)
491
+
492
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
493
+
494
+ if args.generate_during_eval and not is_wandb_available():
495
+ raise ValueError(
496
+ "`generate_during_eval=True` requires Weights and Biases to be installed."
497
+ " Please install `wandb` to resolve."
498
+ )
499
+
500
+ if model is not None:
501
+ self.is_encoder_decoder = model.config.is_encoder_decoder
502
+ elif args.is_encoder_decoder is None:
503
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
504
+ else:
505
+ self.is_encoder_decoder = args.is_encoder_decoder
506
+
507
+ if self.is_encoder_decoder:
508
+ self.decoder_start_token_id = model.config.decoder_start_token_id
509
+ self.pad_token_id = model.config.pad_token_id
510
+
511
+ if tokenizer is None:
512
+ raise ValueError("tokenizer must be specified to tokenize a CPO dataset.")
513
+ if args.max_length is None:
514
+ warnings.warn(
515
+ "`max_length` is not set in the CPOConfig's init"
516
+ "; it will default to `512`, but you should set it yourself in the future.",
517
+ UserWarning,
518
+ )
519
+ max_length = 512
520
+ else:
521
+ max_length = args.max_length
522
+ if args.max_prompt_length is None:
523
+ warnings.warn(
524
+ "`max_prompt_length` is not set in the CPOConfig's init"
525
+ "; it will default to `128`, but you should set it yourself in the future.",
526
+ UserWarning,
527
+ )
528
+ max_prompt_length = 128
529
+ else:
530
+ max_prompt_length = args.max_prompt_length
531
+
532
+ if args.max_target_length is None and self.is_encoder_decoder:
533
+ warnings.warn(
534
+ "When using an encoder-decoder architecture, you should set `max_target_length` in the CPOConfig's init"
535
+ "; it will default to `128`, but you should set it yourself in the future.",
536
+ UserWarning,
537
+ )
538
+ max_target_length = 128
539
+ else:
540
+ max_target_length = args.max_target_length
541
+
542
+ if data_collator is None:
543
+ data_collator = DPODataCollatorWithPadding(
544
+ pad_token_id=tokenizer.pad_token_id,
545
+ label_pad_token_id=args.label_pad_token_id,
546
+ is_encoder_decoder=self.is_encoder_decoder,
547
+ )
548
+
549
+ if args.remove_unused_columns:
550
+ args.remove_unused_columns = False
551
+ # warn users
552
+ warnings.warn(
553
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
554
+ "; we have set it for you, but you should do it yourself in the future.",
555
+ UserWarning,
556
+ )
557
+
558
+ self.use_dpo_data_collator = True
559
+ else:
560
+ self.use_dpo_data_collator = False
561
+
562
+ if args.disable_dropout:
563
+ disable_dropout_in_model(model)
564
+
565
+ self.max_length = max_length
566
+ self.generate_during_eval = args.generate_during_eval
567
+ self.label_pad_token_id = args.label_pad_token_id
568
+ self.padding_value = args.padding_value if args.padding_value is not None else tokenizer.pad_token_id
569
+ self.max_prompt_length = max_prompt_length
570
+ self.truncation_mode = args.truncation_mode
571
+ self.max_target_length = max_target_length
572
+ self.tokenizer = tokenizer
573
+
574
+ if args.loss_type in ["hinge", "ipo", "kto_pair"] and args.label_smoothing > 0:
575
+ warnings.warn(
576
+ "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
577
+ )
578
+
579
+ self.beta = args.beta
580
+ self.label_smoothing = args.label_smoothing
581
+ self.loss_type = args.loss_type
582
+
583
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
584
+
585
+ # Compute that only on the main process for faster data processing.
586
+ # see: https://github.com/huggingface/trl/pull/1255
587
+ with PartialState().local_main_process_first():
588
+ # tokenize the dataset
589
+ train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
590
+ if eval_dataset is not None:
591
+ eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
592
+
593
+ super().__init__(
594
+ model=model,
595
+ args=args,
596
+ data_collator=data_collator,
597
+ train_dataset=train_dataset,
598
+ eval_dataset=eval_dataset,
599
+ tokenizer=tokenizer,
600
+ model_init=model_init,
601
+ compute_metrics=compute_metrics,
602
+ callbacks=callbacks,
603
+ optimizers=optimizers,
604
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
605
+ )
606
+
607
+ # Add tags for models that have been loaded with the correct transformers version
608
+ if hasattr(self.model, "add_model_tags"):
609
+ self.model.add_model_tags(self._tag_names)
610
+
611
+ if not hasattr(self, "accelerator"):
612
+ raise AttributeError(
613
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
614
+ )
615
+
616
+ def build_tokenized_answer(self, prompt, answer):
617
+ """
618
+ Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
619
+ It does, however, ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
620
+ Reference:
621
+ https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
622
+ """
623
+
624
+ full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
625
+ prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
626
+
627
+ answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
628
+ answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
629
+
630
+ # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
631
+ full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
632
+
633
+ # Prepare input tokens for token by token comparison
634
+ full_input_ids = np.array(full_tokenized["input_ids"])
635
+
636
+ if len(full_input_ids) != len(full_concat_input_ids):
637
+ raise ValueError("Tokenizing the prompt and answer together should yield the same total length as the concatenation of their separate tokenizations.")
638
+
639
+ # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
640
+ # can be merged together when tokenizing prompt+answer. This could result
641
+ # in the last token from the prompt being different when tokenized on its own
642
+ # vs when done as prompt+answer.
643
+ response_token_ids_start_idx = len(prompt_input_ids)
644
+
645
+ # If the tokenized prompt differs from the start of the tokenized prompt+answer, it means the
646
+ # last token has changed due to merging.
647
+ if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
648
+ response_token_ids_start_idx -= 1
649
+
650
+ prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
651
+ prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
652
+
653
+ if len(prompt_input_ids) != len(prompt_attention_mask):
654
+ raise ValueError("Prompt input ids and attention mask should have the same length.")
655
+
656
+ answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
657
+ answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
658
+
659
+ return dict(
660
+ prompt_input_ids=prompt_input_ids,
661
+ prompt_attention_mask=prompt_attention_mask,
662
+ input_ids=answer_input_ids,
663
+ attention_mask=answer_attention_mask,
664
+ )
665
+
666
+ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
667
+ """Tokenize a single row from a CPO specific dataset.
668
+
669
+ At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
670
+ in case the prompt + chosen or prompt + rejected responses is/are too long. First
671
+ we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
672
+
673
+ We also create the labels for the chosen/rejected responses, which are of length equal to
674
+ the sum of the length of the prompt and the chosen/rejected response, with
675
+ label_pad_token_id for the prompt tokens.
676
+ """
677
+ batch = {}
678
+ prompt = feature["prompt"]
679
+ chosen = feature["chosen"]
680
+ rejected = feature["rejected"]
681
+
682
+ if not self.is_encoder_decoder:
683
+ # Check issues below for more details
684
+ # 1. https://github.com/huggingface/trl/issues/907
685
+ # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
686
+ # 3. https://github.com/LianjiaTech/BELLE/issues/337
687
+
688
+ if not isinstance(prompt, str):
689
+ raise ValueError(f"prompt should be a str but got {type(prompt)}")
690
+ prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
691
+ prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
692
+
693
+ if not isinstance(chosen, str):
694
+ raise ValueError(f"chosen should be a str but got {type(chosen)}")
695
+ chosen_tokens = self.build_tokenized_answer(prompt, chosen)
696
+
697
+ if not isinstance(rejected, str):
698
+ raise ValueError(f"rejected should be a str but got {type(rejected)}")
699
+ rejected_tokens = self.build_tokenized_answer(prompt, rejected)
700
+
701
+ # Last prompt token might get merged by tokenizer and
702
+ # it should not be included for generation if that happens
703
+ prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
704
+
705
+ chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
706
+ rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
707
+ prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
708
+
709
+ for k, v in prompt_tokens.items():
710
+ prompt_tokens[k] = v[:prompt_len_input_ids]
711
+
712
+ # Make sure prompts only have one different token at most,
713
+ # and length only differs by 1 at most
714
+ num_diff_tokens = sum(
715
+ [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
716
+ )
717
+ num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
718
+ if num_diff_tokens > 1 or num_diff_len > 1:
719
+ raise ValueError(
720
+ "Chosen and rejected prompt_input_ids might only differ on the "
721
+ "last token due to tokenizer merge ops."
722
+ )
723
+
724
+ # add BOS token to head of prompt
725
+ prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
726
+ chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
727
+ rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
728
+
729
+ prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
730
+ chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
731
+ rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
732
+
733
+ # add EOS token to end of answer
734
+ chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
735
+ chosen_tokens["attention_mask"].append(1)
736
+
737
+ rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
738
+ rejected_tokens["attention_mask"].append(1)
739
+
740
+ longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
741
+
742
+ # if combined sequence is too long, truncate the prompt
743
+ for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
744
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
745
+ if self.truncation_mode == "keep_start":
746
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
747
+ answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
748
+ elif self.truncation_mode == "keep_end":
749
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
750
+ answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
751
+ else:
752
+ raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
753
+
754
+ # if that's still too long, truncate the response
755
+ for answer_tokens in [chosen_tokens, rejected_tokens]:
756
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
757
+ for k in ["input_ids", "attention_mask"]:
758
+ answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
759
+
760
+ # Create labels
761
+ chosen_sequence_tokens = {
762
+ k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
763
+ }
764
+ rejected_sequence_tokens = {
765
+ k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
766
+ }
767
+ chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
768
+ chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
769
+ self.label_pad_token_id
770
+ ] * len(chosen_tokens["prompt_input_ids"])
771
+ rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
772
+ rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
773
+ self.label_pad_token_id
774
+ ] * len(rejected_tokens["prompt_input_ids"])
775
+
776
+ for k, toks in {
777
+ "chosen_": chosen_sequence_tokens,
778
+ "rejected_": rejected_sequence_tokens,
779
+ "": prompt_tokens,
780
+ }.items():
781
+ for type_key, tokens in toks.items():
782
+ if type_key == "token_type_ids":
783
+ continue
784
+ batch[f"{k}{type_key}"] = tokens
785
+
786
+ else:
787
+ chosen_tokens = self.tokenizer(
788
+ chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
789
+ )
790
+ rejected_tokens = self.tokenizer(
791
+ rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
792
+ )
793
+ prompt_tokens = self.tokenizer(
794
+ prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
795
+ )
796
+
797
+ batch["chosen_labels"] = chosen_tokens["input_ids"]
798
+ batch["rejected_labels"] = rejected_tokens["input_ids"]
799
+ batch["prompt_input_ids"] = prompt_tokens["input_ids"]
800
+ batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
801
+
802
+ if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
803
+ batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
804
+ labels=torch.tensor(batch["rejected_labels"])
805
+ )
806
+ batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
807
+ labels=torch.tensor(batch["chosen_labels"])
808
+ )
809
+
810
+ return batch
811
+
812
+ @staticmethod
813
+ def concatenated_inputs(
814
+ batch: Dict[str, Union[List, torch.LongTensor]],
815
+ is_encoder_decoder: bool = False,
816
+ label_pad_token_id: int = -100,
817
+ padding_value: int = 0,
818
+ device: Optional[torch.device] = None,
819
+ ) -> Dict[str, torch.LongTensor]:
820
+ """Concatenate the chosen and rejected inputs into a single tensor.
821
+
822
+ Args:
823
+ batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
824
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
825
+ label_pad_token_id: The label pad token id.
826
+ padding_value: The padding value to use for the concatenated inputs_ids.
827
+ device: The device for the concatenated inputs.
828
+
829
+ Returns:
830
+ A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
831
+ """
832
+ concatenated_batch = {}
833
+
834
+ if is_encoder_decoder:
835
+ max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
836
+ else:
837
+ max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
838
+
839
+ for k in batch:
840
+ if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
841
+ if "labels" in k or is_encoder_decoder:
842
+ pad_value = label_pad_token_id
843
+ elif k.endswith("_input_ids"):
844
+ pad_value = padding_value
845
+ elif k.endswith("_attention_mask"):
846
+ pad_value = 0
847
+ concatenated_key = k.replace("chosen", "concatenated")
848
+ concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
849
+ for k in batch:
850
+ if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
851
+ if "labels" in k or is_encoder_decoder:
852
+ pad_value = label_pad_token_id
853
+ elif k.endswith("_input_ids"):
854
+ pad_value = padding_value
855
+ elif k.endswith("_attention_mask"):
856
+ pad_value = 0
857
+ concatenated_key = k.replace("rejected", "concatenated")
858
+ concatenated_batch[concatenated_key] = torch.cat(
859
+ (
860
+ concatenated_batch[concatenated_key],
861
+ pad_to_length(batch[k], max_length, pad_value=pad_value),
862
+ ),
863
+ dim=0,
864
+ ).to(device=device)
865
+
866
+ if is_encoder_decoder:
867
+ concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
868
+ concatenated_batch["concatenated_attention_mask"] = (
869
+ batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
870
+ )
871
+
872
+ return concatenated_batch
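To see what `concatenated_inputs` produces, a small standalone sketch with toy tensors (illustration only, not part of the uploaded file; it assumes the class is importable from this module):

import torch
from UnslothCPOTrainer import _UnslothCPOTrainer  # assumed import path

batch = {
    "chosen_input_ids":        torch.tensor([[5, 6, 7]]),
    "chosen_attention_mask":   torch.tensor([[1, 1, 1]]),
    "chosen_labels":           torch.tensor([[-100, 6, 7]]),
    "rejected_input_ids":      torch.tensor([[5, 8]]),
    "rejected_attention_mask": torch.tensor([[1, 1]]),
    "rejected_labels":         torch.tensor([[-100, 8]]),
}
out = _UnslothCPOTrainer.concatenated_inputs(batch, padding_value=0)
# Rejected tensors are right-padded to the chosen length and stacked under the chosen ones,
# so concatenated_input_ids has shape (2, 3): row 0 = chosen, row 1 = rejected.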
873
+
874
+ def cpo_loss(
875
+ self,
876
+ policy_chosen_logps: torch.FloatTensor,
877
+ policy_rejected_logps: torch.FloatTensor,
878
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
879
+ """Compute the CPO loss for a batch of policy model log probabilities.
880
+
881
+ Args:
882
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
883
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
884
+
885
+ Returns:
886
+ A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
887
+ The losses tensor contains the CPO loss for each example in the batch.
888
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
889
+ """
890
+ logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
891
+
892
+ # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
893
+ # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
894
+ # calculates a conservative CPO loss.
895
+ if self.loss_type == "sigmoid":
896
+ # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
897
+ losses = (
898
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
899
+ - F.logsigmoid(-self.beta * logits) * self.label_smoothing
900
+ )
901
+ elif self.loss_type == "hinge":
902
+ losses = torch.relu(1 - self.beta * logits)
903
+ elif self.loss_type == "ipo":
904
+ # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
905
+ losses = (logits - 1 / (2 * self.beta)) ** 2
906
+ else:
907
+ raise ValueError(
908
+ f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
909
+ )
910
+
911
+ chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
912
+ rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
913
+
914
+ return losses, chosen_rewards, rejected_rewards
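For intuition, a standalone sketch of the `sigmoid` branch above on toy log-probabilities (illustration only, not part of the uploaded file; the beta and label_smoothing values are arbitrary assumptions):

import torch
import torch.nn.functional as F

beta, label_smoothing = 0.1, 0.0
policy_chosen_logps   = torch.tensor([-12.0, -30.0])   # summed log-probs of chosen responses
policy_rejected_logps = torch.tensor([-15.0, -28.0])   # summed log-probs of rejected responses

logits = policy_chosen_logps - policy_rejected_logps
losses = (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
          - F.logsigmoid(-beta * logits) * label_smoothing)
chosen_rewards   = beta * policy_chosen_logps    # detached in the trainer; plain tensors here
rejected_rewards = beta * policy_rejected_logps
# The pair where chosen outscores rejected (first entry) yields a smaller loss than the one where it does not.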
915
+
916
+ @staticmethod
917
+ def get_batch_logps(
918
+ logits: torch.FloatTensor,
919
+ labels: torch.LongTensor,
920
+ average_log_prob: bool = False,
921
+ label_pad_token_id: int = -100,
922
+ is_encoder_decoder: bool = False,
923
+ ) -> torch.FloatTensor:
924
+ """Compute the log probabilities of the given labels under the given logits.
925
+
926
+ Args:
927
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
928
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
929
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
930
+ label_pad_token_id: The label pad token id.
931
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
932
+
933
+ Returns:
934
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
935
+ """
936
+ if logits.shape[:-1] != labels.shape:
937
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
938
+
939
+ if not is_encoder_decoder:
940
+ labels = labels[:, 1:].clone()
941
+ logits = logits[:, :-1, :]
942
+ loss_mask = labels != label_pad_token_id
943
+
944
+ # dummy token; we'll ignore the losses on these tokens later
945
+ labels[labels == label_pad_token_id] = 0
946
+
947
+ per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
948
+
949
+ if average_log_prob:
950
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
951
+ else:
952
+ return (per_token_logps * loss_mask).sum(-1)
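A minimal standalone sketch of the masking logic in `get_batch_logps` for the decoder-only case (illustration only, not part of the uploaded file; the toy shapes and label values are assumptions):

import torch

label_pad_token_id = -100
logits = torch.randn(1, 4, 7)                                   # (batch, seq_len, vocab)
labels = torch.tensor([[label_pad_token_id, 2, 5, label_pad_token_id]])

labels_shift = labels[:, 1:].clone()                            # predict token t from position t-1
logits_shift = logits[:, :-1, :]
loss_mask = labels_shift != label_pad_token_id                  # prompt/padding positions drop out
labels_shift[labels_shift == label_pad_token_id] = 0            # dummy index, ignored via the mask
per_token = torch.gather(logits_shift.log_softmax(-1), 2, labels_shift.unsqueeze(2)).squeeze(2)
print((per_token * loss_mask).sum(-1))                          # one summed log-prob per sequence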
953
+
954
+ def concatenated_forward(
955
+ self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
956
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
957
+ """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
958
+
959
+ We do this to avoid doing two forward passes, because it's faster for FSDP.
960
+ """
961
+ concatenated_batch = self.concatenated_inputs(
962
+ batch,
963
+ is_encoder_decoder=self.is_encoder_decoder,
964
+ label_pad_token_id=self.label_pad_token_id,
965
+ padding_value=self.padding_value,
966
+ device=self.accelerator.device,
967
+ )
968
+ len_chosen = batch["chosen_labels"].shape[0]
969
+
970
+ model_kwargs = (
971
+ {
972
+ "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
973
+ }
974
+ if self.is_encoder_decoder
975
+ else {}
976
+ )
977
+
978
+ outputs = model(
979
+ concatenated_batch["concatenated_input_ids"],
980
+ attention_mask=concatenated_batch["concatenated_attention_mask"],
981
+ use_cache=False,
982
+ **model_kwargs,
983
+ )
984
+ all_logits = outputs.logits
985
+
986
+ def cross_entropy_loss(logits, labels):
987
+ if not self.is_encoder_decoder:
988
+ # Shift so that tokens < n predict n
989
+ logits = logits[..., :-1, :].contiguous()
990
+ labels = labels[..., 1:].contiguous()
991
+ # Flatten the tokens
992
+ loss_fct = nn.CrossEntropyLoss()
993
+ logits = logits.view(-1, logits.shape[-1])
994
+ labels = labels.view(-1)
995
+ # Enable model parallelism
996
+ labels = labels.to(logits.device)
997
+ loss = loss_fct(logits, labels)
998
+ return loss
999
+
1000
+ labels = concatenated_batch["concatenated_labels"].clone()
1001
+ nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
1002
+
1003
+ all_logps = self.get_batch_logps(
1004
+ all_logits,
1005
+ concatenated_batch["concatenated_labels"],
1006
+ average_log_prob=self.loss_type == "ipo",
1007
+ is_encoder_decoder=self.is_encoder_decoder,
1008
+ label_pad_token_id=self.label_pad_token_id,
1009
+ )
1010
+
1011
+ chosen_logps = all_logps[:len_chosen]
1012
+ rejected_logps = all_logps[len_chosen:]
1013
+
1014
+ chosen_logits = all_logits[:len_chosen]
1015
+ rejected_logits = all_logits[len_chosen:]
1016
+
1017
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
1018
+
1019
+ def get_batch_loss_metrics(
1020
+ self,
1021
+ model,
1022
+ batch: Dict[str, Union[List, torch.LongTensor]],
1023
+ train_eval: Literal["train", "eval"] = "train",
1024
+ ):
1025
+ """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
1026
+ metrics = {}
1027
+
1028
+ (
1029
+ policy_chosen_logps,
1030
+ policy_rejected_logps,
1031
+ policy_chosen_logits,
1032
+ policy_rejected_logits,
1033
+ policy_nll_loss,
1034
+ ) = self.concatenated_forward(model, batch)
1035
+
1036
+ losses, chosen_rewards, rejected_rewards = self.cpo_loss(
1037
+ policy_chosen_logps,
1038
+ policy_rejected_logps,
1039
+ )
1040
+
1041
+ loss = losses.mean() + policy_nll_loss
1042
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
1043
+
1044
+ prefix = "eval_" if train_eval == "eval" else ""
1045
+ metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
1046
+ metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
1047
+ metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
1048
+ metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
1049
+ metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
1050
+ metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
1051
+ metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
1052
+ metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
1053
+ metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
1054
+
1055
+ return loss, metrics
1056
+
1057
+ def compute_loss(
1058
+ self,
1059
+ model: Union[PreTrainedModel, nn.Module],
1060
+ inputs: Dict[str, Union[torch.Tensor, Any]],
1061
+ return_outputs=False,
1062
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
1063
+ if not self.use_dpo_data_collator:
1064
+ warnings.warn(
1065
+ "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different from "
1066
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own compute_loss method if you are using a custom data collator"
1067
+ )
1068
+
1069
+ compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
1070
+
1071
+ with compute_loss_context_manager():
1072
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
1073
+
1074
+ # force log the metrics
1075
+ self.store_metrics(metrics, train_eval="train")
1076
+
1077
+ if return_outputs:
1078
+ return (loss, metrics)
1079
+ return loss
1080
+
1081
+ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> List[str]:
1082
+ """Generate samples from the model for the given batch of inputs."""
1083
+
1084
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1085
+ # the torch cuda amp context manager as some hidden states are silently casted to full precision.
1086
+ generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
1087
+
1088
+ with generate_context_manager():
1089
+ policy_output = model.generate(
1090
+ input_ids=batch["prompt_input_ids"],
1091
+ attention_mask=batch["prompt_attention_mask"],
1092
+ max_length=self.max_length,
1093
+ do_sample=True,
1094
+ pad_token_id=self.tokenizer.pad_token_id,
1095
+ )
1096
+
1097
+ policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
1098
+ policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
1099
+
1100
+ return policy_output_decoded
1101
+
1102
+ def prediction_step(
1103
+ self,
1104
+ model: Union[PreTrainedModel, nn.Module],
1105
+ inputs: Dict[str, Union[torch.Tensor, Any]],
1106
+ prediction_loss_only: bool,
1107
+ ignore_keys: Optional[List[str]] = None,
1108
+ ):
1109
+ if not self.use_dpo_data_collator:
1110
+ warnings.warn(
1111
+ "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different from "
1112
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
1113
+ )
1114
+ if ignore_keys is None:
1115
+ if hasattr(model, "config"):
1116
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1117
+ else:
1118
+ ignore_keys = []
1119
+
1120
+ prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
1121
+
1122
+ with torch.no_grad(), prediction_context_manager():
1123
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
1124
+
1125
+ # force log the metrics
1126
+ self.store_metrics(metrics, train_eval="eval")
1127
+
1128
+ if prediction_loss_only:
1129
+ return (loss.detach(), None, None)
1130
+
1131
+ # logits for the chosen and rejected samples from model
1132
+ logits_dict = {
1133
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
1134
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
1135
+ }
1136
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1137
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1138
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1139
+
1140
+ return (loss.detach(), logits, labels)
1141
+
1142
+ def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1143
+ for key, value in metrics.items():
1144
+ self._stored_metrics[train_eval][key].append(value)
1145
+
1146
+ def evaluation_loop(
1147
+ self,
1148
+ dataloader: DataLoader,
1149
+ description: str,
1150
+ prediction_loss_only: Optional[bool] = None,
1151
+ ignore_keys: Optional[List[str]] = None,
1152
+ metric_key_prefix: str = "eval",
1153
+ ) -> EvalLoopOutput:
1154
+ """
1155
+ Overriding built-in evaluation loop to store metrics for each batch.
1156
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1157
+
1158
+ Works both with or without labels.
1159
+ """
1160
+
1161
+ # Sample and save to game log if requested (for one batch to save time)
1162
+ if self.generate_during_eval:
1163
+ # Generate random indices within the range of the total number of samples
1164
+ num_samples = len(dataloader.dataset)
1165
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1166
+
1167
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1168
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1169
+ random_batch = self.data_collator(random_batch_dataset)
1170
+ random_batch = self._prepare_inputs(random_batch)
1171
+
1172
+ policy_output_decoded = self.get_batch_samples(self.model, random_batch)
1173
+
1174
+ self.log(
1175
+ {
1176
+ "game_log": wandb.Table(
1177
+ columns=["Prompt", "Policy"],
1178
+ rows=[
1179
+ [prompt, pol[len(prompt) :]]
1180
+ for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
1181
+ ],
1182
+ )
1183
+ }
1184
+ )
1185
+ self.state.log_history.pop()
1186
+
1187
+ # Base evaluation
1188
+ initial_output = super().evaluation_loop(
1189
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1190
+ )
1191
+
1192
+ return initial_output
1193
+
1194
+ def log(self, logs: Dict[str, float]) -> None:
1195
+ """
1196
+ Log `logs` on the various objects watching training, including stored metrics.
1197
+
1198
+ Args:
1199
+ logs (`Dict[str, float]`):
1200
+ The values to log.
1201
+ """
1202
+ # logs either has 'loss' or 'eval_loss'
1203
+ train_eval = "train" if "loss" in logs else "eval"
1204
+ # Add averaged stored metrics to logs
1205
+ for key, metrics in self._stored_metrics[train_eval].items():
1206
+ logs[key] = torch.tensor(metrics).mean().item()
1207
+ del self._stored_metrics[train_eval]
1208
+ return super().log(logs)
1209
+
1210
+ def _shift_right(self, input_ids):
1211
+ if self.decoder_start_token_id is None:
1212
+ raise ValueError(
1213
+ "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1214
+ )
1215
+
1216
+ # shift inputs to the right
1217
+ if is_torch_fx_proxy(input_ids):
1218
+ # Item assignment is not supported natively for proxies.
1219
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1220
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1221
+ else:
1222
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1223
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1224
+ shifted_input_ids[..., 0] = self.decoder_start_token_id
1225
+
1226
+ if self.pad_token_id is None:
1227
+ raise ValueError("model.config.pad_token_id has to be defined.")
1228
+ # replace possible -100 values in labels by `pad_token_id`
1229
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1230
+
1231
+ return shifted_input_ids
1232
+
1233
+ @wraps(Trainer.push_to_hub)
1234
+ def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
1235
+ """
1236
+ Overwrite the `push_to_hub` method in order to force-add the tag "cpo" when pushing the
1237
+ model to the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
1238
+ """
1239
+ kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
1240
+
1241
+ return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
1242
+ class UnslothCPOTrainer(_UnslothCPOTrainer):
1243
+ """
1244
+
1245
+ Initialize CPOTrainer.
1246
+
1247
+ Args:
1248
+ model (`transformers.PreTrainedModel`):
1249
+ The model to train, preferably an `AutoModelForCausalLM`.
1250
+ args (`CPOConfig`):
1251
+ The CPO config arguments to use for training.
1252
+ data_collator (`transformers.DataCollator`):
1253
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1254
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1255
+ train_dataset (`datasets.Dataset`):
1256
+ The dataset to use for training.
1257
+ eval_dataset (`datasets.Dataset`):
1258
+ The dataset to use for evaluation.
1259
+ tokenizer (`transformers.PreTrainedTokenizerBase`):
1260
+ The tokenizer to use for training. This argument is required if you want to use the default data collator.
1261
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1262
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1263
+ callbacks (`List[transformers.TrainerCallback]`):
1264
+ The callbacks to use for training.
1265
+ optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1266
+ The optimizer and scheduler to use for training.
1267
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1268
+ The function to use to preprocess the logits before computing the metrics.
1269
+ peft_config (`Dict`, defaults to `None`):
1270
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1271
+ compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
1272
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
1273
+ a dictionary string to metric values.
1274
+
1275
+ """
1276
+ def __init__(
1277
+ self,
1278
+ model = None,
1279
+ args = None,
1280
+ data_collator = None,
1281
+ train_dataset = None,
1282
+ eval_dataset = None,
1283
+ tokenizer = None,
1284
+ model_init = None,
1285
+ callbacks = None,
1286
+ preprocess_logits_for_metrics = None,
1287
+ peft_config = None,
1288
+ compute_metrics = None,
1289
+ **kwargs
1290
+ ):
1291
+ if args is None: args = UnslothCPOConfig()
1292
+ use_bf16 = getattr(args, 'bf16', False)
1293
+ if type(use_bf16) is not bool: use_bf16 = False
1294
+ use_fp16 = getattr(args, 'fp16', False)
1295
+ if type(use_fp16) is not bool: use_fp16 = False
1296
+ force_float32 = False
1297
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1298
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1299
+ force_float32 = True
1300
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1301
+ dtype = getattr(model.config, 'torch_dtype', None)
1302
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1303
+ from unsloth_zoo.utils import _get_dtype
1304
+ dtype = _get_dtype(dtype)
1305
+ float16 = dtype == torch.float16
1306
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1307
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1308
+ if force_float32:
1309
+ args.fp16 = False
1310
+ args.bf16 = False
1311
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1312
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1313
+ args.fp16 = float16
1314
+ args.bf16 = not float16
1315
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1316
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1317
+ args.eval_strategy = 'steps'
1318
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1319
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1320
+ if ga_steps is not None and ga_steps > 1:
1321
+ from transformers import __version__ as transformers_version
1322
+ if Version(transformers_version) <= Version('4.45.2'):
1323
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1324
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1325
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1326
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1327
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1328
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1329
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1330
+ if type(fp16_full_eval) is not bool: fp16_full_eval = False
1331
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1332
+ if type(bf16_full_eval) is not bool: bf16_full_eval = False
1333
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1334
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1335
+ if force_float32:
1336
+ args.bf16_full_eval = False
1337
+ args.fp16_full_eval = False
1338
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1339
+ args.bf16_full_eval = True
1340
+ args.fp16_full_eval = False
1341
+ elif not bf16_full_eval and not fp16_full_eval:
1342
+ args.bf16_full_eval = args.bf16
1343
+ args.fp16_full_eval = args.fp16
1344
+ _output_logits = False
1345
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1346
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1347
+ if _output_logits:
1348
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1349
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1350
+ pass
1351
+ else:
1352
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1353
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1354
+ if args_max_seq_length is None and model_max_seq_length is not None:
1355
+ max_seq_length = model.max_seq_length
1356
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1357
+ if model is not None and hasattr(model, 'for_training'):
1358
+ model.for_training()
1359
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1360
+ if 'processing_class' in locals():
1361
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1362
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1363
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1364
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1365
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1366
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1367
+ data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
1368
+ elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1369
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1370
+ else:
1371
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1372
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1373
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1374
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1375
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1376
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1377
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1378
+ else:
1379
+ data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
1380
+ other_metrics = []
1381
+
1382
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1383
+ PatchRLStatistics('cpo_trainer', other_metrics)
1384
+
1385
+ super().__init__(
1386
+ model = model,
1387
+ args = args,
1388
+ data_collator = data_collator,
1389
+ train_dataset = train_dataset,
1390
+ eval_dataset = eval_dataset,
1391
+ tokenizer = tokenizer,
1392
+ model_init = model_init,
1393
+ callbacks = callbacks,
1394
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1395
+ peft_config = peft_config,
1396
+ compute_metrics = compute_metrics,**kwargs)
1397
+ if hasattr(self, 'neftune_hook_handle'):
1398
+ self.neftune_hook_handle.remove()
1399
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1400
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1401
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1402
+ pass
1403
+
1404
+ pass
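A minimal usage sketch for the class above, assuming a model loaded through Unsloth's FastLanguageModel, a preference dataset that already provides "prompt", "chosen" and "rejected" text columns, and that UnslothCPOConfig forwards the usual CPOConfig/TrainingArguments fields; the model and dataset names are placeholders:

    from unsloth import FastLanguageModel
    from datasets import load_dataset

    # Load a 4-bit base model and attach LoRA adapters with Unsloth's default settings.
    model, tokenizer = FastLanguageModel.from_pretrained(
        "unsloth/llama-3-8b-bnb-4bit", max_seq_length=2048, load_in_4bit=True
    )
    model = FastLanguageModel.get_peft_model(model)

    # Any dataset with "prompt"/"chosen"/"rejected" string columns should work here.
    dataset = load_dataset("your-org/your-preference-dataset", split="train")

    trainer = UnslothCPOTrainer(
        model=model,
        args=UnslothCPOConfig(
            output_dir="outputs",
            per_device_train_batch_size=2,
            max_length=1024,
            max_prompt_length=512,
        ),
        train_dataset=dataset,
        tokenizer=tokenizer,
    )
    trainer.train()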
compilefcach/UnslothDDPOTrainer.py ADDED
@@ -0,0 +1,744 @@
1
+ """
2
+ 2025.6.8
3
+ 2025.6.12
4
+ 4.53.0
5
+ 0.8.6
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.ddpo_trainer import (Accelerator, Any, BaseTrainer, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, MODEL_CARD_TEMPLATE, Optional, PerPromptStatTracker, ProjectConfiguration, Tuple, defaultdict, futures, logger, os, set_seed, torch, warn, warnings, whoami)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
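+ # selective_log_softmax returns, for each position, the log-probability of the token
+ # given by `index`, i.e. log_softmax(logits) gathered at `index`. Gathering first and
+ # subtracting the logsumexp avoids materialising the full (batch, seq, vocab)
+ # log-probability tensor, keeping peak memory lower than log_softmax followed by gather.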
42
+ @dataclass
43
+ class UnslothDDPOConfig(DDPOConfig):
44
+ """
45
+
46
+ Configuration class for DDPOTrainer
47
+
48
+ """
49
+ vllm_sampling_params: Optional[Any] = field(
50
+ default = None,
51
+ metadata = {'help': 'vLLM SamplingParams'},
52
+ )
53
+ unsloth_num_chunks : Optional[int] = field(
54
+ default = -1,
55
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
56
+ )
57
+ def __init__(
58
+ self,
59
+ exp_name = 'colab_kernel_launcher',
60
+ run_name = '',
61
+ seed = 3407,
62
+ log_with = None,
63
+ tracker_project_name = 'trl',
64
+ logdir = 'logs',
65
+ num_epochs = 100,
66
+ save_freq = 1,
67
+ num_checkpoint_limit = 5,
68
+ mixed_precision = 'fp16',
69
+ allow_tf32 = True,
70
+ resume_from = '',
71
+ sample_num_steps = 50,
72
+ sample_eta = 1.0,
73
+ sample_guidance_scale = 5.0,
74
+ sample_batch_size = 1,
75
+ sample_num_batches_per_epoch = 2,
76
+ train_batch_size = 1,
77
+ train_use_8bit_adam = False,
78
+ train_learning_rate = 5e-05,
79
+ train_adam_beta1 = 0.9,
80
+ train_adam_beta2 = 0.999,
81
+ train_adam_weight_decay = 0.01,
82
+ train_adam_epsilon = 1e-08,
83
+ train_gradient_accumulation_steps = 2,
84
+ train_max_grad_norm = 1.0,
85
+ train_num_inner_epochs = 1,
86
+ train_cfg = True,
87
+ train_adv_clip_max = 5,
88
+ train_clip_range = 0.0001,
89
+ train_timestep_fraction = 1.0,
90
+ per_prompt_stat_tracking = False,
91
+ per_prompt_stat_tracking_buffer_size = 16,
92
+ per_prompt_stat_tracking_min_count = 16,
93
+ async_reward_computation = False,
94
+ max_workers = 2,
95
+ negative_prompts = '',
96
+ vllm_sampling_params = None,
97
+ unsloth_num_chunks = -1,
98
+ **kwargs,
99
+ ):
100
+
101
+ super().__init__(
102
+ exp_name = exp_name,
103
+ run_name = run_name,
104
+ seed = seed,
105
+ log_with = log_with,
106
+ tracker_project_name = tracker_project_name,
107
+ logdir = logdir,
108
+ num_epochs = num_epochs,
109
+ save_freq = save_freq,
110
+ num_checkpoint_limit = num_checkpoint_limit,
111
+ mixed_precision = mixed_precision,
112
+ allow_tf32 = allow_tf32,
113
+ resume_from = resume_from,
114
+ sample_num_steps = sample_num_steps,
115
+ sample_eta = sample_eta,
116
+ sample_guidance_scale = sample_guidance_scale,
117
+ sample_batch_size = sample_batch_size,
118
+ sample_num_batches_per_epoch = sample_num_batches_per_epoch,
119
+ train_batch_size = train_batch_size,
120
+ train_use_8bit_adam = train_use_8bit_adam,
121
+ train_learning_rate = train_learning_rate,
122
+ train_adam_beta1 = train_adam_beta1,
123
+ train_adam_beta2 = train_adam_beta2,
124
+ train_adam_weight_decay = train_adam_weight_decay,
125
+ train_adam_epsilon = train_adam_epsilon,
126
+ train_gradient_accumulation_steps = train_gradient_accumulation_steps,
127
+ train_max_grad_norm = train_max_grad_norm,
128
+ train_num_inner_epochs = train_num_inner_epochs,
129
+ train_cfg = train_cfg,
130
+ train_adv_clip_max = train_adv_clip_max,
131
+ train_clip_range = train_clip_range,
132
+ train_timestep_fraction = train_timestep_fraction,
133
+ per_prompt_stat_tracking = per_prompt_stat_tracking,
134
+ per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
135
+ per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
136
+ async_reward_computation = async_reward_computation,
137
+ max_workers = max_workers,
138
+ negative_prompts = negative_prompts,**kwargs)
139
+ self.vllm_sampling_params = vllm_sampling_params
140
+ self.unsloth_num_chunks = unsloth_num_chunks
141
+ pass
142
+
143
+ class _UnslothDDPOTrainer(BaseTrainer):
144
+ """"""
145
+
146
+ _tag_names = ["trl", "ddpo"]
147
+
148
+ def __init__(
149
+ self,
150
+ config: DDPOConfig,
151
+ reward_function: Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor],
152
+ prompt_function: Callable[[], Tuple[str, Any]],
153
+ sd_pipeline: DDPOStableDiffusionPipeline,
154
+ image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
155
+ ):
156
+ if image_samples_hook is None:
157
+ warn("No image_samples_hook provided; no images will be logged")
158
+
159
+ self.prompt_fn = prompt_function
160
+ self.reward_fn = reward_function
161
+ self.config = config
162
+ self.image_samples_callback = image_samples_hook
163
+
164
+ accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
165
+
166
+ if self.config.resume_from:
167
+ self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
168
+ if "checkpoint_" not in os.path.basename(self.config.resume_from):
169
+ # get the most recent checkpoint in this directory
170
+ checkpoints = list(
171
+ filter(
172
+ lambda x: "checkpoint_" in x,
173
+ os.listdir(self.config.resume_from),
174
+ )
175
+ )
176
+ if len(checkpoints) == 0:
177
+ raise ValueError(f"No checkpoints found in {self.config.resume_from}")
178
+ checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
179
+ self.config.resume_from = os.path.join(
180
+ self.config.resume_from,
181
+ f"checkpoint_{checkpoint_numbers[-1]}",
182
+ )
183
+
184
+ accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
185
+
186
+ # number of timesteps within each trajectory to train on
187
+ self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
188
+
189
+ self.accelerator = Accelerator(
190
+ log_with=self.config.log_with,
191
+ mixed_precision=self.config.mixed_precision,
192
+ project_config=accelerator_project_config,
193
+ # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
194
+ # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
195
+ # the total number of optimizer steps to accumulate across.
196
+ gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
197
+ **self.config.accelerator_kwargs,
198
+ )
199
+
200
+ is_okay, message = self._config_check()
201
+ if not is_okay:
202
+ raise ValueError(message)
203
+
204
+ is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
205
+
206
+ if self.accelerator.is_main_process:
207
+ self.accelerator.init_trackers(
208
+ self.config.tracker_project_name,
209
+ config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
210
+ init_kwargs=self.config.tracker_kwargs,
211
+ )
212
+
213
+ logger.info(f"\n{config}")
214
+
215
+ set_seed(self.config.seed, device_specific=True)
216
+
217
+ self.sd_pipeline = sd_pipeline
218
+
219
+ self.sd_pipeline.set_progress_bar_config(
220
+ position=1,
221
+ disable=not self.accelerator.is_local_main_process,
222
+ leave=False,
223
+ desc="Timestep",
224
+ dynamic_ncols=True,
225
+ )
226
+
227
+ # For mixed precision training we cast all non-trainable weights [vae, non-lora text_encoder and non-lora unet] to half-precision
228
+ # as these weights are only used for inference, keeping weights in full precision is not required.
229
+ if self.accelerator.mixed_precision == "fp16":
230
+ inference_dtype = torch.float16
231
+ elif self.accelerator.mixed_precision == "bf16":
232
+ inference_dtype = torch.bfloat16
233
+ else:
234
+ inference_dtype = torch.float32
235
+
236
+ self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
237
+ self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
238
+ self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
239
+
240
+ trainable_layers = self.sd_pipeline.get_trainable_layers()
241
+
242
+ self.accelerator.register_save_state_pre_hook(self._save_model_hook)
243
+ self.accelerator.register_load_state_pre_hook(self._load_model_hook)
244
+
245
+ # Enable TF32 for faster training on Ampere GPUs,
246
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
247
+ if self.config.allow_tf32:
248
+ torch.backends.cuda.matmul.allow_tf32 = True
249
+
250
+ self.optimizer = self._setup_optimizer(
251
+ trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
252
+ )
253
+
254
+ self.neg_prompt_embed = self.sd_pipeline.text_encoder(
255
+ self.sd_pipeline.tokenizer(
256
+ [""] if self.config.negative_prompts is None else self.config.negative_prompts,
257
+ return_tensors="pt",
258
+ padding="max_length",
259
+ truncation=True,
260
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
261
+ ).input_ids.to(self.accelerator.device)
262
+ )[0]
263
+
264
+ if config.per_prompt_stat_tracking:
265
+ self.stat_tracker = PerPromptStatTracker(
266
+ config.per_prompt_stat_tracking_buffer_size,
267
+ config.per_prompt_stat_tracking_min_count,
268
+ )
269
+
270
+ # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
271
+ # more memory
272
+ self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
273
+
274
+ if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
275
+ unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
276
+ self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
277
+ else:
278
+ self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
279
+
280
+ if self.config.async_reward_computation:
281
+ self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
282
+
283
+ if config.resume_from:
284
+ logger.info(f"Resuming from {config.resume_from}")
285
+ self.accelerator.load_state(config.resume_from)
286
+ self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
287
+ else:
288
+ self.first_epoch = 0
289
+
290
+ def compute_rewards(self, prompt_image_pairs, is_async=False):
291
+ if not is_async:
292
+ rewards = []
293
+ for images, prompts, prompt_metadata in prompt_image_pairs:
294
+ reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
295
+ rewards.append(
296
+ (
297
+ torch.as_tensor(reward, device=self.accelerator.device),
298
+ reward_metadata,
299
+ )
300
+ )
301
+ else:
302
+ rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
303
+ rewards = [
304
+ (torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
305
+ for reward, reward_metadata in rewards
306
+ ]
307
+
308
+ return zip(*rewards)
309
+
310
+ def step(self, epoch: int, global_step: int):
311
+ """
312
+ Perform a single step of training.
313
+
314
+ Args:
315
+ epoch (int): The current epoch.
316
+ global_step (int): The current global step.
317
+
318
+ Side Effects:
319
+ - Model weights are updated
320
+ - Logs the statistics to the accelerator trackers.
321
+ - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
322
+
323
+ Returns:
324
+ global_step (int): The updated global step.
325
+
326
+ """
327
+ samples, prompt_image_data = self._generate_samples(
328
+ iterations=self.config.sample_num_batches_per_epoch,
329
+ batch_size=self.config.sample_batch_size,
330
+ )
331
+
332
+ # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
333
+ samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
334
+ rewards, rewards_metadata = self.compute_rewards(
335
+ prompt_image_data, is_async=self.config.async_reward_computation
336
+ )
337
+
338
+ for i, image_data in enumerate(prompt_image_data):
339
+ image_data.extend([rewards[i], rewards_metadata[i]])
340
+
341
+ if self.image_samples_callback is not None:
342
+ self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
343
+
344
+ rewards = torch.cat(rewards)
345
+ rewards = self.accelerator.gather(rewards).cpu().numpy()
346
+
347
+ self.accelerator.log(
348
+ {
349
+ "reward": rewards,
350
+ "epoch": epoch,
351
+ "reward_mean": rewards.mean(),
352
+ "reward_std": rewards.std(),
353
+ },
354
+ step=global_step,
355
+ )
356
+
357
+ if self.config.per_prompt_stat_tracking:
358
+ # gather the prompts across processes
359
+ prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
360
+ prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
361
+ advantages = self.stat_tracker.update(prompts, rewards)
362
+ else:
363
+ advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
364
+
365
+ # ungather advantages; keep the entries corresponding to the samples on this process
366
+ samples["advantages"] = (
367
+ torch.as_tensor(advantages)
368
+ .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
369
+ .to(self.accelerator.device)
370
+ )
371
+
372
+ del samples["prompt_ids"]
373
+
374
+ total_batch_size, num_timesteps = samples["timesteps"].shape
375
+
376
+ for inner_epoch in range(self.config.train_num_inner_epochs):
377
+ # shuffle samples along batch dimension
378
+ perm = torch.randperm(total_batch_size, device=self.accelerator.device)
379
+ samples = {k: v[perm] for k, v in samples.items()}
380
+
381
+ # shuffle along time dimension independently for each sample
382
+ # still trying to understand the code below
383
+ perms = torch.stack(
384
+ [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
385
+ )
386
+
387
+ for key in ["timesteps", "latents", "next_latents", "log_probs"]:
388
+ samples[key] = samples[key][
389
+ torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
390
+ perms,
391
+ ]
392
+
393
+ original_keys = samples.keys()
394
+ original_values = samples.values()
395
+ # rebatch them as user defined train_batch_size is different from sample_batch_size
396
+ reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
397
+
398
+ # Transpose the list of original values
399
+ transposed_values = zip(*reshaped_values)
400
+ # Create new dictionaries for each row of transposed values
401
+ samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
402
+
403
+ self.sd_pipeline.unet.train()
404
+ global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
405
+ # ensure optimization step at the end of the inner epoch
406
+ if not self.accelerator.sync_gradients:
407
+ raise ValueError(
408
+ "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
409
+ )
410
+
411
+ if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
412
+ self.accelerator.save_state()
413
+
414
+ return global_step
415
+
416
+ def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
417
+ """
418
+ Calculate the loss for a batch of an unpacked sample
419
+
420
+ Args:
421
+ latents (torch.Tensor):
422
+ The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
423
+ timesteps (torch.Tensor):
424
+ The timesteps sampled from the diffusion model, shape: [batch_size]
425
+ next_latents (torch.Tensor):
426
+ The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
427
+ log_probs (torch.Tensor):
428
+ The log probabilities of the latents, shape: [batch_size]
429
+ advantages (torch.Tensor):
430
+ The advantages of the latents, shape: [batch_size]
431
+ embeds (torch.Tensor):
432
+ The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
433
+ Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
434
+
435
+ Returns:
436
+ loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
437
+ (all of these are of shape (1,))
438
+ """
439
+ with self.autocast():
440
+ if self.config.train_cfg:
441
+ noise_pred = self.sd_pipeline.unet(
442
+ torch.cat([latents] * 2),
443
+ torch.cat([timesteps] * 2),
444
+ embeds,
445
+ ).sample
446
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
447
+ noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
448
+ noise_pred_text - noise_pred_uncond
449
+ )
450
+ else:
451
+ noise_pred = self.sd_pipeline.unet(
452
+ latents,
453
+ timesteps,
454
+ embeds,
455
+ ).sample
456
+ # compute the log prob of next_latents given latents under the current model
457
+
458
+ scheduler_step_output = self.sd_pipeline.scheduler_step(
459
+ noise_pred,
460
+ timesteps,
461
+ latents,
462
+ eta=self.config.sample_eta,
463
+ prev_sample=next_latents,
464
+ )
465
+
466
+ log_prob = scheduler_step_output.log_probs
467
+
468
+ advantages = torch.clamp(
469
+ advantages,
470
+ -self.config.train_adv_clip_max,
471
+ self.config.train_adv_clip_max,
472
+ )
473
+
474
+ ratio = torch.exp(log_prob - log_probs)
475
+
476
+ loss = self.loss(advantages, self.config.train_clip_range, ratio)
477
+
478
+ approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
479
+
480
+ clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
481
+
482
+ return loss, approx_kl, clipfrac
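+ # approx_kl above is the common 0.5 * E[(log-ratio)^2] estimator of the divergence between
+ # the behaviour policy that generated the samples and the current policy; clipfrac is the
+ # fraction of probability ratios that fall outside the clip range, a quick gauge of how
+ # aggressive the current update is.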
483
+
484
+ def loss(
485
+ self,
486
+ advantages: torch.Tensor,
487
+ clip_range: float,
488
+ ratio: torch.Tensor,
489
+ ):
490
+ unclipped_loss = -advantages * ratio
491
+ clipped_loss = -advantages * torch.clamp(
492
+ ratio,
493
+ 1.0 - clip_range,
494
+ 1.0 + clip_range,
495
+ )
496
+ return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
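+ # This is the PPO-style clipped surrogate objective: the advantage-weighted ratio is
+ # clipped to [1 - clip_range, 1 + clip_range], and taking the element-wise maximum of the
+ # two negated terms keeps the pessimistic (lower-improvement) estimate, so a single large
+ # advantage cannot push the policy far from the sampling policy in one update.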
497
+
498
+ def _setup_optimizer(self, trainable_layers_parameters):
499
+ if self.config.train_use_8bit_adam:
500
+ import bitsandbytes
501
+
502
+ optimizer_cls = bitsandbytes.optim.AdamW8bit
503
+ else:
504
+ optimizer_cls = torch.optim.AdamW
505
+
506
+ return optimizer_cls(
507
+ trainable_layers_parameters,
508
+ lr=self.config.train_learning_rate,
509
+ betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
510
+ weight_decay=self.config.train_adam_weight_decay,
511
+ eps=self.config.train_adam_epsilon,
512
+ )
513
+
514
+ def _save_model_hook(self, models, weights, output_dir):
515
+ self.sd_pipeline.save_checkpoint(models, weights, output_dir)
516
+ weights.pop() # ensures that accelerate doesn't try to handle saving of the model
517
+
518
+ def _load_model_hook(self, models, input_dir):
519
+ self.sd_pipeline.load_checkpoint(models, input_dir)
520
+ models.pop() # ensures that accelerate doesn't try to handle loading of the model
521
+
522
+ def _generate_samples(self, iterations, batch_size):
523
+ """
524
+ Generate samples from the model
525
+
526
+ Args:
527
+ iterations (int): Number of iterations to generate samples for
528
+ batch_size (int): Batch size to use for sampling
529
+
530
+ Returns:
531
+ samples (List[Dict[str, torch.Tensor]]), prompt_image_pairs (List[List[Any]])
532
+ """
533
+ samples = []
534
+ prompt_image_pairs = []
535
+ self.sd_pipeline.unet.eval()
536
+
537
+ sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
538
+
539
+ for _ in range(iterations):
540
+ prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
541
+
542
+ prompt_ids = self.sd_pipeline.tokenizer(
543
+ prompts,
544
+ return_tensors="pt",
545
+ padding="max_length",
546
+ truncation=True,
547
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
548
+ ).input_ids.to(self.accelerator.device)
549
+ prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
550
+
551
+ with self.autocast():
552
+ sd_output = self.sd_pipeline(
553
+ prompt_embeds=prompt_embeds,
554
+ negative_prompt_embeds=sample_neg_prompt_embeds,
555
+ num_inference_steps=self.config.sample_num_steps,
556
+ guidance_scale=self.config.sample_guidance_scale,
557
+ eta=self.config.sample_eta,
558
+ output_type="pt",
559
+ )
560
+
561
+ images = sd_output.images
562
+ latents = sd_output.latents
563
+ log_probs = sd_output.log_probs
564
+
565
+ latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
566
+ log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
567
+ timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps)
568
+
569
+ samples.append(
570
+ {
571
+ "prompt_ids": prompt_ids,
572
+ "prompt_embeds": prompt_embeds,
573
+ "timesteps": timesteps,
574
+ "latents": latents[:, :-1], # each entry is the latent before timestep t
575
+ "next_latents": latents[:, 1:], # each entry is the latent after timestep t
576
+ "log_probs": log_probs,
577
+ "negative_prompt_embeds": sample_neg_prompt_embeds,
578
+ }
579
+ )
580
+ prompt_image_pairs.append([images, prompts, prompt_metadata])
581
+
582
+ return samples, prompt_image_pairs
583
+
584
+ def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
585
+ """
586
+ Train on a batch of samples. Main training segment
587
+
588
+ Args:
589
+ inner_epoch (int): The current inner epoch
590
+ epoch (int): The current epoch
591
+ global_step (int): The current global step
592
+ batched_samples (List[Dict[str, torch.Tensor]]): The batched samples to train on
593
+
594
+ Side Effects:
595
+ - Model weights are updated
596
+ - Logs the statistics to the accelerator trackers.
597
+
598
+ Returns:
599
+ global_step (int): The updated global step
600
+ """
601
+ info = defaultdict(list)
602
+ for _i, sample in enumerate(batched_samples):
603
+ if self.config.train_cfg:
604
+ # concat negative prompts to sample prompts to avoid two forward passes
605
+ embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
606
+ else:
607
+ embeds = sample["prompt_embeds"]
608
+
609
+ for j in range(self.num_train_timesteps):
610
+ with self.accelerator.accumulate(self.sd_pipeline.unet):
611
+ loss, approx_kl, clipfrac = self.calculate_loss(
612
+ sample["latents"][:, j],
613
+ sample["timesteps"][:, j],
614
+ sample["next_latents"][:, j],
615
+ sample["log_probs"][:, j],
616
+ sample["advantages"],
617
+ embeds,
618
+ )
619
+ info["approx_kl"].append(approx_kl)
620
+ info["clipfrac"].append(clipfrac)
621
+ info["loss"].append(loss)
622
+
623
+ self.accelerator.backward(loss)
624
+ if self.accelerator.sync_gradients:
625
+ self.accelerator.clip_grad_norm_(
626
+ self.trainable_layers.parameters()
627
+ if not isinstance(self.trainable_layers, list)
628
+ else self.trainable_layers,
629
+ self.config.train_max_grad_norm,
630
+ )
631
+ self.optimizer.step()
632
+ self.optimizer.zero_grad()
633
+
634
+ # Checks if the accelerator has performed an optimization step behind the scenes
635
+ if self.accelerator.sync_gradients:
636
+ # log training-related stuff
637
+ info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
638
+ info = self.accelerator.reduce(info, reduction="mean")
639
+ info.update({"epoch": epoch, "inner_epoch": inner_epoch})
640
+ self.accelerator.log(info, step=global_step)
641
+ global_step += 1
642
+ info = defaultdict(list)
643
+ return global_step
644
+
645
+ def _config_check(self) -> Tuple[bool, str]:
646
+ samples_per_epoch = (
647
+ self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
648
+ )
649
+ total_train_batch_size = (
650
+ self.config.train_batch_size
651
+ * self.accelerator.num_processes
652
+ * self.config.train_gradient_accumulation_steps
653
+ )
654
+
655
+ if not self.config.sample_batch_size >= self.config.train_batch_size:
656
+ return (
657
+ False,
658
+ f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
659
+ )
660
+ if not self.config.sample_batch_size % self.config.train_batch_size == 0:
661
+ return (
662
+ False,
663
+ f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
664
+ )
665
+ if not samples_per_epoch % total_train_batch_size == 0:
666
+ return (
667
+ False,
668
+ f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
669
+ )
670
+ return True, ""
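+ # Worked example with illustrative values: sample_batch_size=2, train_batch_size=1,
+ # num_processes=1, sample_num_batches_per_epoch=2, train_gradient_accumulation_steps=2
+ # gives samples_per_epoch = 2 * 1 * 2 = 4 and total_train_batch_size = 1 * 1 * 2 = 2;
+ # 2 >= 1, 2 % 1 == 0 and 4 % 2 == 0 all hold, so this configuration passes the check.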
671
+
672
+ def train(self, epochs: Optional[int] = None):
673
+ """
674
+ Train the model for a given number of epochs
675
+ """
676
+ global_step = 0
677
+ if epochs is None:
678
+ epochs = self.config.num_epochs
679
+ for epoch in range(self.first_epoch, epochs):
680
+ global_step = self.step(epoch, global_step)
681
+
682
+ def create_model_card(self, path: str, model_name: Optional[str] = "TRL DDPO Model") -> None:
683
+ """Creates and saves a model card for a TRL model.
684
+
685
+ Args:
686
+ path (`str`): The path to save the model card to.
687
+ model_name (`str`, *optional*): The name of the model, defaults to `TRL DDPO Model`.
688
+ """
689
+ try:
690
+ user = whoami()["name"]
691
+ # handle the offline case
692
+ except Exception:
693
+ warnings.warn("Cannot retrieve user information, assuming you are running in offline mode.")
694
+ return
695
+
696
+ if not os.path.exists(path):
697
+ os.makedirs(path)
698
+
699
+ model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}")
700
+ with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
701
+ f.write(model_card_content)
702
+
703
+ def _save_pretrained(self, save_directory):
704
+ self.sd_pipeline.save_pretrained(save_directory)
705
+ self.create_model_card(save_directory)
706
+ class UnslothDDPOTrainer(_UnslothDDPOTrainer):
707
+ """
708
+
709
+ The DDPOTrainer uses Denoising Diffusion Policy Optimization to optimise diffusion models.
710
+ Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
711
+ As of now only Stable Diffusion based pipelines are supported
712
+
713
+ Attributes:
714
+ **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `DDPOConfig` for more
715
+ details.
716
+ **reward_function** (Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor]) -- Reward function to be used
717
+ **prompt_function** (Callable[[], Tuple[str, Any]]) -- Function to generate prompts to guide model
718
+ **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
719
+ **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
720
+
721
+ """
722
+ def __init__(
723
+ self,
724
+ config,
725
+ reward_function,
726
+ prompt_function,
727
+ sd_pipeline,
728
+ image_samples_hook = None,
729
+ **kwargs
730
+ ):
731
+ if config is None: config = UnslothDDPOConfig()
732
+ other_metrics = []
733
+
734
+ from unsloth_zoo.logging_utils import PatchRLStatistics
735
+ PatchRLStatistics('ddpo_trainer', other_metrics)
736
+
737
+ super().__init__(
738
+ config = config,
739
+ reward_function = reward_function,
740
+ prompt_function = prompt_function,
741
+ sd_pipeline = sd_pipeline,
742
+ image_samples_hook = image_samples_hook,**kwargs)
743
+
744
+ pass
compilefcach/UnslothKTOTrainer.py ADDED
@@ -0,0 +1,1629 @@
1
+ """
2
+ 2025.6.8
3
+ 2025.6.12
4
+ 4.53.0
5
+ 0.8.6
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, Dict, EvalLoopOutput, F, KTOConfig, KTOTrainer, List, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Tuple, Union, _get_kl_dataset, _process_tokens, _tokenize, concatenate_datasets, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, has_length, inspect, is_peft_available, is_wandb_available, itemgetter, nn, np, nullcontext, pad_to_length, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, torch, tqdm, trl_sanitze_kwargs_for_tagging, wandb, warnings, wraps)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothKTOConfig(KTOConfig):
44
+ """
45
+
46
+ KTOConfig collects all training arguments related to the [`KTOTrainer`] class.
47
+
48
+ Using [`HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ max_length (`int`, *optional*, defaults to `None`):
54
+ The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
55
+ max_prompt_length (`int`, *optional*, defaults to `None`):
56
+ The maximum length of the prompt. This argument is required if you want to use the default data collator.
57
+ max_completion_length (`int`, *optional*, defaults to `None`):
58
+ The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
59
+ beta (`float`, defaults to 0.1):
60
+ The beta factor in KTO loss. Higher beta means less divergence from the initial policy.
61
+ desirable_weight (`float`, *optional*, defaults to 1.0):
62
+ The desirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs.
63
+ undesirable_weight (`float`, *optional*, defaults to 1.0):
64
+ The undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs.
65
+ label_pad_token_id (`int`, defaults to `-100`):
66
+ The label pad token id. This argument is required if you want to use the default data collator.
67
+ padding_value (`int`, defaults to `0`):
68
+ The padding value if it is different to the tokenizer's pad_token_id.
69
+ truncation_mode (`str`, defaults to `keep_end`):
70
+ The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
71
+ generate_during_eval (`bool`, defaults to `False`):
72
+ Whether to sample and log generations during evaluation step.
73
+ is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
74
+ If no model is provided, we need to know if the model_init returns an encoder-decoder.
75
+ precompute_ref_log_probs (`bool`, defaults to `False`):
76
+ Flag to precompute reference model log probabilities for training and evaluation datasets. This is useful if you want to train
77
+ without the reference model and reduce the total GPU memory needed.
78
+ model_init_kwargs: (`Optional[Dict]`, *optional*):
79
+ Dict of Optional kwargs to pass when instantiating the model from a string.
80
+ ref_model_init_kwargs: (`Optional[Dict]`, *optional*):
81
+ Dict of Optional kwargs to pass when instantiating the ref model from a string.
82
+ dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`):
83
+ Number of processes to use for processing the datasets.
84
+
85
+ """
86
+ vllm_sampling_params: Optional[Any] = field(
87
+ default = None,
88
+ metadata = {'help': 'vLLM SamplingParams'},
89
+ )
90
+ unsloth_num_chunks : Optional[int] = field(
91
+ default = -1,
92
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
93
+ )
94
+ def __init__(
95
+ self,
96
+ output_dir = None,
97
+ overwrite_output_dir = None,
98
+ do_train = False,
99
+ do_eval = False,
100
+ do_predict = False,
101
+ eval_strategy = 'no',
102
+ prediction_loss_only = False,
103
+ per_device_train_batch_size = 4,
104
+ per_device_eval_batch_size = 4,
105
+ per_gpu_train_batch_size = None,
106
+ per_gpu_eval_batch_size = None,
107
+ gradient_accumulation_steps = 2,
108
+ eval_accumulation_steps = 2,
109
+ eval_delay = 0,
110
+ torch_empty_cache_steps = 250,
111
+ learning_rate = 5e-05,
112
+ weight_decay = 0.01,
113
+ adam_beta1 = 0.9,
114
+ adam_beta2 = 0.999,
115
+ adam_epsilon = 1e-08,
116
+ max_grad_norm = 1.0,
117
+ num_train_epochs = 3.0,
118
+ max_steps = -1,
119
+ lr_scheduler_type = 'linear',
120
+ warmup_ratio = 0.1,
121
+ warmup_steps = 0,
122
+ log_level = 'passive',
123
+ log_level_replica = 'warning',
124
+ log_on_each_node = True,
125
+ logging_dir = None,
126
+ logging_strategy = 'steps',
127
+ logging_first_step = False,
128
+ logging_steps = 1,
129
+ logging_nan_inf_filter = False,
130
+ save_strategy = 'steps',
131
+ save_steps = 500,
132
+ save_total_limit = None,
133
+ save_safetensors = True,
134
+ save_on_each_node = False,
135
+ save_only_model = False,
136
+ restore_callback_states_from_checkpoint = False,
137
+ no_cuda = False,
138
+ use_cpu = False,
139
+ use_mps_device = False,
140
+ seed = 3407,
141
+ data_seed = 3407,
142
+ jit_mode_eval = False,
143
+ use_ipex = False,
144
+ bf16 = False,
145
+ fp16 = False,
146
+ fp16_opt_level = 'O1',
147
+ half_precision_backend = 'auto',
148
+ bf16_full_eval = False,
149
+ fp16_full_eval = False,
150
+ tf32 = None,
151
+ local_rank = -1,
152
+ ddp_backend = None,
153
+ tpu_num_cores = None,
154
+ tpu_metrics_debug = False,
155
+ debug = '',
156
+ dataloader_drop_last = False,
157
+ eval_steps = None,
158
+ dataloader_num_workers = 0,
159
+ dataloader_prefetch_factor = None,
160
+ past_index = -1,
161
+ run_name = None,
162
+ disable_tqdm = None,
163
+ remove_unused_columns = True,
164
+ label_names = None,
165
+ load_best_model_at_end = False,
166
+ metric_for_best_model = None,
167
+ greater_is_better = None,
168
+ ignore_data_skip = False,
169
+ fsdp = '',
170
+ fsdp_min_num_params = 0,
171
+ fsdp_config = None,
172
+ fsdp_transformer_layer_cls_to_wrap = None,
173
+ accelerator_config = None,
174
+ deepspeed = None,
175
+ label_smoothing_factor = 0.0,
176
+ optim = 'adamw_8bit',
177
+ optim_args = None,
178
+ adafactor = False,
179
+ group_by_length = False,
180
+ length_column_name = 'length',
181
+ report_to = None,
182
+ ddp_find_unused_parameters = None,
183
+ ddp_bucket_cap_mb = None,
184
+ ddp_broadcast_buffers = None,
185
+ dataloader_pin_memory = True,
186
+ dataloader_persistent_workers = False,
187
+ skip_memory_metrics = True,
188
+ use_legacy_prediction_loop = False,
189
+ push_to_hub = False,
190
+ resume_from_checkpoint = None,
191
+ hub_model_id = None,
192
+ hub_strategy = 'every_save',
193
+ hub_token = None,
194
+ hub_private_repo = None,
195
+ hub_always_push = False,
196
+ hub_revision = None,
197
+ gradient_checkpointing = False,
198
+ gradient_checkpointing_kwargs = None,
199
+ include_inputs_for_metrics = False,
200
+ eval_do_concat_batches = True,
201
+ fp16_backend = 'auto',
202
+ push_to_hub_model_id = None,
203
+ push_to_hub_organization = None,
204
+ push_to_hub_token = None,
205
+ mp_parameters = '',
206
+ auto_find_batch_size = False,
207
+ full_determinism = False,
208
+ torchdynamo = None,
209
+ ray_scope = 'last',
210
+ ddp_timeout = 1800,
211
+ torch_compile = False,
212
+ torch_compile_backend = None,
213
+ torch_compile_mode = None,
214
+ include_tokens_per_second = False,
215
+ include_num_input_tokens_seen = False,
216
+ neftune_noise_alpha = None,
217
+ optim_target_modules = None,
218
+ batch_eval_metrics = False,
219
+ eval_on_start = False,
220
+ use_liger_kernel = False,
221
+ liger_kernel_config = None,
222
+ eval_use_gather_object = False,
223
+ average_tokens_across_devices = False,
224
+ max_length = None,
225
+ max_prompt_length = None,
226
+ max_completion_length = None,
227
+ beta = 0.1,
228
+ desirable_weight = 1.0,
229
+ undesirable_weight = 1.0,
230
+ label_pad_token_id = -100,
231
+ padding_value = None,
232
+ truncation_mode = 'keep_end',
233
+ generate_during_eval = False,
234
+ is_encoder_decoder = None,
235
+ precompute_ref_log_probs = False,
236
+ model_init_kwargs = None,
237
+ ref_model_init_kwargs = None,
238
+ dataset_num_proc = None,
239
+ vllm_sampling_params = None,
240
+ unsloth_num_chunks = -1,
241
+ **kwargs,
242
+ ):
243
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
244
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
245
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
246
+ output_dir = 'unsloth_training_checkpoints'
247
+ save_strategy = 'no'
248
+ if dataset_num_proc is None:
249
+ from multiprocessing import cpu_count
250
+ dataset_num_proc = cpu_count()
251
+
252
+ super().__init__(
253
+ output_dir = output_dir,
254
+ overwrite_output_dir = overwrite_output_dir,
255
+ do_train = do_train,
256
+ do_eval = do_eval,
257
+ do_predict = do_predict,
258
+ eval_strategy = eval_strategy,
259
+ prediction_loss_only = prediction_loss_only,
260
+ per_device_train_batch_size = per_device_train_batch_size,
261
+ per_device_eval_batch_size = per_device_eval_batch_size,
262
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
263
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
264
+ gradient_accumulation_steps = gradient_accumulation_steps,
265
+ eval_accumulation_steps = eval_accumulation_steps,
266
+ eval_delay = eval_delay,
267
+ torch_empty_cache_steps = torch_empty_cache_steps,
268
+ learning_rate = learning_rate,
269
+ weight_decay = weight_decay,
270
+ adam_beta1 = adam_beta1,
271
+ adam_beta2 = adam_beta2,
272
+ adam_epsilon = adam_epsilon,
273
+ max_grad_norm = max_grad_norm,
274
+ num_train_epochs = num_train_epochs,
275
+ max_steps = max_steps,
276
+ lr_scheduler_type = lr_scheduler_type,
277
+ warmup_ratio = warmup_ratio,
278
+ warmup_steps = warmup_steps,
279
+ log_level = log_level,
280
+ log_level_replica = log_level_replica,
281
+ log_on_each_node = log_on_each_node,
282
+ logging_dir = logging_dir,
283
+ logging_strategy = logging_strategy,
284
+ logging_first_step = logging_first_step,
285
+ logging_steps = logging_steps,
286
+ logging_nan_inf_filter = logging_nan_inf_filter,
287
+ save_strategy = save_strategy,
288
+ save_steps = save_steps,
289
+ save_total_limit = save_total_limit,
290
+ save_safetensors = save_safetensors,
291
+ save_on_each_node = save_on_each_node,
292
+ save_only_model = save_only_model,
293
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
294
+ no_cuda = no_cuda,
295
+ use_cpu = use_cpu,
296
+ use_mps_device = use_mps_device,
297
+ seed = seed,
298
+ data_seed = data_seed,
299
+ jit_mode_eval = jit_mode_eval,
300
+ use_ipex = use_ipex,
301
+ bf16 = bf16,
302
+ fp16 = fp16,
303
+ fp16_opt_level = fp16_opt_level,
304
+ half_precision_backend = half_precision_backend,
305
+ bf16_full_eval = bf16_full_eval,
306
+ fp16_full_eval = fp16_full_eval,
307
+ tf32 = tf32,
308
+ local_rank = local_rank,
309
+ ddp_backend = ddp_backend,
310
+ tpu_num_cores = tpu_num_cores,
311
+ tpu_metrics_debug = tpu_metrics_debug,
312
+ debug = debug,
313
+ dataloader_drop_last = dataloader_drop_last,
314
+ eval_steps = eval_steps,
315
+ dataloader_num_workers = dataloader_num_workers,
316
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
317
+ past_index = past_index,
318
+ run_name = run_name,
319
+ disable_tqdm = disable_tqdm,
320
+ remove_unused_columns = remove_unused_columns,
321
+ label_names = label_names,
322
+ load_best_model_at_end = load_best_model_at_end,
323
+ metric_for_best_model = metric_for_best_model,
324
+ greater_is_better = greater_is_better,
325
+ ignore_data_skip = ignore_data_skip,
326
+ fsdp = fsdp,
327
+ fsdp_min_num_params = fsdp_min_num_params,
328
+ fsdp_config = fsdp_config,
329
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
330
+ accelerator_config = accelerator_config,
331
+ deepspeed = deepspeed,
332
+ label_smoothing_factor = label_smoothing_factor,
333
+ optim = optim,
334
+ optim_args = optim_args,
335
+ adafactor = adafactor,
336
+ group_by_length = group_by_length,
337
+ length_column_name = length_column_name,
338
+ report_to = report_to,
339
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
340
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
341
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
342
+ dataloader_pin_memory = dataloader_pin_memory,
343
+ dataloader_persistent_workers = dataloader_persistent_workers,
344
+ skip_memory_metrics = skip_memory_metrics,
345
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
346
+ push_to_hub = push_to_hub,
347
+ resume_from_checkpoint = resume_from_checkpoint,
348
+ hub_model_id = hub_model_id,
349
+ hub_strategy = hub_strategy,
350
+ hub_token = hub_token,
351
+ hub_private_repo = hub_private_repo,
352
+ hub_always_push = hub_always_push,
353
+ hub_revision = hub_revision,
354
+ gradient_checkpointing = gradient_checkpointing,
355
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
356
+ include_inputs_for_metrics = include_inputs_for_metrics,
357
+ eval_do_concat_batches = eval_do_concat_batches,
358
+ fp16_backend = fp16_backend,
359
+ push_to_hub_model_id = push_to_hub_model_id,
360
+ push_to_hub_organization = push_to_hub_organization,
361
+ push_to_hub_token = push_to_hub_token,
362
+ mp_parameters = mp_parameters,
363
+ auto_find_batch_size = auto_find_batch_size,
364
+ full_determinism = full_determinism,
365
+ torchdynamo = torchdynamo,
366
+ ray_scope = ray_scope,
367
+ ddp_timeout = ddp_timeout,
368
+ torch_compile = torch_compile,
369
+ torch_compile_backend = torch_compile_backend,
370
+ torch_compile_mode = torch_compile_mode,
371
+ include_tokens_per_second = include_tokens_per_second,
372
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
373
+ neftune_noise_alpha = neftune_noise_alpha,
374
+ optim_target_modules = optim_target_modules,
375
+ batch_eval_metrics = batch_eval_metrics,
376
+ eval_on_start = eval_on_start,
377
+ use_liger_kernel = use_liger_kernel,
378
+ liger_kernel_config = liger_kernel_config,
379
+ eval_use_gather_object = eval_use_gather_object,
380
+ average_tokens_across_devices = average_tokens_across_devices,
381
+ max_length = max_length,
382
+ max_prompt_length = max_prompt_length,
383
+ max_completion_length = max_completion_length,
384
+ beta = beta,
385
+ desirable_weight = desirable_weight,
386
+ undesirable_weight = undesirable_weight,
387
+ label_pad_token_id = label_pad_token_id,
388
+ padding_value = padding_value,
389
+ truncation_mode = truncation_mode,
390
+ generate_during_eval = generate_during_eval,
391
+ is_encoder_decoder = is_encoder_decoder,
392
+ precompute_ref_log_probs = precompute_ref_log_probs,
393
+ model_init_kwargs = model_init_kwargs,
394
+ ref_model_init_kwargs = ref_model_init_kwargs,
395
+ dataset_num_proc = dataset_num_proc,**kwargs)
396
+ self.vllm_sampling_params = vllm_sampling_params
397
+ self.unsloth_num_chunks = unsloth_num_chunks
398
+ pass
399
+
400
+ class _UnslothKTOTrainer(Trainer):
401
+ r""""""
402
+
403
+ _tag_names = ["trl", "kto"]
404
+
405
+ def __init__(
406
+ self,
407
+ model: Union[PreTrainedModel, nn.Module, str] = None,
408
+ ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
409
+ args: KTOConfig = None,
410
+ train_dataset: Optional[Dataset] = None,
411
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
412
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
413
+ data_collator: Optional[DataCollator] = None,
414
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
415
+ callbacks: Optional[List[TrainerCallback]] = None,
416
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
417
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
418
+ peft_config: Optional[Dict] = None,
419
+ compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
420
+ model_adapter_name: Optional[str] = None,
421
+ ref_adapter_name: Optional[str] = None,
422
+ ):
423
+ if type(args) == TrainingArguments:
424
+ raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
425
+
426
+ if args.model_init_kwargs is None:
427
+ model_init_kwargs = {}
428
+ elif not isinstance(model, str):
429
+ raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
430
+ else:
431
+ model_init_kwargs = args.model_init_kwargs
432
+ model_init_kwargs["torch_dtype"] = (
433
+ model_init_kwargs["torch_dtype"]
434
+ if model_init_kwargs["torch_dtype"] in ["auto", None]
435
+ else getattr(torch, model_init_kwargs["torch_dtype"])
436
+ )
437
+
438
+ if args.ref_model_init_kwargs is None:
439
+ ref_model_init_kwargs = {}
440
+ elif not isinstance(ref_model, str):
441
+ raise ValueError(
442
+ "You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
443
+ )
444
+ else:
445
+ ref_model_init_kwargs = args.ref_model_init_kwargs
446
+ ref_model_init_kwargs["torch_dtype"] = (
447
+ ref_model_init_kwargs["torch_dtype"]
448
+ if ref_model_init_kwargs["torch_dtype"] in ["auto", None]
449
+ else getattr(torch, ref_model_init_kwargs["torch_dtype"])
450
+ )
451
+
452
+ if isinstance(model, str):
453
+ warnings.warn(
454
+ "You passed a model_id to the KTOTrainer. This will automatically create an "
455
+ "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
456
+ )
457
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
458
+
459
+ if isinstance(ref_model, str):
460
+ warnings.warn(
461
+ "You passed a ref model_id to the KTOTrainer. This will automatically create an "
462
+ "`AutoModelForCausalLM`"
463
+ )
464
+ ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
465
+
466
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
467
+ # has been called in order to properly call autocast if needed.
468
+ self._peft_has_been_casted_to_bf16 = False
469
+
470
+ if not is_peft_available() and peft_config is not None:
471
+ raise ValueError(
472
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
473
+ )
474
+ elif is_peft_available() and peft_config is not None:
475
+ # if model is a peft model and we have a peft_config, we merge and unload it first
476
+ if isinstance(model, PeftModel):
477
+ model = model.merge_and_unload()
478
+
479
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
480
+ _support_gc_kwargs = hasattr(
481
+ args, "gradient_checkpointing_kwargs"
482
+ ) and "gradient_checkpointing_kwargs" in list(
483
+ inspect.signature(prepare_model_for_kbit_training).parameters
484
+ )
485
+
486
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
487
+
488
+ if _support_gc_kwargs:
489
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
490
+
491
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
492
+ elif getattr(args, "gradient_checkpointing", False):
493
+ # For backward compatibility with older versions of transformers
494
+ if hasattr(model, "enable_input_require_grads"):
495
+ model.enable_input_require_grads()
496
+ else:
497
+
498
+ def make_inputs_require_grad(module, input, output):
499
+ output.requires_grad_(True)
500
+
501
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
502
+
503
+ # get peft model with the given config
504
+ model = model
505
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
506
+ peft_module_casting_to_bf16(model)
507
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
508
+ self._peft_has_been_casted_to_bf16 = True
509
+
510
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
511
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
512
+ # fail or completely fail.
513
+ elif getattr(args, "gradient_checkpointing", False):
514
+ # For backward compatibility with older versions of transformers
515
+ if hasattr(model, "enable_input_require_grads"):
516
+ model.enable_input_require_grads()
517
+ else:
518
+
519
+ def make_inputs_require_grad(module, input, output):
520
+ output.requires_grad_(True)
521
+
522
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
523
+
524
+ if args.generate_during_eval and not is_wandb_available():
525
+ raise ValueError(
526
+ "`generate_during_eval=True` requires Weights and Biases to be installed."
527
+ " Please install with `pip install wandb` to resolve."
528
+ )
529
+
530
+ if model is not None:
531
+ self.is_encoder_decoder = model.config.is_encoder_decoder
532
+ elif args.is_encoder_decoder is None:
533
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
534
+ else:
535
+ self.is_encoder_decoder = args.is_encoder_decoder
536
+
537
+ self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
538
+ self.model_adapter_name = model_adapter_name
539
+ self.ref_adapter_name = ref_adapter_name
540
+
541
+ if ref_model:
542
+ self.ref_model = ref_model
543
+ elif self.is_peft_model or args.precompute_ref_log_probs:
544
+ # The `model` with adapters turned off will be used as the reference model
545
+ self.ref_model = None
546
+ else:
547
+ self.ref_model = create_reference_model(model)
548
+
549
+ if tokenizer is None:
550
+ raise ValueError(
551
+ "max_length or a tokenizer must be specified when using the default DPODataCollatorWithPadding"
552
+ )
553
+ if args.max_length is None:
554
+ warnings.warn(
555
+ "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
556
+ " it will be set to `512` by default, but you should do it yourself in the future.",
557
+ UserWarning,
558
+ )
559
+ max_length = 512
560
+ if args.max_length is not None:
561
+ max_length = args.max_length
562
+
563
+ if args.max_prompt_length is None:
564
+ warnings.warn(
565
+ "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
566
+ " it will be set to `128` by default, but you should do it yourself in the future.",
567
+ UserWarning,
568
+ )
569
+ max_prompt_length = 128
570
+ if args.max_prompt_length is not None:
571
+ max_prompt_length = args.max_prompt_length
572
+
573
+ max_completion_length = None
574
+ if args.max_completion_length is None and self.is_encoder_decoder:
575
+ warnings.warn(
576
+ "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
577
+ " it will be set to `128` by default, but you should do it yourself in the future.",
578
+ UserWarning,
579
+ )
580
+ max_completion_length = 128
581
+ if args.max_completion_length is not None and self.is_encoder_decoder:
582
+ max_completion_length = args.max_completion_length
583
+
584
+ if data_collator is None:
585
+ data_collator = DPODataCollatorWithPadding(
586
+ pad_token_id=tokenizer.pad_token_id,
587
+ label_pad_token_id=args.label_pad_token_id,
588
+ is_encoder_decoder=self.is_encoder_decoder,
589
+ )
590
+
591
+ if args.remove_unused_columns:
592
+ args.remove_unused_columns = False
593
+ # warn users
594
+ warnings.warn(
595
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
596
+ " we have set it for you, but you should do it yourself in the future.",
597
+ UserWarning,
598
+ )
599
+
600
+ self.use_dpo_data_collator = True
601
+ else:
602
+ self.use_dpo_data_collator = False
603
+
604
+ # disable dropout in the model and reference model
605
+ disable_dropout_in_model(model)
606
+ if self.ref_model is not None:
607
+ disable_dropout_in_model(self.ref_model)
608
+
609
+ self.max_length = max_length
610
+ self.generate_during_eval = args.generate_during_eval
611
+ self.label_pad_token_id = args.label_pad_token_id
612
+ self.padding_value = args.padding_value if args.padding_value is not None else tokenizer.pad_token_id
613
+ self.max_prompt_length = max_prompt_length
614
+ self.truncation_mode = args.truncation_mode
615
+ self.max_completion_length = max_completion_length
616
+ self.tokenizer = tokenizer
617
+ self.precompute_ref_log_probs = args.precompute_ref_log_probs
618
+
619
+ # Since ref_log_probs are precomputed on the first call to get_train/eval_dataloader,
620
+ # keep track of the first call to avoid recomputing them on subsequent calls
621
+ self._precomputed_train_ref_log_probs = False
622
+ self._precomputed_eval_ref_log_probs = False
623
+
624
+ # metric
625
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
626
+
627
+ # KTO parameter
628
+ self.beta = args.beta
629
+ self.desirable_weight = args.desirable_weight
630
+ self.undesirable_weight = args.undesirable_weight
631
+
632
+ with PartialState().local_main_process_first():
633
+ # Shuffle the datasets
634
+ train_dataset = train_dataset.shuffle(seed=args.data_seed)
635
+ if eval_dataset is not None:
636
+ eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
637
+ # Tokenize and prepare the training datasets
638
+ train_dataset = train_dataset.map(
639
+ _tokenize,
640
+ fn_kwargs={"tokenizer": self.tokenizer},
641
+ batched=True,
642
+ desc="Tokenizing train dataset",
643
+ )
644
+ # Get KL datasets
645
+ total_batch_size = (
646
+ max(torch.cuda.device_count(), 1) * args.per_device_train_batch_size * args.gradient_accumulation_steps
647
+ )
648
+ if total_batch_size <= 1:
649
+ raise ValueError(
650
+ "Batch size is 1 (too small). KTO will not work properly because the KL term will be equivalent to the implied reward."
651
+ )
652
+ # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
653
+ # i.e., [x_1, y_1], ..., [x_n, y_n] --> [x_1, y_n], ..., [x_n, y_1] = [x'_1, y'_1], ..., [x'_n, y'_n]
654
+ train_kl_dataset = train_dataset.map(
655
+ _get_kl_dataset, batched=True, batch_size=total_batch_size, desc="Extracting KL train dataset"
656
+ )
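The mismatched-pair trick described in the comment above can be pictured with a small standalone sketch; the actual `_get_kl_dataset` helper is defined elsewhere in this package, so the function below is only an assumed illustration of the idea:

# Illustrative only: pair each prompt with the completion of a different example,
# so the resulting (x, y') pairs can be used to estimate the KL/reward baseline.
def flip_completions(batch):
    # `batch` is a dict of equal-length lists, e.g. {"prompt": [...], "completion": [...]}
    batch["completion"] = batch["completion"][::-1]  # [x_1, y_n], ..., [x_n, y_1]
    return batch

example = {"prompt": ["p1", "p2", "p3"], "completion": ["c1", "c2", "c3"]}
print(flip_completions(example))  # completions become ['c3', 'c2', 'c1']
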
657
+ # Prepare the datasets
658
+ fn_kwargs = {
659
+ "prefix": "",
660
+ "is_encoder_decoder": self.is_encoder_decoder,
661
+ "tokenizer": self.tokenizer,
662
+ "max_length": self.max_length,
663
+ "truncation_mode": self.truncation_mode,
664
+ "label_pad_token_id": self.label_pad_token_id,
665
+ "max_prompt_length": self.max_prompt_length,
666
+ }
667
+ train_dataset = train_dataset.map(
668
+ _process_tokens,
669
+ fn_kwargs=fn_kwargs,
670
+ num_proc=args.dataset_num_proc,
671
+ desc="Processing tokenized train dataset",
672
+ )
673
+ fn_kwargs["prefix"] = "KL_"
674
+ train_kl_dataset = train_kl_dataset.map(
675
+ _process_tokens,
676
+ fn_kwargs=fn_kwargs,
677
+ num_proc=args.dataset_num_proc,
678
+ remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
679
+ desc="Processing tokenized train KL dataset",
680
+ )
681
+
682
+ # merge the datasets
683
+ train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
684
+
685
+ if eval_dataset is not None:
686
+ # Tokenize
687
+ eval_dataset = eval_dataset.map(
688
+ _tokenize,
689
+ fn_kwargs={"tokenizer": self.tokenizer},
690
+ batched=True,
691
+ desc="Tokenizing eval dataset",
692
+ )
693
+ # Get KL dataset
694
+ eval_kl_dataset = eval_dataset.map(
695
+ _get_kl_dataset, batched=True, batch_size=total_batch_size, desc="Extracting eval KL dataset"
696
+ )
697
+ # Process
698
+ fn_kwargs = {
699
+ "prefix": "",
700
+ "is_encoder_decoder": self.is_encoder_decoder,
701
+ "tokenizer": self.tokenizer,
702
+ "max_length": self.max_length,
703
+ "truncation_mode": self.truncation_mode,
704
+ "label_pad_token_id": self.label_pad_token_id,
705
+ "max_prompt_length": self.max_prompt_length,
706
+ }
707
+ eval_dataset = eval_dataset.map(
708
+ _process_tokens,
709
+ fn_kwargs=fn_kwargs,
710
+ num_proc=args.dataset_num_proc,
711
+ desc="Processing tokenized eval dataset",
712
+ )
713
+ fn_kwargs["prefix"] = "KL_"
714
+ eval_kl_dataset = eval_kl_dataset.map(
715
+ _process_tokens,
716
+ fn_kwargs=fn_kwargs,
717
+ num_proc=args.dataset_num_proc,
718
+ remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
719
+ desc="Processing tokenized eval KL dataset",
720
+ )
721
+
722
+ # merge the datasets
723
+ eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
724
+
725
+ desirable = train_dataset.filter(
726
+ lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
727
+ )
728
+ undesirable = train_dataset.filter(
729
+ lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
730
+ )
731
+
732
+ if len(desirable) != len(undesirable):
733
+ # The lower and upper bounds come from Eq. [8] of https://arxiv.org/abs/2402.01306
734
+ des_weight_lower_bound = round((len(undesirable) * self.undesirable_weight / len(desirable)) * 1, 2)
735
+ des_weight_upper_bound = round((len(undesirable) * self.undesirable_weight / len(desirable)) * 1.33, 2)
736
+ und_weight_lower_bound = round((len(desirable) * self.desirable_weight / len(undesirable)) / 1.33, 2)
737
+ und_weight_upper_bound = round((len(desirable) * self.desirable_weight / len(undesirable)) / 1, 2)
738
+
739
+ des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
740
+ und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
741
+
742
+ if not (des_weight_in_range or und_weight_in_range):
743
+ warnings.warn(
744
+ f"""
745
+ You have different amounts of desirable/positive and undesirable/negative examples but the
746
+ weights on the desirable and undesirable losses don't seem to be in an ideal range. Based
747
+ on your data, we recommend EITHER desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}]
748
+ or undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH).
749
+ See the documentation on how to optimally set these weights.""",
750
+ UserWarning,
751
+ )
752
+
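As a quick numeric illustration of the recommended weight ranges above (the example counts are invented):

# Suppose 1000 desirable and 250 undesirable examples, both weights left at 1.0.
n_des, n_und = 1000, 250
desirable_weight, undesirable_weight = 1.0, 1.0

des_lower = round((n_und * undesirable_weight / n_des) * 1, 2)     # 0.25
des_upper = round((n_und * undesirable_weight / n_des) * 1.33, 2)  # 0.33
und_lower = round((n_des * desirable_weight / n_und) / 1.33, 2)    # 3.01
und_upper = round((n_des * desirable_weight / n_und) / 1, 2)       # 4.0

# Neither weight falls in its recommended range, so the warning above would fire;
# setting undesirable_weight to roughly 3.0-4.0 (or desirable_weight to roughly
# 0.25-0.33, but not both) would satisfy the check.
print(des_lower, des_upper, und_lower, und_upper)
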
753
+ super().__init__(
754
+ model=model,
755
+ args=args,
756
+ data_collator=data_collator,
757
+ train_dataset=train_dataset,
758
+ eval_dataset=eval_dataset,
759
+ tokenizer=tokenizer,
760
+ model_init=model_init,
761
+ compute_metrics=compute_metrics,
762
+ callbacks=callbacks,
763
+ optimizers=optimizers,
764
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
765
+ )
766
+
767
+ # Add tags for models that have been loaded with the correct transformers version
768
+ if hasattr(self.model, "add_model_tags"):
769
+ self.model.add_model_tags(self._tag_names)
770
+
771
+ if not hasattr(self, "accelerator"):
772
+ raise AttributeError(
773
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
774
+ )
775
+
776
+ # Deepspeed Zero-3 does not support precompute_ref_log_probs
777
+ if self.is_deepspeed_enabled:
778
+ if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
779
+ raise ValueError(
780
+ "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
781
+ )
782
+
783
+ if self.ref_model is None:
784
+ if not (self.is_peft_model or self.precompute_ref_log_probs):
785
+ raise ValueError(
786
+ "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
787
+ )
788
+ else:
789
+ if self.is_deepspeed_enabled:
790
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
791
+ else:
792
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
793
+
794
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
795
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
796
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
797
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
798
+
799
+ if model is not None:
800
+ if hasattr(model, "config"):
801
+ hidden_size = (
802
+ max(model.config.hidden_sizes)
803
+ if getattr(model.config, "hidden_sizes", None)
804
+ else getattr(model.config, "hidden_size", None)
805
+ )
806
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
807
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
808
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
809
+ config_kwargs.update(
810
+ {
811
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
812
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
813
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
814
+ }
815
+ )
816
+
817
+ # If ZeRO-3 is used, we shard both the active and reference model.
818
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
819
+ if config_kwargs["zero_optimization"]["stage"] != 3:
820
+ config_kwargs["zero_optimization"]["stage"] = 0
821
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
822
+ model.eval()
823
+ return model
824
+
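For a sense of scale, the ZeRO-3 bucket overrides above work out to the following for an assumed hidden size (the value 4096 is only an example):

hidden_size = 4096  # assumed, roughly a 7B-class model

overrides = {
    "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,                # 16_777_216
    "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,         # 40_960
    "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, # 15_099_494.4
}
print(overrides)
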
825
+ @contextmanager
826
+ def null_ref_context(self):
827
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
828
+ with self.accelerator.unwrap_model(
829
+ self.model
830
+ ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
831
+ if self.ref_adapter_name:
832
+ self.model.set_adapter(self.ref_adapter_name)
833
+ yield
834
+ if self.ref_adapter_name:
835
+ self.model.set_adapter(self.model_adapter_name or "default")
836
+
837
+ def get_train_dataloader(self) -> DataLoader:
838
+ """
839
+ Returns the training [`~torch.utils.data.DataLoader`].
840
+
841
+ Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
842
+ """
843
+
844
+ if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
845
+ dataloader_params = {
846
+ "batch_size": self.args.per_device_train_batch_size,
847
+ "collate_fn": self.data_collator,
848
+ "num_workers": self.args.dataloader_num_workers,
849
+ "pin_memory": self.args.dataloader_pin_memory,
850
+ "shuffle": False,
851
+ }
852
+
853
+ # prepare dataloader
854
+ data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
855
+ reference_completion_logps = []
856
+ reference_KL_logps = []
857
+
858
+ for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
859
+ reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
860
+
861
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
862
+ reference_completion_logps.append(reference_completion_logp.cpu())
863
+
864
+ reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
865
+ reference_KL_logps.append(reference_KL_logp.cpu())
866
+
867
+ self.train_dataset = self.train_dataset.add_column(
868
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
869
+ )
870
+ self.train_dataset = self.train_dataset.add_column(
871
+ name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
872
+ )
873
+
874
+ self._precomputed_train_ref_log_probs = True
875
+
876
+ return super().get_train_dataloader()
877
+
878
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
879
+ """
880
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
881
+
882
+ Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.
883
+
884
+ Args:
885
+ eval_dataset (`torch.utils.data.Dataset`, *optional*):
886
+ If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
887
+ by the `model.forward()` method are automatically removed. It must implement `__len__`.
888
+ """
889
+ if eval_dataset is None and self.eval_dataset is None:
890
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
891
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
892
+
893
+ if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
894
+ dataloader_params = {
895
+ "batch_size": self.args.per_device_eval_batch_size,
896
+ "collate_fn": self.data_collator,
897
+ "num_workers": self.args.dataloader_num_workers,
898
+ "pin_memory": self.args.dataloader_pin_memory,
899
+ "shuffle": False,
900
+ }
901
+
902
+ # prepare dataloader
903
+ data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
904
+
905
+ reference_completion_logps = []
906
+ reference_KL_logps = []
907
+
908
+ for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
909
+ reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
910
+
911
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
912
+ reference_completion_logps.append(reference_completion_logp.cpu())
913
+
914
+ reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
915
+ reference_KL_logps.append(reference_KL_logp.cpu())
916
+
917
+ eval_dataset = eval_dataset.add_column(
918
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
919
+ )
920
+ eval_dataset = eval_dataset.add_column(
921
+ name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
922
+ )
923
+
924
+ # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
925
+ if self.eval_dataset is not None:
926
+ self.eval_dataset = eval_dataset
927
+ self._precomputed_eval_ref_log_probs = True
928
+
929
+ return super().get_eval_dataloader(eval_dataset=eval_dataset)
930
+
931
+ def compute_reference_log_probs(self, padded_batch: Dict) -> Dict:
932
+ """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
933
+ with torch.no_grad():
934
+ if self.ref_model is None:
935
+ with self.null_ref_context():
936
+ if self.is_encoder_decoder:
937
+ completion_logits = self.model(
938
+ padded_batch["prompt_input_ids"],
939
+ attention_mask=padded_batch["prompt_attention_mask"],
940
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
941
+ labels=padded_batch["completion_labels"],
942
+ ).logits
943
+
944
+ KL_logits = self.model(
945
+ padded_batch["KL_prompt_input_ids"],
946
+ attention_mask=padded_batch["KL_prompt_attention_mask"],
947
+ decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
948
+ labels=padded_batch["KL_completion_labels"],
949
+ ).logits
950
+ else:
951
+ completion_logits = self.model(
952
+ padded_batch["completion_input_ids"],
953
+ attention_mask=padded_batch["completion_attention_mask"],
954
+ ).logits
955
+
956
+ KL_logits = self.model(
957
+ padded_batch["KL_completion_input_ids"],
958
+ attention_mask=padded_batch["KL_completion_attention_mask"],
959
+ ).logits
960
+ else:
961
+ if self.is_encoder_decoder:
962
+ completion_logits = self.ref_model(
963
+ padded_batch["prompt_input_ids"],
964
+ attention_mask=padded_batch["prompt_attention_mask"],
965
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
966
+ labels=padded_batch["completion_labels"],
967
+ ).logits
968
+
969
+ KL_logits = self.ref_model(
970
+ padded_batch["KL_prompt_input_ids"],
971
+ attention_mask=padded_batch["KL_prompt_attention_mask"],
972
+ decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
973
+ labels=padded_batch["KL_completion_labels"],
974
+ ).logits
975
+ else:
976
+ completion_logits = self.ref_model(
977
+ padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
978
+ ).logits
979
+
980
+ KL_logits = self.ref_model(
981
+ padded_batch["KL_completion_input_ids"],
982
+ attention_mask=padded_batch["KL_completion_attention_mask"],
983
+ ).logits
984
+
985
+ completion_logps = self.get_batch_logps(
986
+ completion_logits,
987
+ padded_batch["completion_labels"],
988
+ average_log_prob=False,
989
+ is_encoder_decoder=self.is_encoder_decoder,
990
+ label_pad_token_id=self.label_pad_token_id,
991
+ )
992
+
993
+ KL_logps = self.get_batch_logps(
994
+ KL_logits,
995
+ padded_batch["KL_completion_labels"],
996
+ average_log_prob=False,
997
+ is_encoder_decoder=self.is_encoder_decoder,
998
+ label_pad_token_id=self.label_pad_token_id,
999
+ )
1000
+
1001
+ return completion_logps, KL_logps
1002
+
1003
+ @staticmethod
1004
+ def get_batch_logps(
1005
+ logits: torch.FloatTensor,
1006
+ labels: torch.LongTensor,
1007
+ average_log_prob: bool = False,
1008
+ label_pad_token_id: int = -100,
1009
+ is_encoder_decoder: bool = False,
1010
+ ) -> torch.FloatTensor:
1011
+ """Compute the log probabilities of the given labels under the given logits.
1012
+
1013
+ Args:
1014
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
1015
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
1016
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
1017
+
1018
+ Returns:
1019
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
1020
+ """
1021
+ if logits.shape[:-1] != labels.shape:
1022
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1023
+
1024
+ if not is_encoder_decoder:
1025
+ labels = labels[:, 1:].clone()
1026
+ logits = logits[:, :-1, :]
1027
+ else:
1028
+ # Fixes enc-dec RuntimeError
1029
+ labels = labels.clone()
1030
+
1031
+ loss_mask = labels != label_pad_token_id
1032
+
1033
+ # dummy token; we'll ignore the losses on these tokens later
1034
+ labels[labels == label_pad_token_id] = 0
1035
+
1036
+ per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
1037
+
1038
+ if average_log_prob:
1039
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1040
+ else:
1041
+ return (per_token_logps * loss_mask).sum(-1)
1042
+
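A tiny self-contained check of the same token-level gather (shapes and values are invented; it mirrors the decoder-only branch with `average_log_prob=False`):

import torch

logits = torch.randn(2, 5, 11)          # (batch, seq_len, vocab)
labels = torch.randint(0, 11, (2, 5))   # (batch, seq_len)
labels[0, -2:] = -100                   # padded positions are ignored

shifted_labels = labels[:, 1:].clone()  # predict token t from positions < t
shifted_logits = logits[:, :-1, :]
loss_mask = shifted_labels != -100
shifted_labels[shifted_labels == -100] = 0  # dummy index, masked out below

per_token_logps = torch.gather(
    shifted_logits.log_softmax(-1), dim=2, index=shifted_labels.unsqueeze(2)
).squeeze(2)
summed_logps = (per_token_logps * loss_mask).sum(-1)
print(summed_logps.shape)  # torch.Size([2]) -- one summed log prob per sequence
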
1043
+ def forward(
1044
+ self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
1045
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1046
+ if self.is_encoder_decoder:
1047
+ with torch.no_grad():
1048
+ KL_logits = model(
1049
+ batch["KL_prompt_input_ids"],
1050
+ attention_mask=batch["KL_prompt_attention_mask"],
1051
+ decoder_input_ids=batch.get("KL_completion_decoder_input_ids"),
1052
+ labels=batch["KL_completion_labels"],
1053
+ ).logits
1054
+
1055
+ completion_logits = model(
1056
+ batch["prompt_input_ids"],
1057
+ attention_mask=batch["prompt_attention_mask"],
1058
+ decoder_input_ids=batch.get("completion_decoder_input_ids"),
1059
+ labels=batch["completion_labels"],
1060
+ ).logits
1061
+ else:
1062
+ with torch.no_grad():
1063
+ KL_logits = model(
1064
+ batch["KL_completion_input_ids"],
1065
+ attention_mask=batch["KL_completion_attention_mask"],
1066
+ ).logits
1067
+
1068
+ completion_logits = model(
1069
+ batch["completion_input_ids"],
1070
+ attention_mask=batch["completion_attention_mask"],
1071
+ ).logits
1072
+
1073
+ completion_logps = self.get_batch_logps(
1074
+ completion_logits,
1075
+ batch["completion_labels"],
1076
+ average_log_prob=False,
1077
+ is_encoder_decoder=self.is_encoder_decoder,
1078
+ label_pad_token_id=self.label_pad_token_id,
1079
+ )
1080
+
1081
+ KL_logps = self.get_batch_logps(
1082
+ KL_logits,
1083
+ batch["KL_completion_labels"],
1084
+ average_log_prob=False,
1085
+ is_encoder_decoder=self.is_encoder_decoder,
1086
+ label_pad_token_id=self.label_pad_token_id,
1087
+ )
1088
+
1089
+ if completion_logps.shape[0] != len(batch["label"]):
1090
+ raise ValueError(
1091
+ "There is a mismatch between the number of examples in this batch and the number of "
1092
+ "examples for which an output sequence was predicted."
1093
+ )
1094
+
1095
+ chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
1096
+ rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
1097
+
1098
+ chosen_logps = completion_logps[chosen_idx, ...]
1099
+ rejected_logps = completion_logps[rejected_idx, ...]
1100
+
1101
+ chosen_logits = completion_logits[chosen_idx, ...]
1102
+ rejected_logits = completion_logits[rejected_idx, ...]
1103
+
1104
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
1105
+
1106
+ def kto_loss(
1107
+ self,
1108
+ policy_chosen_logps: torch.FloatTensor,
1109
+ policy_rejected_logps: torch.FloatTensor,
1110
+ policy_KL_logps: torch.FloatTensor,
1111
+ reference_chosen_logps: torch.FloatTensor,
1112
+ reference_rejected_logps: torch.FloatTensor,
1113
+ reference_KL_logps: torch.FloatTensor,
1114
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1115
+ """Compute the KTO loss for a batch of policy and reference model log probabilities.
1116
+
1117
+ Args:
1118
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
1119
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
1120
+ policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
1121
+ reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
1122
+ reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
1123
+ reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)
1124
+
1125
+ Returns:
1126
+ A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL).
1127
+ The losses tensor contains the KTO loss for each example in the batch.
1128
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
1129
+ The KL tensor contains the detached KL divergence estimate between the policy and reference models.
1130
+ """
1131
+ kl = (policy_KL_logps - reference_KL_logps).mean().detach()
1132
+ kl = self.accelerator.gather(kl).mean().clamp(min=0)
1133
+
1134
+ if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
1135
+ chosen_logratios = policy_chosen_logps - reference_chosen_logps
1136
+ chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
1137
+ chosen_rewards = self.beta * chosen_logratios.detach()
1138
+ else:
1139
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1140
+ chosen_losses = torch.Tensor([]).to(self.accelerator.device)
1141
+ chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
1142
+
1143
+ if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
1144
+ rejected_logratios = policy_rejected_logps - reference_rejected_logps
1145
+ rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
1146
+ rejected_rewards = self.beta * rejected_logratios.detach()
1147
+ else:
1148
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1149
+ rejected_losses = torch.Tensor([]).to(self.accelerator.device)
1150
+ rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
1151
+
1152
+ losses = torch.cat(
1153
+ (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
1154
+ 0,
1155
+ )
1156
+
1157
+ return losses, chosen_rewards, rejected_rewards, kl
1158
+
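A standalone numeric illustration of the loss formulas above (all log probabilities are made up, both weights are assumed to be 1.0, and the distributed gather is omitted):

import torch
import torch.nn.functional as F

beta = 0.1
policy_chosen_logps      = torch.tensor([-12.0, -15.0])
reference_chosen_logps   = torch.tensor([-13.0, -14.0])
policy_rejected_logps    = torch.tensor([-20.0])
reference_rejected_logps = torch.tensor([-18.0])
policy_KL_logps    = torch.tensor([-16.0, -17.0, -19.0])
reference_KL_logps = torch.tensor([-16.5, -17.0, -18.0])

# KL estimate from the mismatched pairs, clamped at zero
kl = (policy_KL_logps - reference_KL_logps).mean().clamp(min=0)

chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps
chosen_losses = 1 - F.sigmoid(beta * (chosen_logratios - kl))
rejected_losses = 1 - F.sigmoid(beta * (kl - rejected_logratios))

losses = torch.cat((1.0 * chosen_losses, 1.0 * rejected_losses), 0)
print(kl.item(), losses.mean().item())
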
1159
+ def get_batch_loss_metrics(
1160
+ self,
1161
+ model,
1162
+ batch: Dict[str, Union[List, torch.LongTensor]],
1163
+ ):
1164
+ """Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
1165
+ metrics = {}
1166
+ batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
1167
+
1168
+ (
1169
+ policy_chosen_logps,
1170
+ policy_rejected_logps,
1171
+ policy_chosen_logits,
1172
+ policy_rejected_logits,
1173
+ policy_KL_logps,
1174
+ ) = self.forward(model, batch)
1175
+
1176
+ # if reference_logps in batch use them, otherwise use the reference model
1177
+ if "reference_logps" in batch:
1178
+ chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
1179
+ rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
1180
+
1181
+ reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
1182
+ reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
1183
+ reference_KL_logps = batch["reference_KL_logps"]
1184
+ else:
1185
+ with torch.no_grad():
1186
+ if self.ref_model is None:
1187
+ with self.null_ref_context():
1188
+ (
1189
+ reference_chosen_logps,
1190
+ reference_rejected_logps,
1191
+ _,
1192
+ _,
1193
+ reference_KL_logps,
1194
+ ) = self.forward(self.model, batch)
1195
+ else:
1196
+ (
1197
+ reference_chosen_logps,
1198
+ reference_rejected_logps,
1199
+ _,
1200
+ _,
1201
+ reference_KL_logps,
1202
+ ) = self.forward(self.ref_model, batch)
1203
+
1204
+ losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
1205
+ policy_chosen_logps,
1206
+ policy_rejected_logps,
1207
+ policy_KL_logps,
1208
+ reference_chosen_logps,
1209
+ reference_rejected_logps,
1210
+ reference_KL_logps,
1211
+ )
1212
+
1213
+ num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
1214
+ num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
1215
+
1216
+ all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
1217
+ all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
1218
+
1219
+ if all_num_chosen > 0:
1220
+ metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
1221
+ metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
1222
+ metrics["count/chosen"] = all_num_chosen
1223
+
1224
+ if all_num_rejected > 0:
1225
+ metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
1226
+ metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
1227
+ metrics["count/rejected"] = all_num_rejected
1228
+
1229
+ metrics["kl"] = kl.item()
1230
+
1231
+ return losses.nanmean(), metrics
1232
+
1233
+ def compute_loss(
1234
+ self,
1235
+ model: Union[PreTrainedModel, nn.Module],
1236
+ inputs: Dict[str, Union[torch.Tensor, Any]],
1237
+ return_outputs=False,
1238
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
1239
+ if not self.use_dpo_data_collator:
1240
+ warnings.warn(
1241
+ "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
1242
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
1243
+ )
1244
+ compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
1245
+
1246
+ with compute_loss_context_manager():
1247
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1248
+
1249
+ # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
1250
+ loss = loss.to(self.args.device)
1251
+ # force log the metrics
1252
+ if self.accelerator.is_main_process:
1253
+ self.store_metrics(metrics, train_eval="train")
1254
+
1255
+ if return_outputs:
1256
+ return (loss, metrics)
1257
+ return loss
1258
+
1259
+ def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1260
+ for key, value in metrics.items():
1261
+ self._stored_metrics[train_eval][key].append(value)
1262
+
1263
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
1264
+ if self.train_dataset is None or not has_length(self.train_dataset):
1265
+ return None
1266
+ return SequentialSampler(self.train_dataset)
1267
+
1268
+ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
1269
+ """Generate samples from the model and reference model for the given batch of inputs."""
1270
+
1271
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1272
+ # the torch cuda amp context manager as some hidden states are silently casted to full precision.
1273
+ generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
1274
+
1275
+ with generate_context_manager():
1276
+ policy_output = model.generate(
1277
+ input_ids=batch["prompt_input_ids"],
1278
+ attention_mask=batch["prompt_attention_mask"],
1279
+ max_length=self.max_length,
1280
+ do_sample=True,
1281
+ pad_token_id=self.tokenizer.pad_token_id,
1282
+ )
1283
+
1284
+ # if reference_output in batch use that otherwise use the reference model
1285
+ if "reference_output" in batch:
1286
+ reference_output = batch["reference_output"]
1287
+ else:
1288
+ if self.ref_model is None:
1289
+ with self.null_ref_context():
1290
+ reference_output = self.model.generate(
1291
+ input_ids=batch["prompt_input_ids"],
1292
+ attention_mask=batch["prompt_attention_mask"],
1293
+ max_length=self.max_length,
1294
+ do_sample=True,
1295
+ pad_token_id=self.tokenizer.pad_token_id,
1296
+ )
1297
+ else:
1298
+ reference_output = self.ref_model.generate(
1299
+ input_ids=batch["prompt_input_ids"],
1300
+ attention_mask=batch["prompt_attention_mask"],
1301
+ max_length=self.max_length,
1302
+ do_sample=True,
1303
+ pad_token_id=self.tokenizer.pad_token_id,
1304
+ )
1305
+
1306
+ policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
1307
+ policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
1308
+
1309
+ reference_output = pad_to_length(reference_output, self.max_length, self.tokenizer.pad_token_id)
1310
+ reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True)
1311
+
1312
+ return policy_output_decoded, reference_output_decoded
1313
+
1314
+ def prediction_step(
1315
+ self,
1316
+ model: Union[PreTrainedModel, nn.Module],
1317
+ inputs: Dict[str, Union[torch.Tensor, Any]],
1318
+ prediction_loss_only: bool,
1319
+ ignore_keys: Optional[List[str]] = None,
1320
+ ):
1321
+ if not self.use_dpo_data_collator:
1322
+ warnings.warn(
1323
+ "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
1324
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
1325
+ )
1326
+ if ignore_keys is None:
1327
+ if hasattr(model, "config"):
1328
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1329
+ else:
1330
+ ignore_keys = []
1331
+
1332
+ prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
1333
+ with torch.no_grad(), prediction_context_manager():
1334
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1335
+
1336
+ # force log the metrics
1337
+ if self.accelerator.is_main_process:
1338
+ self.store_metrics(metrics, train_eval="eval")
1339
+
1340
+ if prediction_loss_only:
1341
+ return (loss.detach(), None, None)
1342
+
1343
+ # logits for the chosen and rejected samples from model
1344
+ logits_dict = {
1345
+ "eval_logits/chosen": metrics["logits/chosen"],
1346
+ "eval_logits/rejected": metrics["logits/rejected"],
1347
+ }
1348
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1349
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1350
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1351
+
1352
+ return (loss.detach(), logits, labels)
1353
+
1354
+ def evaluation_loop(
1355
+ self,
1356
+ dataloader: DataLoader,
1357
+ description: str,
1358
+ prediction_loss_only: Optional[bool] = None,
1359
+ ignore_keys: Optional[List[str]] = None,
1360
+ metric_key_prefix: str = "eval",
1361
+ ) -> EvalLoopOutput:
1362
+ """
1363
+ Overriding built-in evaluation loop to store metrics for each batch.
1364
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1365
+
1366
+ Works both with or without labels.
1367
+ """
1368
+
1369
+ # Sample and save to game log if requested (for one batch to save time)
1370
+ if self.generate_during_eval:
1371
+ # Generate random indices within the range of the total number of samples
1372
+ num_samples = len(dataloader.dataset)
1373
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1374
+
1375
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1376
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1377
+ random_batch = self.data_collator(random_batch_dataset)
1378
+ random_batch = self._prepare_inputs(random_batch)
1379
+
1380
+ target_indices = [i for i in range(len(random_batch["kl"])) if random_batch["kl"][i] is False]
1381
+ target_batch = {
1382
+ "prompt_input_ids": itemgetter(*target_indices)(random_batch["prompt_input_ids"]),
1383
+ "prompt_attention_mask": itemgetter(*target_indices)(random_batch["prompt_attention_mask"]),
1384
+ "prompt": itemgetter(*target_indices)(random_batch["prompt"]),
1385
+ }
1386
+ policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch)
1387
+
1388
+ self.log(
1389
+ {
1390
+ "game_log": wandb.Table(
1391
+ columns=["Prompt", "Policy", "Ref Model"],
1392
+ rows=[
1393
+ [prompt, pol[len(prompt) :], ref[len(prompt) :]]
1394
+ for prompt, pol, ref in zip(
1395
+ target_batch["prompt"], policy_output_decoded, ref_output_decoded
1396
+ )
1397
+ ],
1398
+ )
1399
+ }
1400
+ )
1401
+ self.state.log_history.pop()
1402
+
1403
+ # Base evaluation
1404
+ initial_output = super().evaluation_loop(
1405
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1406
+ )
1407
+
1408
+ return initial_output
1409
+
1410
+ def log(self, logs: Dict[str, float]) -> None:
1411
+ """
1412
+ Log `logs` on the various objects watching training, including stored metrics.
1413
+
1414
+ Args:
1415
+ logs (`Dict[str, float]`):
1416
+ The values to log.
1417
+ """
1418
+ # logs either has 'loss' or 'eval_loss'
1419
+ train_eval = "train" if "loss" in logs else "eval"
1420
+ # accumulate average metrics from sums and lengths
1421
+ for split in ["chosen", "rejected"]:
1422
+ if f"count/{split}" in self._stored_metrics[train_eval]:
1423
+ count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
1424
+ logs[f"{train_eval}/rewards/{split}"] = (
1425
+ torch.Tensor(self._stored_metrics[train_eval][f"rewards/{split}_sum"]).sum().item() / count_sum
1426
+ )
1427
+ logs[f"{train_eval}/logps/{split}"] = (
1428
+ torch.Tensor(self._stored_metrics[train_eval][f"logps/{split}_sum"]).sum().item() / count_sum
1429
+ )
1430
+ for key in [f"count/{split}", f"rewards/{split}_sum", f"logps/{split}_sum"]:
1431
+ del self._stored_metrics[train_eval][key]
1432
+ # calculate reward margin
1433
+ if f"{train_eval}/rewards/chosen" in logs and f"{train_eval}/rewards/rejected" in logs:
1434
+ logs[f"{train_eval}/rewards/margins"] = (
1435
+ logs[f"{train_eval}/rewards/chosen"] - logs[f"{train_eval}/rewards/rejected"]
1436
+ )
1437
+ # Add averaged stored metrics to logs
1438
+ for key, metrics in self._stored_metrics[train_eval].items():
1439
+ logs[f"{train_eval}/{key}"] = torch.Tensor(metrics).mean().item()
1440
+ del self._stored_metrics[train_eval]
1441
+ return super().log(logs)
1442
+
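The sum/count bookkeeping above reduces to a weighted average across logging steps; a toy illustration with invented numbers:

import torch

stored = {"rewards/chosen_sum": [2.0, 6.0], "count/chosen": [4, 8]}
count_sum = torch.Tensor(stored["count/chosen"]).sum().item()  # 12.0
mean_chosen_reward = torch.Tensor(stored["rewards/chosen_sum"]).sum().item() / count_sum
print(mean_chosen_reward)  # 0.666... = (2.0 + 6.0) / 12
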
1443
+ @wraps(Trainer.push_to_hub)
1444
+ def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
1445
+ """
1446
+ Overwrite the `push_to_hub` method in order to force-add the tag "kto" when pushing the
1447
+ model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
1448
+ """
1449
+ kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
1450
+
1451
+ return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
1452
+ class UnslothKTOTrainer(_UnslothKTOTrainer):
1453
+ """
1454
+
1455
+ Initialize KTOTrainer.
1456
+
1457
+ Args:
1458
+ model (`transformers.PreTrainedModel`):
1459
+ The model to train, preferably an `AutoModelForCausalLM`.
1460
+ ref_model (`PreTrainedModelWrapper`):
1461
+ Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
1462
+ reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
1463
+ args (`KTOConfig`):
1464
+ The arguments to use for training.
1465
+ train_dataset (`datasets.Dataset`):
1466
+ The dataset to use for training.
1467
+ eval_dataset (`datasets.Dataset`):
1468
+ The dataset to use for evaluation.
1469
+ tokenizer (`transformers.PreTrainedTokenizerBase`):
1470
+ The tokenizer to use for training. This argument is required if you want to use the default data collator.
1471
+ data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
1472
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1473
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1474
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1475
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1476
+ callbacks (`List[transformers.TrainerCallback]`):
1477
+ The callbacks to use for training.
1478
+ optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1479
+ The optimizer and scheduler to use for training.
1480
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1481
+ The function to use to preprocess the logits before computing the metrics.
1482
+ peft_config (`Dict`, defaults to `None`):
1483
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1484
+ disable_dropout (`bool`, defaults to `True`):
1485
+ Whether or not to disable dropouts in `model` and `ref_model`.
1486
+ compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
1487
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
1488
+ a dictionary string to metric values.
1489
+ model_adapter_name (`str`, defaults to `None`):
1490
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
1491
+ ref_adapter_name (`str`, defaults to `None`):
1492
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
1493
+
1494
+ """
1495
+ def __init__(
1496
+ self,
1497
+ model = None,
1498
+ ref_model = None,
1499
+ args = None,
1500
+ train_dataset = None,
1501
+ eval_dataset = None,
1502
+ tokenizer = None,
1503
+ data_collator = None,
1504
+ model_init = None,
1505
+ callbacks = None,
1506
+ preprocess_logits_for_metrics = None,
1507
+ peft_config = None,
1508
+ compute_metrics = None,
1509
+ model_adapter_name = None,
1510
+ ref_adapter_name = None,
1511
+ **kwargs
1512
+ ):
1513
+ if args is None: args = UnslothKTOConfig()
1514
+ use_bf16 = getattr(args, 'bf16', False)
1515
+ if type(use_bf16) is not bool: use_bf16 = False
1516
+ use_fp16 = getattr(args, 'fp16', False)
1517
+ if type(use_fp16) is not bool: use_fp16 = False
1518
+ force_float32 = False
1519
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1520
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1521
+ force_float32 = True
1522
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1523
+ dtype = getattr(model.config, 'torch_dtype', None)
1524
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1525
+ from unsloth_zoo.utils import _get_dtype
1526
+ dtype = _get_dtype(dtype)
1527
+ float16 = dtype == torch.float16
1528
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1529
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1530
+ if force_float32:
1531
+ args.fp16 = False
1532
+ args.bf16 = False
1533
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1534
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1535
+ args.fp16 = float16
1536
+ args.bf16 = not float16
1537
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1538
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1539
+ args.eval_strategy = 'steps'
1540
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1541
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1542
+ if ga_steps is not None and ga_steps > 1:
1543
+ from transformers import __version__ as transformers_version
1544
+ if Version(transformers_version) <= Version('4.45.2'):
1545
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1546
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1547
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1548
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1549
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1550
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1551
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1552
+ if type(fp16_full_eval) is not bool: fp16_full_eval = False
1553
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1554
+ if type(bf16_full_eval) is not bool: bf16_full_eval = False
1555
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1556
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1557
+ if force_float32:
1558
+ args.bf16_full_eval = False
1559
+ args.fp16_full_eval = False
1560
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1561
+ args.bf16_full_eval = True
1562
+ args.fp16_full_eval = False
1563
+ elif not bf16_full_eval and not fp16_full_eval:
1564
+ args.bf16_full_eval = args.bf16
1565
+ args.fp16_full_eval = args.fp16
1566
+ _output_logits = False
1567
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1568
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1569
+ if _output_logits:
1570
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1571
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1572
+ pass
1573
+ else:
1574
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1575
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1576
+ if args_max_seq_length is None and model_max_seq_length is not None:
1577
+ max_seq_length = model.max_seq_length
1578
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1579
+ if model is not None and hasattr(model, 'for_training'):
1580
+ model.for_training()
1581
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1582
+ if 'processing_class' in locals():
1583
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1584
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1585
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1586
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1587
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1588
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1589
+ data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
1590
+ elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1591
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1592
+ else:
1593
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1594
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1595
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1596
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1597
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1598
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1599
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1600
+ else:
1601
+ data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
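+ # Editor's note (not generated code): the blocks above pick a collator that matches the dataset:
+ # if there is no `labels` column, a plain language-modelling collator is used; if `labels` are
+ # present, `DataCollatorForSeq2Seq` pads them alongside the inputs. When the processing class is
+ # a wrapper without a `pad` method (e.g. a multimodal processor), the collator is rebuilt around
+ # its inner `.tokenizer`.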
1602
+ other_metrics = []
1603
+
1604
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1605
+ PatchRLStatistics('kto_trainer', other_metrics)
1606
+
1607
+ super().__init__(
1608
+ model = model,
1609
+ ref_model = ref_model,
1610
+ args = args,
1611
+ train_dataset = train_dataset,
1612
+ eval_dataset = eval_dataset,
1613
+ tokenizer = tokenizer,
1614
+ data_collator = data_collator,
1615
+ model_init = model_init,
1616
+ callbacks = callbacks,
1617
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1618
+ peft_config = peft_config,
1619
+ compute_metrics = compute_metrics,
1620
+ model_adapter_name = model_adapter_name,
1621
+ ref_adapter_name = ref_adapter_name,**kwargs)
1622
+ if hasattr(self, 'neftune_hook_handle'):
1623
+ self.neftune_hook_handle.remove()
1624
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1625
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1626
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1627
+ pass
1628
+
1629
+ pass
compilefcach/UnslothORPOTrainer.py ADDED
@@ -0,0 +1,1413 @@
1
+ """
2
+ 2025.6.8
3
+ 2025.6.12
4
+ 4.53.0
5
+ 0.8.6
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, Dict, EvalLoopOutput, F, List, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, Trainer, TrainerCallback, Tuple, Union, deepcopy, defaultdict, disable_dropout_in_model, inspect, is_peft_available, is_torch_fx_proxy, is_wandb_available, nn, np, nullcontext, pad_to_length, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, torch, trl_sanitze_kwargs_for_tagging, wandb, warnings, wraps)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
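+ # Editor's note: `selective_log_softmax` relies on the identity log_softmax(x)_i = x_i - logsumexp(x),
+ # so it never materialises the full (batch, seq, vocab) log-softmax tensor. A tiny sanity sketch
+ # (illustrative only, dummy shapes assumed):
+ #   logits = torch.randn(2, 4, 8); index = torch.randint(0, 8, (2, 4))
+ #   ref = torch.gather(F.log_softmax(logits.float(), -1), -1, index.unsqueeze(-1)).squeeze(-1)
+ #   torch.testing.assert_close(selective_log_softmax(logits, index), ref)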
42
+ @dataclass
43
+ class UnslothORPOConfig(ORPOConfig):
44
+ """
45
+
46
+ ORPOConfig collects all training arguments related to the [`ORPOTrainer`] class.
47
+
48
+ Using [`HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ max_length (`int`, defaults to `None`):
54
+ The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
55
+ max_prompt_length (`int`, defaults to `None`):
56
+ The maximum length of the prompt. This argument is required if you want to use the default data collator.
57
+ max_completion_length (`int`, defaults to `None`):
58
+ The maximum length of the completions. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
59
+ beta (`float`, defaults to 0.1):
60
+ The beta factor in ORPO loss (lambda/alpha in paper/code) that is the weight of the relative loss ratio in the SFT loss.
61
+ label_pad_token_id (`int`, defaults to `-100`):
62
+ The label pad token id. This argument is required if you want to use the default data collator.
63
+ padding_value (`int`, defaults to `None`):
64
+ The padding value if it is different to the tokenizer's pad_token_id.
65
+ truncation_mode (`str`, defaults to `keep_end`):
66
+ The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
67
+ generate_during_eval (`bool`, defaults to `False`):
68
+ Whether to sample and log generations during evaluation step.
69
+ is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
70
+ If no model is provided, we need to know if the model_init returns an encoder-decoder.
71
+ disable_dropout (`bool`, defaults to `True`):
72
+ Whether or not to disable dropouts in `model`.
73
+ model_init_kwargs (`Optional[Dict]`, *optional*):
74
+ Dict of Optional kwargs to pass when instantiating the model from a string
75
+ dataset_num_proc (`Optional[int]`, *optional*):
76
+ The number of workers to use to tokenize the data. Defaults to None.
77
+
78
+ """
79
+ vllm_sampling_params: Optional[Any] = field(
80
+ default = None,
81
+ metadata = {'help': 'vLLM SamplingParams'},
82
+ )
83
+ unsloth_num_chunks : Optional[int] = field(
84
+ default = -1,
85
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
86
+ )
87
+ def __init__(
88
+ self,
89
+ output_dir = None,
90
+ overwrite_output_dir = None,
91
+ do_train = False,
92
+ do_eval = False,
93
+ do_predict = False,
94
+ eval_strategy = 'no',
95
+ prediction_loss_only = False,
96
+ per_device_train_batch_size = 4,
97
+ per_device_eval_batch_size = 4,
98
+ per_gpu_train_batch_size = None,
99
+ per_gpu_eval_batch_size = None,
100
+ gradient_accumulation_steps = 2,
101
+ eval_accumulation_steps = 2,
102
+ eval_delay = 0,
103
+ torch_empty_cache_steps = 250,
104
+ learning_rate = 5e-05,
105
+ weight_decay = 0.01,
106
+ adam_beta1 = 0.9,
107
+ adam_beta2 = 0.999,
108
+ adam_epsilon = 1e-08,
109
+ max_grad_norm = 1.0,
110
+ num_train_epochs = 3.0,
111
+ max_steps = -1,
112
+ lr_scheduler_type = 'linear',
113
+ warmup_ratio = 0.1,
114
+ warmup_steps = 0,
115
+ log_level = 'passive',
116
+ log_level_replica = 'warning',
117
+ log_on_each_node = True,
118
+ logging_dir = None,
119
+ logging_strategy = 'steps',
120
+ logging_first_step = False,
121
+ logging_steps = 1,
122
+ logging_nan_inf_filter = False,
123
+ save_strategy = 'steps',
124
+ save_steps = 500,
125
+ save_total_limit = None,
126
+ save_safetensors = True,
127
+ save_on_each_node = False,
128
+ save_only_model = False,
129
+ restore_callback_states_from_checkpoint = False,
130
+ no_cuda = False,
131
+ use_cpu = False,
132
+ use_mps_device = False,
133
+ seed = 3407,
134
+ data_seed = 3407,
135
+ jit_mode_eval = False,
136
+ use_ipex = False,
137
+ bf16 = False,
138
+ fp16 = False,
139
+ fp16_opt_level = 'O1',
140
+ half_precision_backend = 'auto',
141
+ bf16_full_eval = False,
142
+ fp16_full_eval = False,
143
+ tf32 = None,
144
+ local_rank = -1,
145
+ ddp_backend = None,
146
+ tpu_num_cores = None,
147
+ tpu_metrics_debug = False,
148
+ debug = '',
149
+ dataloader_drop_last = False,
150
+ eval_steps = None,
151
+ dataloader_num_workers = 0,
152
+ dataloader_prefetch_factor = None,
153
+ past_index = -1,
154
+ run_name = None,
155
+ disable_tqdm = None,
156
+ remove_unused_columns = True,
157
+ label_names = None,
158
+ load_best_model_at_end = False,
159
+ metric_for_best_model = None,
160
+ greater_is_better = None,
161
+ ignore_data_skip = False,
162
+ fsdp = '',
163
+ fsdp_min_num_params = 0,
164
+ fsdp_config = None,
165
+ fsdp_transformer_layer_cls_to_wrap = None,
166
+ accelerator_config = None,
167
+ deepspeed = None,
168
+ label_smoothing_factor = 0.0,
169
+ optim = 'adamw_8bit',
170
+ optim_args = None,
171
+ adafactor = False,
172
+ group_by_length = False,
173
+ length_column_name = 'length',
174
+ report_to = None,
175
+ ddp_find_unused_parameters = None,
176
+ ddp_bucket_cap_mb = None,
177
+ ddp_broadcast_buffers = None,
178
+ dataloader_pin_memory = True,
179
+ dataloader_persistent_workers = False,
180
+ skip_memory_metrics = True,
181
+ use_legacy_prediction_loop = False,
182
+ push_to_hub = False,
183
+ resume_from_checkpoint = None,
184
+ hub_model_id = None,
185
+ hub_strategy = 'every_save',
186
+ hub_token = None,
187
+ hub_private_repo = None,
188
+ hub_always_push = False,
189
+ hub_revision = None,
190
+ gradient_checkpointing = False,
191
+ gradient_checkpointing_kwargs = None,
192
+ include_inputs_for_metrics = False,
193
+ eval_do_concat_batches = True,
194
+ fp16_backend = 'auto',
195
+ push_to_hub_model_id = None,
196
+ push_to_hub_organization = None,
197
+ push_to_hub_token = None,
198
+ mp_parameters = '',
199
+ auto_find_batch_size = False,
200
+ full_determinism = False,
201
+ torchdynamo = None,
202
+ ray_scope = 'last',
203
+ ddp_timeout = 1800,
204
+ torch_compile = False,
205
+ torch_compile_backend = None,
206
+ torch_compile_mode = None,
207
+ include_tokens_per_second = False,
208
+ include_num_input_tokens_seen = False,
209
+ neftune_noise_alpha = None,
210
+ optim_target_modules = None,
211
+ batch_eval_metrics = False,
212
+ eval_on_start = False,
213
+ use_liger_kernel = False,
214
+ liger_kernel_config = None,
215
+ eval_use_gather_object = False,
216
+ average_tokens_across_devices = False,
217
+ max_length = None,
218
+ max_prompt_length = None,
219
+ max_completion_length = None,
220
+ beta = 0.1,
221
+ disable_dropout = True,
222
+ label_pad_token_id = -100,
223
+ padding_value = None,
224
+ truncation_mode = 'keep_end',
225
+ generate_during_eval = False,
226
+ is_encoder_decoder = None,
227
+ model_init_kwargs = None,
228
+ dataset_num_proc = None,
229
+ vllm_sampling_params = None,
230
+ unsloth_num_chunks = -1,
231
+ **kwargs,
232
+ ):
233
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (< 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
234
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
235
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
236
+ output_dir = 'unsloth_training_checkpoints'
237
+ save_strategy = 'no'
238
+ if dataset_num_proc is None:
239
+ from multiprocessing import cpu_count
240
+ dataset_num_proc = cpu_count()
241
+
242
+ super().__init__(
243
+ output_dir = output_dir,
244
+ overwrite_output_dir = overwrite_output_dir,
245
+ do_train = do_train,
246
+ do_eval = do_eval,
247
+ do_predict = do_predict,
248
+ eval_strategy = eval_strategy,
249
+ prediction_loss_only = prediction_loss_only,
250
+ per_device_train_batch_size = per_device_train_batch_size,
251
+ per_device_eval_batch_size = per_device_eval_batch_size,
252
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
253
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
254
+ gradient_accumulation_steps = gradient_accumulation_steps,
255
+ eval_accumulation_steps = eval_accumulation_steps,
256
+ eval_delay = eval_delay,
257
+ torch_empty_cache_steps = torch_empty_cache_steps,
258
+ learning_rate = learning_rate,
259
+ weight_decay = weight_decay,
260
+ adam_beta1 = adam_beta1,
261
+ adam_beta2 = adam_beta2,
262
+ adam_epsilon = adam_epsilon,
263
+ max_grad_norm = max_grad_norm,
264
+ num_train_epochs = num_train_epochs,
265
+ max_steps = max_steps,
266
+ lr_scheduler_type = lr_scheduler_type,
267
+ warmup_ratio = warmup_ratio,
268
+ warmup_steps = warmup_steps,
269
+ log_level = log_level,
270
+ log_level_replica = log_level_replica,
271
+ log_on_each_node = log_on_each_node,
272
+ logging_dir = logging_dir,
273
+ logging_strategy = logging_strategy,
274
+ logging_first_step = logging_first_step,
275
+ logging_steps = logging_steps,
276
+ logging_nan_inf_filter = logging_nan_inf_filter,
277
+ save_strategy = save_strategy,
278
+ save_steps = save_steps,
279
+ save_total_limit = save_total_limit,
280
+ save_safetensors = save_safetensors,
281
+ save_on_each_node = save_on_each_node,
282
+ save_only_model = save_only_model,
283
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
284
+ no_cuda = no_cuda,
285
+ use_cpu = use_cpu,
286
+ use_mps_device = use_mps_device,
287
+ seed = seed,
288
+ data_seed = data_seed,
289
+ jit_mode_eval = jit_mode_eval,
290
+ use_ipex = use_ipex,
291
+ bf16 = bf16,
292
+ fp16 = fp16,
293
+ fp16_opt_level = fp16_opt_level,
294
+ half_precision_backend = half_precision_backend,
295
+ bf16_full_eval = bf16_full_eval,
296
+ fp16_full_eval = fp16_full_eval,
297
+ tf32 = tf32,
298
+ local_rank = local_rank,
299
+ ddp_backend = ddp_backend,
300
+ tpu_num_cores = tpu_num_cores,
301
+ tpu_metrics_debug = tpu_metrics_debug,
302
+ debug = debug,
303
+ dataloader_drop_last = dataloader_drop_last,
304
+ eval_steps = eval_steps,
305
+ dataloader_num_workers = dataloader_num_workers,
306
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
307
+ past_index = past_index,
308
+ run_name = run_name,
309
+ disable_tqdm = disable_tqdm,
310
+ remove_unused_columns = remove_unused_columns,
311
+ label_names = label_names,
312
+ load_best_model_at_end = load_best_model_at_end,
313
+ metric_for_best_model = metric_for_best_model,
314
+ greater_is_better = greater_is_better,
315
+ ignore_data_skip = ignore_data_skip,
316
+ fsdp = fsdp,
317
+ fsdp_min_num_params = fsdp_min_num_params,
318
+ fsdp_config = fsdp_config,
319
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
320
+ accelerator_config = accelerator_config,
321
+ deepspeed = deepspeed,
322
+ label_smoothing_factor = label_smoothing_factor,
323
+ optim = optim,
324
+ optim_args = optim_args,
325
+ adafactor = adafactor,
326
+ group_by_length = group_by_length,
327
+ length_column_name = length_column_name,
328
+ report_to = report_to,
329
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
330
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
331
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
332
+ dataloader_pin_memory = dataloader_pin_memory,
333
+ dataloader_persistent_workers = dataloader_persistent_workers,
334
+ skip_memory_metrics = skip_memory_metrics,
335
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
336
+ push_to_hub = push_to_hub,
337
+ resume_from_checkpoint = resume_from_checkpoint,
338
+ hub_model_id = hub_model_id,
339
+ hub_strategy = hub_strategy,
340
+ hub_token = hub_token,
341
+ hub_private_repo = hub_private_repo,
342
+ hub_always_push = hub_always_push,
343
+ hub_revision = hub_revision,
344
+ gradient_checkpointing = gradient_checkpointing,
345
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
346
+ include_inputs_for_metrics = include_inputs_for_metrics,
347
+ eval_do_concat_batches = eval_do_concat_batches,
348
+ fp16_backend = fp16_backend,
349
+ push_to_hub_model_id = push_to_hub_model_id,
350
+ push_to_hub_organization = push_to_hub_organization,
351
+ push_to_hub_token = push_to_hub_token,
352
+ mp_parameters = mp_parameters,
353
+ auto_find_batch_size = auto_find_batch_size,
354
+ full_determinism = full_determinism,
355
+ torchdynamo = torchdynamo,
356
+ ray_scope = ray_scope,
357
+ ddp_timeout = ddp_timeout,
358
+ torch_compile = torch_compile,
359
+ torch_compile_backend = torch_compile_backend,
360
+ torch_compile_mode = torch_compile_mode,
361
+ include_tokens_per_second = include_tokens_per_second,
362
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
363
+ neftune_noise_alpha = neftune_noise_alpha,
364
+ optim_target_modules = optim_target_modules,
365
+ batch_eval_metrics = batch_eval_metrics,
366
+ eval_on_start = eval_on_start,
367
+ use_liger_kernel = use_liger_kernel,
368
+ liger_kernel_config = liger_kernel_config,
369
+ eval_use_gather_object = eval_use_gather_object,
370
+ average_tokens_across_devices = average_tokens_across_devices,
371
+ max_length = max_length,
372
+ max_prompt_length = max_prompt_length,
373
+ max_completion_length = max_completion_length,
374
+ beta = beta,
375
+ disable_dropout = disable_dropout,
376
+ label_pad_token_id = label_pad_token_id,
377
+ padding_value = padding_value,
378
+ truncation_mode = truncation_mode,
379
+ generate_during_eval = generate_during_eval,
380
+ is_encoder_decoder = is_encoder_decoder,
381
+ model_init_kwargs = model_init_kwargs,
382
+ dataset_num_proc = dataset_num_proc,**kwargs)
383
+ self.vllm_sampling_params = vllm_sampling_params
384
+ self.unsloth_num_chunks = unsloth_num_chunks
385
+ pass
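+ # Editor's note — minimal usage sketch (hypothetical values, not executed here):
+ #   config = UnslothORPOConfig(
+ #       output_dir = "orpo_outputs",
+ #       per_device_train_batch_size = 2,
+ #       gradient_accumulation_steps = 4,
+ #       max_length = 1024, max_prompt_length = 512, beta = 0.1,
+ #   )
+ #   # `config` is then passed as `args` to the UnslothORPOTrainer defined below.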
386
+
387
+ class _UnslothORPOTrainer(Trainer):
388
+ r""""""
389
+
390
+ _tag_names = ["trl", "orpo"]
391
+
392
+ def __init__(
393
+ self,
394
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
395
+ args: Optional[ORPOConfig] = None,
396
+ data_collator: Optional[DataCollator] = None,
397
+ train_dataset: Optional[Dataset] = None,
398
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
399
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
400
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
401
+ callbacks: Optional[List[TrainerCallback]] = None,
402
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
403
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
404
+ peft_config: Optional[Dict] = None,
405
+ compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
406
+ ):
407
+ if args.model_init_kwargs is None:
408
+ model_init_kwargs = {}
409
+ elif not isinstance(model, str):
410
+ raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
411
+ else:
412
+ model_init_kwargs = args.model_init_kwargs
413
+ model_init_kwargs["torch_dtype"] = (
414
+ model_init_kwargs["torch_dtype"]
415
+ if model_init_kwargs["torch_dtype"] in ["auto", None]
416
+ else getattr(torch, model_init_kwargs["torch_dtype"])
417
+ )
418
+
419
+ if isinstance(model, str):
420
+ warnings.warn(
421
+ "You passed a model_id to the ORPOTrainer. This will automatically create an "
422
+ "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
423
+ )
424
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
425
+
426
+ # Initialize this variable to False. This helps track the case when `peft_module_casting_to_bf16`
427
+ # has been called in order to properly call autocast if needed.
428
+ self._peft_has_been_casted_to_bf16 = False
429
+
430
+ if not is_peft_available() and peft_config is not None:
431
+ raise ValueError(
432
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
433
+ )
434
+ elif is_peft_available() and peft_config is not None:
435
+ # if model is a peft model and we have a peft_config, we merge and unload it first
436
+ if isinstance(model, PeftModel):
437
+ model = model.merge_and_unload()
438
+
439
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
440
+ _support_gc_kwargs = hasattr(
441
+ args, "gradient_checkpointing_kwargs"
442
+ ) and "gradient_checkpointing_kwargs" in list(
443
+ inspect.signature(prepare_model_for_kbit_training).parameters
444
+ )
445
+
446
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
447
+
448
+ if _support_gc_kwargs:
449
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
450
+
451
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
452
+ elif getattr(args, "gradient_checkpointing", False):
453
+ # For backward compatibility with older versions of transformers
454
+ if hasattr(model, "enable_input_require_grads"):
455
+ model.enable_input_require_grads()
456
+ else:
457
+
458
+ def make_inputs_require_grad(module, input, output):
459
+ output.requires_grad_(True)
460
+
461
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
462
+
463
+ # get peft model with the given config
464
+ model = model
465
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
466
+ peft_module_casting_to_bf16(model)
467
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
468
+ self._peft_has_been_casted_to_bf16 = True
469
+
470
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
471
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
472
+ # fail or completely fail.
473
+ elif getattr(args, "gradient_checkpointing", False):
474
+ # For backward compatibility with older versions of transformers
475
+ if hasattr(model, "enable_input_require_grads"):
476
+ model.enable_input_require_grads()
477
+ else:
478
+
479
+ def make_inputs_require_grad(module, input, output):
480
+ output.requires_grad_(True)
481
+
482
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
483
+
484
+ if args.generate_during_eval and not is_wandb_available():
485
+ raise ValueError(
486
+ "`generate_during_eval=True` requires Weights and Biases to be installed."
487
+ " Please install `wandb` to resolve."
488
+ )
489
+
490
+ if model is not None:
491
+ self.is_encoder_decoder = model.config.is_encoder_decoder
492
+ elif args.is_encoder_decoder is None:
493
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
494
+ else:
495
+ self.is_encoder_decoder = args.is_encoder_decoder
496
+
497
+ if self.is_encoder_decoder:
498
+ self.decoder_start_token_id = model.config.decoder_start_token_id
499
+ self.pad_token_id = model.config.pad_token_id
500
+
501
+ if tokenizer is None:
502
+ raise ValueError("tokenizer must be specified to tokenize a ORPO dataset.")
503
+ if args.max_length is None:
504
+ warnings.warn(
505
+ "`max_length` is not set in the ORPOConfig's init"
506
+ " it will default to `512` by default, but you should do it yourself in the future.",
507
+ UserWarning,
508
+ )
509
+ max_length = 512
510
+ else:
511
+ max_length = args.max_length
512
+ if args.max_prompt_length is None:
513
+ warnings.warn(
514
+ "`max_prompt_length` is not set in the ORPOConfig's init"
515
+ " it will default to `128` by default, but you should do it yourself in the future.",
516
+ UserWarning,
517
+ )
518
+ max_prompt_length = 128
519
+ else:
520
+ max_prompt_length = args.max_prompt_length
521
+
522
+ if args.max_completion_length is None and self.is_encoder_decoder:
523
+ warnings.warn(
524
+ "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
525
+ " it will default to `128` by default, but you should do it yourself in the future.",
526
+ UserWarning,
527
+ )
528
+ self.max_completion_length = 128
529
+ else:
530
+ self.max_completion_length = args.max_completion_length
531
+
532
+ if data_collator is None:
533
+ data_collator = DPODataCollatorWithPadding(
534
+ pad_token_id=tokenizer.pad_token_id,
535
+ label_pad_token_id=args.label_pad_token_id,
536
+ is_encoder_decoder=self.is_encoder_decoder,
537
+ )
538
+
539
+ if args.remove_unused_columns:
540
+ args.remove_unused_columns = False
541
+ # warn users
542
+ warnings.warn(
543
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
544
+ " we have set it for you, but you should do it yourself in the future.",
545
+ UserWarning,
546
+ )
547
+
548
+ self.use_dpo_data_collator = True
549
+ else:
550
+ self.use_dpo_data_collator = False
551
+
552
+ if args.disable_dropout:
553
+ disable_dropout_in_model(model)
554
+
555
+ self.max_length = max_length
556
+ self.generate_during_eval = args.generate_during_eval
557
+ self.label_pad_token_id = args.label_pad_token_id
558
+ self.padding_value = args.padding_value if args.padding_value is not None else tokenizer.pad_token_id
559
+ self.max_prompt_length = max_prompt_length
560
+ self.truncation_mode = args.truncation_mode
561
+ self.tokenizer = tokenizer
562
+
563
+ self.beta = args.beta
564
+
565
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
566
+
567
+ # Compute that only on the main process for faster data processing.
568
+ # see: https://github.com/huggingface/trl/pull/1255
569
+ with PartialState().local_main_process_first():
570
+ # tokenize the dataset
571
+ train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
572
+ if eval_dataset is not None:
573
+ eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
574
+
575
+ super().__init__(
576
+ model=model,
577
+ args=args,
578
+ data_collator=data_collator,
579
+ train_dataset=train_dataset,
580
+ eval_dataset=eval_dataset,
581
+ tokenizer=tokenizer,
582
+ model_init=model_init,
583
+ compute_metrics=compute_metrics,
584
+ callbacks=callbacks,
585
+ optimizers=optimizers,
586
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
587
+ )
588
+
589
+ # Add tags for models that have been loaded with the correct transformers version
590
+ if hasattr(self.model, "add_model_tags"):
591
+ self.model.add_model_tags(self._tag_names)
592
+
593
+ if not hasattr(self, "accelerator"):
594
+ raise AttributeError(
595
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
596
+ )
597
+
598
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
599
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
600
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
601
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
602
+
603
+ if model is not None:
604
+ if hasattr(model, "config"):
605
+ hidden_size = (
606
+ max(model.config.hidden_sizes)
607
+ if getattr(model.config, "hidden_sizes", None)
608
+ else getattr(model.config, "hidden_size", None)
609
+ )
610
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
611
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
612
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
613
+ config_kwargs.update(
614
+ {
615
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
616
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
617
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
618
+ }
619
+ )
620
+
621
+ # If ZeRO-3 is used, we shard both the active and reference model.
622
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
623
+ if config_kwargs["zero_optimization"]["stage"] != 3:
624
+ config_kwargs["zero_optimization"]["stage"] = 0
625
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
626
+ model.eval()
627
+ return model
628
+
629
+ def build_tokenized_answer(self, prompt, answer):
630
+ """
631
+ Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
632
+ It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
633
+ Reference:
634
+ https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
635
+ """
636
+
637
+ full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
638
+ prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
639
+
640
+ answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
641
+ answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
642
+
643
+ # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
644
+ full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
645
+
646
+ # Prepare input tokens for token by token comparison
647
+ full_input_ids = np.array(full_tokenized["input_ids"])
648
+
649
+ if len(full_input_ids) != len(full_concat_input_ids):
650
+ raise ValueError("Prompt input ids and answer input ids should have the same length.")
651
+
652
+ # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
653
+ # can be merged together when tokenizing prompt+answer. This could result
654
+ # in the last token from the prompt being different when tokenized on its own
655
+ # vs when done as prompt+answer.
656
+ response_token_ids_start_idx = len(prompt_input_ids)
657
+
658
+ # If the standalone-tokenized prompt differs from the prompt portion of prompt+answer, the
659
+ # last token has changed due to merging.
660
+ if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
661
+ response_token_ids_start_idx -= 1
662
+
663
+ prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
664
+ prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
665
+
666
+ if len(prompt_input_ids) != len(prompt_attention_mask):
667
+ raise ValueError("Prompt input ids and attention mask should have the same length.")
668
+
669
+ answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
670
+ answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
671
+
672
+ return dict(
673
+ prompt_input_ids=prompt_input_ids,
674
+ prompt_attention_mask=prompt_attention_mask,
675
+ input_ids=answer_input_ids,
676
+ attention_mask=answer_attention_mask,
677
+ )
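+ # Editor's note: the index adjustment above handles BPE merges at the prompt/answer boundary.
+ # Hypothetical example: if enc("Q:") == [10, 11] but enc("Q: yes") == [10, 12, 30] (the ":"
+ # merged with " yes"), the boundary moves one token left so the changed token lands in the
+ # answer span and prompt_input_ids + input_ids always reproduces enc(prompt + answer) exactly.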
678
+
679
+ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
680
+ """Tokenize a single row from a ORPO specific dataset.
681
+
682
+ At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
683
+ in case the prompt + chosen or prompt + rejected responses is/are too long. First
684
+ we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
685
+
686
+ We also create the labels for the chosen/rejected responses, which are of length equal to
687
+ the sum of the length of the prompt and the chosen/rejected response, with
688
+ label_pad_token_id for the prompt tokens.
689
+ """
690
+ batch = {}
691
+ prompt = feature["prompt"]
692
+ chosen = feature["chosen"]
693
+ rejected = feature["rejected"]
694
+
695
+ if not self.is_encoder_decoder:
696
+ # Check issues below for more details
697
+ # 1. https://github.com/huggingface/trl/issues/907
698
+ # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
699
+ # 3. https://github.com/LianjiaTech/BELLE/issues/337
700
+
701
+ if not isinstance(prompt, str):
702
+ raise ValueError(f"prompt should be an str but got {type(prompt)}")
703
+ prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
704
+ prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
705
+
706
+ if not isinstance(chosen, str):
707
+ raise ValueError(f"chosen should be an str but got {type(chosen)}")
708
+ chosen_tokens = self.build_tokenized_answer(prompt, chosen)
709
+
710
+ if not isinstance(rejected, str):
711
+ raise ValueError(f"rejected should be an str but got {type(rejected)}")
712
+ rejected_tokens = self.build_tokenized_answer(prompt, rejected)
713
+
714
+ # Last prompt token might get merged by tokenizer and
715
+ # it should not be included for generation if that happens
716
+ prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
717
+
718
+ chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
719
+ rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
720
+ prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
721
+
722
+ for k, v in prompt_tokens.items():
723
+ prompt_tokens[k] = v[:prompt_len_input_ids]
724
+
725
+ # Make sure prompts only have one different token at most,
726
+ # and length only differs by 1 at most
727
+ num_diff_tokens = sum(
728
+ [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
729
+ )
730
+ num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
731
+ if num_diff_tokens > 1 or num_diff_len > 1:
732
+ raise ValueError(
733
+ "Chosen and rejected prompt_input_ids might only differ on the "
734
+ "last token due to tokenizer merge ops."
735
+ )
736
+
737
+ # add BOS token to head of prompt
738
+ prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
739
+ chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
740
+ rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
741
+
742
+ prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
743
+ chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
744
+ rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
745
+
746
+ # add EOS token to end of answer
747
+ chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
748
+ chosen_tokens["attention_mask"].append(1)
749
+
750
+ rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
751
+ rejected_tokens["attention_mask"].append(1)
752
+
753
+ longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
754
+
755
+ # if combined sequence is too long, truncate the prompt
756
+ for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
757
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
758
+ if self.truncation_mode == "keep_start":
759
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
760
+ answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
761
+ elif self.truncation_mode == "keep_end":
762
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
763
+ answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
764
+ else:
765
+ raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
766
+
767
+ # if that's still too long, truncate the response
768
+ for answer_tokens in [chosen_tokens, rejected_tokens]:
769
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
770
+ for k in ["input_ids", "attention_mask"]:
771
+ answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
772
+
773
+ # Create labels
774
+ chosen_sequence_tokens = {
775
+ k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
776
+ }
777
+ rejected_sequence_tokens = {
778
+ k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
779
+ }
780
+ chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
781
+ chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
782
+ self.label_pad_token_id
783
+ ] * len(chosen_tokens["prompt_input_ids"])
784
+ rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
785
+ rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
786
+ self.label_pad_token_id
787
+ ] * len(rejected_tokens["prompt_input_ids"])
788
+
789
+ for k, toks in {
790
+ "chosen_": chosen_sequence_tokens,
791
+ "rejected_": rejected_sequence_tokens,
792
+ "": prompt_tokens,
793
+ }.items():
794
+ for type_key, tokens in toks.items():
795
+ if type_key == "token_type_ids":
796
+ continue
797
+ batch[f"{k}{type_key}"] = tokens
798
+
799
+ else:
800
+ chosen_tokens = self.tokenizer(
801
+ chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
802
+ )
803
+ rejected_tokens = self.tokenizer(
804
+ rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
805
+ )
806
+ prompt_tokens = self.tokenizer(
807
+ prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
808
+ )
809
+
810
+ batch["chosen_labels"] = chosen_tokens["input_ids"]
811
+ batch["rejected_labels"] = rejected_tokens["input_ids"]
812
+ batch["prompt_input_ids"] = prompt_tokens["input_ids"]
813
+ batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
814
+
815
+ if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
816
+ batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
817
+ labels=torch.tensor(batch["rejected_labels"])
818
+ )
819
+ batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
820
+ labels=torch.tensor(batch["chosen_labels"])
821
+ )
822
+
823
+ return batch
824
+
825
+ @staticmethod
826
+ def concatenated_inputs(
827
+ batch: Dict[str, Union[List, torch.LongTensor]],
828
+ is_encoder_decoder: bool = False,
829
+ label_pad_token_id: int = -100,
830
+ padding_value: int = 0,
831
+ device: Optional[torch.device] = None,
832
+ ) -> Dict[str, torch.LongTensor]:
833
+ """Concatenate the chosen and rejected inputs into a single tensor.
834
+
835
+ Args:
836
+ batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
837
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
838
+ label_pad_token_id: The label pad token id.
839
+ padding_value: The padding value to use for the concatenated inputs_ids.
840
+ device: The device for the concatenated inputs.
841
+
842
+ Returns:
843
+ A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
844
+ """
845
+ concatenated_batch = {}
846
+
847
+ if is_encoder_decoder:
848
+ max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
849
+ else:
850
+ max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
851
+
852
+ for k in batch:
853
+ if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
854
+ if "labels" in k or is_encoder_decoder:
855
+ pad_value = label_pad_token_id
856
+ elif k.endswith("_input_ids"):
857
+ pad_value = padding_value
858
+ elif k.endswith("_attention_mask"):
859
+ pad_value = 0
860
+ concatenated_key = k.replace("chosen", "concatenated")
861
+ concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
862
+ for k in batch:
863
+ if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
864
+ if "labels" in k or is_encoder_decoder:
865
+ pad_value = label_pad_token_id
866
+ elif k.endswith("_input_ids"):
867
+ pad_value = padding_value
868
+ elif k.endswith("_attention_mask"):
869
+ pad_value = 0
870
+ concatenated_key = k.replace("rejected", "concatenated")
871
+ concatenated_batch[concatenated_key] = torch.cat(
872
+ (
873
+ concatenated_batch[concatenated_key],
874
+ pad_to_length(batch[k], max_length, pad_value=pad_value),
875
+ ),
876
+ dim=0,
877
+ ).to(device=device)
878
+
879
+ if is_encoder_decoder:
880
+ concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
881
+ concatenated_batch["concatenated_attention_mask"] = (
882
+ batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
883
+ )
884
+
885
+ return concatenated_batch
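+ # Editor's note: after this call the "concatenated_*" tensors stack chosen examples first and
+ # rejected examples second along the batch dimension, both padded to the same max length, so a
+ # single forward pass scores 2 * batch_size sequences (see `concatenated_forward`, which splits
+ # the outputs back apart using `len_chosen`).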
886
+
887
+ def odds_ratio_loss(
888
+ self,
889
+ policy_chosen_logps: torch.FloatTensor,
890
+ policy_rejected_logps: torch.FloatTensor,
891
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
892
+ """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities.
893
+
894
+ Args:
895
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
896
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
897
+
898
+ Returns:
899
+ A tuple of five values: (losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen).
900
+ The losses tensor contains the ORPO loss for each example in the batch.
901
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
902
+ The mean `log(sigmoid(log_odds))` (the odds-ratio term of the loss), for logging purposes.
903
+ The mean log odds of the chosen over the rejected responses (`log_odds_chosen`), for logging purposes.
904
+ """
905
+
906
+ # Derived from Eqs. (4) and (7) from https://arxiv.org/abs/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
907
+ log_odds = (policy_chosen_logps - policy_rejected_logps) - (
908
+ torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
909
+ )
910
+ sig_ratio = F.sigmoid(log_odds)
911
+ ratio = torch.log(sig_ratio)
912
+ losses = self.beta * ratio
913
+
914
+ chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
915
+ rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
916
+
917
+ return losses, chosen_rewards, rejected_rewards, torch.mean(ratio).item(), torch.mean(log_odds).item()
918
+
919
+ @staticmethod
920
+ def get_batch_logps(
921
+ logits: torch.FloatTensor,
922
+ labels: torch.LongTensor,
923
+ average_log_prob: bool = False,
924
+ label_pad_token_id: int = -100,
925
+ is_encoder_decoder: bool = False,
926
+ ) -> torch.FloatTensor:
927
+ """Compute the log probabilities of the given labels under the given logits.
928
+
929
+ Args:
930
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
931
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
932
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
933
+ label_pad_token_id: The label pad token id.
934
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
935
+
936
+ Returns:
937
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
938
+ """
939
+ if logits.shape[:-1] != labels.shape:
940
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
941
+
942
+ if not is_encoder_decoder:
943
+ labels = labels[:, 1:].clone()
944
+ logits = logits[:, :-1, :]
945
+ loss_mask = labels != label_pad_token_id
946
+
947
+ # dummy token; we'll ignore the losses on these tokens later
948
+ labels[labels == label_pad_token_id] = 0
949
+
950
+ per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
951
+
952
+ if average_log_prob:
953
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
954
+ else:
955
+ return (per_token_logps * loss_mask).sum(-1)
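+ # Editor's note — shape sketch (assumed shapes, illustrative only): for a decoder-only model,
+ # logits (B, T, V) and labels (B, T) are shifted by one position, padded label positions are
+ # masked out, and the result is a (B,) tensor of summed (or, with average_log_prob=True as ORPO
+ # uses, length-averaged) label log-probabilities.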
956
+
957
+ def concatenated_forward(
958
+ self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
959
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
960
+ """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
961
+
962
+ We do this to avoid doing two forward passes, because it's faster for FSDP.
963
+ """
964
+ concatenated_batch = self.concatenated_inputs(
965
+ batch,
966
+ is_encoder_decoder=self.is_encoder_decoder,
967
+ label_pad_token_id=self.label_pad_token_id,
968
+ padding_value=self.padding_value,
969
+ device=self.accelerator.device,
970
+ )
971
+ len_chosen = batch["chosen_labels"].shape[0]
972
+
973
+ model_kwargs = (
974
+ {
975
+ "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
976
+ }
977
+ if self.is_encoder_decoder
978
+ else {}
979
+ )
980
+
981
+ outputs = model(
982
+ concatenated_batch["concatenated_input_ids"],
983
+ attention_mask=concatenated_batch["concatenated_attention_mask"],
984
+ use_cache=False,
985
+ **model_kwargs,
986
+ )
987
+ all_logits = outputs.logits
988
+
989
+ def cross_entropy_loss(logits, labels):
990
+ if not self.is_encoder_decoder:
991
+ # Shift so that tokens < n predict n
992
+ logits = logits[..., :-1, :].contiguous()
993
+ labels = labels[..., 1:].contiguous()
994
+ # Flatten the tokens
995
+ loss_fct = nn.CrossEntropyLoss()
996
+ logits = logits.view(-1, logits.shape[-1])
997
+ labels = labels.view(-1)
998
+ # Enable model parallelism
999
+ labels = labels.to(logits.device)
1000
+ loss = loss_fct(logits, labels)
1001
+ return loss
1002
+
1003
+ if self.is_encoder_decoder:
1004
+ labels = concatenated_batch["concatenated_labels"].clone()
1005
+ else:
1006
+ labels = concatenated_batch["concatenated_input_ids"].clone()
1007
+
1008
+ chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
1009
+
1010
+ all_logps = self.get_batch_logps(
1011
+ all_logits,
1012
+ concatenated_batch["concatenated_labels"],
1013
+ average_log_prob=True,
1014
+ is_encoder_decoder=self.is_encoder_decoder,
1015
+ label_pad_token_id=self.label_pad_token_id,
1016
+ )
1017
+
1018
+ chosen_logps = all_logps[:len_chosen]
1019
+ rejected_logps = all_logps[len_chosen:]
1020
+
1021
+ chosen_logits = all_logits[:len_chosen]
1022
+ rejected_logits = all_logits[len_chosen:]
1023
+
1024
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
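+ # Editor's note: `average_log_prob=True` matters for the ORPO loss, which exponentiates these
+ # values; length-averaged log-probs keep exp(logps) in a usable (0, 1) range, whereas summed
+ # log-probs would push exp(logps) toward 0 for long completions and weaken the (1 - p)
+ # correction in the odds ratio.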
1025
+
1026
+ def get_batch_loss_metrics(
1027
+ self,
1028
+ model,
1029
+ batch: Dict[str, Union[List, torch.LongTensor]],
1030
+ train_eval: Literal["train", "eval"] = "train",
1031
+ ):
1032
+ """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
1033
+ metrics = {}
1034
+
1035
+ (
1036
+ policy_chosen_logps,
1037
+ policy_rejected_logps,
1038
+ policy_chosen_logits,
1039
+ policy_rejected_logits,
1040
+ policy_nll_loss,
1041
+ ) = self.concatenated_forward(model, batch)
1042
+
1043
+ losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
1044
+ policy_chosen_logps, policy_rejected_logps
1045
+ )
1046
+ # full ORPO loss
1047
+ loss = policy_nll_loss - losses.mean()
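+ # Editor's note: this is the full ORPO objective, L = L_SFT - beta * E[log sigmoid(log_odds)],
+ # since `losses` already carries the beta factor from `odds_ratio_loss`.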
1048
+
1049
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
1050
+
1051
+ prefix = "eval_" if train_eval == "eval" else ""
1052
+ metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
1053
+ metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
1054
+ metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
1055
+ metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
1056
+ metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
1057
+ metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
1058
+ metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
1059
+ metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
1060
+ metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
1061
+ metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
1062
+ metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
1063
+
1064
+ return loss, metrics
1065
+
1066
+ def compute_loss(
1067
+ self,
1068
+ model: Union[PreTrainedModel, nn.Module],
1069
+ inputs: Dict[str, Union[torch.Tensor, Any]],
1070
+ return_outputs=False,
1071
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
1072
+ if not self.use_dpo_data_collator:
1073
+ warnings.warn(
1074
+ "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
1075
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
1076
+ )
1077
+
1078
+ compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
1079
+
1080
+ with compute_loss_context_manager():
1081
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
1082
+
1083
+ # force log the metrics
1084
+ self.store_metrics(metrics, train_eval="train")
1085
+
1086
+ if return_outputs:
1087
+ return (loss, metrics)
1088
+ return loss
1089
+
1090
+ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
1091
+ """Generate samples from the model and reference model for the given batch of inputs."""
1092
+
1093
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1094
+ # the torch cuda amp context manager as some hidden states are silently cast to full precision.
1095
+ generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
1096
+
1097
+ with generate_context_manager():
1098
+ policy_output = model.generate(
1099
+ input_ids=batch["prompt_input_ids"],
1100
+ attention_mask=batch["prompt_attention_mask"],
1101
+ max_length=self.max_length,
1102
+ do_sample=True,
1103
+ pad_token_id=self.tokenizer.pad_token_id,
1104
+ )
1105
+
1106
+ policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
1107
+ policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
1108
+
1109
+ return policy_output_decoded
1110
+
1111
+ def prediction_step(
1112
+ self,
1113
+ model: Union[PreTrainedModel, nn.Module],
1114
+ inputs: Dict[str, Union[torch.Tensor, Any]],
1115
+ prediction_loss_only: bool,
1116
+ ignore_keys: Optional[List[str]] = None,
1117
+ ):
1118
+ if not self.use_dpo_data_collator:
1119
+ warnings.warn(
1120
+ "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
1121
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
1122
+ )
1123
+ if ignore_keys is None:
1124
+ if hasattr(model, "config"):
1125
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1126
+ else:
1127
+ ignore_keys = []
1128
+
1129
+ prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
1130
+
1131
+ with torch.no_grad(), prediction_context_manager():
1132
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
1133
+
1134
+ # force log the metrics
1135
+ self.store_metrics(metrics, train_eval="eval")
1136
+
1137
+ if prediction_loss_only:
1138
+ return (loss.detach(), None, None)
1139
+
1140
+ # logits for the chosen and rejected samples from model
1141
+ logits_dict = {
1142
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
1143
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
1144
+ }
1145
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1146
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1147
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1148
+
1149
+ return (loss.detach(), logits, labels)
1150
+
1151
+ def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1152
+ for key, value in metrics.items():
1153
+ self._stored_metrics[train_eval][key].append(value)
1154
+
1155
+ def evaluation_loop(
1156
+ self,
1157
+ dataloader: DataLoader,
1158
+ description: str,
1159
+ prediction_loss_only: Optional[bool] = None,
1160
+ ignore_keys: Optional[List[str]] = None,
1161
+ metric_key_prefix: str = "eval",
1162
+ ) -> EvalLoopOutput:
1163
+ """
1164
+ Overriding built-in evaluation loop to store metrics for each batch.
1165
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1166
+
1167
+ Works both with or without labels.
1168
+ """
1169
+
1170
+ # Sample and save to game log if requested (for one batch to save time)
1171
+ if self.generate_during_eval:
1172
+ # Generate random indices within the range of the total number of samples
1173
+ num_samples = len(dataloader.dataset)
1174
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1175
+
1176
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1177
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1178
+ random_batch = self.data_collator(random_batch_dataset)
1179
+ random_batch = self._prepare_inputs(random_batch)
1180
+
1181
+ policy_output_decoded = self.get_batch_samples(self.model, random_batch)
1182
+
1183
+ self.log(
1184
+ {
1185
+ "game_log": wandb.Table(
1186
+ columns=["Prompt", "Policy"],
1187
+ rows=[
1188
+ [prompt, pol[len(prompt) :]]
1189
+ for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
1190
+ ],
1191
+ )
1192
+ }
1193
+ )
1194
+ self.state.log_history.pop()
1195
+
1196
+ # Base evaluation
1197
+ initial_output = super().evaluation_loop(
1198
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1199
+ )
1200
+
1201
+ return initial_output
1202
+
1203
+ def log(self, logs: Dict[str, float]) -> None:
1204
+ """
1205
+ Log `logs` on the various objects watching training, including stored metrics.
1206
+
1207
+ Args:
1208
+ logs (`Dict[str, float]`):
1209
+ The values to log.
1210
+ """
1211
+ # logs either has 'loss' or 'eval_loss'
1212
+ train_eval = "train" if "loss" in logs else "eval"
1213
+ # Add averaged stored metrics to logs
1214
+ for key, metrics in self._stored_metrics[train_eval].items():
1215
+ logs[key] = torch.tensor(metrics).mean().item()
1216
+ del self._stored_metrics[train_eval]
1217
+ return super().log(logs)
1218
+
1219
+ def _shift_right(self, input_ids):
1220
+ if self.decoder_start_token_id is None:
1221
+ raise ValueError(
1222
+ "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1223
+ )
1224
+
1225
+ # shift inputs to the right
1226
+ if is_torch_fx_proxy(input_ids):
1227
+ # Item assignment is not supported natively for proxies.
1228
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1229
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1230
+ else:
1231
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1232
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1233
+ shifted_input_ids[..., 0] = self.decoder_start_token_id
1234
+
1235
+ if self.pad_token_id is None:
1236
+ raise ValueError("model.config.pad_token_id has to be defined.")
1237
+ # replace possible -100 values in labels by `pad_token_id`
1238
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1239
+
1240
+ return shifted_input_ids
1241
+
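A toy illustration (editorial, not part of the uploaded file) of the right-shift above, assuming decoder_start_token_id = 0 and pad_token_id = 0:
import torch
labels = torch.tensor([[5, 6, 7, -100]])
# _shift_right(labels) -> tensor([[0, 5, 6, 7]]): tokens move one slot to the right,
# the decoder start token fills position 0, and any shifted -100 is mapped to pad_token_id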
1242
+ @wraps(Trainer.push_to_hub)
1243
+ def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
1244
+ """
1245
+ Overwrite the `push_to_hub` method in order to force-add the tag "orpo" when pushing the
1246
+ model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
1247
+ """
1248
+ kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
1249
+
1250
+ return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
1251
+ class UnslothORPOTrainer(_UnslothORPOTrainer):
1252
+ """
1253
+
1254
+ Initialize ORPOTrainer.
1255
+
1256
+ Args:
1257
+ model (`transformers.PreTrainedModel`):
1258
+ The model to train, preferably an `AutoModelForCausalLM`.
1259
+ args (`ORPOConfig`):
1260
+ The ORPO config arguments to use for training.
1261
+ data_collator (`transformers.DataCollator`):
1262
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1263
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1264
+ train_dataset (`datasets.Dataset`):
1265
+ The dataset to use for training.
1266
+ eval_dataset (`datasets.Dataset`):
1267
+ The dataset to use for evaluation.
1268
+ tokenizer (`transformers.PreTrainedTokenizerBase`):
1269
+ The tokenizer to use for training. This argument is required if you want to use the default data collator.
1270
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1271
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1272
+ callbacks (`List[transformers.TrainerCallback]`):
1273
+ The callbacks to use for training.
1274
+ optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1275
+ The optimizer and scheduler to use for training.
1276
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1277
+ The function to use to preprocess the logits before computing the metrics.
1278
+ peft_config (`Dict`, defaults to `None`):
1279
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1280
+ compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
1281
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return
1282
+ a dictionary mapping metric names to metric values.
1283
+
1284
+ """
1285
+ def __init__(
1286
+ self,
1287
+ model = None,
1288
+ args = None,
1289
+ data_collator = None,
1290
+ train_dataset = None,
1291
+ eval_dataset = None,
1292
+ tokenizer = None,
1293
+ model_init = None,
1294
+ callbacks = None,
1295
+ preprocess_logits_for_metrics = None,
1296
+ peft_config = None,
1297
+ compute_metrics = None,
1298
+ **kwargs
1299
+ ):
1300
+ if args is None: args = UnslothORPOConfig()
1301
+ use_bf16 = getattr(args, 'bf16', False)
1302
+ if type(use_bf16) is not bool: use_bf16 = False
1303
+ use_fp16 = getattr(args, 'fp16', False)
1304
+ if type(use_fp16) is not bool: use_fp16 = False
1305
+ force_float32 = False
1306
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1307
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1308
+ force_float32 = True
1309
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1310
+ dtype = getattr(model.config, 'torch_dtype', None)
1311
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1312
+ from unsloth_zoo.utils import _get_dtype
1313
+ dtype = _get_dtype(dtype)
1314
+ float16 = dtype == torch.float16
1315
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1316
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1317
+ if force_float32:
1318
+ args.fp16 = False
1319
+ args.bf16 = False
1320
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1321
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1322
+ args.fp16 = float16
1323
+ args.bf16 = not float16
1324
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1325
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1326
+ args.eval_strategy = 'steps'
1327
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1328
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1329
+ if ga_steps is not None and ga_steps > 1:
1330
+ from transformers import __version__ as transformers_version
1331
+ if Version(transformers_version) <= Version('4.45.2'):
1332
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1333
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1334
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1335
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1336
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1337
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1338
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1339
+ if type(fp16_full_eval) is not bool: fp16_full_eval = False
1340
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1341
+ if type(bf16_full_eval) is not bool: bf16_full_eval = False
1342
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1343
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1344
+ if force_float32:
1345
+ args.bf16_full_eval = False
1346
+ args.fp16_full_eval = False
1347
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1348
+ args.bf16_full_eval = True
1349
+ args.fp16_full_eval = False
1350
+ elif not bf16_full_eval and not fp16_full_eval:
1351
+ args.bf16_full_eval = args.bf16
1352
+ args.fp16_full_eval = args.fp16
1353
+ _output_logits = False
1354
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1355
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1356
+ if _output_logits:
1357
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1358
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1359
+ pass
1360
+ else:
1361
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1362
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1363
+ if args_max_seq_length is None and model_max_seq_length is not None:
1364
+ max_seq_length = model.max_seq_length
1365
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1366
+ if model is not None and hasattr(model, 'for_training'):
1367
+ model.for_training()
1368
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1369
+ if 'processing_class' in locals():
1370
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1371
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1372
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1373
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1374
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1375
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1376
+ data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
1377
+ elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1378
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1379
+ else:
1380
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1381
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1382
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1383
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1384
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1385
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1386
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1387
+ else:
1388
+ data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
1389
+ other_metrics = []
1390
+
1391
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1392
+ PatchRLStatistics('orpo_trainer', other_metrics)
1393
+
1394
+ super().__init__(
1395
+ model = model,
1396
+ args = args,
1397
+ data_collator = data_collator,
1398
+ train_dataset = train_dataset,
1399
+ eval_dataset = eval_dataset,
1400
+ tokenizer = tokenizer,
1401
+ model_init = model_init,
1402
+ callbacks = callbacks,
1403
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1404
+ peft_config = peft_config,
1405
+ compute_metrics = compute_metrics,**kwargs)
1406
+ if hasattr(self, 'neftune_hook_handle'):
1407
+ self.neftune_hook_handle.remove()
1408
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1409
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1410
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1411
+ pass
1412
+
1413
+ pass
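For orientation, a hedged usage sketch (editorial, not part of the uploaded file): the patched class above is driven like TRL's ORPOTrainer. The model name and dataset below are placeholders, and UnslothORPOConfig is assumed to expose the usual ORPOConfig/TrainingArguments fields.
from unsloth import FastLanguageModel
from datasets import load_dataset

model, tokenizer = FastLanguageModel.from_pretrained("unsloth/llama-3-8b-bnb-4bit", max_seq_length=2048)
model = FastLanguageModel.get_peft_model(model, r=16)                       # attach LoRA adapters
preference_data = load_dataset("your/preference-dataset", split="train")   # placeholder: prompt/chosen/rejected columns

trainer = UnslothORPOTrainer(
    model=model,
    args=UnslothORPOConfig(output_dir="outputs", per_device_train_batch_size=2, max_length=1024),
    train_dataset=preference_data,
    tokenizer=tokenizer,
)
trainer.train()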
compilefcach/UnslothPPOTrainer.py ADDED
@@ -0,0 +1,1566 @@
1
+ """
2
+ 2025.6.8
3
+ 2025.6.12
4
+ 4.53.0
5
+ 0.8.6
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.ppo_trainer import (Accelerator, Adam, AdaptiveKLController, BaseTrainer, Callable, DataCollatorForLanguageModeling, Dataset, F, FixedKLController, List, MODEL_CARD_TEMPLATE, Optional, PPOConfig, PPODecorators, PPOTrainer, PreTrainedModelWrapper, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast, ProjectConfiguration, RunningMoments, SUPPORTED_ARCHITECTURES, Union, WANDB_PADDING, clip_by_value, convert_to_scalar, create_reference_model, datasets, entropy_from_logits, flatten_dict, gather_object, inspect, is_npu_available, is_torch_greater_2_0, is_xpu_available, logprobs_from_logits, masked_mean, masked_var, masked_whiten, math, np, nullcontext, os, set_seed, stack_dicts, stats_to_np, time, torch, typing, unwrap_model_for_generation, version, warnings, whoami)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
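As a quick aside (editorial), the compiled helper above relies on the identity log_softmax(x)_i = x_i - logsumexp(x), so only the gathered per-token log-probabilities are materialised rather than the full (batch, seq, vocab) log-softmax tensor. A sanity check one might run:
import torch
logits = torch.randn(2, 5, 10)
index = torch.randint(0, 10, (2, 5))
expected = torch.log_softmax(logits.float(), dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
torch.testing.assert_close(selective_log_softmax(logits, index), expected)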
42
+ @dataclass
43
+ class UnslothPPOConfig(PPOConfig):
44
+ """
45
+
46
+ Configuration class for PPOTrainer
47
+
48
+ """
49
+ vllm_sampling_params: Optional[Any] = field(
50
+ default = None,
51
+ metadata = {'help': 'vLLM SamplingParams'},
52
+ )
53
+ unsloth_num_chunks : Optional[int] = field(
54
+ default = -1,
55
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
56
+ )
57
+ def __init__(
58
+ self,
59
+ exp_name = 'colab_kernel_launcher',
60
+ seed = 3407,
61
+ log_with = None,
62
+ task_name = None,
63
+ model_name = 'gpt2',
64
+ query_dataset = 'imdb',
65
+ reward_model = 'sentiment-analysis:lvwerra/distilbert-imdb',
66
+ remove_unused_columns = True,
67
+ tracker_project_name = 'trl',
68
+ steps = 20000,
69
+ learning_rate = 5e-05,
70
+ adap_kl_ctrl = True,
71
+ init_kl_coef = 0.2,
72
+ kl_penalty = 'kl',
73
+ target = 6,
74
+ horizon = 10000,
75
+ gamma = 1,
76
+ lam = 0.95,
77
+ cliprange = 0.2,
78
+ cliprange_value = 0.2,
79
+ vf_coef = 0.1,
80
+ batch_size = 128,
81
+ forward_batch_size = None,
82
+ mini_batch_size = 128,
83
+ gradient_accumulation_steps = 2,
84
+ world_size = None,
85
+ ppo_epochs = 4,
86
+ max_grad_norm = None,
87
+ optimize_cuda_cache = None,
88
+ optimize_device_cache = False,
89
+ early_stopping = False,
90
+ target_kl = 1,
91
+ compare_steps = 1,
92
+ ratio_threshold = 10.0,
93
+ use_score_scaling = False,
94
+ use_score_norm = False,
95
+ score_clip = None,
96
+ whiten_rewards = False,
97
+ is_encoder_decoder = None,
98
+ is_peft_model = None,
99
+ backward_batch_size = None,
100
+ global_backward_batch_size = None,
101
+ global_batch_size = None,
102
+ vllm_sampling_params = None,
103
+ unsloth_num_chunks = -1,
104
+ **kwargs,
105
+ ):
106
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
107
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
108
+
109
+ super().__init__(
110
+ exp_name = exp_name,
111
+ seed = seed,
112
+ log_with = log_with,
113
+ task_name = task_name,
114
+ model_name = model_name,
115
+ query_dataset = query_dataset,
116
+ reward_model = reward_model,
117
+ remove_unused_columns = remove_unused_columns,
118
+ tracker_project_name = tracker_project_name,
119
+ steps = steps,
120
+ learning_rate = learning_rate,
121
+ adap_kl_ctrl = adap_kl_ctrl,
122
+ init_kl_coef = init_kl_coef,
123
+ kl_penalty = kl_penalty,
124
+ target = target,
125
+ horizon = horizon,
126
+ gamma = gamma,
127
+ lam = lam,
128
+ cliprange = cliprange,
129
+ cliprange_value = cliprange_value,
130
+ vf_coef = vf_coef,
131
+ batch_size = batch_size,
132
+ forward_batch_size = forward_batch_size,
133
+ mini_batch_size = mini_batch_size,
134
+ gradient_accumulation_steps = gradient_accumulation_steps,
135
+ world_size = world_size,
136
+ ppo_epochs = ppo_epochs,
137
+ max_grad_norm = max_grad_norm,
138
+ optimize_cuda_cache = optimize_cuda_cache,
139
+ optimize_device_cache = optimize_device_cache,
140
+ early_stopping = early_stopping,
141
+ target_kl = target_kl,
142
+ compare_steps = compare_steps,
143
+ ratio_threshold = ratio_threshold,
144
+ use_score_scaling = use_score_scaling,
145
+ use_score_norm = use_score_norm,
146
+ score_clip = score_clip,
147
+ whiten_rewards = whiten_rewards,
148
+ is_encoder_decoder = is_encoder_decoder,
149
+ is_peft_model = is_peft_model,
150
+ backward_batch_size = backward_batch_size,
151
+ global_backward_batch_size = global_backward_batch_size,
152
+ global_batch_size = global_batch_size,**kwargs)
153
+ self.vllm_sampling_params = vllm_sampling_params
154
+ self.unsloth_num_chunks = unsloth_num_chunks
155
+ pass
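A hedged construction sketch (editorial): the guards above reject implausible learning rates at config-build time, while everything else is forwarded to TRL's PPOConfig.
config = UnslothPPOConfig(model_name="gpt2", learning_rate=1.41e-5, batch_size=16, mini_batch_size=4)
try:
    UnslothPPOConfig(learning_rate=1e-9)      # below the 1e-7 floor enforced above
except FloatingPointError as err:
    print(err)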
156
+
157
+ class _UnslothPPOTrainer(BaseTrainer):
158
+ """"""
159
+
160
+ _tag_names = ["trl", "ppo"]
161
+
162
+ def __init__(
163
+ self,
164
+ config: Optional[PPOConfig] = None,
165
+ model: Optional[PreTrainedModelWrapper] = None,
166
+ ref_model: Optional[PreTrainedModelWrapper] = None,
167
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
168
+ dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None,
169
+ optimizer: Optional[torch.optim.Optimizer] = None,
170
+ data_collator: Optional[typing.Callable] = None,
171
+ num_shared_layers: Optional[int] = None,
172
+ lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
173
+ ):
174
+ """
175
+ Initialize PPOTrainer.
176
+
177
+ Args:
178
+ config (`PPOConfig`):
179
+ Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more details.
180
+ model (`PreTrainedModelWrapper`):
181
+ Hugging Face transformer model with a value head.
182
+ ref_model (`PreTrainedModelWrapper`):
183
+ Hugging Face transformer model with a causal language modelling head. Used for the KL penalty.
184
+ tokenizer (`transformers.PreTrainedTokenizerBase`):
185
+ Hugging Face tokenizer
186
+ dataset (Optional[Union[`torch.utils.data.Dataset`, `datasets.Dataset`]]):
187
+ PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset
188
+ will be preprocessed by removing the columns that are not used by the model. If none is passed,
189
+ a warning will be raised in a multi-GPU setting.
190
+ optimizer (Optional[`torch.optim.Optimizer`]):
191
+ Optimizer used for training. If `None`, `Adam` is used by default.
192
+ data_collator (Optional[function]):
193
+ Data collator function.
194
+ num_shared_layers (Optional[int]):
195
+ Number of shared layers between the model and the reference model. If `None`, all layers are shared.
196
+ Used only if `ref_model` is `None`.
197
+ lr_scheduler (Optional[`torch.optim.lr_scheduler`]):
198
+ Learning rate scheduler used for training.
199
+ """
200
+ super().__init__(config)
201
+
202
+ # initial seed for reproducible experiments
203
+ set_seed(config.seed)
204
+
205
+ # Step 0: check positional arguments validity
206
+ if not isinstance(config, PPOConfig):
207
+ raise ValueError(f"config must be a PPOConfig, got {type(config)}")
208
+ if not isinstance(tokenizer, (PreTrainedTokenizerBase)):
209
+ raise ValueError(
210
+ f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}"
211
+ )
212
+ if not isinstance(model, (SUPPORTED_ARCHITECTURES)):
213
+ raise ValueError(
214
+ f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}"
215
+ )
216
+ # Step 1: Initialize Accelerator
217
+ self.accelerator = Accelerator(
218
+ log_with=config.log_with,
219
+ gradient_accumulation_steps=config.gradient_accumulation_steps,
220
+ project_config=ProjectConfiguration(**config.project_kwargs),
221
+ **config.accelerator_kwargs,
222
+ )
223
+
224
+ # Step 1.1 Runtime variables filled by the accelerator
225
+ config.world_size = self.accelerator.num_processes
226
+ config.global_backward_batch_size = config.backward_batch_size * config.world_size
227
+ config.global_batch_size = config.batch_size * config.world_size
228
+
229
+ self.model = model
230
+ self.model_params = filter(lambda p: p.requires_grad, self.model.parameters())
231
+ self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
232
+ self.is_peft_model = getattr(self.model, "is_peft_model", False)
233
+ config.is_encoder_decoder = self.is_encoder_decoder
234
+ config.is_peft_model = self.is_peft_model
235
+
236
+ is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
237
+ self.accelerator.init_trackers(
238
+ config.tracker_project_name,
239
+ config=dict(trl_ppo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
240
+ init_kwargs=config.tracker_kwargs,
241
+ )
242
+ self.is_using_text_environment = getattr(config, "use_text_environment", False)
243
+
244
+ if isinstance(ref_model, SUPPORTED_ARCHITECTURES):
245
+ self.ref_model = ref_model
246
+ if num_shared_layers is not None:
247
+ warnings.warn(
248
+ "num_shared_layers is ignored when ref_model is provided. Two different models are used for the "
249
+ "model and the reference model and no layers are shared.",
250
+ UserWarning,
251
+ )
252
+ elif ref_model is None and not self.is_peft_model:
253
+ self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers)
254
+ elif self.is_peft_model:
255
+ self.ref_model = None
256
+ else:
257
+ raise ValueError(
258
+ f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported "
259
+ f"architectures are: {SUPPORTED_ARCHITECTURES} "
260
+ )
261
+ self.optional_peft_ctx = (
262
+ self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter
263
+ if self.is_peft_model
264
+ else nullcontext
265
+ )
266
+
267
+ if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)):
268
+ raise ValueError(
269
+ "tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast"
270
+ )
271
+ self.tokenizer = tokenizer
272
+
273
+ if dataset is not None and not (isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, Dataset)):
274
+ raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset")
275
+ elif dataset is None:
276
+ warnings.warn(
277
+ "No dataset is provided. Make sure to set config.batch_size to the correct value before training.",
278
+ UserWarning,
279
+ )
280
+ self.dataset = dataset
281
+ self._signature_columns = None
282
+ if self.dataset is not None:
283
+ self.dataloader = self.prepare_dataloader(self.dataset, data_collator)
284
+ elif self.dataset is None and self.accelerator.num_processes > 1:
285
+ warnings.warn(
286
+ "No dataset is provided. In a multi-GPU setting, this will lead to an error. You should"
287
+ " prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`"
288
+ " and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please "
289
+ " refer to the documentation for more details.",
290
+ UserWarning,
291
+ )
292
+ self.dataloader = None
293
+ else:
294
+ self.dataloader = None
295
+
296
+ # Step 3: Initialize optimizer and data collator
297
+ self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
298
+ if optimizer is None:
299
+ self.optimizer = Adam(
300
+ filter(lambda p: p.requires_grad, self.model.parameters()),
301
+ lr=self.config.learning_rate,
302
+ )
303
+ else:
304
+ self.optimizer = optimizer
305
+
306
+ self.lr_scheduler = lr_scheduler
307
+ if self.lr_scheduler is not None:
308
+ lr_scheduler_class = (
309
+ torch.optim.lr_scheduler._LRScheduler
310
+ if not is_torch_greater_2_0()
311
+ else torch.optim.lr_scheduler.LRScheduler
312
+ )
313
+
314
+ if not isinstance(self.lr_scheduler, lr_scheduler_class):
315
+ raise ValueError(
316
+ "lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)"
317
+ )
318
+
319
+ if self.config.adap_kl_ctrl:
320
+ self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon)
321
+ else:
322
+ self.kl_ctl = FixedKLController(self.config.init_kl_coef)
323
+
324
+ # Safety checkers for DS integration
325
+ is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
326
+ self.accelerator.state, "deepspeed_plugin"
327
+ )
328
+
329
+ (
330
+ self.model,
331
+ self.optimizer,
332
+ self.data_collator,
333
+ self.dataloader,
334
+ self.lr_scheduler,
335
+ ) = self.accelerator.prepare(
336
+ self.model,
337
+ self.optimizer,
338
+ self.data_collator,
339
+ self.dataloader,
340
+ self.lr_scheduler,
341
+ )
342
+ if is_deepspeed_used:
343
+ # Quantized models are already set on the correct device
344
+ if not self.is_peft_model and not (
345
+ getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False)
346
+ or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False)
347
+ ):
348
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
349
+ else:
350
+ self.ref_model = self.accelerator.prepare(self.ref_model)
351
+
352
+ # In a distributed setup, only logging needs to be performed on the main process
353
+ # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
354
+ # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11
355
+ self.is_distributed = self.accelerator.num_processes > 1
356
+
357
+ # init the current step
358
+ self.current_step = 0
359
+
360
+ # init variables for pushing model to hub
361
+ if config.push_to_hub_if_best_kwargs:
362
+ if "repo_id" not in config.push_to_hub_if_best_kwargs:
363
+ raise ValueError("You have to specify repo_id in order to push the model to the hub!")
364
+ self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs
365
+ self.compare_step = 0
366
+ self.highest_reward = torch.tensor(-float("inf"))
367
+
368
+ # post process for PP
369
+ if not getattr(self.model, "is_sequential_parallel", False):
370
+ self.current_device = self.accelerator.device
371
+ else:
372
+ if is_xpu_available():
373
+ self.current_device = torch.device("xpu:0")
374
+ elif is_npu_available():
375
+ self.current_device = torch.device("npu:0")
376
+ else:
377
+ self.current_device = torch.device("cuda:0")
378
+
379
+ PPODecorators.optimize_device_cache = self.config.optimize_device_cache
380
+
381
+ self.running = RunningMoments(self.accelerator)
382
+
383
+ def _filter_kwargs(self, kwargs, target_func):
384
+ """
385
+ Filter the keyword arguments, keeping only those supported by the target function.
386
+
387
+ Args:
388
+ kwargs (dict):
389
+ Keyword arguments
390
+ target_func (function):
391
+ Target function
392
+ """
393
+ return {k: v for k, v in kwargs.items() if k in inspect.signature(target_func).parameters.keys()}
394
+
395
+ def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset], data_collator=None):
396
+ """
397
+ Prepare the dataloader for training.
398
+
399
+ Args:
400
+ dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]):
401
+ PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset
402
+ will be preprocessed by removing the columns that are not used by the model.
403
+ data_collator (Optional[function]):
404
+ Data collator function.
405
+
406
+ Returns:
407
+ `torch.utils.data.DataLoader`: PyTorch dataloader
408
+ """
409
+ if isinstance(dataset, Dataset):
410
+ dataset = self._remove_unused_columns(dataset)
411
+ dataloader = torch.utils.data.DataLoader(
412
+ dataset,
413
+ batch_size=self.config.batch_size,
414
+ collate_fn=data_collator,
415
+ shuffle=True,
416
+ drop_last=True,
417
+ )
418
+ return dataloader
419
+
420
+ # Adapted from transformers.Trainer._set_signature_columns_if_needed
421
+ def _set_signature_columns_if_needed(self):
422
+ if self._signature_columns is None:
423
+ # Inspect model forward signature to keep only the arguments it accepts.
424
+ signature = inspect.signature(self.model.forward)
425
+ self._signature_columns = list(signature.parameters.keys())
426
+ # label => sentiment | we need query and response for logging purpose
427
+ self._signature_columns += ["label", "query", "response"]
428
+
429
+ # Adapted from transformers.Trainer._remove_unused_columns
430
+ def _remove_unused_columns(self, dataset: "Dataset"):
431
+ if not self.config.remove_unused_columns:
432
+ return dataset
433
+ self._set_signature_columns_if_needed()
434
+ signature_columns = self._signature_columns
435
+
436
+ ignored_columns = list(set(dataset.column_names) - set(signature_columns))
437
+
438
+ columns = [k for k in signature_columns if k in dataset.column_names]
439
+
440
+ if version.parse(datasets.__version__) < version.parse("1.4.0"):
441
+ dataset.set_format(
442
+ type=dataset.format["type"],
443
+ columns=columns,
444
+ format_kwargs=dataset.format["format_kwargs"],
445
+ )
446
+ return dataset
447
+ else:
448
+ return dataset.remove_columns(ignored_columns)
449
+
450
+ def generate(
451
+ self,
452
+ query_tensor: Union[torch.Tensor, List[torch.Tensor]],
453
+ length_sampler: Optional[Callable] = None,
454
+ batch_size: int = 4,
455
+ return_prompt: bool = True,
456
+ generate_ref_response: bool = False,
457
+ **generation_kwargs,
458
+ ):
459
+ """
460
+ Generate response with the model given the query tensor.
461
+ Calls the `generate` method of the model.
462
+
463
+ Args:
464
+ query_tensor (`torch.LongTensor`):
465
+ A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`).
466
+ length_sampler (`Callable`, *optional*):
467
+ Callable that returns the number of newly generated tokens.
468
+ batch_size (`int`, *optional*):
469
+ Batch size used for generation, defaults to `4`.
470
+ return_prompt (`bool`, *optional*):
471
+ If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`.
472
+ generate_ref_response (`bool`, *optional*):
473
+ If set to `True` the reference response is also generated, defaults to `False`.
474
+ generation_kwargs (dict[str, Any]):
475
+ Keyword arguments for generation.
476
+
477
+ Returns:
478
+ `torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens.
479
+ """
480
+ if generate_ref_response:
481
+ ref_model = self.model if self.is_peft_model else self.ref_model
482
+ if isinstance(query_tensor, List):
483
+ response = self._generate_batched(
484
+ self.model,
485
+ query_tensor,
486
+ length_sampler=length_sampler,
487
+ batch_size=batch_size,
488
+ return_prompt=return_prompt,
489
+ **generation_kwargs,
490
+ )
491
+ if generate_ref_response:
492
+ ref_response = self._generate_batched(
493
+ ref_model,
494
+ query_tensor,
495
+ length_sampler=length_sampler,
496
+ batch_size=batch_size,
497
+ return_prompt=return_prompt,
498
+ **generation_kwargs,
499
+ )
500
+
501
+ else:
502
+ if len(query_tensor.shape) == 2:
503
+ raise ValueError(
504
+ "query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)"
505
+ )
506
+
507
+ if length_sampler is not None:
508
+ generation_kwargs["max_new_tokens"] = length_sampler()
509
+
510
+ with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
511
+ response = unwrapped_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
512
+
513
+ if generate_ref_response:
514
+ with unwrap_model_for_generation(
515
+ ref_model, self.accelerator, is_peft_model=self.is_peft_model
516
+ ) as unwrapped_model:
517
+ ref_response = unwrapped_model.generate(
518
+ input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
519
+ )
520
+
521
+ if not return_prompt and not self.is_encoder_decoder:
522
+ response = response[:, query_tensor.shape[0] :]
523
+ if generate_ref_response:
524
+ ref_response = ref_response[:, query_tensor.shape[0] :]
525
+
526
+ if generate_ref_response:
527
+ return response, ref_response
528
+ return response
529
+
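An illustrative rollout sketch (editorial; ppo_trainer, tokenizer and batch are placeholders): query tensors go in, response tensors come out, and the generation kwargs are forwarded to model.generate.
generation_kwargs = {"max_new_tokens": 32, "do_sample": True, "top_p": 1.0,
                     "pad_token_id": tokenizer.eos_token_id}
query_tensors = [tokenizer(q, return_tensors="pt").input_ids.squeeze(0) for q in batch["query"]]
response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
batch["response"] = tokenizer.batch_decode(response_tensors)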
530
+ def _generate_batched(
531
+ self,
532
+ model: PreTrainedModelWrapper,
533
+ query_tensors: List[torch.Tensor],
534
+ length_sampler: Optional[Callable] = None,
535
+ batch_size: int = 4,
536
+ return_prompt: bool = True,
537
+ pad_to_multiple_of: Optional[int] = None,
538
+ remove_padding: bool = True,
539
+ **generation_kwargs,
540
+ ):
541
+ outputs = []
542
+
543
+ padding_side_default = self.tokenizer.padding_side
544
+ if not self.is_encoder_decoder:
545
+ self.tokenizer.padding_side = "left"
546
+
547
+ # in case we have fewer examples than bs
548
+ batch_size = min(len(query_tensors), batch_size)
549
+
550
+ for i in range(0, len(query_tensors), batch_size):
551
+ if length_sampler is not None:
552
+ generation_kwargs["max_new_tokens"] = length_sampler()
553
+
554
+ # prevent overflow when the number of query tensors is not an exact multiple of the batch size
555
+ end_index = min(len(query_tensors), i + batch_size)
556
+
557
+ batch = query_tensors[i:end_index]
558
+ batch_mask = [torch.ones_like(element) for element in batch]
559
+ inputs = {"input_ids": batch, "attention_mask": batch_mask}
560
+
561
+ padded_inputs = self.tokenizer.pad(
562
+ inputs,
563
+ padding=True,
564
+ max_length=None,
565
+ pad_to_multiple_of=pad_to_multiple_of,
566
+ return_tensors="pt",
567
+ ).to(self.current_device)
568
+
569
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
570
+ generations = unwrapped_model.generate(**padded_inputs, **generation_kwargs)
571
+
572
+ for generation, mask in zip(generations, padded_inputs["attention_mask"]):
573
+ if not self.is_encoder_decoder:
574
+ output = generation[(1 - mask).sum() :] # remove padding
575
+ else:
576
+ output = generation
577
+
578
+ if not return_prompt and not self.is_encoder_decoder:
579
+ output = output[(mask).sum() :] # remove prompt
580
+
581
+ if remove_padding and self.tokenizer.eos_token_id in output:
582
+ pad_mask = output == self.tokenizer.eos_token_id
583
+ pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item()
584
+ output = output[: pad_start + 1] # keep the eos token at the end
585
+
586
+ outputs.append(output)
587
+
588
+ self.tokenizer.padding_side = padding_side_default
589
+ return outputs
590
+
591
+ def _step_safety_checker(
592
+ self,
593
+ batch_size: int,
594
+ queries: List[torch.LongTensor],
595
+ responses: List[torch.LongTensor],
596
+ scores: List[torch.FloatTensor],
597
+ masks: Optional[List[torch.LongTensor]] = None,
598
+ ):
599
+ """
600
+ Check if the input data is valid for training.
601
+
602
+ Args:
603
+ batch_size (int):
604
+ Batch size from the config file.
605
+ queries (List[`torch.LongTensor`]):
606
+ List of tensors containing the encoded queries of shape (`query_length`)
607
+ responses (List[`torch.LongTensor`]):
608
+ List of tensors containing the encoded responses of shape (`response_length`)
609
+ scores (List[`torch.FloatTensor`]):
610
+ List of tensors containing the scores.
611
+ masks (List[`torch.LongTensor`], *optional*):
612
+ list of optional tensors containing the masks of shape (`query_length` + `response_length`)
613
+ Returns:
614
+ `tuple`: The input processed data.
615
+ """
616
+ for name, tensor_list in zip(["queries", "responses", "scores"], [queries, responses, scores]):
617
+ if not isinstance(tensor_list, list):
618
+ raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
619
+ if not isinstance(tensor_list[0], torch.Tensor):
620
+ raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
621
+ if batch_size is not None and len(tensor_list) != batch_size:
622
+ raise ValueError(
623
+ f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}"
624
+ )
625
+
626
+ # add queries, scores and responses on the correct device
627
+ queries = [tensor.to(self.current_device) for tensor in queries]
628
+ responses = [tensor.to(self.current_device) for tensor in responses]
629
+ scores = [tensor.to(self.current_device) for tensor in scores]
630
+ masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None
631
+
632
+ # squeeze scores if needed
633
+ for i, score in enumerate(scores):
634
+ if score.dim() > 1:
635
+ raise ValueError(f"Scores must be 1-dimensional - got {score.dim()} for {score}")
636
+ elif score.dim() == 1:
637
+ scores[i] = score.squeeze()
638
+
639
+ return queries, responses, scores, masks
640
+
641
+ @PPODecorators.empty_device_cache()
642
+ def step(
643
+ self,
644
+ queries: List[torch.LongTensor],
645
+ responses: List[torch.LongTensor],
646
+ scores: List[torch.FloatTensor],
647
+ response_masks: Optional[List[torch.LongTensor]] = None,
648
+ ):
649
+ """
650
+ Run a PPO optimisation step given a list of queries, model responses, and rewards.
651
+
652
+ Args:
653
+ queries (List[`torch.LongTensor`]):
654
+ List of tensors containing the encoded queries of shape (`query_length`)
655
+ responses (List[`torch.LongTensor`]):
656
+ List of tensors containing the encoded responses of shape (`response_length`)
657
+ scores (List[`torch.FloatTensor`]):
658
+ List of tensors containing the scores.
659
+ response_masks (List[`torch.LongTensor`], *optional*):
660
+ List of tensors containing masks of the response tokens.
661
+
662
+ Returns:
663
+ `dict[str, Any]`: A summary of the training statistics
664
+ """
665
+ bs = self.config.batch_size
666
+
667
+ queries, responses, scores, response_masks = self._step_safety_checker(
668
+ bs, queries, responses, scores, response_masks
669
+ )
670
+ scores = torch.tensor(scores, device=self.current_device)
671
+ if self.config.use_score_scaling:
672
+ # Score scaling
673
+ scores_mean, scores_std = self.running.update(scores)
674
+ tensor_to_kwargs = dict(dtype=scores.dtype, device=scores.device)
675
+ score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps
676
+ if self.config.use_score_norm:
677
+ scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor
678
+ else:
679
+ scores /= score_scaling_factor
680
+
681
+ if self.config.score_clip is not None:
682
+ # Score clipping
683
+ scores_dtype = scores.dtype
684
+ scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)
685
+
686
+ # if we want to push best model to the hub
687
+ if hasattr(self, "highest_reward"):
688
+ if self.compare_step % self.config.compare_steps == 0:
689
+ curr_mean_reward = scores.mean()
690
+ # if the best reward ever seen
691
+ if curr_mean_reward > self.highest_reward:
692
+ self.highest_reward = curr_mean_reward
693
+ # push model to hub
694
+ self.push_to_hub(**self.push_to_hub_kwargs)
695
+ self.compare_step += 1
696
+
697
+ timing = dict()
698
+ t0 = time.time()
699
+
700
+ t = time.time()
701
+
702
+ model_inputs = self.prepare_model_inputs(queries, responses)
703
+
704
+ if self.is_distributed:
705
+ pad_first = self.tokenizer.padding_side == "left"
706
+
707
+ model_inputs["input_ids"] = self.accelerator.pad_across_processes(
708
+ model_inputs["input_ids"],
709
+ dim=1,
710
+ pad_index=self.tokenizer.pad_token_id,
711
+ pad_first=pad_first,
712
+ )
713
+ model_inputs["attention_mask"] = self.accelerator.pad_across_processes(
714
+ model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first
715
+ )
716
+ if self.is_encoder_decoder:
717
+ model_inputs["decoder_input_ids"] = self.accelerator.pad_across_processes(
718
+ model_inputs["decoder_input_ids"],
719
+ dim=1,
720
+ pad_index=self.tokenizer.pad_token_id,
721
+ pad_first=pad_first,
722
+ )
723
+ model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes(
724
+ model_inputs["decoder_attention_mask"],
725
+ dim=1,
726
+ pad_index=0,
727
+ pad_first=pad_first,
728
+ )
729
+
730
+ model_inputs_names = list(model_inputs.keys())
731
+
732
+ full_kl_penalty = self.config.kl_penalty == "full"
733
+
734
+ with torch.no_grad():
735
+ all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
736
+ self.model,
737
+ queries,
738
+ responses,
739
+ model_inputs,
740
+ response_masks=response_masks,
741
+ return_logits=full_kl_penalty,
742
+ )
743
+ with self.optional_peft_ctx():
744
+ ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
745
+ self.model if self.is_peft_model else self.ref_model,
746
+ queries,
747
+ responses,
748
+ model_inputs,
749
+ return_logits=full_kl_penalty,
750
+ )
751
+
752
+ timing["time/ppo/forward_pass"] = time.time() - t
753
+
754
+ with torch.no_grad():
755
+ t = time.time()
756
+ if full_kl_penalty:
757
+ active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False)
758
+ ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False)
759
+
760
+ rewards, non_score_reward, kls = self.compute_rewards(
761
+ scores, active_full_logprobs, ref_full_logprobs, masks
762
+ )
763
+ else:
764
+ rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
765
+ timing["time/ppo/compute_rewards"] = time.time() - t
766
+
767
+ t = time.time()
768
+ values, advantages, returns = self.compute_advantages(values, rewards, masks)
769
+ timing["time/ppo/compute_advantages"] = time.time() - t
770
+
771
+ # upcast to float32 to avoid dataset issues
772
+ batch_dict = {
773
+ "queries": queries,
774
+ "responses": responses,
775
+ "logprobs": all_logprobs.to(torch.float32),
776
+ "values": values.to(torch.float32),
777
+ "masks": masks,
778
+ "advantages": advantages,
779
+ "returns": returns,
780
+ }
781
+ batch_dict.update(model_inputs)
782
+
783
+ t = time.time()
784
+ all_stats = []
785
+ early_stop = False
786
+ for _ in range(self.config.ppo_epochs):
787
+ if early_stop:
788
+ break
789
+ b_inds = np.random.permutation(bs)
790
+ for backward_batch_start in range(0, bs, self.config.backward_batch_size):
791
+ backward_batch_end = backward_batch_start + self.config.backward_batch_size
792
+ backward_batch_inds = b_inds[backward_batch_start:backward_batch_end]
793
+
794
+ for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
795
+ mini_batch_end = mini_batch_start + self.config.mini_batch_size
796
+ mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end]
797
+ mini_batch_dict = {
798
+ "logprobs": batch_dict["logprobs"][mini_batch_inds],
799
+ "values": batch_dict["values"][mini_batch_inds],
800
+ "masks": batch_dict["masks"][mini_batch_inds],
801
+ # hacks: the queries and responses are ragged.
802
+ "queries": [batch_dict["queries"][i] for i in mini_batch_inds],
803
+ "responses": [batch_dict["responses"][i] for i in mini_batch_inds],
804
+ "advantages": batch_dict["advantages"][mini_batch_inds],
805
+ "returns": batch_dict["returns"][mini_batch_inds],
806
+ }
807
+ for k in model_inputs_names:
808
+ mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
809
+ with self.accelerator.accumulate(self.model):
810
+ model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}
811
+
812
+ logprobs, logits, vpreds, _ = self.batched_forward_pass(
813
+ self.model,
814
+ mini_batch_dict["queries"],
815
+ mini_batch_dict["responses"],
816
+ model_inputs,
817
+ return_logits=True,
818
+ )
819
+ train_stats = self.train_minibatch(
820
+ mini_batch_dict["logprobs"],
821
+ mini_batch_dict["values"],
822
+ logprobs,
823
+ logits,
824
+ vpreds,
825
+ mini_batch_dict["masks"],
826
+ mini_batch_dict["advantages"],
827
+ mini_batch_dict["returns"],
828
+ )
829
+ all_stats.append(train_stats)
830
+
831
+ # typically, early stopping is done at the epoch level
832
+ if self.config.early_stopping:
833
+ policykl = train_stats["policy/policykl"]
834
+ early_stop = self._early_stop(policykl)
835
+ if early_stop:
836
+ break
837
+
838
+ timing["time/ppo/optimize_step"] = time.time() - t
839
+
840
+ t = time.time()
841
+ train_stats = stack_dicts(all_stats)
842
+
843
+ # reshape advantages/ratios such that they are not averaged.
844
+ train_stats["policy/advantages"] = torch.flatten(train_stats["policy/advantages"]).unsqueeze(0)
845
+ train_stats["policy/advantages"] = torch.nan_to_num(train_stats["policy/advantages"], WANDB_PADDING)
846
+ train_stats["policy/ratio"] = torch.flatten(train_stats["policy/ratio"]).unsqueeze(0)
847
+
848
+ stats = self.record_step_stats(
849
+ scores=scores,
850
+ logprobs=all_logprobs,
851
+ ref_logprobs=ref_logprobs,
852
+ non_score_reward=non_score_reward,
853
+ train_stats=train_stats,
854
+ kl_coef=self.kl_ctl.value,
855
+ masks=masks,
856
+ queries=queries,
857
+ responses=responses,
858
+ kls=kls,
859
+ )
860
+ # Gather/Reduce stats from all processes
861
+ if self.is_distributed:
862
+ stats = self.gather_stats(stats)
863
+ stats = stats_to_np(stats)
864
+ timing["time/ppo/calc_stats"] = time.time() - t
865
+ stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"]
866
+
867
+ # Update the KL control - multiply the batch_size by the number of processes
868
+ self.kl_ctl.update(
869
+ stats["objective/kl"],
870
+ self.config.batch_size * self.accelerator.num_processes,
871
+ )
872
+
873
+ # Log the total ppo time
874
+ timing["time/ppo/total"] = time.time() - t0
875
+ stats.update(timing)
876
+
877
+ # post-process stats for tensorboard and other loggers
878
+ if self.config.log_with != "wandb":
879
+ stats = convert_to_scalar(stats)
880
+
881
+ if self.lr_scheduler is not None:
882
+ self.lr_scheduler.step()
883
+
884
+ return stats
885
+
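A hedged end-to-end sketch of how the step above is typically called (editorial), continuing the rollout sketch from the generate() method; query_tensors, response_tensors and batch are the placeholders used there.
rewards = [torch.tensor(1.0) for _ in response_tensors]    # placeholder scalar reward per sample, e.g. from a reward model
# note: the number of queries/responses/rewards must equal config.batch_size
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
ppo_trainer.log_stats(stats, batch, rewards)               # log_stats is inherited from the TRL base trainer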
886
+ def _early_stop(self, policykl):
887
+ r"""
888
+ Handles the early stopping logic. If the policy KL is greater than the target KL, then the gradient is zeroed and
889
+ the optimization step is skipped.
890
+ This also handles the multi-gpu case where the policy KL is averaged across all processes.
891
+
892
+ Args:
893
+ policykl (torch.Tensor):
894
+ the policy KL
895
+
896
+ Returns:
897
+ `bool`: whether to early stop or not
898
+ """
899
+ early_stop = False
900
+ if not self.config.early_stopping:
901
+ return early_stop
902
+
903
+ if not self.is_distributed and policykl > 1.5 * self.config.target_kl:
904
+ self.optimizer.zero_grad()
905
+ early_stop = True
906
+ elif self.is_distributed:
907
+ import torch.distributed as dist
908
+
909
+ # Wait for all processes to finish
910
+ dist.barrier()
911
+
912
+ # all gather the policykl
913
+ dist.all_reduce(policykl, dist.ReduceOp.SUM)
914
+ policykl /= self.accelerator.num_processes
915
+
916
+ if policykl > 1.5 * self.config.target_kl:
917
+ self.optimizer.zero_grad()
918
+ early_stop = True
919
+ return early_stop
920
+
921
+ def gather_stats(self, stats):
922
+ """
923
+ Gather stats from all processes. Useful in the context of distributed training.
924
+
925
+ Args:
926
+ stats (dict[str, Any]):
927
+ a dictionary of stats to be gathered. The stats should contain torch tensors.
928
+
929
+ Returns:
930
+ `dict[str, Any]`: A dictionary of stats with the tensors gathered.
931
+ """
932
+ import torch.distributed as dist
933
+
934
+ # Wait for all processes to finish
935
+ dist.barrier()
936
+
937
+ for k, v in stats.items():
938
+ if isinstance(v, torch.Tensor):
939
+ dist.all_reduce(v.to(self.accelerator.device), dist.ReduceOp.SUM)
940
+ v /= self.accelerator.num_processes
941
+ stats[k] = v
942
+ return stats
943
+
944
+ def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
945
+ if self.is_encoder_decoder:
946
+ input_data = self.data_collator(
947
+ [{"input_ids": q, "attention_mask": torch.ones_like(q)} for q in queries]
948
+ ).to(self.current_device)
949
+
950
+ decoder_inputs = self.data_collator(
951
+ [{"input_ids": r, "attention_mask": torch.ones_like(r)} for r in responses]
952
+ ).to(self.current_device)
953
+
954
+ input_data["decoder_input_ids"] = decoder_inputs["input_ids"]
955
+ input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"]
956
+ else:
957
+ input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
958
+ input_data = self.data_collator(
959
+ [{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids]
960
+ ).to(self.current_device)
961
+
962
+ input_data.pop("labels", None) # we don't want to compute LM losses
963
+ return input_data
964
+
965
+ @PPODecorators.empty_device_cache()
966
+ def batched_forward_pass(
967
+ self,
968
+ model: PreTrainedModelWrapper,
969
+ queries: torch.Tensor,
970
+ responses: torch.Tensor,
971
+ model_inputs: dict,
972
+ return_logits: bool = False,
973
+ response_masks: Optional[torch.Tensor] = None,
974
+ ):
975
+ """
976
+ Calculate model outputs in multiple batches.
977
+
978
+ Args:
979
+ queries (`torch.LongTensor`):
980
+ List of tensors containing the encoded queries, shape (`batch_size`, `query_length`)
981
+ responses (`torch.LongTensor`):
982
+ List of tensors containing the encoded responses, shape (`batch_size`, `response_length`)
983
+ return_logits (`bool`, *optional*, defaults to `False`):
984
+ Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption.
985
+ Returns:
986
+ (tuple):
987
+ - all_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
988
+ shape (`batch_size`, `response_length`)
989
+ - all_logits (`torch.FloatTensor` or `None`): Logits of the responses, returned only when
990
+ `return_logits=True` (otherwise `None`)
991
+ - all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`)
992
+ """
993
+ bs = len(queries)
994
+ fbs = self.config.mini_batch_size
995
+ all_logprobs = []
996
+ all_logits = []
997
+ all_masks = []
998
+ all_values = []
999
+
1000
+ model.eval()
1001
+
1002
+ for i in range(math.ceil(bs / fbs)):
1003
+ input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
1004
+ query_batch = queries[i * fbs : (i + 1) * fbs]
1005
+ response_batch = responses[i * fbs : (i + 1) * fbs]
1006
+ if response_masks is not None:
1007
+ response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
1008
+ logits, _, values = model(**input_kwargs)
1009
+
1010
+ if self.is_encoder_decoder:
1011
+ input_ids = input_kwargs["decoder_input_ids"]
1012
+ attention_mask = input_kwargs["decoder_attention_mask"]
1013
+ else:
1014
+ input_ids = input_kwargs["input_ids"]
1015
+ attention_mask = input_kwargs["attention_mask"]
1016
+
1017
+ logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
1018
+ masks = torch.zeros_like(attention_mask)
1019
+ masks[:, :-1] = attention_mask[:, 1:]
1020
+
1021
+ for j in range(len(query_batch)):
1022
+ if self.is_encoder_decoder:
1023
+ # For encoder-decoder models, the decoder sentence always starts at index 1 (right after the decoder start token)
1024
+ start = 1
1025
+ end = attention_mask[j, :].sum() - 1
1026
+ else:
1027
+ start = len(query_batch[j]) - 1 # logprobs starts from the second query token
1028
+ if attention_mask[j, 0] == 0: # offset left padding
1029
+ start += attention_mask[j, :].nonzero()[0]
1030
+ end = start + len(response_batch[j])
1031
+ if response_masks is not None:
1032
+ response_masks_batch[j] = torch.cat(
1033
+ (torch.zeros_like(query_batch[j]), response_masks_batch[j])
1034
+ )[1:]
1035
+
1036
+ masks[j, :start] = 0
1037
+ masks[j, end:] = 0
1038
+ if response_masks is not None:
1039
+ masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
1040
+
1041
+ if return_logits:
1042
+ all_logits.append(logits)
1043
+ else:
1044
+ del logits
1045
+ all_values.append(values)
1046
+ all_logprobs.append(logprobs)
1047
+ all_masks.append(masks)
1048
+
1049
+ return (
1050
+ torch.cat(all_logprobs),
1051
+ torch.cat(all_logits)[:, :-1] if return_logits else None,
1052
+ torch.cat(all_values)[:, :-1],
1053
+ torch.cat(all_masks)[:, :-1],
1054
+ )
1055
+
1056
+ @PPODecorators.empty_device_cache()
1057
+ def train_minibatch(
1058
+ self,
1059
+ old_logprobs: torch.FloatTensor,
1060
+ values: torch.FloatTensor,
1061
+ logprobs: torch.FloatTensor,
1062
+ logits: torch.FloatTensor,
1063
+ vpreds: torch.FloatTensor,
1064
+ mask: torch.LongTensor,
1065
+ advantages: torch.FloatTensor,
1066
+ returns: torch.FloatTensor,
1067
+ ):
1068
+ """
1069
+ Train one PPO minibatch
1070
+
1071
+ Args:
1072
+ old_logprobs (`torch.FloatTensor`):
+ Log probabilities of the model at rollout time, shape [mini_batch_size, response_length]
+ values (`torch.FloatTensor`):
+ Values of the value head at rollout time, shape [mini_batch_size, response_length]
+ logprobs (`torch.FloatTensor`):
+ Current log probabilities of the model, shape [mini_batch_size, response_length]
+ logits (`torch.FloatTensor`):
+ Current logits of the model, shape [mini_batch_size, response_length, vocab_size]
+ vpreds (`torch.FloatTensor`):
+ Current values of the value head, shape [mini_batch_size, response_length]
+ mask (`torch.LongTensor`):
+ Mask selecting the response tokens, shape [mini_batch_size, response_length]
+ advantages (`torch.FloatTensor`):
+ GAE advantages, shape [mini_batch_size, response_length]
+ returns (`torch.FloatTensor`):
+ Returns (advantages + values), shape [mini_batch_size, response_length]
1082
+
1083
+ Returns:
1084
+ train_stats (dict[str, `torch.Tensor`]):
1085
+ Dictionary of training statistics
1086
+ """
1087
+ self.model.train()
1088
+ loss_p, loss_v, train_stats = self.loss(
1089
+ old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns
1090
+ )
1091
+ loss = loss_p + loss_v
1092
+ self.accelerator.backward(loss)
1093
+ if self.config.max_grad_norm is not None:
1094
+ if self.accelerator.sync_gradients:
1095
+ self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm)
1096
+ self.optimizer.step()
1097
+ # we call optimizer.zero_grad() every time and let `accelerator` handle accumulation
1098
+ # see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code
1099
+ self.optimizer.zero_grad()
1100
+ return train_stats
1101
+
1102
+ def compute_rewards(
1103
+ self,
1104
+ scores: torch.FloatTensor,
1105
+ logprobs: torch.FloatTensor,
1106
+ ref_logprobs: torch.FloatTensor,
1107
+ masks: torch.LongTensor,
1108
+ ):
1109
+ """
1110
+ Compute per token rewards from scores and KL-penalty.
1111
+
1112
+ Args:
1113
+ scores (`torch.FloatTensor`):
1114
+ Scores from the reward model, shape (`batch_size`)
1115
+ logprobs (`torch.FloatTensor`):
1116
+ Log probabilities of the model, shape (`batch_size`, `response_length`)
1117
+ ref_logprobs (`torch.FloatTensor`):
1118
+ Log probabilities of the reference model, shape (`batch_size`, `response_length`)
1119
+
1120
+ Returns:
1121
+ `torch.FloatTensor`: Per token rewards, shape (`batch_size`, `response_length`)
1122
+ `torch.FloatTensor`: Non score rewards, shape (`batch_size`, `response_length`)
1123
+ `torch.FloatTensor`: KL penalty, shape (`batch_size`, `response_length`)
1124
+ """
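+ # Note (added comment): every response token gets a dense reward of -kl_ctl.value * KL(token),
+ # and the scalar reward-model score is added only at the last non-masked token of each response.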
1125
+ rewards, non_score_rewards, kls = [], [], []
1126
+ for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
1127
+ # compute KL penalty (from difference in logprobs)
1128
+ kl = self._kl_penalty(logprob, ref_logprob)
1129
+ kls.append(kl)
1130
+ non_score_reward = -self.kl_ctl.value * kl
1131
+ non_score_rewards.append(non_score_reward)
1132
+ reward = non_score_reward.clone()
1133
+ last_non_masked_index = mask.nonzero()[-1]
1134
+
1135
+ # reward is preference model score + KL penalty
1136
+ reward[last_non_masked_index] += score
1137
+ rewards.append(reward)
1138
+ return torch.stack(rewards), torch.stack(non_score_rewards), torch.stack(kls)
1139
+
1140
+ def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor) -> torch.FloatTensor:
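+ # Note (added comment): "kl" is the simple per-token estimator logprob - ref_logprob, "abs" and
+ # "mse" are magnitude-based variants, and "full" treats the inputs as full log-prob distributions
+ # over the vocabulary and computes the exact KL with F.kl_div.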
1141
+ if self.config.kl_penalty == "kl":
1142
+ return logprob - ref_logprob
1143
+
1144
+ if self.config.kl_penalty == "abs":
1145
+ return (logprob - ref_logprob).abs()
1146
+
1147
+ if self.config.kl_penalty == "mse":
1148
+ return 0.5 * (logprob - ref_logprob).square()
1149
+
1150
+ if self.config.kl_penalty == "full":
1151
+ # The argument order is flipped because of https://github.com/pytorch/pytorch/issues/57459
1152
+ return F.kl_div(ref_logprob, logprob, log_target=True, reduction="none").sum(-1)
1153
+
1154
+ raise NotImplementedError
1155
+
1156
+ def compute_advantages(
1157
+ self,
1158
+ values: torch.FloatTensor,
1159
+ rewards: torch.FloatTensor,
1160
+ mask: torch.FloatTensor,
1161
+ ):
1162
+ lastgaelam = 0
1163
+ advantages_reversed = []
1164
+ gen_len = rewards.shape[-1]
1165
+
1166
+ values = values * mask
1167
+ rewards = rewards * mask
1168
+
1169
+ if self.config.whiten_rewards:
1170
+ rewards = masked_whiten(rewards, mask, shift_mean=False)
1171
+
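+ # Note (added comment): Generalized Advantage Estimation (GAE), computed backwards in time:
+ # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
+ # A_t = delta_t + gamma * lam * A_{t+1}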
1172
+ for t in reversed(range(gen_len)):
1173
+ nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
1174
+ delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
1175
+ lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
1176
+ advantages_reversed.append(lastgaelam)
1177
+ advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
1178
+
1179
+ returns = advantages + values
1180
+ advantages = masked_whiten(advantages, mask)
1181
+ advantages = advantages.detach()
1182
+ return values, advantages, returns
1183
+
1184
+ def loss(
1185
+ self,
1186
+ old_logprobs: torch.FloatTensor,
1187
+ values: torch.FloatTensor,
1188
+ logits: torch.FloatTensor,
1189
+ vpreds: torch.FloatTensor,
1190
+ logprobs: torch.FloatTensor,
1191
+ mask: torch.LongTensor,
1192
+ advantages: torch.FloatTensor,
1193
+ returns: torch.FloatTensor,
1194
+ ):
1195
+ """
1196
+ Calculate policy and value losses.
1197
+
1198
+ Args:
1199
+ old_logprobs (`torch.FloatTensor`):
1200
+ Log probabilities of the model, shape (`batch_size`, `response_length`)
1201
+ values (`torch.FloatTensor`):
1202
+ Values of the value head, shape (`batch_size`, `response_length`)
1203
+ logits (`torch.FloatTensor`):
+ Logits of the model, shape (`batch_size`, `response_length`, `vocab_size`)
+ vpreds (`torch.FloatTensor`):
+ New values predicted by the value head, shape (`batch_size`, `response_length`)
+ logprobs (`torch.FloatTensor`):
+ Log probabilities of the model, shape (`batch_size`, `response_length`)
+ mask (`torch.LongTensor`):
+ Mask selecting the response tokens, shape (`batch_size`, `response_length`)
+ advantages (`torch.FloatTensor`):
+ GAE advantages, shape (`batch_size`, `response_length`)
+ returns (`torch.FloatTensor`):
+ Returns (advantages + values), shape (`batch_size`, `response_length`)
+ """
1212
+
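+ # Note (added comment): PPO-style clipped value loss -- vpreds is clipped to stay within
+ # +/- cliprange_value of the rollout values, and the larger of the clipped / unclipped
+ # squared errors is averaged over the masked tokens.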
1213
+ vpredclipped = clip_by_value(
1214
+ vpreds,
1215
+ values - self.config.cliprange_value,
1216
+ values + self.config.cliprange_value,
1217
+ )
1218
+
1219
+ vf_losses1 = (vpreds - returns) ** 2
1220
+ vf_losses2 = (vpredclipped - returns) ** 2
1221
+ vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
1222
+ vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask)
1223
+
1224
+ ratio = torch.exp(logprobs - old_logprobs)
1225
+
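+ # Note (added comment): clipped surrogate objective -- with ratio r = exp(logprobs - old_logprobs),
+ # the per-token loss is max(-A * r, -A * clip(r, 1 - cliprange, 1 + cliprange)), averaged over the mask.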
1226
+ pg_losses = -advantages * ratio
1227
+ pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)
1228
+
1229
+ pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask)
1230
+ pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask)
1231
+
1232
+ loss = pg_loss + self.config.vf_coef * vf_loss
1233
+
1234
+ avg_ratio = masked_mean(ratio, mask).item()
1235
+ if avg_ratio > self.config.ratio_threshold:
1236
+ warnings.warn(
1237
+ f"The average ratio of batch ({avg_ratio:.2f}) exceeds threshold {self.config.ratio_threshold:.2f}. Skipping batch."
1238
+ )
1239
+ pg_loss = pg_loss * 0.0
1240
+ vf_loss = vf_loss * 0.0
1241
+ loss = loss * 0.0
1242
+
1243
+ entropy = masked_mean(entropy_from_logits(logits), mask)
1244
+
1245
+ approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)
1246
+ policykl = masked_mean(old_logprobs - logprobs, mask)
1247
+
1248
+ return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask)
1249
+ value_mean, value_var = masked_mean(values, mask), masked_var(values, mask)
1250
+
1251
+ stats = dict(
1252
+ loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()),
1253
+ policy=dict(
1254
+ entropy=entropy.detach(),
1255
+ approxkl=approxkl.detach(),
1256
+ policykl=policykl.detach(),
1257
+ clipfrac=pg_clipfrac.detach(),
1258
+ advantages=advantages.detach(),
1259
+ advantages_mean=masked_mean(advantages, mask).detach(),
1260
+ ratio=ratio.detach(),
1261
+ ),
1262
+ returns=dict(mean=return_mean.detach(), var=return_var.detach()),
1263
+ val=dict(
1264
+ vpred=masked_mean(vpreds, mask).detach(),
1265
+ error=masked_mean((vpreds - returns) ** 2, mask).detach(),
1266
+ clipfrac=vf_clipfrac.detach(),
1267
+ mean=value_mean.detach(),
1268
+ var=value_var.detach(),
1269
+ ),
1270
+ )
1271
+ return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats)
1272
+
1273
+ def record_step_stats(self, kl_coef: float, **data):
1274
+ """
1275
+ Record training step statistics.
1276
+ Args:
1277
+ kl_coef (`float`):
1278
+ KL coefficient
1279
+ data (`dict`):
1280
+ Dictionary of training step data
1281
+
1282
+ Returns:
1283
+ stats (`dict`):
1284
+ Dictionary of training step statistics
1285
+ """
1286
+ mask = data.pop("masks")
1287
+
1288
+ kls = data.pop("kls")
1289
+ kl_list = ((kls) * mask).sum(axis=-1)
1290
+ mean_kl = kl_list.mean()
1291
+ mean_entropy = (-data["logprobs"] * mask).sum(axis=-1).mean()
1292
+
1293
+ mean_non_score_reward = masked_mean(
1294
+ data["non_score_reward"], mask
1295
+ ) # non_score_reward is size `batch_size`, `response_length`
1296
+ mean_scores = data["scores"].mean() # scores is size `batch_size`
1297
+ std_scores = data["scores"].std()
1298
+
1299
+ if mean_kl.item() < -1.0:
1300
+ # warn users
1301
+ warnings.warn(
1302
+ f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training."
1303
+ " Sometimes this happens because the generation kwargs are not correctly set. Please make sure"
1304
+ " that the generation kwargs are set correctly, or review your training hyperparameters."
1305
+ )
1306
+
1307
+ stats = {
1308
+ "objective/kl": mean_kl,
1309
+ "objective/kl_dist": kl_list,
1310
+ "objective/logprobs": data["logprobs"],
1311
+ "objective/ref_logprobs": data["ref_logprobs"],
1312
+ "objective/kl_coef": kl_coef,
1313
+ "objective/entropy": mean_entropy,
1314
+ "ppo/mean_non_score_reward": mean_non_score_reward,
1315
+ "ppo/mean_scores": mean_scores,
1316
+ "ppo/std_scores": std_scores,
1317
+ }
1318
+
1319
+ # Log text properties
1320
+ query_lens = torch.tensor([len(query) for query in data["queries"]], dtype=torch.float)
1321
+ response_lens = torch.tensor([len(response) for response in data["responses"]], dtype=torch.float)
1322
+
1323
+ stats["tokens/queries_len_mean"] = torch.mean(query_lens).cpu().numpy().item()
1324
+ stats["tokens/queries_len_std"] = torch.std(query_lens).cpu().numpy().item()
1325
+ stats["tokens/queries_dist"] = query_lens.cpu().numpy()
1326
+ stats["tokens/responses_len_mean"] = torch.mean(response_lens).cpu().numpy().item()
1327
+ stats["tokens/responses_len_std"] = torch.std(response_lens).cpu().numpy().item()
1328
+ stats["tokens/responses_dist"] = response_lens.cpu().numpy()
1329
+
1330
+ for k, v in data["train_stats"].items():
1331
+ stats[f"ppo/{k}"] = torch.mean(v, axis=0)
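+ # Note (added comment): explained variance of the value head, approximated as
+ # 1 - MSE(vpred, returns) / Var(returns); close to 1 means the value head tracks returns well,
+ # while values <= 0 mean it predicts no better than a constant.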
1332
+ stats["ppo/val/var_explained"] = 1 - stats["ppo/val/error"] / stats["ppo/returns/var"]
1333
+ return stats
1334
+
1335
+ def log_stats(
1336
+ self,
1337
+ stats: dict,
1338
+ batch: dict,
1339
+ rewards: List[torch.FloatTensor],
1340
+ columns_to_log: typing.Iterable[str] = ("query", "response"),
1341
+ ):
1342
+ """
1343
+ A function that logs all the training stats. Call it at the end of each epoch.
1344
+
1345
+ Args:
1346
+ stats (dict[str, Any]):
1347
+ A dictionary of training stats.
1348
+ batch (dict[str, Any]):
1349
+ A dictionary of batch data, this contains the queries and responses.
1350
+ rewards (`List[torch.FloatTensor]`):
1351
+ A tensor of rewards.
1352
+ """
1353
+
1354
+ # all gather stats
1355
+ if not isinstance(rewards, torch.Tensor):
1356
+ rewards = torch.tensor(rewards).to(self.current_device)
1357
+ rewards = self.accelerator.gather(rewards).flatten()
1358
+
1359
+ if self.config.log_with == "wandb":
1360
+ import wandb
1361
+
1362
+ if any(column_to_log not in batch.keys() for column_to_log in columns_to_log):
1363
+ raise ValueError(f"Columns to log {columns_to_log} are not present in the batch {batch.keys()}.")
1364
+
1365
+ batch_list = [batch[column_to_log] for column_to_log in columns_to_log]
1366
+ if self.is_distributed:
1367
+ gathered_batch_list = []
1368
+ for b in batch_list:
1369
+ flattened = gather_object(b)
1370
+ gathered_batch_list.append(flattened)
1371
+ batch_list = gathered_batch_list
1372
+
1373
+ # Log only if we are in the main process
1374
+ if self.accelerator.is_main_process:
1375
+ logs = {}
1376
+
1377
+ # Log stats
1378
+ if "query" not in batch.keys() and "response" not in batch.keys():
1379
+ # warn the user that the game logs will not be logged
1380
+ warnings.warn(
1381
+ "The game logs will not be logged because the batch does not contain the keys 'query' and "
1382
+ "'response'. "
1383
+ )
1384
+ elif self.config.log_with == "wandb":
1385
+ table_rows = [list(r) for r in zip(*batch_list, rewards.cpu().tolist())]
1386
+ logs.update({"game_log": wandb.Table(columns=[*columns_to_log, "reward"], rows=table_rows)})
1387
+
1388
+ logs.update(stats)
1389
+
1390
+ # manually cast in fp32 for bf16 torch tensors
1391
+ for k, v in logs.items():
1392
+ if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16:
1393
+ logs[k] = v.float()
1394
+
1395
+ logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item()
1396
+ logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
1397
+ logs["env/reward_dist"] = rewards.cpu().numpy()
1398
+
1399
+ if self.config.log_with == "tensorboard":
1400
+ # update the current step
1401
+ self.current_step += 1
1402
+
1403
+ self.accelerator.log(
1404
+ logs,
1405
+ step=self.current_step if self.config.log_with == "tensorboard" else None,
1406
+ )
1407
+
1408
+ def create_model_card(self, path: str, model_name: Optional[str] = "TRL Model") -> None:
1409
+ """Creates and saves a model card for a TRL model.
1410
+
1411
+ Args:
1412
+ path (`str`): The path to save the model card to.
1413
+ model_name (`str`, *optional*): The name of the model, defaults to `TRL Model`.
1414
+ """
1415
+ try:
1416
+ user = whoami()["name"]
1417
+ # handle the offline case
1418
+ except Exception:
1419
+ warnings.warn("Cannot retrieve user information; assuming you are running in offline mode.")
1420
+ return
1421
+
1422
+ if not os.path.exists(path):
1423
+ os.makedirs(path)
1424
+
1425
+ model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}")
1426
+ with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
1427
+ f.write(model_card_content)
1428
+
1429
+ def _save_pretrained(self, save_directory: str) -> None:
1430
+ self.accelerator.unwrap_model(self.model).save_pretrained(save_directory)
1431
+ self.tokenizer.save_pretrained(save_directory)
1432
+ self.create_model_card(save_directory)
1433
+
1434
+ def _show_tokens(self, tokens, masks):
1435
+ from rich import print
1436
+ from rich.text import Text
1437
+
1438
+ text = Text()
1439
+
1440
+ for _i, (token, mask) in enumerate(zip(tokens, masks)):
1441
+ if mask == 1:
1442
+ text.append(self.tokenizer.decode(token.item()), style="black on deep_sky_blue1")
1443
+ text.append(" ")
1444
+ else:
1445
+ text.append(self.tokenizer.decode(token.item()), style="black on cyan3")
1446
+ text.append(" ")
1447
+ print(text)
1448
+
1449
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
1450
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
1451
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
1452
+ config_kwargs = deepspeed_plugin.deepspeed_config
1453
+ if model is not None:
1454
+ if hasattr(model, "config"):
1455
+ hidden_size = (
1456
+ max(model.config.hidden_sizes)
1457
+ if getattr(model.config, "hidden_sizes", None)
1458
+ else getattr(model.config, "hidden_size", None)
1459
+ )
1460
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
1461
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
1462
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
1463
+ config_kwargs.update(
1464
+ {
1465
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
1466
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
1467
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
1468
+ }
1469
+ )
1470
+
1471
+ # If ZeRO-3 is used, we shard both the active and reference model.
1472
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
1473
+ if config_kwargs["zero_optimization"]["stage"] != 3:
1474
+ config_kwargs["zero_optimization"]["stage"] = 0
1475
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
1476
+ model.eval()
1477
+ return model
1478
+ class UnslothPPOTrainer(_UnslothPPOTrainer):
1479
+ """
1480
+
1481
+ The PPOTrainer uses Proximal Policy Optimization to optimise language models.
1482
+ Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here:
1483
+ https://github.com/openai/summarize-from-feedback
1484
+
1485
+ Attributes:
1486
+ **config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more
1487
+ details.
1488
+ **model** (`PreTrainedModelWrapper`) -- Model to be optimized, Hugging Face transformer model with a value head.
1489
+ Check the documentation of `PreTrainedModelWrapper` for more details.
1490
+ **ref_model** (`PreTrainedModelWrapper`, *optional*) -- Reference model to be used for KL penalty, Hugging Face
1491
+ transformer model with a causal language modelling head. Check the documentation of `PreTrainedModelWrapper`
1492
+ for more details. If no reference model is provided, the trainer will create a reference model with the same
1493
+ architecture as the model to be optimized with shared layers.
1494
+ **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the
1495
+ data. Check the documentation of `transformers.PreTrainedTokenizer` and
1496
+ `transformers.PreTrainedTokenizerFast` for more details.
1497
+ **dataset** (Union[`torch.utils.data.Dataset`, `datasets.Dataset`], *optional*) -- PyTorch dataset or Hugging
1498
+ Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided, the dataloader must be
1499
+ created outside the trainer; users need to design their own dataloader and make sure the batch
1500
+ size that is used is the same as the one specified in the configuration object.
1501
+ **optimizer** (`torch.optim.Optimizer`, *optional*) -- Optimizer to be used for training. If no optimizer is
1502
+ provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration
1503
+ object.
1504
+ **data_collator** (DataCollatorForLanguageModeling, *optional*) -- Data collator to be used for training and
1505
+ passed along to the dataloader
1506
+ **num_shared_layers** (int, *optional*) -- Number of layers to be shared between the model and the reference
1507
+ model, if no reference model is passed. If no number is provided, all the layers will be shared.
1508
+ **lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training.
1509
+
1510
+ """
1511
+ def __init__(
1512
+ self,
1513
+ config = None,
1514
+ model = None,
1515
+ ref_model = None,
1516
+ tokenizer = None,
1517
+ dataset = None,
1518
+ optimizer = None,
1519
+ data_collator = None,
1520
+ num_shared_layers = None,
1521
+ lr_scheduler = None,
1522
+ **kwargs
1523
+ ):
1524
+ if config is None: config = UnslothPPOConfig()
1525
+ _output_logits = False
1526
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1527
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1528
+ if _output_logits:
1529
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1530
+ if 'max_seq_length' not in locals() and not hasattr(config, 'max_seq_length'):
1531
+ pass
1532
+ else:
1533
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1534
+ args_max_seq_length = getattr(config, 'max_seq_length', None)
1535
+ if args_max_seq_length is None and model_max_seq_length is not None:
1536
+ max_seq_length = model.max_seq_length
1537
+ if hasattr(config, 'max_seq_length'): config.max_seq_length = max_seq_length
1538
+ if model is not None and hasattr(model, 'for_training'):
1539
+ model.for_training()
1540
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1541
+ if 'processing_class' in locals():
1542
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1543
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1544
+ other_metrics = []
1545
+
1546
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1547
+ PatchRLStatistics('ppo_trainer', other_metrics)
1548
+
1549
+ super().__init__(
1550
+ config = config,
1551
+ model = model,
1552
+ ref_model = ref_model,
1553
+ tokenizer = tokenizer,
1554
+ dataset = dataset,
1555
+ optimizer = optimizer,
1556
+ data_collator = data_collator,
1557
+ num_shared_layers = num_shared_layers,
1558
+ lr_scheduler = lr_scheduler,**kwargs)
1559
+ if hasattr(self, 'neftune_hook_handle'):
1560
+ self.neftune_hook_handle.remove()
1561
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1562
+ if getattr(config, 'neftune_noise_alpha', None) is not None:
1563
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1564
+ pass
1565
+
1566
+ pass
compilefcach/UnslothRewardTrainer.py ADDED
@@ -0,0 +1,722 @@
1
+ """
2
+ 2025.6.8
3
+ 2025.6.12
4
+ 4.53.0
5
+ 0.8.6
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.reward_trainer import (Any, Callable, DataCollator, Dataset, Dict, EvalPrediction, FrozenInstanceError, List, Optional, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, RewardConfig, RewardDataCollatorWithPadding, RewardTrainer, Trainer, TrainerCallback, TrainingArguments, Tuple, Union, compute_accuracy, inspect, is_peft_available, nested_detach, nn, prepare_model_for_kbit_training, replace, torch, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
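+ # Note (added usage sketch, illustrative only): for `logits` of shape (batch, seq, vocab) and
+ # `labels` of shape (batch, seq), selective_log_softmax(logits, labels) returns the per-token
+ # log-probabilities of the labels, i.e. the same result as
+ # torch.gather(logits.log_softmax(-1), -1, labels.unsqueeze(-1)).squeeze(-1),
+ # computed as logit - logsumexp(logits) so the full log-softmax tensor is never gathered separately.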
42
+ @dataclass
43
+ class UnslothRewardConfig(RewardConfig):
44
+ """
45
+
46
+ RewardConfig collects all training arguments related to the [`RewardTrainer`] class.
47
+
48
+ Using [`HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ max_length (`int`, *optional*, defaults to `None`):
54
+ The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
55
+ gradient_checkpointing (`bool`, *optional*, defaults to `True`):
56
+ If True, use gradient checkpointing to save memory at the expense of slower backward pass.
57
+
58
+ """
59
+ vllm_sampling_params: Optional[Any] = field(
60
+ default = None,
61
+ metadata = {'help': 'vLLM SamplingParams'},
62
+ )
63
+ unsloth_num_chunks : Optional[int] = field(
64
+ default = -1,
65
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
66
+ )
67
+ def __init__(
68
+ self,
69
+ output_dir = None,
70
+ overwrite_output_dir = None,
71
+ do_train = False,
72
+ do_eval = False,
73
+ do_predict = False,
74
+ eval_strategy = 'no',
75
+ prediction_loss_only = False,
76
+ per_device_train_batch_size = 4,
77
+ per_device_eval_batch_size = 4,
78
+ per_gpu_train_batch_size = None,
79
+ per_gpu_eval_batch_size = None,
80
+ gradient_accumulation_steps = 2,
81
+ eval_accumulation_steps = 2,
82
+ eval_delay = 0,
83
+ torch_empty_cache_steps = 250,
84
+ learning_rate = 5e-05,
85
+ weight_decay = 0.01,
86
+ adam_beta1 = 0.9,
87
+ adam_beta2 = 0.999,
88
+ adam_epsilon = 1e-08,
89
+ max_grad_norm = 1.0,
90
+ num_train_epochs = 3.0,
91
+ max_steps = -1,
92
+ lr_scheduler_type = 'linear',
93
+ warmup_ratio = 0.1,
94
+ warmup_steps = 0,
95
+ log_level = 'passive',
96
+ log_level_replica = 'warning',
97
+ log_on_each_node = True,
98
+ logging_dir = None,
99
+ logging_strategy = 'steps',
100
+ logging_first_step = False,
101
+ logging_steps = 1,
102
+ logging_nan_inf_filter = False,
103
+ save_strategy = 'steps',
104
+ save_steps = 500,
105
+ save_total_limit = None,
106
+ save_safetensors = True,
107
+ save_on_each_node = False,
108
+ save_only_model = False,
109
+ restore_callback_states_from_checkpoint = False,
110
+ no_cuda = False,
111
+ use_cpu = False,
112
+ use_mps_device = False,
113
+ seed = 3407,
114
+ data_seed = 3407,
115
+ jit_mode_eval = False,
116
+ use_ipex = False,
117
+ bf16 = False,
118
+ fp16 = False,
119
+ fp16_opt_level = 'O1',
120
+ half_precision_backend = 'auto',
121
+ bf16_full_eval = False,
122
+ fp16_full_eval = False,
123
+ tf32 = None,
124
+ local_rank = -1,
125
+ ddp_backend = None,
126
+ tpu_num_cores = None,
127
+ tpu_metrics_debug = False,
128
+ debug = '',
129
+ dataloader_drop_last = False,
130
+ eval_steps = None,
131
+ dataloader_num_workers = 0,
132
+ dataloader_prefetch_factor = None,
133
+ past_index = -1,
134
+ run_name = None,
135
+ disable_tqdm = None,
136
+ remove_unused_columns = True,
137
+ label_names = None,
138
+ load_best_model_at_end = False,
139
+ metric_for_best_model = None,
140
+ greater_is_better = None,
141
+ ignore_data_skip = False,
142
+ fsdp = '',
143
+ fsdp_min_num_params = 0,
144
+ fsdp_config = None,
145
+ fsdp_transformer_layer_cls_to_wrap = None,
146
+ accelerator_config = None,
147
+ deepspeed = None,
148
+ label_smoothing_factor = 0.0,
149
+ optim = 'adamw_8bit',
150
+ optim_args = None,
151
+ adafactor = False,
152
+ group_by_length = False,
153
+ length_column_name = 'length',
154
+ report_to = None,
155
+ ddp_find_unused_parameters = None,
156
+ ddp_bucket_cap_mb = None,
157
+ ddp_broadcast_buffers = None,
158
+ dataloader_pin_memory = True,
159
+ dataloader_persistent_workers = False,
160
+ skip_memory_metrics = True,
161
+ use_legacy_prediction_loop = False,
162
+ push_to_hub = False,
163
+ resume_from_checkpoint = None,
164
+ hub_model_id = None,
165
+ hub_strategy = 'every_save',
166
+ hub_token = None,
167
+ hub_private_repo = None,
168
+ hub_always_push = False,
169
+ hub_revision = None,
170
+ gradient_checkpointing = False,
171
+ gradient_checkpointing_kwargs = None,
172
+ include_inputs_for_metrics = False,
173
+ eval_do_concat_batches = True,
174
+ fp16_backend = 'auto',
175
+ push_to_hub_model_id = None,
176
+ push_to_hub_organization = None,
177
+ push_to_hub_token = None,
178
+ mp_parameters = '',
179
+ auto_find_batch_size = False,
180
+ full_determinism = False,
181
+ torchdynamo = None,
182
+ ray_scope = 'last',
183
+ ddp_timeout = 1800,
184
+ torch_compile = False,
185
+ torch_compile_backend = None,
186
+ torch_compile_mode = None,
187
+ include_tokens_per_second = False,
188
+ include_num_input_tokens_seen = False,
189
+ neftune_noise_alpha = None,
190
+ optim_target_modules = None,
191
+ batch_eval_metrics = False,
192
+ eval_on_start = False,
193
+ use_liger_kernel = False,
194
+ liger_kernel_config = None,
195
+ eval_use_gather_object = False,
196
+ average_tokens_across_devices = False,
197
+ max_length = None,
198
+ vllm_sampling_params = None,
199
+ unsloth_num_chunks = -1,
200
+ **kwargs,
201
+ ):
202
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
203
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
204
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
205
+ output_dir = 'unsloth_training_checkpoints'
206
+ save_strategy = 'no'
207
+
208
+ super().__init__(
209
+ output_dir = output_dir,
210
+ overwrite_output_dir = overwrite_output_dir,
211
+ do_train = do_train,
212
+ do_eval = do_eval,
213
+ do_predict = do_predict,
214
+ eval_strategy = eval_strategy,
215
+ prediction_loss_only = prediction_loss_only,
216
+ per_device_train_batch_size = per_device_train_batch_size,
217
+ per_device_eval_batch_size = per_device_eval_batch_size,
218
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
219
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
220
+ gradient_accumulation_steps = gradient_accumulation_steps,
221
+ eval_accumulation_steps = eval_accumulation_steps,
222
+ eval_delay = eval_delay,
223
+ torch_empty_cache_steps = torch_empty_cache_steps,
224
+ learning_rate = learning_rate,
225
+ weight_decay = weight_decay,
226
+ adam_beta1 = adam_beta1,
227
+ adam_beta2 = adam_beta2,
228
+ adam_epsilon = adam_epsilon,
229
+ max_grad_norm = max_grad_norm,
230
+ num_train_epochs = num_train_epochs,
231
+ max_steps = max_steps,
232
+ lr_scheduler_type = lr_scheduler_type,
233
+ warmup_ratio = warmup_ratio,
234
+ warmup_steps = warmup_steps,
235
+ log_level = log_level,
236
+ log_level_replica = log_level_replica,
237
+ log_on_each_node = log_on_each_node,
238
+ logging_dir = logging_dir,
239
+ logging_strategy = logging_strategy,
240
+ logging_first_step = logging_first_step,
241
+ logging_steps = logging_steps,
242
+ logging_nan_inf_filter = logging_nan_inf_filter,
243
+ save_strategy = save_strategy,
244
+ save_steps = save_steps,
245
+ save_total_limit = save_total_limit,
246
+ save_safetensors = save_safetensors,
247
+ save_on_each_node = save_on_each_node,
248
+ save_only_model = save_only_model,
249
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
250
+ no_cuda = no_cuda,
251
+ use_cpu = use_cpu,
252
+ use_mps_device = use_mps_device,
253
+ seed = seed,
254
+ data_seed = data_seed,
255
+ jit_mode_eval = jit_mode_eval,
256
+ use_ipex = use_ipex,
257
+ bf16 = bf16,
258
+ fp16 = fp16,
259
+ fp16_opt_level = fp16_opt_level,
260
+ half_precision_backend = half_precision_backend,
261
+ bf16_full_eval = bf16_full_eval,
262
+ fp16_full_eval = fp16_full_eval,
263
+ tf32 = tf32,
264
+ local_rank = local_rank,
265
+ ddp_backend = ddp_backend,
266
+ tpu_num_cores = tpu_num_cores,
267
+ tpu_metrics_debug = tpu_metrics_debug,
268
+ debug = debug,
269
+ dataloader_drop_last = dataloader_drop_last,
270
+ eval_steps = eval_steps,
271
+ dataloader_num_workers = dataloader_num_workers,
272
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
273
+ past_index = past_index,
274
+ run_name = run_name,
275
+ disable_tqdm = disable_tqdm,
276
+ remove_unused_columns = remove_unused_columns,
277
+ label_names = label_names,
278
+ load_best_model_at_end = load_best_model_at_end,
279
+ metric_for_best_model = metric_for_best_model,
280
+ greater_is_better = greater_is_better,
281
+ ignore_data_skip = ignore_data_skip,
282
+ fsdp = fsdp,
283
+ fsdp_min_num_params = fsdp_min_num_params,
284
+ fsdp_config = fsdp_config,
285
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
286
+ accelerator_config = accelerator_config,
287
+ deepspeed = deepspeed,
288
+ label_smoothing_factor = label_smoothing_factor,
289
+ optim = optim,
290
+ optim_args = optim_args,
291
+ adafactor = adafactor,
292
+ group_by_length = group_by_length,
293
+ length_column_name = length_column_name,
294
+ report_to = report_to,
295
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
296
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
297
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
298
+ dataloader_pin_memory = dataloader_pin_memory,
299
+ dataloader_persistent_workers = dataloader_persistent_workers,
300
+ skip_memory_metrics = skip_memory_metrics,
301
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
302
+ push_to_hub = push_to_hub,
303
+ resume_from_checkpoint = resume_from_checkpoint,
304
+ hub_model_id = hub_model_id,
305
+ hub_strategy = hub_strategy,
306
+ hub_token = hub_token,
307
+ hub_private_repo = hub_private_repo,
308
+ hub_always_push = hub_always_push,
309
+ hub_revision = hub_revision,
310
+ gradient_checkpointing = gradient_checkpointing,
311
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
312
+ include_inputs_for_metrics = include_inputs_for_metrics,
313
+ eval_do_concat_batches = eval_do_concat_batches,
314
+ fp16_backend = fp16_backend,
315
+ push_to_hub_model_id = push_to_hub_model_id,
316
+ push_to_hub_organization = push_to_hub_organization,
317
+ push_to_hub_token = push_to_hub_token,
318
+ mp_parameters = mp_parameters,
319
+ auto_find_batch_size = auto_find_batch_size,
320
+ full_determinism = full_determinism,
321
+ torchdynamo = torchdynamo,
322
+ ray_scope = ray_scope,
323
+ ddp_timeout = ddp_timeout,
324
+ torch_compile = torch_compile,
325
+ torch_compile_backend = torch_compile_backend,
326
+ torch_compile_mode = torch_compile_mode,
327
+ include_tokens_per_second = include_tokens_per_second,
328
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
329
+ neftune_noise_alpha = neftune_noise_alpha,
330
+ optim_target_modules = optim_target_modules,
331
+ batch_eval_metrics = batch_eval_metrics,
332
+ eval_on_start = eval_on_start,
333
+ use_liger_kernel = use_liger_kernel,
334
+ liger_kernel_config = liger_kernel_config,
335
+ eval_use_gather_object = eval_use_gather_object,
336
+ average_tokens_across_devices = average_tokens_across_devices,
337
+ max_length = max_length,**kwargs)
338
+ self.vllm_sampling_params = vllm_sampling_params
339
+ self.unsloth_num_chunks = unsloth_num_chunks
340
+ pass
341
+
342
+ class _UnslothRewardTrainer(Trainer):
343
+ r""""""
344
+
345
+ _tag_names = ["trl", "reward-trainer"]
346
+
347
+ def __init__(
348
+ self,
349
+ model: Optional[Union[PreTrainedModel, nn.Module]] = None,
350
+ args: Optional[RewardConfig] = None,
351
+ data_collator: Optional[DataCollator] = None,
352
+ train_dataset: Optional[Dataset] = None,
353
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
354
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
355
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
356
+ compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
357
+ callbacks: Optional[List[TrainerCallback]] = None,
358
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
359
+ None,
360
+ None,
361
+ ),
362
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
363
+ max_length: Optional[int] = None,
364
+ peft_config: Optional[Dict] = None,
365
+ ):
366
+ """
367
+ Initialize RewardTrainer.
368
+
369
+ Args:
370
+ model (`transformers.PreTrainedModel`):
371
+ The model to train, preferably an `AutoModelForSequenceClassification`.
372
+ args (`RewardConfig`):
373
+ The arguments to use for training.
374
+ data_collator (`transformers.DataCollator`):
375
+ The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used
376
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
377
+ train_dataset (`datasets.Dataset`):
378
+ The dataset to use for training.
379
+ eval_dataset (`datasets.Dataset`):
380
+ The dataset to use for evaluation.
381
+ tokenizer (`transformers.PreTrainedTokenizerBase`):
382
+ The tokenizer to use for training. This argument is required if you want to use the default data collator.
383
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
384
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
385
+ compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to `compute_accuracy`):
386
+ The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
387
+ callbacks (`List[transformers.TrainerCallback]`):
388
+ The callbacks to use for training.
389
+ optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
390
+ The optimizer and scheduler to use for training.
391
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
392
+ The function to use to preprocess the logits before computing the metrics.
393
+ max_length (`int`, defaults to `None`):
394
+ The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
395
+ peft_config (`Dict`, defaults to `None`):
396
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
397
+ """
398
+ if type(args) == TrainingArguments:
399
+ warnings.warn(
400
+ "Using `transformers.TrainingArguments` for `args` is deprecated and will be removed in a future version. Please use `RewardConfig` instead.",
401
+ FutureWarning,
402
+ )
403
+ if max_length is not None:
404
+ warnings.warn(
405
+ "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
406
+ FutureWarning,
407
+ )
408
+ else:
409
+ if max_length is not None and args.max_length is not None:
410
+ raise ValueError(
411
+ "You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once."
412
+ )
413
+ if max_length is not None and args.max_length is None:
414
+ warnings.warn(
415
+ "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
416
+ FutureWarning,
417
+ )
418
+ if not is_peft_available() and peft_config is not None:
419
+ raise ValueError(
420
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
421
+ )
422
+ elif is_peft_available() and peft_config is not None:
423
+ if not isinstance(model, PeftModel):
424
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
425
+ _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
426
+ inspect.signature(prepare_model_for_kbit_training).parameters
427
+ )
428
+
429
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
430
+
431
+ if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
432
+ warnings.warn(
433
+ "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
434
+ "please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
435
+ )
436
+ elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
437
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
438
+
439
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
440
+
441
+ model = model
442
+
443
+ if compute_metrics is None:
444
+ compute_metrics = compute_accuracy
445
+
446
+ if data_collator is None:
447
+ if tokenizer is None:
448
+ raise ValueError(
449
+ "max_length or a tokenizer must be specified when using the default RewardDataCollatorWithPadding"
450
+ )
451
+ if type(args) == TrainingArguments:
452
+ if max_length is None:
453
+ warnings.warn(
454
+ "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
455
+ " It will be set to `512` by default, but you should do it yourself in the future.",
456
+ UserWarning,
457
+ )
458
+ max_length = 512
459
+ else:
460
+ if max_length is None and args.max_length is None:
461
+ warnings.warn(
462
+ "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
463
+ " It will be set to `512` by default, but you should do it yourself in the future.",
464
+ UserWarning,
465
+ )
466
+ max_length = 512
467
+ if max_length is None and args.max_length is not None:
468
+ max_length = args.max_length
469
+
470
+ data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length)
471
+
472
+ if args.remove_unused_columns:
473
+ try: # for backward compatibility before https://github.com/huggingface/transformers/pull/25435
474
+ args.remove_unused_columns = False
475
+ except FrozenInstanceError:
476
+ args = replace(args, remove_unused_columns=False)
477
+ # warn users
478
+ warnings.warn(
479
+ "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
480
+ " we have set it for you, but you should do it yourself in the future.",
481
+ UserWarning,
482
+ )
483
+
484
+ self.use_reward_data_collator = True
485
+ else:
486
+ self.use_reward_data_collator = False
487
+ super().__init__(
488
+ model=model,
489
+ args=args,
490
+ data_collator=data_collator,
491
+ train_dataset=train_dataset,
492
+ eval_dataset=eval_dataset,
493
+ tokenizer=tokenizer,
494
+ model_init=model_init,
495
+ compute_metrics=compute_metrics,
496
+ callbacks=callbacks,
497
+ optimizers=optimizers,
498
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
499
+ )
500
+
501
+ # Add tags for models that have been loaded with the correct transformers version
502
+ if hasattr(self.model, "add_model_tags"):
503
+ self.model.add_model_tags(self._tag_names)
504
+
505
+ def compute_loss(
506
+ self,
507
+ model: Union[PreTrainedModel, nn.Module],
508
+ inputs: Dict[str, Union[torch.Tensor, Any]],
509
+ return_outputs=False,
510
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
511
+ if not self.use_reward_data_collator:
512
+ warnings.warn(
513
+ "The current compute_loss is implemented for RewardDataCollatorWithPadding,"
514
+ " if you are using a custom data collator make sure you know what you are doing or"
515
+ " implement your own compute_loss method."
516
+ )
517
+ rewards_chosen = model(
518
+ input_ids=inputs["input_ids_chosen"],
519
+ attention_mask=inputs["attention_mask_chosen"],
520
+ return_dict=True,
521
+ )["logits"]
522
+ rewards_rejected = model(
523
+ input_ids=inputs["input_ids_rejected"],
524
+ attention_mask=inputs["attention_mask_rejected"],
525
+ return_dict=True,
526
+ )["logits"]
527
+ # calculate loss, optionally modulate with margin
528
+ if "margin" in inputs:
529
+ loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
530
+ else:
531
+ loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
532
+
533
+ if return_outputs:
534
+ return loss, {
535
+ "rewards_chosen": rewards_chosen,
536
+ "rewards_rejected": rewards_rejected,
537
+ }
538
+ return loss
539
+
540
+ def prediction_step(
541
+ self,
542
+ model: Union[PreTrainedModel, nn.Module],
543
+ inputs: Dict[str, Union[torch.Tensor, Any]],
544
+ prediction_loss_only: bool,
545
+ ignore_keys: Optional[List[str]] = None,
546
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
547
+ inputs = self._prepare_inputs(inputs)
548
+ if ignore_keys is None:
549
+ if hasattr(self.model, "config"):
550
+ ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
551
+ else:
552
+ ignore_keys = []
553
+
554
+ with torch.no_grad():
555
+ loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
556
+
557
+ if prediction_loss_only:
558
+ return (loss, None, None)
559
+
560
+ loss = loss.detach()
561
+ logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
562
+ logits = nested_detach(logits)
563
+ # Stack accepted against rejected, mean over logits
564
+ # and softmax to get preferences between accepted and rejected to sum to 1
565
+ logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
566
+
567
+ labels = torch.zeros(logits.shape[0])
568
+ labels = self._prepare_inputs(labels)
569
+
570
+ return loss, logits, labels
571
+ class UnslothRewardTrainer(_UnslothRewardTrainer):
572
+ """
573
+
574
+ The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the
575
+ `transformers.Trainer` class and inherits all of its attributes and methods. It is recommended to use
576
+ an `AutoModelForSequenceClassification` as the reward model. The reward model should be trained on a dataset
577
+ of paired examples, where each example is a tuple of two sequences. The reward model should be trained to
578
+ predict which example in the pair is more relevant to the task at hand.
579
+
580
+ The reward trainer expects a very specific format for the dataset. The dataset should contain two 4 entries at least
581
+ if you don't use the default `RewardDataCollatorWithPadding` data collator. The entries should be named
582
+ - `input_ids_chosen`
583
+ - `attention_mask_chosen`
584
+ - `input_ids_rejected`
585
+ - `attention_mask_rejected`
586
+
587
+ Optionally, you can also pass a `margin` entry to the dataset. This entry should contain the margin used to modulate the
588
+ loss of the reward model as outlined in https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/.
589
+ If you don't pass a margin, no margin will be used.
590
+
591
+ """
592
+ def __init__(
593
+ self,
594
+ model = None,
595
+ args = None,
596
+ data_collator = None,
597
+ train_dataset = None,
598
+ eval_dataset = None,
599
+ tokenizer = None,
600
+ model_init = None,
601
+ compute_metrics = None,
602
+ callbacks = None,
603
+ preprocess_logits_for_metrics = None,
604
+ max_length = None,
605
+ peft_config = None,
606
+ **kwargs
607
+ ):
608
+ if args is None: args = UnslothRewardConfig()
609
+ use_bf16 = getattr(args, 'bf16', False)
610
+ if type(use_bf16) is not bool: use_bf16 = False
611
+ use_fp16 = getattr(args, 'fp16', False)
612
+ if type(use_fp16) is not bool: use_fp16 = False
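+ # Note (added comment): the block below selects the mixed-precision mode -- the model's own dtype
+ # decides between fp16 and bf16 unless UNSLOTH_FORCE_FLOAT32 forces full float32, and user flags
+ # that contradict the model dtype raise a TypeError.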
613
+ force_float32 = False
614
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
615
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
616
+ force_float32 = True
617
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
618
+ dtype = getattr(model.config, 'torch_dtype', None)
619
+ if dtype is None: dtype = model.get_input_embeddings().dtype
620
+ from unsloth_zoo.utils import _get_dtype
621
+ dtype = _get_dtype(dtype)
622
+ float16 = dtype == torch.float16
623
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
624
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
625
+ if force_float32:
626
+ args.fp16 = False
627
+ args.bf16 = False
628
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
629
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
630
+ args.fp16 = float16
631
+ args.bf16 = not float16
632
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
633
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
634
+ args.eval_strategy = 'steps'
635
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
636
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
637
+ if ga_steps is not None and ga_steps > 1:
638
+ from transformers import __version__ as transformers_version
639
+ if Version(transformers_version) <= Version('4.45.2'):
640
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
641
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
642
+ if getattr(args, 'eval_strategy', 'no') != 'no':
643
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
644
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
645
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
646
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
647
+ if type(fp16_full_eval) is not bool: fp16_full_eval = False
648
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
649
+ if type(bf16_full_eval) is not bool: bf16_full_eval = False
650
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
651
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
652
+ if force_float32:
653
+ args.bf16_full_eval = False
654
+ args.fp16_full_eval = False
655
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
656
+ args.bf16_full_eval = True
657
+ args.fp16_full_eval = False
658
+ elif not bf16_full_eval and not fp16_full_eval:
659
+ args.bf16_full_eval = args.bf16
660
+ args.fp16_full_eval = args.fp16
661
+ _output_logits = False
662
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
663
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
664
+ if _output_logits:
665
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
666
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
667
+ pass
668
+ else:
669
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
670
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
671
+ if args_max_seq_length is None and model_max_seq_length is not None:
672
+ max_seq_length = model.max_seq_length
673
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
674
+ if model is not None and hasattr(model, 'for_training'):
675
+ model.for_training()
676
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
677
+ if 'processing_class' in locals():
678
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
679
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
680
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
681
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
682
+ if not isinstance(data_collator, UnslothVisionDataCollator):
683
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
684
+ data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
685
+ elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
686
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
687
+ else:
688
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
689
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
690
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
691
+ if not isinstance(data_collator, UnslothVisionDataCollator):
692
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
693
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
694
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
695
+ else:
696
+ data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
697
+ other_metrics = []
698
+
699
+ from unsloth_zoo.logging_utils import PatchRLStatistics
700
+ PatchRLStatistics('reward_trainer', other_metrics)
701
+
702
+ super().__init__(
703
+ model = model,
704
+ args = args,
705
+ data_collator = data_collator,
706
+ train_dataset = train_dataset,
707
+ eval_dataset = eval_dataset,
708
+ tokenizer = tokenizer,
709
+ model_init = model_init,
710
+ compute_metrics = compute_metrics,
711
+ callbacks = callbacks,
712
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
713
+ max_length = max_length,
714
+ peft_config = peft_config,**kwargs)
715
+ if hasattr(self, 'neftune_hook_handle'):
716
+ self.neftune_hook_handle.remove()
717
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
718
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
719
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
720
+ pass
721
+
722
+ pass
compilefcach/__pycache__/UnslothCPOTrainer.cpython-311.pyc ADDED
Binary file (68.7 kB).

compilefcach/__pycache__/UnslothDDPOTrainer.cpython-311.pyc ADDED
Binary file (38.7 kB).

compilefcach/__pycache__/UnslothKTOTrainer.cpython-311.pyc ADDED
Binary file (81.8 kB).

compilefcach/__pycache__/UnslothORPOTrainer.cpython-311.pyc ADDED
Binary file (69.8 kB).

compilefcach/__pycache__/UnslothPPOTrainer.cpython-311.pyc ADDED
Binary file (83.5 kB).

compilefcach/__pycache__/UnslothRewardTrainer.cpython-311.pyc ADDED
Binary file (33.1 kB).