# qwen2.5-7b-custom / train_transformer.py
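# Fine-tunes Qwen2.5-VL with LoRA adapters via TRL's SFTTrainer under DeepSpeed.
# A typical launch (an assumption; adjust to your setup) might look like:
#   deepspeed train_transformer.py --model_id Qwen/Qwen2.5-VL-7B-Instruct --output_dir checkpoints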
# Standard library
import argparse
import gc
import json
import logging
import os

# Third-party
import torch
import torch.distributed
import wandb
from peft import LoraConfig, get_peft_model
from qwen_vl_utils import process_vision_info
from transformers import (
    BitsAndBytesConfig,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2_5_VLProcessor,
    TrainerCallback,
)
from transformers import logging as transformers_logging
from trl import SFTConfig, SFTTrainer

# Local modules
from dataset import AgentDatapointDataset, AgentEvalDatapointDataset
from evaluate import evaluate_model

# from unsloth import is_bf16_supported
# The transformers implementation above needs more VRAM; the variant below
# needs less since it uses compiled components and a fused cross-entropy loss:
# from model import Qwen2_5_VLForConditionalGeneration
# Perhaps add these back later:
# from unsloth.models._utils import prepare_model_for_kbit_training
# from gradient_checkpointing import patch_unsloth_smart_gradient_checkpointing

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
transformers_logging.set_verbosity_error()

torch.cuda.empty_cache()
torch.set_float32_matmul_precision("medium")
def train_collate_fn(examples, processor):
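    """Collate chat-formatted examples into a padded batch for Qwen2.5-VL.

    Builds labels as a copy of input_ids with padding and image tokens
    masked to -100 so the cross-entropy loss ignores them.
    """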
texts = [
processor.apply_chat_template(example["messages"], tokenize=False)
for example in examples
]
image_inputs = [process_vision_info(example["messages"])[0] for example in examples]
model_inputs = processor(
text=texts, images=image_inputs, return_tensors="pt", padding=True
)
labels = model_inputs["input_ids"].clone()
# mask padding tokens in labels
labels[labels == processor.tokenizer.pad_token_id] = -100
if isinstance(processor, Qwen2_5_VLProcessor):
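        # Hard-coded Qwen2.5-VL special-token IDs; these should correspond to
        # <|vision_start|>, <|vision_end|>, and <|image_pad|> in its tokenizer.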
image_tokens = [151652, 151653, 151655]
else:
image_tokens = [
processor.tokenizer.convert_tokens_to_ids(processor.image_token)
]
# mask image token IDs in the labels
for image_token_id in image_tokens:
labels[labels == image_token_id] = -100
# Return a dictionary instead of a tuple
return {
"input_ids": model_inputs["input_ids"],
"attention_mask": model_inputs["attention_mask"],
"pixel_values": model_inputs["pixel_values"],
"image_grid_thw": model_inputs["image_grid_thw"],
"labels": labels,
}
def _wrap_fast_inference(generate, device_type, dtype, model):
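    """Wrap a model's `generate` in inference_mode plus autocast so generation
    runs in bfloat16 / float16 without touching training state."""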
    @torch.inference_mode()
    def _fast_generate(*args, **kwargs):
        # For num_logits_to_keep
        # kwargs["num_logits_to_keep"] = 1
        # Remove token_type_ids, which generate() does not accept
        kwargs.pop("token_type_ids", None)
        # Default pad_token_id to the model's (first) EOS token
        model_eos_token_id = getattr(model.config, "eos_token_id", None)
        if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"):
            model_eos_token_id = model_eos_token_id[0]
        kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)
        # Cast pixel values to the model dtype when present
        try:
            kwargs["pixel_values"] = kwargs["pixel_values"].to(model.dtype)
        except (KeyError, AttributeError):
            pass
        with torch.autocast(device_type=device_type, dtype=dtype):
            output = generate(*args, **kwargs)
        return output

    return _fast_generate
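# for_inference / for_training are unsloth-style mode switches: they flip
# training and gradient-checkpointing flags module-by-module, swap
# `model.generate` for the fast wrapper above (and back), and set the
# tokenizer padding side (left for generation, right for training).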
def for_inference(model):
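    """Switch the model into generation mode: disable gradient checkpointing,
    wrap `generate` with the fast autocast wrapper, pad on the left, and
    turn off the embedding training flags used by NEFTune."""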
    model.gradient_checkpointing = False
    model.training = False
    for name, module in model.named_modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = False
        if hasattr(module, "training"):
            module.training = False
    dtype = model.config.torch_dtype
    if type(dtype) is str:
        if dtype == "float16":
            dtype = torch.float16
        elif dtype == "bfloat16":
            dtype = torch.bfloat16
    device_type = model.device.type
    # Wrap model.generate
    if model.generate.__name__ != "_fast_generate":
        model._unwrapped_old_generate = model.generate
        model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model)
    # Patch tokenizer to pad to the left
    internal_model = model
    while hasattr(internal_model, "model"):
        if hasattr(internal_model, "_saved_temp_tokenizer"):
            internal_model._saved_temp_tokenizer.tokenizer.padding_side = "left"
        internal_model = internal_model.model
    if hasattr(internal_model, "_saved_temp_tokenizer"):
        internal_model._saved_temp_tokenizer.tokenizer.padding_side = "left"
    # Also disable training for embeddings for NEFTune
    if hasattr(model, "get_input_embeddings"):
        embeddings = model.get_input_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = False
    if hasattr(model, "get_output_embeddings"):
        embeddings = model.get_output_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = False
return model
def for_training(model, use_gradient_checkpointing=True):
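    """Undo `for_inference`: re-enable training flags and gradient
    checkpointing, restore the original `generate`, and pad on the right."""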
    model.train()
    model.gradient_checkpointing = use_gradient_checkpointing
    model.training = True
    for name, module in model.named_modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = use_gradient_checkpointing
        if hasattr(module, "training"):
            module.training = True
    # Revert model.generate to the original implementation
    if hasattr(model, "_unwrapped_old_generate"):
        model.generate = model._unwrapped_old_generate
        del model._unwrapped_old_generate
    # Patch tokenizer to pad to the right
    internal_model = model
    while hasattr(internal_model, "model"):
        if hasattr(internal_model, "_saved_temp_tokenizer"):
            internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right"
        internal_model = internal_model.model
    if hasattr(internal_model, "_saved_temp_tokenizer"):
        internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right"
    # Also re-enable training for embeddings for NEFTune
    if hasattr(model, "get_input_embeddings"):
        embeddings = model.get_input_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = True
    if hasattr(model, "get_output_embeddings"):
        embeddings = model.get_output_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = True
return model
class CustomTrainingCallback(TrainerCallback):
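    """TrainerCallback that mirrors training logs to wandb and, every
    `eval_epoch_interval` epochs (after a warm-up of 4 epochs), evaluates
    train/test accuracy and persists the best metrics to the output dir."""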
def __init__(self, trainer, eval_epoch_interval=2):
self.trainer = trainer
self.eval_epoch_interval = eval_epoch_interval
self.best_test_accuracy = 0.0
self.best_test_epoch = 0
self.best_metrics = {
'test_accuracy': 0.0,
'train_accuracy': 0.0,
'epoch': 0,
'global_step': 0
}
def save_best_metrics(self, output_dir):
"""Save best metrics to a file in the checkpoint directory"""
metrics_file = os.path.join(output_dir, 'best_metrics.json')
with open(metrics_file, 'w') as f:
json.dump(self.best_metrics, f, indent=4)
print(f"Saved best metrics to {metrics_file}")
def on_log(self, args, state, control, logs=None, **kwargs):
"""Log metrics at each logging step"""
if logs is not None:
# Ensure wandb is initialized
import wandb
if not wandb.run:
wandb.init(
project="qwen-vl-trainer",
reinit=True,
name=f"{os.environ.get('RANK', '0')}-training",
group=os.environ.get("WANDB_RUN_GROUP", None),
)
# Log all metrics from the logs dictionary
step = state.global_step if hasattr(state, "global_step") else 0
# Extract and log training metrics
log_data = {}
for key, value in logs.items():
# Prefix training metrics to differentiate from eval metrics
if key not in ["eval_loss", "epoch", "learning_rate"]:
log_data[f"train/{key}"] = value
else:
log_data[key] = value
wandb.log(log_data, step=step)
    def on_epoch_end(self, args, state, control, **kwargs):
        # At on_epoch_end, state.epoch already equals the number of completed
        # epochs, so it is printed directly rather than incremented.
        epoch = int(round(state.epoch))
        print(f"Epoch {epoch} ended")
        was_training = self.trainer.model.training
        for_inference(self.trainer.model)
        self.trainer.model.eval()
        if (epoch + 1) % self.eval_epoch_interval == 0 and epoch > 4:
# Get test accuracy
test_accuracy = self.trainer.evaluate_step(dataset=self.trainer.eval_dataset, split="test")
train_accuracy = self.trainer.evaluate_step(dataset=self.trainer.train_dataset_eval, split="train")
print(f"Test accuracy: {test_accuracy:.4f}, Train accuracy: {train_accuracy:.4f}")
# Update best test accuracy if current is better
            if test_accuracy > self.best_test_accuracy:
                self.best_test_accuracy = test_accuracy
                self.best_test_epoch = epoch
                # Update best metrics dictionary (keys match the init above)
                self.best_metrics.update({
                    'test_accuracy': float(test_accuracy),
                    'train_accuracy': float(train_accuracy),
                    'epoch': epoch,
                    'global_step': int(state.global_step)
                })
                # Save best metrics to file
                self.save_best_metrics(args.output_dir)
                print(f"\nNew best test accuracy: {self.best_test_accuracy:.4f} at epoch {self.best_test_epoch}")
if was_training:
for_training(self.trainer.model)
self.trainer.model.train()
class CustomSFTTrainer(SFTTrainer):
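    """SFTTrainer extended with a processor-aware collator, a separate
    train-split evaluation dataset, and accuracy evaluation through
    `evaluate_model`."""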
def __init__(
self,
model,
tokenizer,
processor,
data_collator,
train_dataset=None,
train_dataset_eval=None,
eval_dataset=None,
eval_epoch_interval=2,
args=None,
):
self.custom_callback = CustomTrainingCallback(
self, eval_epoch_interval=eval_epoch_interval
)
callbacks = [self.custom_callback]
super().__init__(
model=model,
tokenizer=tokenizer,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=callbacks,
args=args,
)
self.eval_dataset = eval_dataset
self.train_dataset_eval = train_dataset_eval
        # Placeholder state; Trainer.train() replaces it with a real TrainerState.
        self.state = type("State", (), {"global_step": 0})()
self.processor = processor
def evaluate_step(self, dataset, split):
print(f"Evaluating {split} dataset")
        try:
            # evaluate_model signature: evaluate_model(model, processor, dataset, split, verbose=False)
            accuracy = evaluate_model(self.model, self.processor, dataset, split)
# Initialize wandb if not already initialized
import wandb
if not wandb.run:
wandb.init(
project="qwen-vl-trainer",
reinit=True,
name=f"{os.environ.get('RANK', '0')}-evaluation",
group=os.environ.get("WANDB_RUN_GROUP", None),
)
wandb.log(
{
f"{split}/accuracy": accuracy,
}
)
return accuracy # Return the accuracy value
except Exception as e:
logger.error(f"Error evaluating: {e}")
raise
def cleanup(self):
"""Cleanup method to ensure wandb runs are properly closed"""
import wandb
if wandb.run:
wandb.finish()
def load_model(MODEL_ID: str, USE_QLORA: bool, training_args):
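    """Load the Qwen2.5-VL base model and processor, attach LoRA adapters,
    and configure gradient checkpointing. Returns (model, processor)."""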
# patch_unsloth_smart_gradient_checkpointing()
    # 4-bit quantization config (currently unused: the quantization_config
    # argument below is commented out, so USE_QLORA has no effect).
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    # LoRA configuration
    lora_config = LoraConfig(
        r=200,  # High rank for more expressiveness
        lora_alpha=50,  # Scaling factor
        lora_dropout=0.001,  # Very light dropout
        bias="lora_only",
        # Note: this list mixes module names from several architectures; PEFT
        # matches by suffix and skips names absent from Qwen2.5-VL (e.g.
        # "qkv_proj", "fc1"), provided at least one name matches.
        target_modules=[
            "qkv_proj",
            "o_proj",
            "gate_up_proj",
            "down_proj",
            "gate_proj",
            "up_proj",
            "fc1",
            "fc2",
            "mlp.0",
            "mlp.2",
        ],
        task_type="CAUSAL_LM",
        inference_mode=False,
        modules_to_save=None,
    )
# Clear memory before model load
torch.cuda.empty_cache()
gc.collect()
# Load DeepSpeed config
with open(training_args.deepspeed, "r") as f:
ds_config = json.load(f)
    # Check whether ZeRO-3 is configured (currently informational only)
    is_deepspeed_zero3_enabled = (
        ds_config.get("zero_optimization", {}).get("stage", 0) == 3
    )
    # Load the base model; with ZeRO-3, transformers picks up the DeepSpeed
    # config from the SFTConfig created before this function is called.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID,
# quantization_config=bnb_config if USE_QLORA else None, # Use the config
torch_dtype=torch.bfloat16,
# device_map=None, # Let DeepSpeed handle device mapping
use_cache=False,
attn_implementation="flash_attention_2",
)
# Reset generation config to avoid warnings
from transformers import GenerationConfig
model.generation_config = GenerationConfig.from_model_config(model.config)
# Ensure no conflicting generation parameters
model.generation_config.temperature = None
model.generation_config.top_p = None
model.generation_config.top_k = None
model.generation_config.early_stopping = False
processor = Qwen2_5_VLProcessor.from_pretrained(MODEL_ID)
model.enable_input_require_grads() # unsloth added this prior to loading peft
model = get_peft_model(model, lora_config)
model.gradient_checkpointing_enable()
model.config.use_cache = False
model.config.pretraining_tp = 1
    # Gradient-checkpointing hints stored on the config. Note that
    # model.gradient_checkpointing_enable() above is what actually enables
    # checkpointing; stock transformers does not read these attributes.
    model.config.gradient_checkpointing = True
    model.config.use_reentrant = False
    model.config.gradient_checkpointing_kwargs = {
        "use_reentrant": False,
        "checkpoint_every_n_layers": 1,
        "offload_to_cpu": True,
    }
return model, processor
def train(args):
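    """Entry point: initialize distributed training and wandb, build the
    SFTConfig, model, and datasets, run training, and save the final model."""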
# Set CUDA device explicitly based on local_rank
if args.local_rank != -1:
torch.cuda.set_device(args.local_rank)
# Initialize process group with the correct device
if not torch.distributed.is_initialized():
# Get world size from environment if available
world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
rank = int(os.environ.get("RANK", args.local_rank))
print(
f"Initializing process group with rank={rank}, world_size={world_size}"
)
try:
torch.distributed.init_process_group(
backend="nccl",
init_method="env://",
world_size=world_size,
rank=rank,
)
print(f"Successfully initialized process group for rank {rank}")
except Exception as e:
print(f"Could not initialize process group: {e}")
# Remove memory management env vars that might interfere with DeepSpeed
os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None)
os.environ.pop("MAX_JOBS", None)
os.environ.pop("CUDA_LAUNCH_BLOCKING", None)
# Set up DeepSpeed config path first
ds_config_path = "deepspeed_config.json"
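    # A minimal sketch of what deepspeed_config.json might contain, assuming
    # ZeRO-3 (an assumption; the actual file ships alongside this script):
    # {
    #   "zero_optimization": {"stage": 3},
    #   "bf16": {"enabled": true},
    #   "train_micro_batch_size_per_gpu": "auto",
    #   "gradient_accumulation_steps": "auto"
    # }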
# Set up wandb configuration
os.environ["WANDB_MODE"] = "online"
# Create a unique timestamp for this training run
import datetime
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
run_id = timestamp
os.environ["WANDB_RUN_GROUP"] = f"qwen_training_{run_id}"
# Create a timestamped output directory
timestamped_output_dir = os.path.join(args.output_dir, f"run_{timestamp}")
os.makedirs(timestamped_output_dir, exist_ok=True)
print(f"Model checkpoints will be saved to: {timestamped_output_dir}")
# Configure wandb properly for Trainer
os.environ["WANDB_PROJECT"] = "qwen-vl-trainer"
os.environ["WANDB_LOG_MODEL"] = "end" # Changed from "true" to "end"
os.environ["WANDB_WATCH"] = "all" # Monitor all gradients and parameters
os.environ["WANDB_NAME"] = f"run_{timestamp}_rank{os.environ.get('RANK', '0')}"
# Initialize wandb only once at the beginning for the main process
if args.local_rank <= 0: # Only initialize on rank 0 or single GPU
import wandb
wandb.init(
project="qwen-vl-trainer",
name=f"transformer_training_{timestamp}",
group=os.environ.get("WANDB_RUN_GROUP"),
# Important: we're logging the model as an artifact
settings=wandb.Settings(_disable_stats=True, _disable_meta=True),
)
# Log config information
wandb.config.update(
{
"model_id": args.model_id,
"use_qlora": args.use_qlora,
"output_dir": timestamped_output_dir,
}
)
print(f"Initialized wandb with run ID: {wandb.run.id}")
# Create SFTConfig with DeepSpeed config before loading the model
training_args = SFTConfig(
per_device_train_batch_size=1, # Equivalent to train_micro_batch_size_per_gpu
gradient_accumulation_steps=2,
logging_steps=1, # Log every step
logging_strategy="steps", # Log based on steps
log_level="info",
num_train_epochs=2000, # Set to desired number of epochs
# eval_steps=100,
bf16=True,
optim="adamw_8bit",
lr_scheduler_type="linear",
seed=3407,
output_dir=timestamped_output_dir, # Use timestamped directory
overwrite_output_dir=True,
report_to="wandb", # Explicitly report to wandb
remove_unused_columns=False,
dataset_text_field="",
dataset_kwargs={"skip_prepare_dataset": True},
dataset_num_proc=4,
        max_seq_length=800000,  # very large; effectively disables truncation
save_strategy="epoch",
evaluation_strategy="no",
save_total_limit=2000,
deepspeed=ds_config_path, # Pass the DeepSpeed config
)
# Pass training args to load_model function
model, processor = load_model(args.model_id, args.use_qlora, training_args)
# Train dataset
train_dataset = AgentDatapointDataset(split="train", num_samples=args.train_size)
# Eval datasets
test_dataset = AgentEvalDatapointDataset(split="test", num_samples=args.test_size)
train_dataset_eval = AgentEvalDatapointDataset(split="train", num_samples=args.train_size)
for_training(model)
trainer = CustomSFTTrainer(
model=model,
processor=processor,
tokenizer=processor.tokenizer,
data_collator=lambda examples: train_collate_fn(examples, processor),
train_dataset_eval=train_dataset_eval,
train_dataset=train_dataset,
eval_dataset=test_dataset,
args=training_args,
)
training_stats = trainer.train()
logger.info("Training completed.")
print(f"Training Statistics: {training_stats}")
# Save the final model explicitly with timestamp
final_model_path = os.path.join(timestamped_output_dir, "final_model")
if args.local_rank <= 0: # Only save on rank 0 or single GPU
print(f"Saving final model to {final_model_path}")
trainer.save_model(final_model_path)
print(f"Final model saved to {final_model_path}")
# Also save the processor
processor.save_pretrained(final_model_path)
# Log the final model to wandb
# import wandb
# if wandb.run:
# model_artifact = wandb.Artifact(
# name=f"model_{timestamp}",
# type="model",
# description=f"Final trained model from run {timestamp}"
# )
# model_artifact.add_dir(final_model_path)
# wandb.log_artifact(model_artifact)
# print(f"Final model logged to wandb as artifact: model_{timestamp}")
#
# print(f"Final model saved to {final_model_path}")
# Ensure proper cleanup of wandb
trainer.cleanup()
# Final cleanup for the main process
if args.local_rank <= 0: # Only finalize on rank 0 or single GPU
import wandb
if wandb.run:
print("Finalizing main wandb run...")
wandb.finish()
print("Training process completed successfully.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Training configuration")
parser.add_argument(
"--model_id",
type=str,
default="Qwen/Qwen2.5-VL-7B-Instruct",
help="Model ID to use",
)
    parser.add_argument(
        "--use_qlora",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to use QLoRA (pass --no-use_qlora to disable; "
        "type=bool would treat any non-empty string, even 'False', as True)",
    )
parser.add_argument(
"--output_dir", type=str, default="checkpoints_27feb", help="Output directory"
)
# Add local_rank argument for DeepSpeed
parser.add_argument(
"--local_rank", type=int, default=-1, help="Local rank for distributed training"
)
parser.add_argument(
"--train_size", type=int, default=10000000, help="Number of training samples"
)
parser.add_argument(
"--test_size", type=int, default=10000000, help="Number of test samples"
)
args = parser.parse_args()
train(args)