AIR-hl commited on
Commit
6928048
·
verified ·
1 Parent(s): c2a5993

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +127 -3
README.md CHANGED
@@ -1,3 +1,127 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ datasets:
4
+ - HuggingFaceH4/ultrafeedback_binarized
5
+ base_model:
6
+ - AIR-hl/Qwen2.5-1.5B-ultrachat200k
7
+ pipeline_tag: text-generation
8
+ tags:
9
+ - trl
10
+ - qwen
11
+ - simpo
12
+ - alignment
13
+ - transformers
14
+ - custome
15
+ - chat
16
+ ---
17
+ # Qwen2.5-1.5B-SimPO
18
+
19
+
20
+ ## Model Details
21
+
22
+ - **Model type:** aligned model
23
+ - **License:** Apache license 2.0
24
+ - **Finetuned from model:** [AIR-hl/Qwen2.5-1.5B-ultrachat200k](https://huggingface.co/AIR-hl/Qwen2.5-1.5B-ultrachat200k)
25
+ - **Training data:** [HuggingFaceH4/ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
26
+ - **Training framework:** [trl](https://github.com/huggingface/trl)
27
+
28
+ ## Training Details
29
+
30
+ devices: 4 * NPU 910B-64GB \
31
+ precision: bf16 mixed-precision \
32
+ global_batch_size: 128
33
+
34
+ ### Training Hyperparameters
35
+ `beta`: 1 \
36
+ `gamma`: 0.1 \
37
+ `bf16`: True \
38
+ `learning_rate`: 1e-6 \
39
+ `lr_scheduler_type`: cosine \
40
+ `per_device_train_batch_size`: 16 \
41
+ `gradient_accumulation_steps`: 2 \
42
+ `torch_dtype`: bfloat16 \
43
+ `num_train_epochs`: 1 \
44
+ `max_prompt_length`: 512 \
45
+ `max_length`: 1024 \
46
+ `warmup_ratio`: 0.05
47
+
48
+ ### Results
49
+
50
+ `init_train_loss`: 0.7551 \
51
+ `final_train_loss`: 0.6715 \
52
+ `accuracy`: 0.6375 \
53
+ `reward_margin`: 0.3633
54
+
55
+ ### Training script
56
+
57
+ ```python
58
+ import torch
59
+ from datasets import load_dataset
60
+ from transformers import AutoModelForCausalLM, AutoTokenizer
61
+ from trl import (
62
+ CPOConfig,
63
+ CPOTrainer,
64
+ ModelConfig,
65
+ ScriptArguments,
66
+ TrlParser,
67
+ get_kbit_device_map,
68
+ get_peft_config,
69
+ get_quantization_config,
70
+ )
71
+ from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
72
+
73
+ if __name__ == "__main__":
74
+ parser = TrlParser((ScriptArguments, CPOConfig, ModelConfig))
75
+ script_args, training_args, model_config = parser.parse_args_and_config()
76
+
77
+ torch_dtype = (
78
+ model_config.torch_dtype
79
+ if model_config.torch_dtype in ["auto", None]
80
+ else getattr(torch, model_config.torch_dtype)
81
+ )
82
+
83
+ quantization_config = get_quantization_config(model_config)
84
+
85
+ model_kwargs = dict(
86
+ revision=model_config.model_revision,
87
+ attn_implementation=model_config.attn_implementation,
88
+ torch_dtype=torch_dtype,
89
+ use_cache=False if training_args.gradient_checkpointing else True,
90
+ device_map=get_kbit_device_map() if quantization_config is not None else None,
91
+ quantization_config=quantization_config,
92
+ )
93
+
94
+ model = AutoModelForCausalLM.from_pretrained(
95
+ model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
96
+ )
97
+
98
+ peft_config = get_peft_config(model_config)
99
+
100
+ tokenizer = AutoTokenizer.from_pretrained(
101
+ model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
102
+ )
103
+ if tokenizer.pad_token is None:
104
+ tokenizer.pad_token = tokenizer.eos_token
105
+ if tokenizer.chat_template is None:
106
+ tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
107
+ if script_args.ignore_bias_buffers:
108
+ model._ddp_params_and_buffers_to_ignore = [
109
+ name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
110
+ ]
111
+
112
+ dataset=load_dataset(script_args.dataset_name,
113
+ split=script_args.dataset_train_split)
114
+ dataset=dataset.select_columns(['prompt', 'chosen', 'rejected'])
115
+
116
+ trainer = CPOTrainer(
117
+ model,
118
+ args=training_args,
119
+ train_dataset=dataset,
120
+ processing_class=tokenizer,
121
+ peft_config=peft_config,
122
+ )
123
+
124
+ trainer.train()
125
+
126
+ trainer.save_model(training_args.output_dir)
127
+ ```