import argparse
import gc
import json
import logging
import os

import numpy as np
import torch
import torch.distributed
import wandb
from lightning.pytorch.loggers import WandbLogger
from peft import LoraConfig, get_peft_model
from torch.optim import AdamW
from transformers import BitsAndBytesConfig, TrainerCallback
from transformers import logging as transformers_logging
# This version of Qwen requires more VRAM.
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
# This version of Qwen requires less VRAM, since it uses compiled components and a fused cross-entropy loss.
# from model import Qwen2_5_VLForConditionalGeneration
from trl import SFTTrainer, SFTConfig
from qwen_vl_utils import process_vision_info

from dataset import AgentDatapointDataset, AgentEvalDatapointDataset
from evaluate import evaluate_model

# Perhaps add these back later:
# from unsloth import is_bf16_supported
# from unsloth.models._utils import prepare_model_for_kbit_training
# from gradient_checkpointing import patch_unsloth_smart_gradient_checkpointing

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
transformers_logging.set_verbosity_error()

torch.cuda.empty_cache()
torch.set_float32_matmul_precision("medium")
def train_collate_fn(examples, processor):
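    """Collate a batch of chat-formatted examples: render the chat template,
    bundle the images, and build labels with pad/image tokens masked to -100
    so loss is computed only on text tokens."""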
texts = [
processor.apply_chat_template(example["messages"], tokenize=False)
for example in examples
]
image_inputs = [process_vision_info(example["messages"])[0] for example in examples]
model_inputs = processor(
text=texts, images=image_inputs, return_tensors="pt", padding=True
)
labels = model_inputs["input_ids"].clone()
# mask padding tokens in labels
labels[labels == processor.tokenizer.pad_token_id] = -100
if isinstance(processor, Qwen2_5_VLProcessor):
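        # Hardcoded Qwen2.5-VL vision token ids:
        # <|vision_start|>, <|vision_end|>, <|image_pad|>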
image_tokens = [151652, 151653, 151655]
else:
image_tokens = [
processor.tokenizer.convert_tokens_to_ids(processor.image_token)
]
# mask image token IDs in the labels
for image_token_id in image_tokens:
labels[labels == image_token_id] = -100
# Return a dictionary instead of a tuple
return {
"input_ids": model_inputs["input_ids"],
"attention_mask": model_inputs["attention_mask"],
"pixel_values": model_inputs["pixel_values"],
"image_grid_thw": model_inputs["image_grid_thw"],
"labels": labels,
}
def _wrap_fast_inference(generate, device_type, dtype, model):
    # Wraps generation in inference_mode with bfloat16 / float16 autocast
    @torch.inference_mode()
    def _fast_generate(*args, **kwargs):
        # For num_logits_to_keep
        # kwargs["num_logits_to_keep"] = 1
        # Remove token_type_ids, which generate() does not accept
        kwargs.pop("token_type_ids", None)
        # Default pad_token_id to the model's eos_token_id
        model_eos_token_id = getattr(model.config, "eos_token_id", None)
        if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"):
            model_eos_token_id = model_eos_token_id[0]
        kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)
        # Cast pixel_values to the model dtype when present
        if kwargs.get("pixel_values") is not None:
            kwargs["pixel_values"] = kwargs["pixel_values"].to(model.dtype)
        # Autocasted generation
        with torch.autocast(device_type=device_type, dtype=dtype):
            output = generate(*args, **kwargs)
        return output

    return _fast_generate
def for_inference(model):
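    """Switch the model into inference mode: disable gradient checkpointing,
    wrap `generate` with autocast + inference_mode, and pad on the left."""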
    model.gradient_checkpointing = False
    model.training = False
    for name, module in model.named_modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = False
        if hasattr(module, "training"):
            module.training = False
    # Resolve the model dtype, which the config may store as a string
    dtype = model.config.torch_dtype
    if isinstance(dtype, str):
        if dtype == "float16":
            dtype = torch.float16
        elif dtype == "bfloat16":
            dtype = torch.bfloat16
    device_type = model.device.type
    # Wrap model.generate (skip if it is already wrapped)
    if model.generate.__name__ != "_fast_generate":
        model._unwrapped_old_generate = model.generate
        model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model)
    # Patch any saved tokenizers to pad on the left for generation
    internal_model = model
    while hasattr(internal_model, "model"):
        if hasattr(internal_model, "_saved_temp_tokenizer"):
            internal_model._saved_temp_tokenizer.tokenizer.padding_side = "left"
        internal_model = internal_model.model
    if hasattr(internal_model, "_saved_temp_tokenizer"):
        internal_model._saved_temp_tokenizer.tokenizer.padding_side = "left"
    # Also disable training mode for embeddings (relevant for NEFTune)
    if hasattr(model, "get_input_embeddings"):
        embeddings = model.get_input_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = False
    if hasattr(model, "get_output_embeddings"):
        embeddings = model.get_output_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = False
return model
def for_training(model, use_gradient_checkpointing=True):
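    """Switch the model back into training mode: re-enable gradient
    checkpointing, restore the original `generate`, and pad on the right."""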
    model.train()
    model.gradient_checkpointing = use_gradient_checkpointing
    model.training = True
    for name, module in model.named_modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = use_gradient_checkpointing
        if hasattr(module, "training"):
            module.training = True
    # Restore the original model.generate if it was wrapped for inference
    if hasattr(model, "_unwrapped_old_generate"):
        model.generate = model._unwrapped_old_generate
        del model._unwrapped_old_generate
    # Patch any saved tokenizers to pad on the right for training
    internal_model = model
    while hasattr(internal_model, "model"):
        if hasattr(internal_model, "_saved_temp_tokenizer"):
            internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right"
        internal_model = internal_model.model
    if hasattr(internal_model, "_saved_temp_tokenizer"):
        internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right"
    # Also re-enable training mode for embeddings (relevant for NEFTune)
    if hasattr(model, "get_input_embeddings"):
        embeddings = model.get_input_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = True
    if hasattr(model, "get_output_embeddings"):
        embeddings = model.get_output_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = True
return model
class CustomTrainingCallback(TrainerCallback):
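    """TrainerCallback that logs metrics to wandb and periodically evaluates
    train/test accuracy, tracking and persisting the best results."""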
def __init__(self, trainer, eval_epoch_interval=2):
self.trainer = trainer
self.eval_epoch_interval = eval_epoch_interval
self.best_test_accuracy = 0.0
self.best_test_epoch = 0
self.best_metrics = {
'test_accuracy': 0.0,
'train_accuracy': 0.0,
'epoch': 0,
'global_step': 0
}
def save_best_metrics(self, output_dir):
"""Save best metrics to a file in the checkpoint directory"""
metrics_file = os.path.join(output_dir, 'best_metrics.json')
with open(metrics_file, 'w') as f:
json.dump(self.best_metrics, f, indent=4)
print(f"Saved best metrics to {metrics_file}")
def on_log(self, args, state, control, logs=None, **kwargs):
"""Log metrics at each logging step"""
if logs is not None:
# Ensure wandb is initialized
import wandb
if not wandb.run:
wandb.init(
project="qwen-vl-trainer",
reinit=True,
name=f"{os.environ.get('RANK', '0')}-training",
group=os.environ.get("WANDB_RUN_GROUP", None),
)
# Log all metrics from the logs dictionary
step = state.global_step if hasattr(state, "global_step") else 0
# Extract and log training metrics
log_data = {}
for key, value in logs.items():
# Prefix training metrics to differentiate from eval metrics
if key not in ["eval_loss", "epoch", "learning_rate"]:
log_data[f"train/{key}"] = value
else:
log_data[key] = value
wandb.log(log_data, step=step)
    def on_epoch_end(self, args, state, control, **kwargs):
        # state.epoch counts completed epochs (as a float) at this point
        epoch = int(state.epoch)
        print(f"Epoch {epoch} ended")
        was_training = self.trainer.model.training
        for_inference(self.trainer.model)
        self.trainer.model.eval()
        if epoch % self.eval_epoch_interval == 0 and epoch > 4:
            # Get test and train accuracy
            test_accuracy = self.trainer.evaluate_step(
                dataset=self.trainer.eval_dataset, split="test"
            )
            train_accuracy = self.trainer.evaluate_step(
                dataset=self.trainer.train_dataset_eval, split="train"
            )
            print(f"Test accuracy: {test_accuracy:.4f}, Train accuracy: {train_accuracy:.4f}")
            # Update best test accuracy if the current one is better
            if test_accuracy > self.best_test_accuracy:
                self.best_test_accuracy = test_accuracy
                self.best_test_epoch = epoch
                # Update the best-metrics dictionary (keys mirror __init__)
                self.best_metrics.update({
                    'test_accuracy': float(test_accuracy),
                    'train_accuracy': float(train_accuracy),
                    'epoch': epoch,
                    'global_step': int(state.global_step),
                })
                # Persist best metrics and log the improvement
                self.save_best_metrics(args.output_dir)
                print(f"\nNew best test accuracy: {self.best_test_accuracy:.4f} at epoch {self.best_test_epoch}")
        if was_training:
            for_training(self.trainer.model)
            self.trainer.model.train()
class CustomSFTTrainer(SFTTrainer):
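    """SFTTrainer variant that keeps a held-out eval dataset plus a subset of
    the train split for accuracy checks, driven by CustomTrainingCallback."""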
def __init__(
self,
model,
tokenizer,
processor,
data_collator,
train_dataset=None,
train_dataset_eval=None,
eval_dataset=None,
eval_epoch_interval=2,
args=None,
):
self.custom_callback = CustomTrainingCallback(
self, eval_epoch_interval=eval_epoch_interval
)
callbacks = [self.custom_callback]
super().__init__(
model=model,
tokenizer=tokenizer,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=callbacks,
args=args,
)
        self.eval_dataset = eval_dataset
        self.train_dataset_eval = train_dataset_eval
        # Note: do not override self.state here; the base Trainer creates and
        # manages its own TrainerState.
        self.processor = processor
def evaluate_step(self, dataset, split):
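        """Run accuracy evaluation on `dataset` and log the result to wandb
        under `{split}/accuracy`. Returns the accuracy as a float."""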
print(f"Evaluating {split} dataset")
try:
            # evaluate_model signature: evaluate_model(model, processor, dataset, split, verbose=False)
            accuracy = evaluate_model(self.model, self.processor, dataset, split)
# Initialize wandb if not already initialized
import wandb
if not wandb.run:
wandb.init(
project="qwen-vl-trainer",
reinit=True,
name=f"{os.environ.get('RANK', '0')}-evaluation",
group=os.environ.get("WANDB_RUN_GROUP", None),
)
wandb.log(
{
f"{split}/accuracy": accuracy,
}
)
return accuracy # Return the accuracy value
except Exception as e:
logger.error(f"Error evaluating: {e}")
raise
def cleanup(self):
"""Cleanup method to ensure wandb runs are properly closed"""
import wandb
if wandb.run:
wandb.finish()
def load_model(MODEL_ID: str, USE_QLORA: bool, training_args):
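    """Load the Qwen2.5-VL base model with flash attention, attach LoRA
    adapters, and enable gradient checkpointing. Returns (model, processor)."""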
# patch_unsloth_smart_gradient_checkpointing()
    # More aggressive 4-bit quantization config (note: only applied if the
    # quantization_config argument below is re-enabled)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
# More aggressive LoRA config
lora_config = LoraConfig(
r=200, # Increase rank for more expressiveness
lora_alpha=50, # Higher scaling factor
lora_dropout=0.001, # Moderate dropout
bias="lora_only",
        target_modules=[
            "qkv_proj",
            "o_proj",
            "gate_up_proj",
            "down_proj",
            "gate_proj",
            "up_proj",
            "fc1",
            "fc2",
            "mlp.0",
            "mlp.2",
        ],
task_type="CAUSAL_LM",
inference_mode=False,
modules_to_save=None,
)
# Clear memory before model load
torch.cuda.empty_cache()
gc.collect()
# Load DeepSpeed config
with open(training_args.deepspeed, "r") as f:
ds_config = json.load(f)
    # Flag whether ZeRO-3 is enabled in the DeepSpeed config (currently informational only)
    is_deepspeed_zero3_enabled = (
        ds_config.get("zero_optimization", {}).get("stage", 0) == 3
    )
# Pass DeepSpeed configuration to from_pretrained
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID,
# quantization_config=bnb_config if USE_QLORA else None, # Use the config
torch_dtype=torch.bfloat16,
# device_map=None, # Let DeepSpeed handle device mapping
use_cache=False,
attn_implementation="flash_attention_2",
)
# Reset generation config to avoid warnings
from transformers import GenerationConfig
model.generation_config = GenerationConfig.from_model_config(model.config)
# Ensure no conflicting generation parameters
model.generation_config.temperature = None
model.generation_config.top_p = None
model.generation_config.top_k = None
model.generation_config.early_stopping = False
processor = Qwen2_5_VLProcessor.from_pretrained(MODEL_ID)
model.enable_input_require_grads() # unsloth added this prior to loading peft
model = get_peft_model(model, lora_config)
model.gradient_checkpointing_enable()
model.config.use_cache = False
model.config.pretraining_tp = 1
# More aggressive gradient checkpointing
model.config.gradient_checkpointing = True
model.config.use_reentrant = False
model.config.gradient_checkpointing_kwargs = {
"use_reentrant": False,
"checkpoint_every_n_layers": 1,
"offload_to_cpu": True,
}
return model, processor
def train(args):
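    """Set up distributed training and wandb logging, run DeepSpeed-backed
    SFT of Qwen2.5-VL with LoRA, and save the final model and processor."""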
# Set CUDA device explicitly based on local_rank
if args.local_rank != -1:
torch.cuda.set_device(args.local_rank)
# Initialize process group with the correct device
if not torch.distributed.is_initialized():
# Get world size from environment if available
world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
rank = int(os.environ.get("RANK", args.local_rank))
print(
f"Initializing process group with rank={rank}, world_size={world_size}"
)
try:
torch.distributed.init_process_group(
backend="nccl",
init_method="env://",
world_size=world_size,
rank=rank,
)
print(f"Successfully initialized process group for rank {rank}")
except Exception as e:
print(f"Could not initialize process group: {e}")
# Remove memory management env vars that might interfere with DeepSpeed
os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None)
os.environ.pop("MAX_JOBS", None)
os.environ.pop("CUDA_LAUNCH_BLOCKING", None)
# Set up DeepSpeed config path first
ds_config_path = "deepspeed_config.json"
# Set up wandb configuration
os.environ["WANDB_MODE"] = "online"
# Create a unique timestamp for this training run
import datetime
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
run_id = timestamp
os.environ["WANDB_RUN_GROUP"] = f"qwen_training_{run_id}"
# Create a timestamped output directory
timestamped_output_dir = os.path.join(args.output_dir, f"run_{timestamp}")
os.makedirs(timestamped_output_dir, exist_ok=True)
print(f"Model checkpoints will be saved to: {timestamped_output_dir}")
# Configure wandb properly for Trainer
os.environ["WANDB_PROJECT"] = "qwen-vl-trainer"
os.environ["WANDB_LOG_MODEL"] = "end" # Changed from "true" to "end"
os.environ["WANDB_WATCH"] = "all" # Monitor all gradients and parameters
os.environ["WANDB_NAME"] = f"run_{timestamp}_rank{os.environ.get('RANK', '0')}"
# Initialize wandb only once at the beginning for the main process
if args.local_rank <= 0: # Only initialize on rank 0 or single GPU
import wandb
wandb.init(
project="qwen-vl-trainer",
name=f"transformer_training_{timestamp}",
group=os.environ.get("WANDB_RUN_GROUP"),
# Important: we're logging the model as an artifact
settings=wandb.Settings(_disable_stats=True, _disable_meta=True),
)
# Log config information
wandb.config.update(
{
"model_id": args.model_id,
"use_qlora": args.use_qlora,
"output_dir": timestamped_output_dir,
}
)
print(f"Initialized wandb with run ID: {wandb.run.id}")
# Create SFTConfig with DeepSpeed config before loading the model
training_args = SFTConfig(
per_device_train_batch_size=1, # Equivalent to train_micro_batch_size_per_gpu
gradient_accumulation_steps=2,
logging_steps=1, # Log every step
logging_strategy="steps", # Log based on steps
log_level="info",
num_train_epochs=2000, # Set to desired number of epochs
# eval_steps=100,
bf16=True,
optim="adamw_8bit",
lr_scheduler_type="linear",
seed=3407,
output_dir=timestamped_output_dir, # Use timestamped directory
overwrite_output_dir=True,
report_to="wandb", # Explicitly report to wandb
remove_unused_columns=False,
dataset_text_field="",
dataset_kwargs={"skip_prepare_dataset": True},
dataset_num_proc=4,
max_seq_length=800000,
save_strategy="epoch",
evaluation_strategy="no",
save_total_limit=2000,
deepspeed=ds_config_path, # Pass the DeepSpeed config
)
# Pass training args to load_model function
model, processor = load_model(args.model_id, args.use_qlora, training_args)
# Train dataset
train_dataset = AgentDatapointDataset(split="train", num_samples=args.train_size)
# Eval datasets
test_dataset = AgentEvalDatapointDataset(split="test", num_samples=args.test_size)
train_dataset_eval = AgentEvalDatapointDataset(split="train", num_samples=args.train_size)
for_training(model)
trainer = CustomSFTTrainer(
model=model,
processor=processor,
tokenizer=processor.tokenizer,
data_collator=lambda examples: train_collate_fn(examples, processor),
train_dataset_eval=train_dataset_eval,
train_dataset=train_dataset,
eval_dataset=test_dataset,
args=training_args,
)
training_stats = trainer.train()
logger.info("Training completed.")
print(f"Training Statistics: {training_stats}")
# Save the final model explicitly with timestamp
final_model_path = os.path.join(timestamped_output_dir, "final_model")
if args.local_rank <= 0: # Only save on rank 0 or single GPU
print(f"Saving final model to {final_model_path}")
trainer.save_model(final_model_path)
print(f"Final model saved to {final_model_path}")
# Also save the processor
processor.save_pretrained(final_model_path)
# Log the final model to wandb
# import wandb
# if wandb.run:
# model_artifact = wandb.Artifact(
# name=f"model_{timestamp}",
# type="model",
# description=f"Final trained model from run {timestamp}"
# )
# model_artifact.add_dir(final_model_path)
# wandb.log_artifact(model_artifact)
# print(f"Final model logged to wandb as artifact: model_{timestamp}")
#
# print(f"Final model saved to {final_model_path}")
# Ensure proper cleanup of wandb
trainer.cleanup()
# Final cleanup for the main process
if args.local_rank <= 0: # Only finalize on rank 0 or single GPU
import wandb
if wandb.run:
print("Finalizing main wandb run...")
wandb.finish()
print("Training process completed successfully.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Training configuration")
parser.add_argument(
"--model_id",
type=str,
default="Qwen/Qwen2.5-VL-7B-Instruct",
help="Model ID to use",
)
    parser.add_argument(
        "--use_qlora",
        # argparse's type=bool would treat any non-empty string (even "False") as True
        type=lambda x: str(x).lower() in ("true", "1", "yes"),
        default=True,
        help="Whether to use QLoRA",
    )
parser.add_argument(
"--output_dir", type=str, default="checkpoints_27feb", help="Output directory"
)
# Add local_rank argument for DeepSpeed
parser.add_argument(
"--local_rank", type=int, default=-1, help="Local rank for distributed training"
)
parser.add_argument(
"--train_size", type=int, default=10000000, help="Number of training samples"
)
parser.add_argument(
"--test_size", type=int, default=10000000, help="Number of test samples"
)
args = parser.parse_args()
train(args)