# Author: dufry2024
# Commit e6d9233 ("Create train.py"), verified
# Code adapted from https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py
# and https://huggingface.co/blog/gemma-peft
import argparse
import multiprocessing
import os
import torch
import transformers
from accelerate import PartialState
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from transformers import (
AutoModelForCausalLM,
BitsAndBytesConfig,
is_torch_npu_available,
is_torch_xpu_available,
logging,
set_seed,
)
from trl import SFTTrainer, SFTConfig
#from trl import SFTTrainer
def get_args(argv=None):
    """Parse the command-line options for the LoRA supervised fine-tuning run.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behaviour; passing a list is useful in tests/notebooks.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()

    def _str2bool(value):
        # argparse's ``type=bool`` is a trap: bool("False") is True because any
        # non-empty string is truthy, so "--bf16 False" used to keep bf16 on.
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ("true", "t", "yes", "y", "1", "on"):
            return True
        if text in ("false", "f", "no", "n", "0", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean (true/false), got {value!r}")

    def _add_flag(name, default):
        # Accepts "--flag true" / "--flag false"; a bare "--flag" means True.
        parser.add_argument(name, type=_str2bool, nargs="?", const=True, default=default)

    # Model / data selection.
    parser.add_argument("--model_id", type=str, default="HuggingFaceTB/SmolLM2-1.7B")
    parser.add_argument("--dataset_name", type=str, default="bigcode/the-stack-smol")
    parser.add_argument("--subset", type=str, default="data/python")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--dataset_text_field", type=str, default="content")
    # Optimisation schedule.
    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--max_steps", type=int, default=1000)
    parser.add_argument("--micro_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    _add_flag("--bf16", True)
    _add_flag("--use_bnb", False)
    parser.add_argument("--attention_dropout", type=float, default=0.1)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--seed", type=int, default=0)
    # Outputs.
    parser.add_argument("--output_dir", type=str, default="finetune_smollm2_python")
    # None -> main() falls back to multiprocessing.cpu_count() for dataset loading.
    parser.add_argument("--num_proc", type=int, default=None)
    _add_flag("--save_merged_model", True)
    _add_flag("--push_to_hub", True)
    parser.add_argument("--repo_id", type=str, default="SmolLM2-1.7B-finetune")
    return parser.parse_args(argv)
def _empty_accelerator_cache():
    """Collect garbage, then release cached allocator memory on the available backend."""
    import gc

    # Trainer/model objects can sit in reference cycles; collect them before
    # asking the allocator to hand memory back.
    gc.collect()
    if is_torch_xpu_available():
        torch.xpu.empty_cache()
    elif is_torch_npu_available():
        torch.npu.empty_cache()
    else:
        torch.cuda.empty_cache()


def main(args):
    """Fine-tune ``args.model_id`` with LoRA adapters using TRL's ``SFTTrainer``.

    Steps:
      1. Load the base model, 4-bit NF4 quantized when ``args.use_bnb`` is set,
         onto the device whose index equals this process's index.
      2. Load ``args.dataset_name`` (``data_dir=args.subset``, ``split=args.split``).
      3. Train LoRA adapters (r=16, alpha=32, on q_proj/v_proj) for
         ``args.max_steps`` steps.
      4. On the main process, save the adapter to ``<output_dir>/final_checkpoint``.
      5. If ``args.save_merged_model`` (main process only): reload that adapter on
         its base model in bfloat16, merge the weights, and save the result to
         ``<output_dir>/final_merged_checkpoint``. If ``args.push_to_hub`` is also
         set, upload it to ``args.repo_id``.

    Args:
        args: Namespace produced by ``get_args()``.
    """
    state = PartialState()

    # LoRA adapters on the attention query/value projections only.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
    bnb_config = None
    if args.use_bnb:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

    # load model and dataset
    token = os.environ.get("HF_TOKEN", None)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        quantization_config=bnb_config,
        # Map every module to the device indexed by this process (one full
        # replica per process when launched with several processes).
        device_map={"": state.process_index},
        attention_dropout=args.attention_dropout,
    )
    data = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        token=token,
        num_proc=args.num_proc if args.num_proc else multiprocessing.cpu_count(),
    )

    # setup the trainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=data,
        args=SFTConfig(
            per_device_train_batch_size=args.micro_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            # Dataset options live in SFTConfig alongside max_seq_length; newer
            # TRL versions no longer accept dataset_text_field on SFTTrainer itself.
            max_seq_length=args.max_seq_length,
            dataset_text_field=args.dataset_text_field,
            learning_rate=args.learning_rate,
            lr_scheduler_type=args.lr_scheduler_type,
            weight_decay=args.weight_decay,
            bf16=args.bf16,
            #logging_strategy="steps",
            #logging_steps=10,
            output_dir=args.output_dir,
            # NOTE(review): this optimizer is implemented by bitsandbytes, so that
            # package is needed even without --use_bnb; confirm in your environment.
            optim="paged_adamw_8bit",
            seed=args.seed,
            run_name=f"train-{args.model_id.split('/')[-1]}",
            report_to="none",
        ),
        peft_config=lora_config,
    )

    # launch
    print("Training...")
    trainer.train()

    adapter_dir = os.path.join(args.output_dir, "final_checkpoint")
    if state.is_main_process:
        print("Saving the last checkpoint of the model")
        # With peft_config set, SFTTrainer trains a PEFT-wrapped model exposed as
        # trainer.model. Saving that wrapper (not the bare `model` local) writes
        # the adapter files that AutoPeftModelForCausalLM expects below.
        trainer.model.save_pretrained(adapter_dir)

    if args.save_merged_model and state.is_main_process:
        # The trainer keeps its own reference to the model, so drop both before
        # loading a second copy for merging; otherwise no memory is freed.
        del model, trainer
        _empty_accelerator_cache()

        # Load from the directory the adapter was actually saved to.
        merged = AutoPeftModelForCausalLM.from_pretrained(
            adapter_dir, device_map="auto", torch_dtype=torch.bfloat16
        )
        merged = merged.merge_and_unload()
        output_merged_dir = os.path.join(args.output_dir, "final_merged_checkpoint")
        merged.save_pretrained(output_merged_dir, safe_serialization=True)
        if args.push_to_hub:
            # push_to_hub's second positional parameter is use_temp_dir, so the
            # commit message must be passed by keyword.
            merged.push_to_hub(args.repo_id, commit_message="Upload model")
    print("Training Done! 💥")
if __name__ == "__main__":
    # Script entry point: parse options, seed RNGs, silence transformers
    # logging below ERROR, ensure the output directory exists, then train.
    cli_args = get_args()
    set_seed(cli_args.seed)
    logging.set_verbosity_error()
    os.makedirs(cli_args.output_dir, exist_ok=True)
    main(cli_args)