cstr commited on
Commit
6e81596
1 Parent(s): 2509795

Upload folder using huggingface_hub

Browse files
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ wandb
2
+ src/__pycache__
3
+ scripts/run_orpo.sh
4
+ src/accelerate/fsdp.yaml
5
+ scripts/run_orpo.sh
6
+ src/__pycache__/args.cpython-311.pyc
7
+ src/__pycache__/utils.cpython-311.pyc
8
+ src/accelerate/fsdp.yaml
README.md CHANGED
@@ -1,183 +1,63 @@
1
- ---
2
- tags:
3
- - merge
4
- - mergekit
5
- - lazymergekit
6
- - flemmingmiguel/NeuDist-Ro-7B
7
- - johannhartmann/Brezn3
8
- - ResplendentAI/Flora_DPO_7B
9
- base_model:
10
- - flemmingmiguel/NeuDist-Ro-7B
11
- - johannhartmann/Brezn3
12
- - ResplendentAI/Flora_DPO_7B
13
- language:
14
- - de
15
- - en
16
- ---
17
-
18
- # Spaetzle-v8-7b
19
-
20
- This model is supposed to show adequate performance in German and English on a number of tasks, while mostly behaving well, that is, without rambling on, intermixing tokens from different templates in training and adapting, etc.
21
-
22
- It is mostly a quick test, and considerably weaker in German grammar and orthography than DiscoLM e.g., but for use cases where this is not too important, but e.g. instruction following, reasoning, etc, it might actually be a little bit preferable.
23
-
24
- It is a merge of the following models using [LazyMergekit](https://colab.research.google.com/drive/1obulZ1ROXHjYLn6PPZJwRR6GzgQogxxb?usp=sharing):
25
- * [flemmingmiguel/NeuDist-Ro-7B](https://huggingface.co/flemmingmiguel/NeuDist-Ro-7B)
26
- * [johannhartmann/Brezn3](https://huggingface.co/johannhartmann/Brezn3)
27
- * [ResplendentAI/Flora_DPO_7B](https://huggingface.co/ResplendentAI/Flora_DPO_7B)
28
- * on the basis of [mayflowergmbh/Wiedervereinigung-7b-dpo-laser](https://huggingface.co/mayflowergmbh/Wiedervereinigung-7b-dpo-laser)
29
-
30
- All credits are due to the creators of those original models and the training datasets involved.
31
-
32
- For a suitable quantized version, try [cstr/Spaetzle-v8-7b-GGUF](https://huggingface.co/cstr/Spaetzle-v8-7b-GGUF)
33
-
34
-
35
- ## Evaluation
36
- [Open LLM Leaderboard Evaluation Results](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard)
37
- Detailed results can be found [here](https://huggingface.co/datasets/open-llm-leaderboard/details_cstr__Spaetzle-v8-7b)
38
-
39
- | Metric |Value|
40
- |---------------------------------|----:|
41
- |Avg. |72.27|
42
- |AI2 Reasoning Challenge (25-Shot)|68.69|
43
- |HellaSwag (10-Shot) |86.68|
44
- |MMLU (5-Shot) |64.60|
45
- |TruthfulQA (0-shot) |64.05|
46
- |Winogrande (5-shot) |81.45|
47
- |GSM8k (5-shot) |68.16|
48
-
49
- EQ-Bench (v2_de): 61.04 / english (v2): 78.3
50
-
51
- | Model |AGIEval|GPT4All|TruthfulQA|Bigbench|Average|
52
- |------------------------------------------------------------|------:|------:|---------:|-------:|------:|
53
- |[Spaetzle-v8-7b](https://huggingface.co/cstr/Spaetzle-v8-7b)| 45.31| 75.69| 63.94| 45.57| 57.63|
54
-
55
- ### AGIEval
56
- | Task |Version| Metric |Value| |Stderr|
57
- |------------------------------|------:|--------|----:|---|-----:|
58
- |agieval_aqua_rat | 0|acc |25.59|± | 2.74|
59
- | | |acc_norm|24.80|± | 2.72|
60
- |agieval_logiqa_en | 0|acc |39.63|± | 1.92|
61
- | | |acc_norm|39.78|± | 1.92|
62
- |agieval_lsat_ar | 0|acc |23.48|± | 2.80|
63
- | | |acc_norm|24.35|± | 2.84|
64
- |agieval_lsat_lr | 0|acc |50.98|± | 2.22|
65
- | | |acc_norm|51.96|± | 2.21|
66
- |agieval_lsat_rc | 0|acc |62.08|± | 2.96|
67
- | | |acc_norm|62.83|± | 2.95|
68
- |agieval_sat_en | 0|acc |78.64|± | 2.86|
69
- | | |acc_norm|79.13|± | 2.84|
70
- |agieval_sat_en_without_passage| 0|acc |44.66|± | 3.47|
71
- | | |acc_norm|44.66|± | 3.47|
72
- |agieval_sat_math | 0|acc |37.27|± | 3.27|
73
- | | |acc_norm|35.00|± | 3.22|
74
-
75
- Average: 45.31%
76
-
77
- ### GPT4All
78
- | Task |Version| Metric |Value| |Stderr|
79
- |-------------|------:|--------|----:|---|-----:|
80
- |arc_challenge| 0|acc |63.14|± | 1.41|
81
- | | |acc_norm|64.51|± | 1.40|
82
- |arc_easy | 0|acc |85.98|± | 0.71|
83
- | | |acc_norm|82.49|± | 0.78|
84
- |boolq | 1|acc |88.10|± | 0.57|
85
- |hellaswag | 0|acc |66.31|± | 0.47|
86
- | | |acc_norm|85.17|± | 0.35|
87
- |openbookqa | 0|acc |38.00|± | 2.17|
88
- | | |acc_norm|47.20|± | 2.23|
89
- |piqa | 0|acc |83.35|± | 0.87|
90
- | | |acc_norm|84.17|± | 0.85|
91
- |winogrande | 0|acc |78.22|± | 1.16|
92
-
93
- Average: 75.69%
94
-
95
- ### TruthfulQA
96
- | Task |Version|Metric|Value| |Stderr|
97
- |-------------|------:|------|----:|---|-----:|
98
- |truthfulqa_mc| 1|mc1 |47.74|± | 1.75|
99
- | | |mc2 |63.94|± | 1.53|
100
-
101
- Average: 63.94%
102
-
103
- ### Bigbench
104
- | Task |Version| Metric |Value| |Stderr|
105
- |------------------------------------------------|------:|---------------------|----:|---|-----:|
106
- |bigbench_causal_judgement | 0|multiple_choice_grade|56.84|± | 3.60|
107
- |bigbench_date_understanding | 0|multiple_choice_grade|66.12|± | 2.47|
108
- |bigbench_disambiguation_qa | 0|multiple_choice_grade|41.47|± | 3.07|
109
- |bigbench_geometric_shapes | 0|multiple_choice_grade|22.01|± | 2.19|
110
- | | |exact_str_match | 0.00|± | 0.00|
111
- |bigbench_logical_deduction_five_objects | 0|multiple_choice_grade|31.40|± | 2.08|
112
- |bigbench_logical_deduction_seven_objects | 0|multiple_choice_grade|23.14|± | 1.60|
113
- |bigbench_logical_deduction_three_objects | 0|multiple_choice_grade|56.00|± | 2.87|
114
- |bigbench_movie_recommendation | 0|multiple_choice_grade|45.00|± | 2.23|
115
- |bigbench_navigate | 0|multiple_choice_grade|50.70|± | 1.58|
116
- |bigbench_reasoning_about_colored_objects | 0|multiple_choice_grade|70.05|± | 1.02|
117
- |bigbench_ruin_names | 0|multiple_choice_grade|45.54|± | 2.36|
118
- |bigbench_salient_translation_error_detection | 0|multiple_choice_grade|26.05|± | 1.39|
119
- |bigbench_snarks | 0|multiple_choice_grade|71.82|± | 3.35|
120
- |bigbench_sports_understanding | 0|multiple_choice_grade|72.92|± | 1.42|
121
- |bigbench_temporal_sequences | 0|multiple_choice_grade|44.20|± | 1.57|
122
- |bigbench_tracking_shuffled_objects_five_objects | 0|multiple_choice_grade|22.80|± | 1.19|
123
- |bigbench_tracking_shuffled_objects_seven_objects| 0|multiple_choice_grade|18.23|± | 0.92|
124
- |bigbench_tracking_shuffled_objects_three_objects| 0|multiple_choice_grade|56.00|± | 2.87|
125
-
126
- Average: 45.57%
127
-
128
- Average score: 57.63%
129
-
130
- ## 💻 Usage
131
-
132
- ```python
133
- !pip install -qU transformers accelerate
134
-
135
- from transformers import AutoTokenizer
136
- import transformers
137
- import torch
138
-
139
- model = "cstr/Spaetzle-v8-7b"
140
- messages = [{"role": "user", "content": "What is a large language model?"}]
141
-
142
- tokenizer = AutoTokenizer.from_pretrained(model)
143
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
144
- pipeline = transformers.pipeline(
145
- "text-generation",
146
- model=model,
147
- torch_dtype=torch.float16,
148
- device_map="auto",
149
- )
150
-
151
- outputs = pipeline(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
152
- print(outputs[0]["generated_text"])
153
- ```
154
-
155
-
156
- ## 🧩 Configuration
157
-
158
- The model uses ChatML and should work well with this (as it is merged from models which (mostly) saw ChatML templates in training).
159
-
160
- ```yaml
161
- models:
162
- - model: mayflowergmbh/Wiedervereinigung-7b-dpo-laser
163
- # no parameters necessary for base model
164
- - model: flemmingmiguel/NeuDist-Ro-7B
165
- parameters:
166
- density: 0.60
167
- weight: 0.30
168
- - model: johannhartmann/Brezn3
169
- parameters:
170
- density: 0.65
171
- weight: 0.40
172
- - model: ResplendentAI/Flora_DPO_7B
173
- parameters:
174
- density: 0.6
175
- weight: 0.3
176
- merge_method: dare_ties
177
- base_model: mayflowergmbh/Wiedervereinigung-7b-dpo-laser
178
- parameters:
179
- int8_mask: true
180
- dtype: bfloat16
181
- random_seed: 0
182
- tokenizer_source: base
183
- ```
 
1
+ # **ORPO**
2
+
3
+ ### **`Updates (24.03.25)`**
4
+ - [X] Sample script for ORPOTrainer in 🤗<a class="link" href="https://github.com/huggingface/trl">TRL</a> is added to `trl/test_orpo_trainer_demo.py`
5
+ - [X] New model, 🤗<a class="link" href="https://huggingface.co/kaist-ai/mistral-orpo-capybara-7k">kaist-ai/mistral-orpo-capybara-7k</a>, is added to 🤗<a class="link" href="https://huggingface.co/collections/kaist-ai/orpo-65efef87544ba100aef30013">ORPO Collection</a>
6
+ - [X] Now you can try ORPO in 🤗<a class="link" href="https://github.com/huggingface/trl">TRL</a> and <a class="link" href="https://github.com/OpenAccess-AI-Collective/axolotl">Axolotl</a>🔥
7
+ - [X] We are making general guideline for training LLMs with ORPO, stay tuned🔥
8
+ - [X] **Mistral-ORPO-β** achieved a 14.7% in the length-controlled (LC) win rate on <a class="link" href="https://tatsu-lab.github.io/alpaca_eval/">official AlpacaEval Leaderboard</a>🔥
9
+
10
+ &nbsp;
11
+
12
+ This is the official repository for <a class="link" href="https://arxiv.org/abs/2403.07691">**ORPO: Monolithic Preference Optimization without Reference Model**</a>. The detailed results in the paper can be found in:
13
+ - [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=kaist-ai%2Fmistral-orpo-beta)
14
+ - [AlpacaEval](#alpacaeval)
15
+ - [MT-Bench](#mt-bench)
16
+ - [IFEval](#ifeval)
17
+
18
+
19
+ ### **`Model Checkpoints`**
20
+
21
+ Our models trained with ORPO can be found in:
22
+
23
+ - [X] **Mistral-ORPO-Capybara-7k**: 🤗 <a class="link" href="https://huggingface.co/kaist-ai/mistral-orpo-capybara-7k">kaist-ai/mistral-orpo-capybara-7k</a>
24
+ - [X] **Mistral-ORPO-⍺**: 🤗 <a class="link" href="https://huggingface.co/kaist-ai/mistral-orpo-alpha">kaist-ai/mistral-orpo-alpha</a>
25
+ - [X] **Mistral-ORPO-β**: 🤗 <a class="link" href="https://huggingface.co/kaist-ai/mistral-orpo-beta">kaist-ai/mistral-orpo-beta</a>
26
+
27
+ And the corresponding logs for the average log probabilities of chosen/rejected responses during training are reported in:
28
+
29
+ - [X] **Mistral-ORPO-Capybara-7k**: TBU
30
+ - [X] **Mistral-ORPO-⍺**: <a class="link" href="https://wandb.ai/jiwooya1000/PREF/reports/Mistral-ORPO-7B-Training-Log--Vmlldzo3MTE1NzE0?accessToken=rms6o4mg5vo3feu1bvbpk632m4cspe19l0u1p4he3othx5bgean82chn9neiile6">Wandb Report for Mistral-ORPO-⍺</a>
31
+ - [X] **Mistral-ORPO-β**: <a class="link" href="https://wandb.ai/jiwooya1000/PREF/reports/Mistral-ORPO-7B-Training-Log--Vmlldzo3MTE3MzMy?accessToken=dij4qbp6dcrofsanzbgobjsne9el8a2zkly2u5z82rxisd4wiwv1rhp0s2dub11e">Wandb Report for Mistral-ORPO-β</a>
32
+
33
+ &nbsp;
34
+
35
+ ### **`AlpacaEval`**
36
+
37
+ <figure>
38
+ <img class="png" src="/assets/img/alpaca_blog.png" alt="Description of the image">
39
+ <figcaption><b>Figure 1.</b> AlpacaEval 2.0 score for the models trained with different alignment methods.</figcaption>
40
+ </figure>
41
+
42
+ &nbsp;
43
+
44
+ ### **`MT-Bench`**
45
+
46
+ <figure>
47
+ <img class="png" src="/assets/img/mtbench_hf.png" alt="Description of the image">
48
+ <figcaption><b>Figure 2.</b> MT-Bench result by category.</figcaption>
49
+ </figure>
50
+
51
+ &nbsp;
52
+
53
+ ### **`IFEval`**
54
+
55
+ IFEval scores are measured with <a class="link" href="https://github.com/EleutherAI/lm-evaluation-harness">EleutherAI/lm-evaluation-harness</a> by applying the chat template. The scores for Llama-2-Chat (70B), Zephyr-β (7B), and Mixtral-8X7B-Instruct-v0.1 are originally reported in <a class="link" href="https://twitter.com/wiskojo/status/1739767758462877823">this tweet</a>.
56
+
57
+ | **Model Type** | **Prompt-Strict** | **Prompt-Loose** | **Inst-Strict** | **Inst-Loose** |
58
+ |--------------------|:-----------------:|:----------------:|:---------------:|----------------|
59
+ | **Llama-2-Chat (70B)** | 0.4436 | 0.5342 | 0.5468 | 0.6319 |
60
+ | **Zephyr-β (7B)** | 0.4233 | 0.4547 | 0.5492 | 0.5767 |
61
+ | **Mixtral-8X7B-Instruct-v0.1** | 0.5213 | **0.5712** | 0.6343 | **0.6823** |
62
+ | **Mistral-ORPO-⍺ (7B)** | 0.5009 | 0.5083 | 0.5995 | 0.6163 |
63
+ | **Mistral-ORPO-β (7B)** | **0.5287** | 0.5564 | **0.6355** | 0.6619 |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
assets/img/alpaca_blog.png ADDED
assets/img/mtbench_hf.png ADDED
main.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import time
4
+ import wandb
5
+ import torch
6
+ import argparse
7
+ from datasets import load_dataset
8
+ from typing import List, Dict, Union
9
+ from transformers import (
10
+ AutoTokenizer,
11
+ AutoModelForCausalLM,
12
+ TrainingArguments,
13
+ DataCollatorForLanguageModeling
14
+ )
15
+
16
+ from src.args import default_args
17
+ from src.orpo_trainer import ORPOTrainer
18
+ from src.utils import preprocess_logits_for_metrics, dataset_split_selector
19
+
20
+ class ORPO(object):
21
+ def __init__(self, args) -> None:
22
+ self.start = time.gmtime()
23
+ self.args = args
24
+
25
+ # Load Tokenizer
26
+ print(">>> 1. Loading Tokenizer")
27
+ self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name, cache_dir=self.args.cache_dir)
28
+ if self.tokenizer.chat_template is None:
29
+ self.tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
30
+ print(" 1-1. Chat Template Applied (<|user|> <|assistant|>)")
31
+ else:
32
+ pass
33
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
34
+
35
+ # Load Model
36
+ print(">>> 2. Loading Model")
37
+ if self.args.flash_attention_2:
38
+ self.model = AutoModelForCausalLM.from_pretrained(self.args.model_name,
39
+ cache_dir=self.args.cache_dir,
40
+ torch_dtype=torch.bfloat16,
41
+ attn_implementation="flash_attention_2")
42
+ else:
43
+ self.model = AutoModelForCausalLM.from_pretrained(self.args.model_name,
44
+ cache_dir=self.args.cache_dir,
45
+ torch_dtype=torch.bfloat16)
46
+
47
+ # Load Dataset
48
+ print(">>> 3. Loading Dataset")
49
+ self.data = load_dataset(self.args.data_name, cache_dir=self.args.cache_dir)
50
+
51
+ # Preprocess Dataset
52
+ print(">>> 4. Filtering and Preprocessing Dataset")
53
+ data_split = dataset_split_selector(self.data)
54
+
55
+ if len(data_split) == 1:
56
+ self.is_test = False
57
+ train_split = data_split[0]
58
+ print(f" >>> Test Set = {self.is_test}")
59
+ else:
60
+ self.is_test = True
61
+ train_split = data_split[0]
62
+ test_split = data_split[1]
63
+
64
+ test = self.data[test_split].filter(self.filter_dataset)
65
+ self.test = test.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[test_split].column_names)
66
+
67
+ train = self.data[train_split].filter(self.filter_dataset)
68
+ print(f"\n\n>>> {len(train)} / {len(self.data[train_split])} rows left after filtering by prompt length.")
69
+ self.train = train.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[train_split].column_names)
70
+
71
+ # Set WANDB & Logging Configurations
72
+ self.run_name = f"{self.args.model_name.split('/')[-1]}-{self.args.data_name.split('/')[-1]}-lambda{self.args.alpha}-ORPO-{self.start.tm_mday}-{self.start.tm_hour}-{self.start.tm_min}"
73
+ self.save_dir = os.path.join('./checkpoints/', f"{self.args.data_name.split('/')[-1]}/{self.run_name}")
74
+ self.log_dir = os.path.join('./checkpoints/', f"{self.args.data_name.split('/')[-1]}/{self.run_name}/logs")
75
+
76
+ os.makedirs(self.save_dir, exist_ok=True)
77
+ os.makedirs(self.log_dir, exist_ok=True)
78
+
79
+ def preprocess_dataset(self, examples: Union[List, Dict]):
80
+ if ('instruction' in examples.keys()) or ('question' in examples.keys()):
81
+ prompt_key = 'instruction' if 'instruction' in examples.keys() else 'question'
82
+ prompt = [self.tokenizer.apply_chat_template([{'role': 'user', 'content': item}], tokenize=False, add_generation_prompt=True) for item in examples[prompt_key]]
83
+ chosen = [self.tokenizer.apply_chat_template([{'role': 'user', 'content': item_prompt}, {'role': 'assistant', 'content': item_chosen}], tokenize=False) for item_prompt, item_chosen in zip(examples[prompt_key], examples['chosen'])]
84
+ rejected = [self.tokenizer.apply_chat_template([{'role': 'user', 'content': item_prompt}, {'role': 'assistant', 'content': item_rejected}], tokenize=False) for item_prompt, item_rejected in zip(examples[prompt_key], examples['rejected'])]
85
+ else:
86
+ prompt = [self.tokenizer.apply_chat_template([item[0]], tokenize=False, add_generation_prompt=True) for item in examples['chosen']]
87
+ chosen = [self.tokenizer.apply_chat_template(item, tokenize=False) for item in examples['chosen']]
88
+ rejected = [self.tokenizer.apply_chat_template(item, tokenize=False) for item in examples['rejected']]
89
+
90
+ model_inputs = self.tokenizer(prompt,
91
+ max_length=self.args.response_max_length,
92
+ padding='max_length',
93
+ truncation=True,
94
+ return_tensors='pt')
95
+ pos_labels = self.tokenizer(chosen,
96
+ max_length=self.args.response_max_length,
97
+ padding='max_length',
98
+ truncation=True,
99
+ return_tensors='pt')
100
+ neg_labels = self.tokenizer(rejected,
101
+ max_length=self.args.response_max_length,
102
+ padding='max_length',
103
+ truncation=True,
104
+ return_tensors='pt')
105
+
106
+ model_inputs['positive_input_ids'] = pos_labels['input_ids']
107
+ model_inputs['positive_attention_mask'] = pos_labels['attention_mask']
108
+
109
+ model_inputs['negative_input_ids'] = neg_labels['input_ids']
110
+ model_inputs['negative_attention_mask'] = neg_labels['attention_mask']
111
+
112
+ return model_inputs
113
+
114
+ def filter_dataset(self, examples: Union[List, Dict]):
115
+ if 'instruction' in examples.keys():
116
+ query = examples['instruction']
117
+ prompt_length = self.tokenizer.apply_chat_template([{'content': query, 'role': 'user'}], tokenize=True, add_generation_prompt=True, return_tensors='pt').size(-1)
118
+ elif 'question' in examples.keys():
119
+ query = examples['question']
120
+ prompt_length = self.tokenizer.apply_chat_template([{'content': query, 'role': 'user'}], tokenize=True, add_generation_prompt=True, return_tensors='pt').size(-1)
121
+ else:
122
+ prompt_length = self.tokenizer.apply_chat_template([examples['chosen'][0]], tokenize=True, add_generation_prompt=True, return_tensors='pt').size(-1)
123
+
124
+ if prompt_length < self.args.prompt_max_length:
125
+ return True
126
+ else:
127
+ return False
128
+
129
+ def prepare_trainer(self):
130
+ wandb.init(name=self.run_name)
131
+ arguments = TrainingArguments(
132
+ output_dir=self.save_dir, # The output directory
133
+ logging_dir=self.log_dir,
134
+ logging_steps=50,
135
+ learning_rate=self.args.lr,
136
+ overwrite_output_dir=True, # overwrite the content of the output directory
137
+ num_train_epochs=self.args.num_train_epochs, # number of training epochs
138
+ per_device_train_batch_size=self.args.per_device_train_batch_size, # batch size for training
139
+ per_device_eval_batch_size=self.args.per_device_eval_batch_size, # batch size for evaluation
140
+ evaluation_strategy=self.args.evaluation_strategy if self.is_test else 'no', # batch size for evaluation
141
+ save_strategy=self.args.evaluation_strategy,
142
+ optim=self.args.optim,
143
+ warmup_steps=self.args.warmup_steps,
144
+ gradient_accumulation_steps=self.args.gradient_accumulation_steps,
145
+ gradient_checkpointing=True, #if ('llama' in self.args.model_name.lower()) or ('mistral' in self.args.model_name.lower()) else False,
146
+ gradient_checkpointing_kwargs={'use_reentrant':True},
147
+ load_best_model_at_end=self.is_test,
148
+ do_train=True,
149
+ do_eval=self.is_test,
150
+ lr_scheduler_type=self.args.lr_scheduler_type,
151
+ remove_unused_columns=False,
152
+ report_to='wandb',
153
+ run_name=self.run_name,
154
+ bf16=True
155
+ )
156
+
157
+ data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False)
158
+
159
+ self.trainer = ORPOTrainer(
160
+ model=self.model,
161
+ alpha=self.args.alpha,
162
+ pad=self.tokenizer.pad_token_id,
163
+ args=arguments,
164
+ train_dataset=self.train,
165
+ eval_dataset=self.test if self.is_test else None,
166
+ data_collator=data_collator,
167
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics
168
+ )
169
+
170
+ def run(self):
171
+ print(">>> 5. Preparing ORPOTrainer")
172
+ self.prepare_trainer()
173
+ self.trainer.train()
174
+
175
+ # Saving code for FSDP
176
+ if self.trainer.is_fsdp_enabled:
177
+ self.trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
178
+ self.trainer.save_model()
179
+
180
+
181
+ if __name__ == '__main__':
182
+ parser = argparse.ArgumentParser("ORPO")
183
+ args = default_args(parser)
184
+
185
+ # Set WANDB configurations
186
+ if args.wandb_entity is not None and args.wandb_project_name is not None:
187
+ os.environ["WANDB_ENTITY"] = args.wandb_entity
188
+ os.environ["WANDB_PROJECT"] = args.wandb_project_name
189
+ else:
190
+ pass
191
+ os.environ["TOKENIZERS_PARALLELISM"] = 'false'
192
+
193
+ print("================================================================================================\n")
194
+ print(f">>> Fine-tuning {args.model_name} with ORPO on {args.data_name}\n")
195
+ print("================================================================================================")
196
+ print("\n\n>>> Summary:")
197
+ print(f" - Lambda : {args.alpha}")
198
+ print(f" - Training Epochs : {args.num_train_epochs}")
199
+ print(f" - Prompt Max Length : {args.prompt_max_length}")
200
+ print(f" - Response Max Length : {args.response_max_length}")
201
+
202
+ item = ORPO(args=args)
203
+ item.run()
outputs/alpacaeval/Mistral-ORPO-alpha.json ADDED
The diff for this file is too large to render. See raw diff
 
outputs/alpacaeval/Mistral-ORPO-beta.json ADDED
The diff for this file is too large to render. See raw diff
 
outputs/mtbench/Mistral-ORPO-alpha.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
outputs/mtbench/Mistral-ORPO-beta.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate @ file:///home/conda/feedstock_root/build_artifacts/accelerate_1710334587919/work
2
+ aiohttp @ file:///croot/aiohttp_1707342283163/work
3
+ aiosignal @ file:///tmp/build/80754af9/aiosignal_1637843061372/work
4
+ appdirs==1.4.4
5
+ asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work
6
+ attrs @ file:///croot/attrs_1695717823297/work
7
+ bitsandbytes==0.43.0
8
+ Bottleneck @ file:///croot/bottleneck_1707864210935/work
9
+ Brotli @ file:///work/ci_py311/brotli-split_1676830125088/work
10
+ cachetools==5.3.3
11
+ certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1707022139797/work/certifi
12
+ cffi @ file:///croot/cffi_1700254295673/work
13
+ charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
14
+ click @ file:///croot/click_1698129812380/work
15
+ comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work
16
+ datasets @ file:///home/conda/feedstock_root/build_artifacts/datasets_1709395865330/work
17
+ debugpy @ file:///croot/debugpy_1690905042057/work
18
+ decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
19
+ dill @ file:///croot/dill_1692271232022/work
20
+ docker-pycreds @ file:///Users/ktietz/demo/mc3/conda-bld/docker-pycreds_1630654474270/work
21
+ einops==0.7.0
22
+ exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1704921103267/work
23
+ executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work
24
+ filelock @ file:///croot/filelock_1700591183607/work
25
+ flash-attn==2.5.6
26
+ frozenlist @ file:///croot/frozenlist_1698702560391/work
27
+ fsspec==2023.4.0
28
+ gitdb @ file:///tmp/build/80754af9/gitdb_1617117951232/work
29
+ GitPython @ file:///croot/gitpython_1696936983078/work
30
+ gmpy2 @ file:///work/ci_py311/gmpy2_1676839849213/work
31
+ huggingface-hub @ file:///croot/huggingface_hub_1708634519519/work
32
+ idna @ file:///work/ci_py311/idna_1676822698822/work
33
+ importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1709821103657/work
34
+ ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1708996548741/work
35
+ ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1709559745751/work
36
+ jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
37
+ Jinja2==3.1.2
38
+ jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1710255804825/work
39
+ jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257359434/work
40
+ MarkupSafe @ file:///croot/markupsafe_1704205993651/work
41
+ matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
42
+ mkl-fft @ file:///croot/mkl_fft_1695058164594/work
43
+ mkl-random @ file:///croot/mkl_random_1695059800811/work
44
+ mkl-service==2.4.0
45
+ mpmath @ file:///croot/mpmath_1690848262763/work
46
+ multidict @ file:///croot/multidict_1701096859099/work
47
+ multiprocess @ file:///croot/multiprocess_1692294385131/work
48
+ nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work
49
+ networkx==3.2.1
50
+ ninja==1.11.1.1
51
+ numexpr @ file:///croot/numexpr_1696515281613/work
52
+ numpy @ file:///croot/numpy_and_numpy_base_1708638617955/work/dist/numpy-1.26.4-cp311-cp311-linux_x86_64.whl#sha256=5f96f274d410a1682519282ae769c877d32fdbf171aa8badec7bf5e1d3a1748a
53
+ nvidia-cublas-cu11==11.11.3.6
54
+ nvidia-cuda-cupti-cu11==11.8.87
55
+ nvidia-cuda-nvrtc-cu11==11.8.89
56
+ nvidia-cuda-runtime-cu11==11.8.89
57
+ nvidia-cudnn-cu11==8.7.0.84
58
+ nvidia-cufft-cu11==10.9.0.58
59
+ nvidia-curand-cu11==10.3.0.86
60
+ nvidia-cusolver-cu11==11.4.1.48
61
+ nvidia-cusparse-cu11==11.7.5.86
62
+ nvidia-ml-py==12.535.133
63
+ nvidia-nccl-cu11==2.19.3
64
+ nvidia-nvtx-cu11==11.8.86
65
+ nvitop==1.3.2
66
+ packaging @ file:///croot/packaging_1693575174725/work
67
+ pandas @ file:///croot/pandas_1709590491089/work/dist/pandas-2.2.1-cp311-cp311-linux_x86_64.whl#sha256=0a2793a31a0135a35735e1431d453a06186a3a7c607d9b441d9bd5f0fe4ded31
68
+ parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
69
+ pathtools @ file:///Users/ktietz/demo/mc3/conda-bld/pathtools_1629713893697/work
70
+ pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work
71
+ pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
72
+ pillow==10.2.0
73
+ platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1706713388748/work
74
+ prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1702399386289/work
75
+ protobuf==3.20.3
76
+ psutil @ file:///work/ci_py311_2/psutil_1679337388738/work
77
+ ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
78
+ pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
79
+ pyarrow @ file:///croot/pyarrow_1707330824290/work/python
80
+ pyarrow-hotfix @ file:///home/conda/feedstock_root/build_artifacts/pyarrow-hotfix_1700596371886/work
81
+ pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
82
+ Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1700607939962/work
83
+ PySocks @ file:///work/ci_py311/pysocks_1676822712504/work
84
+ python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
85
+ pytz @ file:///croot/pytz_1695131579487/work
86
+ PyYAML @ file:///croot/pyyaml_1698096049011/work
87
+ pyzmq @ file:///croot/pyzmq_1705605076900/work
88
+ regex @ file:///croot/regex_1696515298636/work
89
+ requests @ file:///croot/requests_1707355572290/work
90
+ safetensors @ file:///croot/safetensors_1708633833937/work
91
+ sentry-sdk @ file:///work/ci_py311/sentry-sdk_1676862120883/work
92
+ setproctitle @ file:///work/ci_py311/setproctitle_1676838789127/work
93
+ six @ file:///tmp/build/80754af9/six_1644875935023/work
94
+ smmap @ file:///tmp/build/80754af9/smmap_1611694433573/work
95
+ stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
96
+ sympy @ file:///croot/sympy_1701397643339/work
97
+ termcolor==2.4.0
98
+ tokenizers @ file:///croot/tokenizers_1708633814160/work
99
+ torch==2.2.1+cu118
100
+ torchaudio==2.2.1+cu118
101
+ torchvision==0.17.1+cu118
102
+ tornado @ file:///croot/tornado_1696936946304/work
103
+ tqdm @ file:///croot/tqdm_1679561862951/work
104
+ traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1710254411456/work
105
+ transformers @ file:///home/conda/feedstock_root/build_artifacts/transformers_1709308155748/work
106
+ triton==2.2.0
107
+ typing_extensions==4.8.0
108
+ tzdata @ file:///croot/python-tzdata_1690578112552/work
109
+ urllib3 @ file:///croot/urllib3_1707770551213/work
110
+ wandb @ file:///home/conda/feedstock_root/build_artifacts/wandb_1707246480133/work
111
+ wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work
112
+ xxhash @ file:///work/ci_py311/python-xxhash_1676842384694/work
113
+ yarl @ file:///croot/yarl_1701105127787/work
114
+ zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1695255097490/work
scripts/run_mistral_orpo_beta.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Mistral-ORPO series are trained on 4 * A100s
4
+
5
+ accelerate launch --config_file ./src/accelerate/fsdp.yaml main.py \
6
+ --lr 5e-6 \
7
+ --lr_scheduler_type inverse_sqrt \
8
+ --alpha 0.1 \
9
+ --torch_compile False \
10
+ --warmup_steps 200 \
11
+ --model_name mistralai/Mistral-7B-v0.1 \
12
+ --data_name argilla/ultrafeedback-binarized-preferences-cleaned \
13
+ --num_train_epochs 5 \
14
+ --prompt_max_length 1792 \
15
+ --response_max_length 2048 \
16
+ --per_device_train_batch_size 8 \
17
+ --per_device_eval_batch_size 8 \
18
+ --gradient_accumulation_steps 1 \
19
+ --num_proc 8 \
20
+ --flash_attention_2
scripts/run_mistral_orpo_capybara.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Mistral-ORPO series are trained on 4 * A100s
4
+
5
+ accelerate launch --config_file ./src/accelerate/fsdp.yaml main.py \
6
+ --lr 5e-6 \
7
+ --torch_compile False \
8
+ --alpha 0.05 \
9
+ --lr_scheduler_type inverse_sqrt \
10
+ --cache_dir /projects/hf_cache/ \
11
+ --warmup_steps 100 \
12
+ --model_name mistralai/Mistral-7B-v0.1 \
13
+ --data_name argilla/distilabel-capybara-dpo-7k-binarized \
14
+ --num_train_epochs 3 \
15
+ --optim adamw_bnb_8bit \
16
+ --gradient_accumulation_steps 1 \
17
+ --prompt_max_length 1792 \
18
+ --response_max_length 2048 \
19
+ --per_device_train_batch_size 8 \
20
+ --per_device_eval_batch_size 8 \
21
+ --num_proc 8 \
22
+ --flash_attention_2
src/accelerate/ds2.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ gradient_accumulation_steps: 1
5
+ offload_optimizer_device: none
6
+ offload_param_device: none
7
+ zero3_init_flag: false
8
+ zero_stage: 2
9
+ distributed_type: DEEPSPEED
10
+ downcast_bf16: 'no'
11
+ machine_rank: 0
12
+ main_training_function: main
13
+ mixed_precision: bf16
14
+ num_machines: 1
15
+ num_processes: 2
16
+ rdzv_backend: static
17
+ same_network: true
18
+ tpu_env: []
19
+ tpu_use_cluster: false
20
+ tpu_use_sudo: false
21
+ use_cpu: false
src/args.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def default_args(parser):
2
+ parser.add_argument("--cache_dir", default=None, type=str)
3
+ parser.add_argument("--save_dir", default='./saved', type=str)
4
+ parser.add_argument("--data_name", default='HuggingfaceH4/UltraFeedback', type=str)
5
+ parser.add_argument("--model_name", default="gpt2", type=str)
6
+
7
+ # Training Arguments
8
+ parser.add_argument("--torch_compile", default=False, type=bool)
9
+ parser.add_argument("--flash_attention_2", action='store_true')
10
+ parser.add_argument("--lr_scheduler_type", default="cosine", type=str)
11
+ parser.add_argument("--optim", default="paged_adamw_32bit", type=str)
12
+ parser.add_argument("--overwrite_output_dir", default=True, type=bool)
13
+ parser.add_argument("--lr", default=2e-5, type=float)
14
+ parser.add_argument("--num_proc", default=1, type=int)
15
+ parser.add_argument("--num_train_epochs", default=10, type=int)
16
+ parser.add_argument("--per_device_train_batch_size", default=2, type=int)
17
+ parser.add_argument("--per_device_eval_batch_size", default=2, type=int)
18
+ parser.add_argument("--warmup_steps", default=5000, type=int)
19
+ parser.add_argument("--evaluation_strategy", default='epoch', type=str)
20
+ parser.add_argument("--do_eval", action='store_true')
21
+ parser.add_argument("--gradient_accumulation_steps", default=1, type=int)
22
+ parser.add_argument("--save_strategy", default='epoch', type=str)
23
+ parser.add_argument("--prompt_max_length", default=256, type=int)
24
+ parser.add_argument("--response_max_length", default=1024, type=int)
25
+ parser.add_argument("--alpha", default=1.0, type=float, help="Hyperparameter for weighting L_OR")
26
+
27
+ # Wandb Configurations
28
+ parser.add_argument("--wandb_entity", default=None, type=str)
29
+ parser.add_argument("--wandb_project_name", default=None, type=str)
30
+
31
+
32
+ args = parser.parse_args()
33
+
34
+ return args
src/orpo_trainer.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import wandb
4
+ from transformers import Trainer
5
+
6
+
7
+ class ORPOTrainer(Trainer):
8
+ def __init__(self, alpha, pad, *args, **kwargs):
9
+ super().__init__(*args, **kwargs)
10
+ self.pad = pad
11
+ self.alpha = alpha
12
+ self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
13
+ print("Pad Token ID: ", self.pad)
14
+
15
+ def compute_custom_loss(self, logits, labels):
16
+
17
+ logits = logits.contiguous()
18
+
19
+ if labels is not None:
20
+ # move labels to correct device to enable model parallelism
21
+ labels = labels.to(logits.device)
22
+ # Shift so that tokens < n predict n
23
+ shift_logits = logits[..., :-1, :].contiguous()
24
+ shift_labels = labels[..., 1:].contiguous()
25
+
26
+ # Flatten the tokens
27
+ loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(dim=-1)
28
+
29
+ return loss
30
+
31
+ def compute_logps(self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
32
+ mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]
33
+ per_token_logps = torch.gather(logits[:, :-1, :].log_softmax(-1), dim=2,
34
+ index=(mask * chosen_inputs[:, 1:]).unsqueeze(2)).squeeze(2)
35
+ return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(dtype=torch.float64) / mask.sum(dim=1).to(dtype=torch.float64)
36
+
37
+ def compute_loss(self, model, inputs, return_outputs=False):
38
+ if self.label_smoother is not None and "labels" in inputs:
39
+ labels = inputs.pop("labels")
40
+ else:
41
+ labels = None
42
+
43
+ # Generate the hidden states for 'chosen' and 'reject'
44
+ neg_labels = inputs['negative_input_ids'].clone()
45
+ pos_labels = inputs['positive_input_ids'].clone()
46
+
47
+ neg_labels[neg_labels == self.pad] = -100
48
+ pos_labels[pos_labels == self.pad] = -100
49
+
50
+ outputs_neg = model(**{'input_ids': inputs['negative_input_ids'],
51
+ 'attention_mask': inputs['negative_attention_mask'],
52
+ 'labels': neg_labels,}, output_hidden_states=True)
53
+ outputs_pos = model(**{'input_ids': inputs['positive_input_ids'],
54
+ 'attention_mask': inputs['positive_attention_mask'],
55
+ 'labels': pos_labels,}, output_hidden_states=True)
56
+
57
+ # Calculate NLL loss
58
+ pos_loss = self.compute_custom_loss(logits=outputs_pos.logits, labels=inputs['positive_input_ids'])
59
+
60
+ # Calculate Log Probability
61
+ pos_prob = self.compute_logps(prompt_attention_mask=inputs['attention_mask'],
62
+ chosen_inputs=inputs['positive_input_ids'],
63
+ chosen_attention_mask=inputs['positive_attention_mask'],
64
+ logits=outputs_pos.logits)
65
+ neg_prob = self.compute_logps(prompt_attention_mask=inputs['attention_mask'],
66
+ chosen_inputs=inputs['negative_input_ids'],
67
+ chosen_attention_mask=inputs['negative_attention_mask'],
68
+ logits=outputs_neg.logits)
69
+
70
+ # Calculate log odds
71
+ log_odds = (pos_prob - neg_prob) - (torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob)))
72
+ sig_ratio = torch.nn.functional.sigmoid(log_odds)
73
+ ratio = torch.log(sig_ratio)
74
+
75
+ # Calculate the Final Loss
76
+ loss = torch.mean(pos_loss - self.alpha * ratio).to(dtype=torch.bfloat16)
77
+
78
+ wandb.log({'Positive Geometric Mean': torch.mean(pos_prob).item(),
79
+ 'Negative Geometric Mean': torch.mean(neg_prob).item(),
80
+ 'Log Odds Ratio': torch.mean(ratio).item(),
81
+ 'Log Odds': torch.mean(log_odds).item()})
82
+
83
+ return (loss, outputs_pos) if return_outputs else loss
src/utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import List
3
+
4
+ def preprocess_logits_for_metrics(logits, labels):
5
+ if isinstance(logits, tuple):
6
+ logits = logits[0]
7
+ return logits.argmax(dim=-1)
8
+
9
+ def dataset_split_selector(data) -> List:
10
+ """
11
+ This is a function for automating the process of selecting data split.
12
+ Will be further updated.
13
+ """
14
+ if len(data.keys()) == 1:
15
+ return ['train']
16
+ else:
17
+ if 'train_prefs' in data.keys():
18
+ return ['train_prefs', 'test_prefs']
19
+ else:
20
+ return ['train', 'test']
trl/test_orpo_trainer_demo.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from dataclasses import dataclass, field
3
+ from typing import Optional
4
+
5
+ import os
6
+ import torch
7
+ from datasets import load_dataset
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, HfArgumentParser, pipeline
10
+
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+ from trl import ORPOConfig, ORPOTrainer, set_seed
13
+ from trl.core import LengthSampler
14
+
15
+ # This code is built on top of the example code from Huggingface TRL Team
16
+
17
+ tqdm.pandas()
18
+
19
+ @dataclass
20
+ class ScriptArguments:
21
+ model_name: Optional[str] = field(default="microsoft/phi-2", metadata={"help": "the model name"})
22
+ optim: Optional[str] = field(default="adamw_torch", metadata={"help": "the model name"})
23
+ data_name: Optional[str] = field(default="argilla/dpo-mix-7k", metadata={"help": "the model name"})
24
+ cache_dir: Optional[str] = field(default="", metadata={"help": "the model name"})
25
+ log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
26
+ output_dir: Optional[str] = field(default='', metadata={"help": "use 'wandb' to log with wandb"})
27
+ learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
28
+ lr_scheduler_type: Optional[str] = field(default='cosine', metadata={"help": "the learning rate scheduler"})
29
+ per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "the batch size"})
30
+ num_train_epochs: Optional[int] = field(default=5, metadata={"help": "the batch size"})
31
+ beta: Optional[float] = field(default=0.25, metadata={"help": "weighting hyperparameter for L_OR"})
32
+ gradient_accumulation_steps: Optional[int] = field(
33
+ default=1, metadata={"help": "the number of gradient accumulation steps"}
34
+ )
35
+
36
+
37
+ parser = HfArgumentParser(ScriptArguments)
38
+ script_args = parser.parse_args_into_dataclasses()[0]
39
+
40
+ config = ORPOConfig(
41
+ output_dir=script_args.output_dir,
42
+ max_prompt_length=1024,
43
+ max_length=2048,
44
+ logging_steps=100,
45
+ save_strategy='no',
46
+ max_completion_length=2048,
47
+ per_device_train_batch_size=script_args.per_device_train_batch_size,
48
+ remove_unused_columns=False,
49
+ gradient_accumulation_steps=script_args.gradient_accumulation_steps,
50
+ learning_rate=script_args.learning_rate,
51
+ optim=script_args.optim,
52
+ lr_scheduler_type=script_args.lr_scheduler_type,
53
+ gradient_checkpointing=True,
54
+ gradient_checkpointing_kwargs={'use_reentrant':True},
55
+ evaluation_strategy='epoch',
56
+ beta=script_args.beta,
57
+ report_to='wandb',
58
+ num_train_epochs=script_args.num_train_epochs,
59
+ bf16=True,
60
+ do_eval=True
61
+ )
62
+
63
+ model = AutoModelForCausalLM.from_pretrained(script_args.model_name,
64
+ cache_dir=script_args.cache_dir,
65
+ attn_implementation='flash_attention_2',
66
+ torch_dtype=torch.bfloat16)
67
+ tokenizer = AutoTokenizer.from_pretrained(script_args.model_name,
68
+ cache_dir=script_args.cache_dir)
69
+ tokenizer.pad_token_id = tokenizer.eos_token_id
70
+ tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
71
+
72
+ def build_dataset(tokenizer):
73
+ ds_train = load_dataset(script_args.data_name, split="train",
74
+ cache_dir=script_args.cache_dir)
75
+ ds_test = load_dataset(script_args.data_name, split="test",
76
+ cache_dir=script_args.cache_dir)
77
+
78
+ def chat_template_to_text(sample):
79
+ sample["chosen"] = [tokenizer.apply_chat_template(item_chosen, tokenize=False) for item_chosen in sample['chosen']]
80
+ sample["rejected"] = [tokenizer.apply_chat_template(item_rejected, tokenize=False) for item_rejected in sample['rejected']]
81
+ sample['prompt'] = [tokenizer.apply_chat_template([item[0]], tokenize=False, add_generation_prompt=True) for item in sample['chosen']]
82
+
83
+ return sample
84
+
85
+ ds_train = ds_train.map(chat_template_to_text, batched=True, num_proc=8)
86
+ ds_test = ds_test.map(chat_template_to_text, batched=True, num_proc=8)
87
+
88
+ return ds_train, ds_test
89
+
90
+ train, test = build_dataset(tokenizer=tokenizer)
91
+
92
+ trainer = ORPOTrainer(
93
+ model=model,
94
+ args=config,
95
+ tokenizer=tokenizer,
96
+ train_dataset=train,
97
+ eval_dataset=test
98
+ )
99
+
100
+ trainer.train()