Create train.py
train.py ADDED
@@ -0,0 +1,150 @@
# Code adapted from https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py
# and https://huggingface.co/blog/gemma-peft
import argparse
import multiprocessing
import os

import torch
import transformers
from accelerate import PartialState
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    is_torch_npu_available,
    is_torch_xpu_available,
    logging,
    set_seed,
)
from trl import SFTTrainer, SFTConfig
#from trl import SFTTrainer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="HuggingFaceTB/SmolLM2-1.7B")
    parser.add_argument("--dataset_name", type=str, default="bigcode/the-stack-smol")
    parser.add_argument("--subset", type=str, default="data/python")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--dataset_text_field", type=str, default="content")

    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--max_steps", type=int, default=1000)
    parser.add_argument("--micro_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--bf16", type=bool, default=True)

    parser.add_argument("--use_bnb", type=bool, default=False)
    parser.add_argument("--attention_dropout", type=float, default=0.1)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--output_dir", type=str, default="finetune_smollm2_python")
    parser.add_argument("--num_proc", type=int, default=None)
    parser.add_argument("--save_merged_model", type=bool, default=True)
    parser.add_argument("--push_to_hub", type=bool, default=True)
    parser.add_argument("--repo_id", type=str, default="SmolLM2-1.7B-finetune")
    return parser.parse_args()


def main(args):
    # config
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
    bnb_config = None
    if args.use_bnb:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    # load model and dataset
    token = os.environ.get("HF_TOKEN", None)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        quantization_config=bnb_config,
        device_map={"": PartialState().process_index},
        attention_dropout=args.attention_dropout,
    )

    data = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        token=token,
        num_proc=args.num_proc if args.num_proc else multiprocessing.cpu_count(),
    )

    # setup the trainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=data,
        #max_seq_length=args.max_seq_length,
        args=SFTConfig(
            per_device_train_batch_size=args.micro_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            max_seq_length=args.max_seq_length,
            learning_rate=args.learning_rate,
            lr_scheduler_type=args.lr_scheduler_type,
            weight_decay=args.weight_decay,
            bf16=args.bf16,
            #logging_strategy="steps",
            #logging_steps=10,
            output_dir=args.output_dir,
            optim="paged_adamw_8bit",
            seed=args.seed,
            run_name=f"train-{args.model_id.split('/')[-1]}",
            report_to="none",
        ),
        peft_config=lora_config,
        dataset_text_field=args.dataset_text_field,
    )

    # launch
    print("Training...")
    trainer.train()

    print("Saving the last checkpoint of the model")
    model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))

    if args.save_merged_model:
        # Free memory for merging weights
        del model
        if is_torch_xpu_available():
            torch.xpu.empty_cache()
        elif is_torch_npu_available():
            torch.npu.empty_cache()
        else:
            torch.cuda.empty_cache()

        model = AutoPeftModelForCausalLM.from_pretrained(args.output_dir, device_map="auto", torch_dtype=torch.bfloat16)
        model = model.merge_and_unload()

        output_merged_dir = os.path.join(args.output_dir, "final_merged_checkpoint")
        model.save_pretrained(output_merged_dir, safe_serialization=True)

        if args.push_to_hub:
            model.push_to_hub(args.repo_id, "Upload model")

    print("Training Done! 💥")


if __name__ == "__main__":
    args = get_args()
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    logging.set_verbosity_error()

    main(args)
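
For reference, a minimal sketch of how this script might be launched. The command below is not part of the commit; the flag values simply echo the parser defaults, and `accelerate launch` is one option (a plain `python train.py` run also works with the `PartialState` device map, assuming accelerate is installed and configured).

accelerate launch train.py \
    --model_id HuggingFaceTB/SmolLM2-1.7B \
    --dataset_name bigcode/the-stack-smol \
    --subset data/python \
    --max_steps 1000 \
    --output_dir finetune_smollm2_python

Note that the boolean flags (--bf16, --use_bnb, --save_merged_model, --push_to_hub) are declared with type=bool, so argparse treats any non-empty value as True; to turn one of them off, change the default in get_args() rather than passing e.g. --push_to_hub False.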