import argparse
import gc
import json
import logging
import os

import numpy as np
import torch
import torch.distributed
import wandb
from lightning.pytorch.loggers import WandbLogger
from peft import get_peft_model, LoraConfig
from qwen_vl_utils import process_vision_info
from torch.optim import AdamW
from transformers import BitsAndBytesConfig, TrainerCallback
from transformers import logging as transformers_logging

# from unsloth import is_bf16_supported

# This version of qwen requires more vram
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
from trl import SFTTrainer, SFTConfig

# This version of qwen requires less vram since it uses compiled components and
# a fused cross-entropy loss
# from model import Qwen2_5_VLForConditionalGeneration

from dataset import AgentDatapointDataset, AgentEvalDatapointDataset
from evaluate import evaluate_model

# Perhaps add these back later
# from unsloth.models._utils import prepare_model_for_kbit_training
# from gradient_checkpointing import patch_unsloth_smart_gradient_checkpointing

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
transformers_logging.set_verbosity_error()

torch.cuda.empty_cache()
torch.set_float32_matmul_precision("medium")


def train_collate_fn(examples, processor):
    texts = [
        processor.apply_chat_template(example["messages"], tokenize=False)
        for example in examples
    ]
    image_inputs = [
        process_vision_info(example["messages"])[0] for example in examples
    ]

    model_inputs = processor(
        text=texts, images=image_inputs, return_tensors="pt", padding=True
    )

    labels = model_inputs["input_ids"].clone()

    # Mask padding tokens in the labels
    labels[labels == processor.tokenizer.pad_token_id] = -100

    if isinstance(processor, Qwen2_5_VLProcessor):
        image_tokens = [151652, 151653, 151655]
    else:
        image_tokens = [
            processor.tokenizer.convert_tokens_to_ids(processor.image_token)
        ]

    # Mask image token IDs in the labels
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100

    # Return a dictionary instead of a tuple
    return {
        "input_ids": model_inputs["input_ids"],
        "attention_mask": model_inputs["attention_mask"],
        "pixel_values": model_inputs["pixel_values"],
        "image_grid_thw": model_inputs["image_grid_thw"],
        "labels": labels,
    }
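
# Note on the collator: train() below binds it to the processor with a lambda,
#   data_collator=lambda examples: train_collate_fn(examples, processor)
# and the -100 labels mark positions (padding and image placeholder tokens)
# that the cross-entropy loss ignores.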


def _wrap_fast_inference(generate, device_type, dtype, model):
    # Wraps inference with bfloat16 / float16
    @torch.inference_mode
    def _fast_generate(*args, **kwargs):
        # For num_logits_to_keep
        # kwargs["num_logits_to_keep"] = 1

        # Remove token_type_ids
        kwargs.pop("token_type_ids", None)

        # Check pad_token
        model_eos_token_id = getattr(model.config, "eos_token_id", None)
        if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"):
            model_eos_token_id = model_eos_token_id[0]
        kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)

        try:
            kwargs["pixel_values"] = kwargs["pixel_values"].to(model.dtype)
        except Exception:
            pass

        # Autocasted
        with torch.autocast(device_type=device_type, dtype=dtype):
            output = generate(*args, **kwargs)
        return output

    return _fast_generate


def for_inference(model):
    model.gradient_checkpointing = False
    model.training = False

    for name, module in model.named_modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = False
        if hasattr(module, "training"):
            module.training = False

    dtype = model.config.torch_dtype
    if type(dtype) is str:
        if dtype == "float16":
            dtype = torch.float16
        elif dtype == "bfloat16":
            dtype = torch.bfloat16
    device_type = model.device.type

    # Wrap model.generate
    if model.generate.__name__ != "_fast_generate":
        model._unwrapped_old_generate = model.generate
        model.generate = _wrap_fast_inference(
            model.generate, device_type, dtype, model
        )

    # Patch tokenizer to pad to the left
    internal_model = model
    while hasattr(internal_model, "model"):
        if hasattr(internal_model, "_saved_temp_tokenizer"):
            internal_model._saved_temp_tokenizer.tokenizer.padding_side = "left"
        internal_model = internal_model.model
    if hasattr(internal_model, "_saved_temp_tokenizer"):
        internal_model._saved_temp_tokenizer.tokenizer.padding_side = "left"

    # Also disable training for embeddings for NEFTune
    if hasattr(model, "get_input_embeddings"):
        embeddings = model.get_input_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = False
    if hasattr(model, "get_output_embeddings"):
        embeddings = model.get_output_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = False

    return model


def for_training(model, use_gradient_checkpointing=True):
    model.train()
    model.gradient_checkpointing = use_gradient_checkpointing
    model.training = True

    for name, module in model.named_modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = use_gradient_checkpointing
        if hasattr(module, "training"):
            module.training = True

    # Also revert model.generate
    if hasattr(model, "_unwrapped_old_generate"):
        model.generate = model._unwrapped_old_generate
        del model._unwrapped_old_generate

    # Patch tokenizer to pad to the right
    internal_model = model
    while hasattr(internal_model, "model"):
        if hasattr(internal_model, "_saved_temp_tokenizer"):
            internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right"
        internal_model = internal_model.model
    if hasattr(internal_model, "_saved_temp_tokenizer"):
        internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right"

    # Also re-enable training for embeddings for NEFTune
    if hasattr(model, "get_input_embeddings"):
        embeddings = model.get_input_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = True
    if hasattr(model, "get_output_embeddings"):
        embeddings = model.get_output_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = True

    return model


class CustomTrainingCallback(TrainerCallback):
    def __init__(self, trainer, eval_epoch_interval=2):
        self.trainer = trainer
        self.eval_epoch_interval = eval_epoch_interval
        self.best_test_accuracy = 0.0
        self.best_test_epoch = 0
        self.best_metrics = {
            "test_accuracy": 0.0,
            "train_accuracy": 0.0,
            "epoch": 0,
            "global_step": 0,
        }

    def save_best_metrics(self, output_dir):
        """Save best metrics to a file in the checkpoint directory"""
        metrics_file = os.path.join(output_dir, "best_metrics.json")
        with open(metrics_file, "w") as f:
            json.dump(self.best_metrics, f, indent=4)
        print(f"Saved best metrics to {metrics_file}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log metrics at each logging step"""
        if logs is not None:
            # Ensure wandb is initialized
            if not wandb.run:
                wandb.init(
                    project="qwen-vl-trainer",
                    reinit=True,
                    name=f"{os.environ.get('RANK', '0')}-training",
                    group=os.environ.get("WANDB_RUN_GROUP", None),
                )

            # Log all metrics from the logs dictionary
            step = state.global_step if hasattr(state, "global_step") else 0

            # Extract and log training metrics
            log_data = {}
            for key, value in logs.items():
                # Prefix training metrics to differentiate from eval metrics
                if key not in ["eval_loss", "epoch", "learning_rate"]:
                    log_data[f"train/{key}"] = value
                else:
                    log_data[key] = value

            wandb.log(log_data, step=step)

    def on_epoch_end(self, args, state, control, **kwargs):
        print(f"Epoch {state.epoch + 1} ended")
        was_training = self.trainer.model.training
        for_inference(self.trainer.model)
        self.trainer.model.eval()

        if (state.epoch + 1) % self.eval_epoch_interval == 0 and state.epoch > 4:
            # Get test and train accuracy
            test_accuracy = self.trainer.evaluate_step(
                dataset=self.trainer.eval_dataset, split="test"
            )
            train_accuracy = self.trainer.evaluate_step(
                dataset=self.trainer.train_dataset_eval, split="train"
            )
            print(
                f"Test accuracy: {test_accuracy:.4f}, Train accuracy: {train_accuracy:.4f}"
            )

            # Update best test accuracy if current is better
            if test_accuracy > self.best_test_accuracy:
                self.best_test_accuracy = test_accuracy
                self.best_test_epoch = state.epoch + 1

                # Update best metrics dictionary
                self.best_metrics.update(
                    {
                        "test_accuracy": float(test_accuracy),
                        "train_accuracy": float(train_accuracy),
                        "epoch": int(state.epoch + 1),
                        "global_step": int(state.global_step),
                    }
                )

                # Save best metrics to file
                self.save_best_metrics(args.output_dir)
                print(
                    f"\nNew best test accuracy: {self.best_test_accuracy:.4f} "
                    f"at epoch {self.best_test_epoch}"
                )

        if was_training:
            for_training(self.trainer.model)
            self.trainer.model.train()


class CustomSFTTrainer(SFTTrainer):
    def __init__(
        self,
        model,
        tokenizer,
        processor,
        data_collator,
        train_dataset=None,
        train_dataset_eval=None,
        eval_dataset=None,
        eval_epoch_interval=2,
        args=None,
    ):
        self.custom_callback = CustomTrainingCallback(
            self, eval_epoch_interval=eval_epoch_interval
        )
        callbacks = [self.custom_callback]
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            callbacks=callbacks,
            args=args,
        )
        self.eval_dataset = eval_dataset
        self.train_dataset_eval = train_dataset_eval
        self.state = type("State", (), {"global_step": 0})()
        self.processor = processor

    def evaluate_step(self, dataset, split):
        print(f"Evaluating {split} dataset")
        try:
            device = self.model.device
            # The correct signature is: evaluate_model(model, processor, dataset, split, verbose=False)
            accuracy = evaluate_model(self.model, self.processor, dataset, split)

            # Initialize wandb if not already initialized
            if not wandb.run:
                wandb.init(
                    project="qwen-vl-trainer",
                    reinit=True,
                    name=f"{os.environ.get('RANK', '0')}-evaluation",
                    group=os.environ.get("WANDB_RUN_GROUP", None),
                )

            wandb.log(
                {
                    f"{split}/accuracy": accuracy,
                }
            )
            return accuracy  # Return the accuracy value
        except Exception as e:
            logger.error(f"Error evaluating: {e}")
            raise

    def cleanup(self):
        """Cleanup method to ensure wandb runs are properly closed"""
        if wandb.run:
            wandb.finish()


def load_model(MODEL_ID: str, USE_QLORA: bool, training_args):
    # patch_unsloth_smart_gradient_checkpointing()

    # Configure more aggressive quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # More aggressive LoRA config
    lora_config = LoraConfig(
        r=200,  # Increase rank for more expressiveness
        lora_alpha=50,  # Higher scaling factor
        lora_dropout=0.001,  # Moderate dropout
        bias="lora_only",
        target_modules=[
            "qkv_proj",
            "o_proj",
            "gate_up_proj",
            "down_proj",
            "gate_proj",
            "up_proj",
            "fc1",
            "fc2",
            "mlp.0",
            "mlp.2",
        ],
        task_type="CAUSAL_LM",
        inference_mode=False,
        modules_to_save=None,
    )

    # Clear memory before model load
    torch.cuda.empty_cache()
    gc.collect()

    # Load DeepSpeed config
    with open(training_args.deepspeed, "r") as f:
        ds_config = json.load(f)

    # Set is_deepspeed_zero3_enabled flag for ZeRO-3
    is_deepspeed_zero3_enabled = (
        ds_config.get("zero_optimization", {}).get("stage", 0) == 3
    )

    # Pass DeepSpeed configuration to from_pretrained
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        # quantization_config=bnb_config if USE_QLORA else None,  # Use the config
        torch_dtype=torch.bfloat16,
        # device_map=None,  # Let DeepSpeed handle device mapping
        use_cache=False,
        attn_implementation="flash_attention_2",
    )

    # Reset generation config to avoid warnings
    from transformers import GenerationConfig

    model.generation_config = GenerationConfig.from_model_config(model.config)

    # Ensure no conflicting generation parameters
    model.generation_config.temperature = None
    model.generation_config.top_p = None
    model.generation_config.top_k = None
    model.generation_config.early_stopping = False

    processor = Qwen2_5_VLProcessor.from_pretrained(MODEL_ID)

    model.enable_input_require_grads()  # unsloth added this prior to loading peft
    model = get_peft_model(model, lora_config)
    model.gradient_checkpointing_enable()
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    # More aggressive gradient checkpointing
    model.config.gradient_checkpointing = True
    model.config.use_reentrant = False
    model.config.gradient_checkpointing_kwargs = {
        "use_reentrant": False,
        "checkpoint_every_n_layers": 1,
        "offload_to_cpu": True,
    }

    return model, processor
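

# Illustrative helper (assumption: not called anywhere in this script). It shows
# how a model/processor pair returned by load_model() can be exercised with the
# for_inference()/for_training() toggles defined above, mirroring how
# CustomTrainingCallback switches modes around evaluation.
def _debug_generate(model, processor, messages, max_new_tokens=64):
    was_training = model.training
    for_inference(model)
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, _ = process_vision_info(messages)
    inputs = processor(
        text=[text], images=image_inputs, return_tensors="pt", padding=True
    ).to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    if was_training:
        for_training(model)
    return processor.batch_decode(output_ids, skip_special_tokens=True)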


def train(args):
    # Set CUDA device explicitly based on local_rank
    if args.local_rank != -1:
        torch.cuda.set_device(args.local_rank)

        # Initialize process group with the correct device
        if not torch.distributed.is_initialized():
            # Get world size from environment if available
            world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
            rank = int(os.environ.get("RANK", args.local_rank))
            print(
                f"Initializing process group with rank={rank}, world_size={world_size}"
            )
            try:
                torch.distributed.init_process_group(
                    backend="nccl",
                    init_method="env://",
                    world_size=world_size,
                    rank=rank,
                )
                print(f"Successfully initialized process group for rank {rank}")
            except Exception as e:
                print(f"Could not initialize process group: {e}")

    # Remove memory management env vars that might interfere with DeepSpeed
    os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None)
    os.environ.pop("MAX_JOBS", None)
    os.environ.pop("CUDA_LAUNCH_BLOCKING", None)

    # Set up DeepSpeed config path first
    ds_config_path = "deepspeed_config.json"

    # Set up wandb configuration
    os.environ["WANDB_MODE"] = "online"

    # Create a unique timestamp for this training run
    import datetime

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    run_id = timestamp
    os.environ["WANDB_RUN_GROUP"] = f"qwen_training_{run_id}"

    # Create a timestamped output directory
    timestamped_output_dir = os.path.join(args.output_dir, f"run_{timestamp}")
    os.makedirs(timestamped_output_dir, exist_ok=True)
    print(f"Model checkpoints will be saved to: {timestamped_output_dir}")

    # Configure wandb properly for Trainer
    os.environ["WANDB_PROJECT"] = "qwen-vl-trainer"
    os.environ["WANDB_LOG_MODEL"] = "end"  # Changed from "true" to "end"
    os.environ["WANDB_WATCH"] = "all"  # Monitor all gradients and parameters
    os.environ["WANDB_NAME"] = f"run_{timestamp}_rank{os.environ.get('RANK', '0')}"

    # Initialize wandb only once at the beginning for the main process
    if args.local_rank <= 0:  # Only initialize on rank 0 or single GPU
        wandb.init(
            project="qwen-vl-trainer",
            name=f"transformer_training_{timestamp}",
            group=os.environ.get("WANDB_RUN_GROUP"),
            # Important: we're logging the model as an artifact
            settings=wandb.Settings(_disable_stats=True, _disable_meta=True),
        )
        # Log config information
        wandb.config.update(
            {
                "model_id": args.model_id,
                "use_qlora": args.use_qlora,
                "output_dir": timestamped_output_dir,
            }
        )
        print(f"Initialized wandb with run ID: {wandb.run.id}")

    # Create SFTConfig with DeepSpeed config before loading the model
    training_args = SFTConfig(
        per_device_train_batch_size=1,  # Equivalent to train_micro_batch_size_per_gpu
        gradient_accumulation_steps=2,
        logging_steps=1,  # Log every step
        logging_strategy="steps",  # Log based on steps
        log_level="info",
        num_train_epochs=2000,  # Set to desired number of epochs
        # eval_steps=100,
        bf16=True,
        optim="adamw_8bit",
        lr_scheduler_type="linear",
        seed=3407,
        output_dir=timestamped_output_dir,  # Use timestamped directory
        overwrite_output_dir=True,
        report_to="wandb",  # Explicitly report to wandb
        remove_unused_columns=False,
        dataset_text_field="",
        dataset_kwargs={"skip_prepare_dataset": True},
        dataset_num_proc=4,
        max_seq_length=800000,
        save_strategy="epoch",
        evaluation_strategy="no",
        save_total_limit=2000,
        deepspeed=ds_config_path,  # Pass the DeepSpeed config
    )

    # Dynamically set devices based on availability
    num_gpus = torch.cuda.device_count()
    devices = list(range(num_gpus)) if num_gpus > 0 else None

    # Pass training args to load_model function
    model, processor = load_model(args.model_id, args.use_qlora, training_args)

    # Train dataset
    train_dataset = AgentDatapointDataset(split="train", num_samples=args.train_size)

    # Eval datasets
    test_dataset = AgentEvalDatapointDataset(split="test", num_samples=args.test_size)
    train_dataset_eval = AgentEvalDatapointDataset(
        split="train", num_samples=args.train_size
    )

    for_training(model)

    trainer = CustomSFTTrainer(
        model=model,
        processor=processor,
        tokenizer=processor.tokenizer,
        data_collator=lambda examples: train_collate_fn(examples, processor),
        train_dataset_eval=train_dataset_eval,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        args=training_args,
    )

    training_stats = trainer.train()
    logger.info("Training completed.")
    print(f"Training Statistics: {training_stats}")

    # Save the final model explicitly with timestamp
    final_model_path = os.path.join(timestamped_output_dir, "final_model")
    if args.local_rank <= 0:  # Only save on rank 0 or single GPU
        print(f"Saving final model to {final_model_path}")
        trainer.save_model(final_model_path)
        print(f"Final model saved to {final_model_path}")

        # Also save the processor
        processor.save_pretrained(final_model_path)

        # Log the final model to wandb
        # if wandb.run:
        #     model_artifact = wandb.Artifact(
        #         name=f"model_{timestamp}",
        #         type="model",
        #         description=f"Final trained model from run {timestamp}",
        #     )
        #     model_artifact.add_dir(final_model_path)
        #     wandb.log_artifact(model_artifact)
        #     print(f"Final model logged to wandb as artifact: model_{timestamp}")
        #     print(f"Final model saved to {final_model_path}")

    # Ensure proper cleanup of wandb
    trainer.cleanup()

    # Final cleanup for the main process
    if args.local_rank <= 0:  # Only finalize on rank 0 or single GPU
        if wandb.run:
            print("Finalizing main wandb run...")
            wandb.finish()

    print("Training process completed successfully.")
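

# Assumed runtime requirement: load_model() and train() read a DeepSpeed JSON
# config from ./deepspeed_config.json (see ds_config_path above). Only the
# zero_optimization.stage key is inspected directly by this script; a minimal
# sketch of the expected shape (the project's actual config may differ):
#
#   {
#       "train_micro_batch_size_per_gpu": "auto",
#       "gradient_accumulation_steps": "auto",
#       "bf16": {"enabled": true},
#       "zero_optimization": {"stage": 3}
#   }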


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Training configuration")
    parser.add_argument(
        "--model_id",
        type=str,
        default="Qwen/Qwen2.5-VL-7B-Instruct",
        help="Model ID to use",
    )
    parser.add_argument(
        "--use_qlora",
        # argparse's `type=bool` treats any non-empty string (including "False")
        # as True, so parse the flag explicitly.
        type=lambda x: str(x).lower() in ("1", "true", "yes"),
        default=True,
        help="Whether to use QLoRA",
    )
    parser.add_argument(
        "--output_dir", type=str, default="checkpoints_27feb", help="Output directory"
    )
    # Add local_rank argument for DeepSpeed
    parser.add_argument(
        "--local_rank", type=int, default=-1, help="Local rank for distributed training"
    )
    parser.add_argument(
        "--train_size", type=int, default=10000000, help="Number of training samples"
    )
    parser.add_argument(
        "--test_size", type=int, default=10000000, help="Number of test samples"
    )
    args = parser.parse_args()
    train(args)
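
# Example launch (illustrative; assumes this file is named train.py and that
# deepspeed_config.json sits next to it):
#   deepspeed --num_gpus=4 train.py \
#       --model_id Qwen/Qwen2.5-VL-7B-Instruct \
#       --output_dir checkpoints_27feb \
#       --train_size 1000 --test_size 200
# The DeepSpeed launcher supplies --local_rank and the RANK / WORLD_SIZE
# environment variables that train() uses to initialize the process group.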