import torch |
import gc |
import numpy as np |
import json |
torch.cuda.empty_cache() |
import torch.distributed |
from dataset import AgentDatapointDataset |
import os |
import wandb |
from lightning.pytorch.loggers import WandbLogger |
from peft import get_peft_model, LoraConfig |
from transformers import TrainerCallback |
from transformers import BitsAndBytesConfig |
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration |
from trl import SFTTrainer, SFTConfig |
from transformers import logging as transformers_logging |
import logging |
logging.basicConfig(level=logging.INFO) |
logger = logging.getLogger(__name__) |
transformers_logging.set_verbosity_error() |
import argparse |
from torch.optim import AdamW |
from qwen_vl_utils import process_vision_info |
torch.set_float32_matmul_precision("medium") |
import json |
from evaluate import evaluate_model |
from dataset import AgentEvalDatapointDataset, AgentDatapointDataset |
def train_collate_fn(examples, processor):
    """Collate chat examples into one padded training batch.

    Each example must provide a ``"messages"`` entry. The messages are
    rendered to text with the processor's chat template, the first element
    returned by ``process_vision_info`` is used as the image input, and both
    are passed through ``processor`` with padding enabled.

    Labels start as a copy of ``input_ids``; padding positions and image
    token positions are then overwritten with ``-100``.

    Returns:
        dict with ``input_ids``, ``attention_mask``, ``pixel_values``,
        ``image_grid_thw`` and ``labels``.
    """
    conversations = [sample["messages"] for sample in examples]
    prompts = [
        processor.apply_chat_template(conversation, tokenize=False)
        for conversation in conversations
    ]
    images = [process_vision_info(conversation)[0] for conversation in conversations]

    batch = processor(text=prompts, images=images, return_tensors="pt", padding=True)

    if isinstance(processor, Qwen2_5_VLProcessor):
        # Hard-coded ids; presumably the Qwen2.5-VL vision-start / vision-end /
        # image-pad tokens -- TODO confirm against the tokenizer vocabulary.
        masked_token_ids = [151652, 151653, 151655]
    else:
        masked_token_ids = [
            processor.tokenizer.convert_tokens_to_ids(processor.image_token)
        ]

    labels = batch["input_ids"].clone()
    # Padding is masked first, then every image-related token id.
    for token_id in [processor.tokenizer.pad_token_id, *masked_token_ids]:
        labels[labels == token_id] = -100

    collated = {
        key: batch[key]
        for key in ("input_ids", "attention_mask", "pixel_values", "image_grid_thw")
    }
    collated["labels"] = labels
    return collated
def _wrap_fast_inference(generate, device_type, dtype, model):
    """Wrap ``generate`` so it runs under inference mode and autocast.

    Args:
        generate: The callable to wrap (typically ``model.generate``).
        device_type: Device type string handed to ``torch.autocast``.
        dtype: Autocast dtype handed to ``torch.autocast``.
        model: Object exposing ``config`` (read for ``eos_token_id``) and
            ``dtype`` (target dtype for ``pixel_values``).

    Returns:
        A function named ``_fast_generate`` that forwards to ``generate``.

    The wrapper is intentionally NOT decorated with ``functools.wraps``:
    ``for_inference`` recognises an already-wrapped model by checking that
    ``model.generate.__name__`` equals ``"_fast_generate"``.
    """

    @torch.inference_mode()
    def _fast_generate(*args, **kwargs):
        # token_type_ids is dropped before forwarding to ``generate``.
        kwargs.pop("token_type_ids", None)

        # Default pad_token_id to the model's (first) EOS id unless the
        # caller supplied one explicitly.
        eos_token_id = getattr(model.config, "eos_token_id", None)
        if eos_token_id is not None and hasattr(eos_token_id, "__iter__"):
            # An empty collection yields None instead of raising.
            eos_token_id = next(iter(eos_token_id), None)
        kwargs.setdefault("pad_token_id", eos_token_id)

        # Best-effort cast of image features to the model dtype. Only skip
        # when there is nothing castable, instead of swallowing every error.
        pixel_values = kwargs.get("pixel_values")
        if pixel_values is not None and hasattr(pixel_values, "to"):
            kwargs["pixel_values"] = pixel_values.to(model.dtype)

        with torch.autocast(device_type=device_type, dtype=dtype):
            return generate(*args, **kwargs)

    return _fast_generate
def _resolve_torch_dtype(dtype):
    """Turn a dtype given by name (e.g. ``"bfloat16"``) into a ``torch.dtype``.

    Non-string values, and strings that do not name a torch dtype, are
    returned unchanged.
    """
    if isinstance(dtype, str):
        candidate = getattr(torch, dtype.split(".")[-1], None)
        if isinstance(candidate, torch.dtype):
            return candidate
    return dtype


def for_inference(model):
    """Switch ``model`` into inference configuration in place.

    Steps performed:
      * clear ``gradient_checkpointing`` and ``training`` on the model and on
        every module yielded by ``model.named_modules()``;
      * replace ``model.generate`` with the ``_fast_generate`` wrapper (the
        original is kept on ``model._unwrapped_old_generate``) unless it is
        already wrapped;
      * set ``padding_side = "left"`` on any ``_saved_temp_tokenizer`` found
        while walking the ``.model`` attribute chain;
      * clear ``training`` on the input and output embeddings, when the model
        exposes the corresponding getters.

    Returns:
        The same ``model`` object.
    """
    model.gradient_checkpointing = False
    model.training = False
    for _, module in model.named_modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = False
        if hasattr(module, "training"):
            module.training = False

    # The config may carry the dtype as a string; torch.autocast needs a
    # real torch.dtype, so resolve any dtype name rather than two of them.
    dtype = _resolve_torch_dtype(model.config.torch_dtype)
    device_type = model.device.type

    # Wrap only once: the wrapper is recognised by its function name.
    if getattr(model.generate, "__name__", None) != "_fast_generate":
        model._unwrapped_old_generate = model.generate
        model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model)

    # Visit the model and every nested ``.model`` wrapper, innermost included.
    internal_model = model
    while True:
        if hasattr(internal_model, "_saved_temp_tokenizer"):
            internal_model._saved_temp_tokenizer.tokenizer.padding_side = "left"
        if not hasattr(internal_model, "model"):
            break
        internal_model = internal_model.model

    for getter_name in ("get_input_embeddings", "get_output_embeddings"):
        if hasattr(model, getter_name):
            embeddings = getattr(model, getter_name)()
            if hasattr(embeddings, "training"):
                embeddings.training = False

    return model
def for_training(model, use_gradient_checkpointing=True):
    """Put ``model`` back into training configuration in place.

    Calls ``model.train()``, sets ``gradient_checkpointing`` to
    ``use_gradient_checkpointing`` and ``training`` to ``True`` on the model
    and on every module from ``model.named_modules()`` that has those
    attributes, restores the original ``generate`` if one was stashed on
    ``_unwrapped_old_generate``, sets ``padding_side = "right"`` on every
    ``_saved_temp_tokenizer`` along the ``.model`` chain, and marks the
    input/output embeddings as training when the getters exist.

    Returns:
        The same ``model`` object.
    """
    model.train()
    model.gradient_checkpointing = use_gradient_checkpointing
    model.training = True
    for _, submodule in model.named_modules():
        if hasattr(submodule, "gradient_checkpointing"):
            submodule.gradient_checkpointing = use_gradient_checkpointing
        if hasattr(submodule, "training"):
            submodule.training = True

    # Undo the inference-time generate wrapper, if one was installed.
    if hasattr(model, "_unwrapped_old_generate"):
        model.generate = model._unwrapped_old_generate
        del model._unwrapped_old_generate

    # Walk model -> model.model -> ... and fix the tokenizer on each level,
    # including the innermost one.
    current = model
    while True:
        if hasattr(current, "_saved_temp_tokenizer"):
            current._saved_temp_tokenizer.tokenizer.padding_side = "right"
        if not hasattr(current, "model"):
            break
        current = current.model

    for getter_name in ("get_input_embeddings", "get_output_embeddings"):
        if not hasattr(model, getter_name):
            continue
        embedding_layer = getattr(model, getter_name)()
        if hasattr(embedding_layer, "training"):
            embedding_layer.training = True

    return model
class CustomTrainingCallback(TrainerCallback):
    """Trainer callback that mirrors logs to wandb and tracks the best accuracy.

    At the end of each epoch the model is switched to inference mode and, on
    qualifying epochs, scored on the trainer's test and train-eval datasets
    through ``trainer.evaluate_step``. The best test accuracy seen so far is
    written to ``best_metrics.json`` in the run's output directory.
    """

    # Log keys forwarded under their own name; all others get a "train/" prefix.
    _UNPREFIXED_LOG_KEYS = frozenset({"eval_loss", "epoch", "learning_rate"})

    def __init__(self, trainer, eval_epoch_interval=2):
        """Store the owning trainer and the evaluation cadence (in epochs)."""
        self.trainer = trainer
        self.eval_epoch_interval = eval_epoch_interval
        self.best_test_accuracy = 0.0
        self.best_test_epoch = 0
        self.best_metrics = {
            'test_accuracy': 0.0,
            'train_accuracy': 0.0,
            'epoch': 0,
            'global_step': 0
        }

    def save_best_metrics(self, output_dir):
        """Save best metrics to a file in the checkpoint directory"""
        # Create the directory if needed so the write cannot fail on a fresh run.
        os.makedirs(output_dir, exist_ok=True)
        metrics_file = os.path.join(output_dir, 'best_metrics.json')
        with open(metrics_file, 'w') as f:
            json.dump(self.best_metrics, f, indent=4)
        print(f"Saved best metrics to {metrics_file}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log metrics at each logging step"""
        if logs is None:
            return
        import wandb

        # Start a run lazily if no wandb run is active in this process.
        if not wandb.run:
            wandb.init(
                project="qwen-vl-trainer",
                reinit=True,
                name=f"{os.environ.get('RANK', '0')}-training",
                group=os.environ.get("WANDB_RUN_GROUP", None),
            )
        step = getattr(state, "global_step", 0)
        log_data = {
            (key if key in self._UNPREFIXED_LOG_KEYS else f"train/{key}"): value
            for key, value in logs.items()
        }
        wandb.log(log_data, step=step)

    def on_epoch_end(self, args, state, control, **kwargs):
        """Evaluate on qualifying epochs, then restore the training mode.

        Evaluation runs when ``state.epoch + 1`` is a multiple of
        ``eval_epoch_interval`` and ``state.epoch > 4``.
        """
        # NOTE(review): state.epoch is used with a +1 offset throughout;
        # confirm this matches the epoch numbering the trainer reports.
        epoch_number = state.epoch + 1
        print(f"Epoch {epoch_number} ended")
        model = self.trainer.model
        was_training = model.training
        for_inference(model)
        model.eval()
        try:
            if epoch_number % self.eval_epoch_interval == 0 and state.epoch > 4:
                self._evaluate_and_track(args, state, epoch_number)
        finally:
            # Restore training mode even if evaluation raised, so the model
            # is never left half-configured for inference.
            if was_training:
                for_training(model)
                model.train()

    def _evaluate_and_track(self, args, state, epoch_number):
        """Score test/train splits and persist metrics on a new best test score."""
        test_accuracy = self.trainer.evaluate_step(dataset=self.trainer.eval_dataset, split="test")
        train_accuracy = self.trainer.evaluate_step(dataset=self.trainer.train_dataset_eval, split="train")
        print(f"Test accuracy: {test_accuracy:.4f}, Train accuracy: {train_accuracy:.4f}")
        if test_accuracy <= self.best_test_accuracy:
            return
        self.best_test_accuracy = test_accuracy
        self.best_test_epoch = epoch_number
        self.best_metrics.update({
            # Keep the key created in __init__ in sync as well; previously it
            # stayed at its initial 0.0 in the saved file.
            'test_accuracy': float(test_accuracy),
            'best_test_accuracy': float(test_accuracy),
            'train_accuracy': float(train_accuracy),
            'epoch': int(epoch_number),
            'global_step': int(state.global_step)
        })
        self.save_best_metrics(args.output_dir)
        print(f"\nNew best test accuracy: {self.best_test_accuracy:.4f} at epoch {self.best_test_epoch}")
class CustomSFTTrainer(SFTTrainer):
    """SFTTrainer that also carries a processor and a train-split eval dataset.

    A ``CustomTrainingCallback`` is registered automatically; it calls
    ``evaluate_step`` on this trainer at the end of qualifying epochs.
    """

    def __init__(
        self,
        model,
        tokenizer,
        processor,
        data_collator,
        train_dataset=None,
        train_dataset_eval=None,
        eval_dataset=None,
        eval_epoch_interval=2,
        args=None,
    ):
        """Build the trainer.

        Args:
            model, tokenizer, data_collator, train_dataset, eval_dataset, args:
                forwarded to ``SFTTrainer.__init__``.
            processor: kept on ``self.processor`` and passed to
                ``evaluate_model`` during ``evaluate_step``.
            train_dataset_eval: dataset used to score the train split.
            eval_epoch_interval: forwarded to ``CustomTrainingCallback``.
        """
        self.custom_callback = CustomTrainingCallback(
            self, eval_epoch_interval=eval_epoch_interval
        )
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            callbacks=[self.custom_callback],
            args=args,
        )
        self.eval_dataset = eval_dataset
        self.train_dataset_eval = train_dataset_eval
        # Only fall back to a minimal stub when the base class did not create
        # a state object; unconditionally replacing it would discard every
        # field except ``global_step``.
        if not hasattr(self, "state"):
            self.state = type("State", (), {"global_step": 0})()
        self.processor = processor

    def evaluate_step(self, dataset, split):
        """Run ``evaluate_model`` on ``dataset`` and log ``{split}/accuracy`` to wandb.

        Returns:
            The accuracy value returned by ``evaluate_model``.

        Raises:
            Re-raises any exception from evaluation or logging after
            recording it.
        """
        print(f"Evaluating {split} dataset")
        try:
            accuracy = evaluate_model(self.model, self.processor, dataset, split)
            import wandb

            # Start a run lazily if no wandb run is active in this process.
            if not wandb.run:
                wandb.init(
                    project="qwen-vl-trainer",
                    reinit=True,
                    name=f"{os.environ.get('RANK', '0')}-evaluation",
                    group=os.environ.get("WANDB_RUN_GROUP", None),
                )
            wandb.log(
                {
                    f"{split}/accuracy": accuracy,
                }
            )
            return accuracy
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception(f"Error evaluating: {e}")
            raise

    def cleanup(self):
        """Cleanup method to ensure wandb runs are properly closed"""
        import wandb

        if wandb.run:
            wandb.finish()
def load_model(MODEL_ID: str, USE_QLORA: bool, training_args):
    """Load the Qwen2.5-VL model and processor and attach LoRA adapters.

    Args:
        MODEL_ID: Hub id or local path given to ``from_pretrained`` for both
            the model and the processor.
        USE_QLORA: Currently has no effect (see the NOTE below).
        training_args: Object whose ``deepspeed`` attribute is a path to a
            DeepSpeed JSON config, a config dict, or unset/None.

    Returns:
        ``(model, processor)`` where ``model`` is the PEFT-wrapped model with
        gradient checkpointing enabled.
    """
    # NOTE(review): this 4-bit config is built but not passed to
    # from_pretrained, so USE_QLORA is ignored and the model loads in bf16.
    # Wiring it in would change memory use and training behaviour; confirm
    # the intended behaviour before doing so.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    lora_config = LoraConfig(
        r=200,
        lora_alpha=50,
        lora_dropout=0.001,
        bias="lora_only",
        # "down_proj" was listed twice; the duplicate is removed.
        target_modules=[
            "qkv_proj",
            "o_proj",
            "gate_up_proj",
            "down_proj",
            "gate_proj",
            "up_proj",
            "fc1",
            "fc2",
            "mlp.0",
            "mlp.2",
        ],
        task_type="CAUSAL_LM",
        inference_mode=False,
        modules_to_save=None,
    )
    torch.cuda.empty_cache()
    gc.collect()

    # Read the DeepSpeed config without crashing when none is configured or
    # when it was supplied as a dict rather than a file path.
    ds_source = getattr(training_args, "deepspeed", None)
    if isinstance(ds_source, dict):
        ds_config = ds_source
    elif ds_source:
        with open(ds_source, "r") as f:
            ds_config = json.load(f)
    else:
        ds_config = {}
    is_deepspeed_zero3_enabled = (
        ds_config.get("zero_optimization", {}).get("stage", 0) == 3
    )
    logger.info("DeepSpeed ZeRO-3 enabled: %s", is_deepspeed_zero3_enabled)

    # The model id was missing from this call, so the requested checkpoint
    # was never passed to from_pretrained.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        use_cache=False,
        attn_implementation="flash_attention_2",
    )
    from transformers import GenerationConfig

    # Rebuild the generation config from the model config and clear the
    # sampling parameters.
    model.generation_config = GenerationConfig.from_model_config(model.config)
    model.generation_config.temperature = None
    model.generation_config.top_p = None
    model.generation_config.top_k = None
    model.generation_config.early_stopping = False

    processor = Qwen2_5_VLProcessor.from_pretrained(MODEL_ID)

    model.enable_input_require_grads()
    model = get_peft_model(model, lora_config)
    model.gradient_checkpointing_enable()

    model.config.use_cache = False
    model.config.pretraining_tp = 1
    model.config.gradient_checkpointing = True
    model.config.use_reentrant = False
    # NOTE(review): these kwargs are only stored on the config here; they are
    # not passed to gradient_checkpointing_enable() above -- confirm they are
    # consumed somewhere.
    model.config.gradient_checkpointing_kwargs = {
        "use_reentrant": False,
        "checkpoint_every_n_layers": 1,
        "offload_to_cpu": True,
    }
    return model, processor
def _init_process_group(local_rank):
    """Select this process's GPU and join the NCCL process group if not yet joined."""
    torch.cuda.set_device(local_rank)
    if torch.distributed.is_initialized():
        return
    world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
    rank = int(os.environ.get("RANK", local_rank))
    print(
        f"Initializing process group with rank={rank}, world_size={world_size}"
    )
    try:
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=world_size,
            rank=rank,
        )
        print(f"Successfully initialized process group for rank {rank}")
    except Exception as e:
        # Deliberately best-effort: report the failure and carry on.
        print(f"Could not initialize process group: {e}")


def _shared_timestamp():
    """Return a run timestamp that is identical on every rank.

    When a process group is initialised, rank 0's value is broadcast to all
    ranks; otherwise the local wall-clock time is used.
    """
    import datetime

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        holder = [timestamp]
        torch.distributed.broadcast_object_list(holder, src=0)
        timestamp = holder[0]
    return timestamp


def train(args):
    """Run the full fine-tuning job described by the parsed CLI ``args``.

    Reads ``args.local_rank``, ``args.output_dir``, ``args.model_id``,
    ``args.use_qlora``, ``args.train_size`` and ``args.test_size``. Sets up
    distributed training and wandb, builds the model, datasets and trainer,
    trains, and saves the final model and processor on rank <= 0.
    """
    if args.local_rank != -1:
        _init_process_group(args.local_rank)

    os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None)
    os.environ.pop("MAX_JOBS", None)
    os.environ.pop("CUDA_LAUNCH_BLOCKING", None)
    ds_config_path = "deepspeed_config.json"
    os.environ["WANDB_MODE"] = "online"

    # Each rank used to compute its own timestamp, so ranks could end up
    # with different output directories and wandb groups.
    timestamp = _shared_timestamp()
    run_id = timestamp
    os.environ["WANDB_RUN_GROUP"] = f"qwen_training_{run_id}"
    timestamped_output_dir = os.path.join(args.output_dir, f"run_{timestamp}")
    os.makedirs(timestamped_output_dir, exist_ok=True)
    print(f"Model checkpoints will be saved to: {timestamped_output_dir}")

    os.environ["WANDB_PROJECT"] = "qwen-vl-trainer"
    os.environ["WANDB_LOG_MODEL"] = "end"
    os.environ["WANDB_WATCH"] = "all"
    os.environ["WANDB_NAME"] = f"run_{timestamp}_rank{os.environ.get('RANK', '0')}"

    # Only the main process (local_rank 0, or -1 when not distributed)
    # opens the primary wandb run.
    if args.local_rank <= 0:
        import wandb

        wandb.init(
            project="qwen-vl-trainer",
            name=f"transformer_training_{timestamp}",
            group=os.environ.get("WANDB_RUN_GROUP"),
            settings=wandb.Settings(_disable_stats=True, _disable_meta=True),
        )
        wandb.config.update(
            {
                "model_id": args.model_id,
                "use_qlora": args.use_qlora,
                "output_dir": timestamped_output_dir,
            }
        )
        print(f"Initialized wandb with run ID: {wandb.run.id}")

    training_args = SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        logging_steps=1,
        logging_strategy="steps",
        log_level="info",
        num_train_epochs=2000,
        bf16=True,
        optim="adamw_8bit",
        lr_scheduler_type="linear",
        seed=3407,
        output_dir=timestamped_output_dir,
        overwrite_output_dir=True,
        report_to="wandb",
        remove_unused_columns=False,
        dataset_text_field="",
        dataset_kwargs={"skip_prepare_dataset": True},
        dataset_num_proc=4,
        max_seq_length=800000,
        save_strategy="epoch",
        evaluation_strategy="no",
        save_total_limit=2000,
        deepspeed=ds_config_path,
    )

    model, processor = load_model(args.model_id, args.use_qlora, training_args)

    train_dataset = AgentDatapointDataset(split="train", num_samples=args.train_size)
    test_dataset = AgentEvalDatapointDataset(split="test", num_samples=args.test_size)
    train_dataset_eval = AgentEvalDatapointDataset(split="train", num_samples=args.train_size)

    for_training(model)
    trainer = CustomSFTTrainer(
        model=model,
        processor=processor,
        tokenizer=processor.tokenizer,
        data_collator=lambda examples: train_collate_fn(examples, processor),
        train_dataset_eval=train_dataset_eval,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        args=training_args,
    )

    training_stats = trainer.train()
    logger.info("Training completed.")
    print(f"Training Statistics: {training_stats}")

    final_model_path = os.path.join(timestamped_output_dir, "final_model")
    if args.local_rank <= 0:
        print(f"Saving final model to {final_model_path}")
        trainer.save_model(final_model_path)
        print(f"Final model saved to {final_model_path}")
        processor.save_pretrained(final_model_path)

    trainer.cleanup()
    if args.local_rank <= 0:
        import wandb

        if wandb.run:
            print("Finalizing main wandb run...")
            wandb.finish()
    print("Training process completed successfully.")
def _str_to_bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this parser interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Training configuration")
    parser.add_argument(
        "--model_id",
        type=str,
        default="Qwen/Qwen2.5-VL-7B-Instruct",
        help="Model ID to use",
    )
    parser.add_argument(
        "--use_qlora", type=_str_to_bool, default=True, help="Whether to use QLoRA"
    )
    parser.add_argument(
        "--output_dir", type=str, default="checkpoints_27feb", help="Output directory"
    )
    parser.add_argument(
        "--local_rank", type=int, default=-1, help="Local rank for distributed training"
    )
    parser.add_argument(
        "--train_size", type=int, default=10000000, help="Number of training samples"
    )
    parser.add_argument(
        "--test_size", type=int, default=10000000, help="Number of test samples"
    )
    args = parser.parse_args()
    train(args)