# Code adapted from https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py
# and https://huggingface.co/blog/gemma-peft
import argparse
import multiprocessing
import os

import torch
import transformers
from accelerate import PartialState
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    is_torch_npu_available,
    is_torch_xpu_available,
    logging,
    set_seed,
)
from trl import SFTTrainer, SFTConfig


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="HuggingFaceTB/SmolLM2-1.7B")
    parser.add_argument("--dataset_name", type=str, default="bigcode/the-stack-smol")
    parser.add_argument("--subset", type=str, default="data/python")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--dataset_text_field", type=str, default="content")
    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--max_steps", type=int, default=1000)
    parser.add_argument("--micro_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--bf16", type=bool, default=True)
    parser.add_argument("--use_bnb", type=bool, default=False)
    parser.add_argument("--attention_dropout", type=float, default=0.1)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--output_dir", type=str, default="finetune_smollm2_python")
    parser.add_argument("--num_proc", type=int, default=None)
    parser.add_argument("--save_merged_model", type=bool, default=True)
    parser.add_argument("--push_to_hub", type=bool, default=True)
    parser.add_argument("--repo_id", type=str, default="SmolLM2-1.7B-finetune")
    return parser.parse_args()


def main(args):
    # LoRA config: low-rank adapters on the attention query/value projections
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )

    # optional 4-bit (NF4) quantization for QLoRA-style training
    bnb_config = None
    if args.use_bnb:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

    # load model and dataset
    token = os.environ.get("HF_TOKEN", None)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        quantization_config=bnb_config,
        device_map={"": PartialState().process_index},
        attention_dropout=args.attention_dropout,
    )

    data = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        token=token,
        num_proc=args.num_proc if args.num_proc else multiprocessing.cpu_count(),
    )

    # setup the trainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=data,
        args=SFTConfig(
            dataset_text_field=args.dataset_text_field,
            max_seq_length=args.max_seq_length,
            per_device_train_batch_size=args.micro_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            learning_rate=args.learning_rate,
            lr_scheduler_type=args.lr_scheduler_type,
            weight_decay=args.weight_decay,
            bf16=args.bf16,
            output_dir=args.output_dir,
            optim="paged_adamw_8bit",
            seed=args.seed,
            run_name=f"train-{args.model_id.split('/')[-1]}",
            report_to="none",
        ),
        peft_config=lora_config,
    )

    # launch
    print("Training...")
    trainer.train()

    # save the LoRA adapter of the last checkpoint
    print("Saving the last checkpoint of the model")
    final_checkpoint_dir = os.path.join(args.output_dir, "final_checkpoint/")
    trainer.model.save_pretrained(final_checkpoint_dir)

    if args.save_merged_model:
        # Free memory for merging weights
        del model
        if is_torch_xpu_available():
            torch.xpu.empty_cache()
        elif is_torch_npu_available():
            torch.npu.empty_cache()
        else:
            torch.cuda.empty_cache()

        # reload the adapter, merge it into the base model, and save the result
        model = AutoPeftModelForCausalLM.from_pretrained(
            final_checkpoint_dir, device_map="auto", torch_dtype=torch.bfloat16
        )
        model = model.merge_and_unload()

        output_merged_dir = os.path.join(args.output_dir, "final_merged_checkpoint")
        model.save_pretrained(output_merged_dir, safe_serialization=True)

    if args.push_to_hub:
        model.push_to_hub(args.repo_id, commit_message="Upload model")

    print("Training Done! 💥")


if __name__ == "__main__":
    args = get_args()
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    logging.set_verbosity_error()
    main(args)
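

# Example invocation (a minimal sketch: it assumes this script is saved as
# finetune.py and that torch, transformers, datasets, peft, trl, accelerate,
# and bitsandbytes are installed; the flags are the ones defined in get_args()
# above, and HF_TOKEN is read from the environment for Hub access):
#
#   python finetune.py \
#       --model_id HuggingFaceTB/SmolLM2-1.7B \
#       --dataset_name bigcode/the-stack-smol \
#       --subset data/python \
#       --max_steps 1000 \
#       --output_dir finetune_smollm2_python
#
# For multi-GPU runs the same flags can be passed through `accelerate launch`,
# since the model is placed per process via PartialState().process_index:
#
#   accelerate launch finetune.py --model_id HuggingFaceTB/SmolLM2-1.7B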