dufry2024 committed
Commit e6d9233 · verified · Parent: 9e849a4

Create train.py

Files changed (1)
  1. train.py +150 -0
train.py ADDED
@@ -0,0 +1,150 @@
+ # Code adapted from https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py
+ # and https://huggingface.co/blog/gemma-peft
+ import argparse
+ import multiprocessing
+ import os
+
+ import torch
+ from accelerate import PartialState
+ from datasets import load_dataset
+ from peft import AutoPeftModelForCausalLM, LoraConfig
+ from transformers import (
+     AutoModelForCausalLM,
+     BitsAndBytesConfig,
+     is_torch_npu_available,
+     is_torch_xpu_available,
+     logging,
+     set_seed,
+ )
+ from trl import SFTTrainer, SFTConfig
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_id", type=str, default="HuggingFaceTB/SmolLM2-1.7B")
+     parser.add_argument("--dataset_name", type=str, default="bigcode/the-stack-smol")
+     parser.add_argument("--subset", type=str, default="data/python")
+     parser.add_argument("--split", type=str, default="train")
+     parser.add_argument("--dataset_text_field", type=str, default="content")
+
+     parser.add_argument("--max_seq_length", type=int, default=2048)
+     parser.add_argument("--max_steps", type=int, default=1000)
+     parser.add_argument("--micro_batch_size", type=int, default=1)
+     parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
+     parser.add_argument("--weight_decay", type=float, default=0.01)
+     # argparse's type=bool treats any non-empty string (even "False") as True, so
+     # boolean options use BooleanOptionalAction (--bf16 / --no-bf16) instead.
+     parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
+
+     parser.add_argument("--use_bnb", action=argparse.BooleanOptionalAction, default=False)
+     parser.add_argument("--attention_dropout", type=float, default=0.1)
+     parser.add_argument("--learning_rate", type=float, default=2e-4)
+     parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
+     parser.add_argument("--warmup_steps", type=int, default=100)
+     parser.add_argument("--seed", type=int, default=0)
+     parser.add_argument("--output_dir", type=str, default="finetune_smollm2_python")
+     parser.add_argument("--num_proc", type=int, default=None)
+     parser.add_argument("--save_merged_model", action=argparse.BooleanOptionalAction, default=True)
+     parser.add_argument("--push_to_hub", action=argparse.BooleanOptionalAction, default=True)
+     parser.add_argument("--repo_id", type=str, default="SmolLM2-1.7B-finetune")
+     return parser.parse_args()
+
+
+ def main(args):
+     # config
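+     # LoRA: train small rank-16 adapter matrices on the attention query/value
+     # projections; the base model weights stay frozen.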
+     lora_config = LoraConfig(
+         r=16,
+         lora_alpha=32,
+         lora_dropout=0.05,
+         target_modules=["q_proj", "v_proj"],
+         bias="none",
+         task_type="CAUSAL_LM",
+     )
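+     # Optional QLoRA-style 4-bit NF4 quantization of the frozen base model,
+     # enabled with --use_bnb.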
+     bnb_config = None
+     if args.use_bnb:
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.bfloat16,
+         )
+     # load model and dataset
+     token = os.environ.get("HF_TOKEN", None)
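+     # device_map places this process's model replica on its own accelerator:
+     # process_index is 0 for a single-process run, the rank under `accelerate launch`.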
+     model = AutoModelForCausalLM.from_pretrained(
+         args.model_id,
+         quantization_config=bnb_config,
+         device_map={"": PartialState().process_index},
+         attention_dropout=args.attention_dropout,
+     )
+
+     data = load_dataset(
+         args.dataset_name,
+         data_dir=args.subset,
+         split=args.split,
+         token=token,
+         num_proc=args.num_proc if args.num_proc else multiprocessing.cpu_count(),
+     )
+
+     # setup the trainer
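+     # SFTTrainer tokenizes the raw text in `dataset_text_field` and wraps the model
+     # with the LoRA adapter defined in peft_config before training.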
+     trainer = SFTTrainer(
+         model=model,
+         train_dataset=data,
+         args=SFTConfig(
+             per_device_train_batch_size=args.micro_batch_size,
+             gradient_accumulation_steps=args.gradient_accumulation_steps,
+             warmup_steps=args.warmup_steps,
+             max_steps=args.max_steps,
+             max_seq_length=args.max_seq_length,
+             learning_rate=args.learning_rate,
+             lr_scheduler_type=args.lr_scheduler_type,
+             weight_decay=args.weight_decay,
+             bf16=args.bf16,
+             output_dir=args.output_dir,
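+             # bitsandbytes' paged 8-bit AdamW keeps optimizer-state memory low.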
+             optim="paged_adamw_8bit",
+             seed=args.seed,
+             run_name=f"train-{args.model_id.split('/')[-1]}",
+             report_to="none",
+             dataset_text_field=args.dataset_text_field,
+         ),
+         peft_config=lora_config,
+     )
+
+     # launch
+     print("Training...")
+     trainer.train()
+
+     print("Saving the last checkpoint of the model")
+     final_checkpoint_dir = os.path.join(args.output_dir, "final_checkpoint")
+     # trainer.model is the PEFT-wrapped model, so this writes only the LoRA adapter.
+     trainer.model.save_pretrained(final_checkpoint_dir)
+
+     if args.save_merged_model:
+         # Free memory for merging weights
+         del model, trainer
+         if is_torch_xpu_available():
+             torch.xpu.empty_cache()
+         elif is_torch_npu_available():
+             torch.npu.empty_cache()
+         else:
+             torch.cuda.empty_cache()
+
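+         # Reload the saved adapter on top of the base model and fold the LoRA deltas
+         # into the base weights so the merged model loads as a plain transformers model.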
+         model = AutoPeftModelForCausalLM.from_pretrained(
+             final_checkpoint_dir, device_map="auto", torch_dtype=torch.bfloat16
+         )
+         model = model.merge_and_unload()
+
+         output_merged_dir = os.path.join(args.output_dir, "final_merged_checkpoint")
+         model.save_pretrained(output_merged_dir, safe_serialization=True)
+
+         if args.push_to_hub:
+             model.push_to_hub(args.repo_id, commit_message="Upload model")
+
+     print("Training Done! 💥")
+
+
+ if __name__ == "__main__":
+     args = get_args()
+     set_seed(args.seed)
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     logging.set_verbosity_error()
+
+     main(args)
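A quick way to smoke-test the script (a sketch, assuming `torch`, `transformers`, `datasets`, `trl`, `peft`, `bitsandbytes`, and `accelerate` are installed, and `HF_TOKEN` is exported if the model, dataset, or target repo requires it):

python train.py --use_bnb --max_steps 100 --no-push_to_hub

For multi-GPU training the same arguments can be passed through `accelerate launch train.py ...`; the `device_map={"": PartialState().process_index}` line then gives each process its own model replica.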