|
import argparse
import datetime
import gc
import json
import logging
import os

import torch
import torch.distributed
import wandb
from peft import LoraConfig, get_peft_model
from qwen_vl_utils import process_vision_info
from transformers import (
    BitsAndBytesConfig,
    GenerationConfig,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2_5_VLProcessor,
    TrainerCallback,
)
from transformers import logging as transformers_logging
from trl import SFTConfig, SFTTrainer

from dataset import AgentDatapointDataset, AgentEvalDatapointDataset
from evaluate import evaluate_model

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
transformers_logging.set_verbosity_error()

# Free any cached GPU memory left over from a previous process and trade a
# little matmul precision for speed on Ampere+ GPUs.
torch.cuda.empty_cache()
torch.set_float32_matmul_precision("medium")
|
def train_collate_fn(examples, processor):
    """Collate chat-formatted examples into padded tensors for Qwen2.5-VL."""
    texts = [
        processor.apply_chat_template(example["messages"], tokenize=False)
        for example in examples
    ]
    image_inputs = [process_vision_info(example["messages"])[0] for example in examples]

    model_inputs = processor(
        text=texts, images=image_inputs, return_tensors="pt", padding=True
    )

    # Labels are the input ids with padding and vision tokens masked to -100,
    # so the loss is computed only on real text tokens.
    labels = model_inputs["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100

    if isinstance(processor, Qwen2_5_VLProcessor):
        # Hard-coded Qwen2.5-VL vision special-token ids:
        # <|vision_start|>, <|vision_end|>, <|image_pad|>.
        image_tokens = [151652, 151653, 151655]
    else:
        image_tokens = [
            processor.tokenizer.convert_tokens_to_ids(processor.image_token)
        ]
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100

    return {
        "input_ids": model_inputs["input_ids"],
        "attention_mask": model_inputs["attention_mask"],
        "pixel_values": model_inputs["pixel_values"],
        "image_grid_thw": model_inputs["image_grid_thw"],
        "labels": labels,
    }
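
# A minimal sketch of the record shape train_collate_fn expects (following the
# Qwen2.5-VL chat-template conventions; paths and texts are illustrative):
#
#   example = {
#       "messages": [
#           {
#               "role": "user",
#               "content": [
#                   {"type": "image", "image": "file:///path/to/screenshot.png"},
#                   {"type": "text", "text": "Click the submit button."},
#               ],
#           },
#           {
#               "role": "assistant",
#               "content": [{"type": "text", "text": "click(512, 384)"}],
#           },
#       ]
#   }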
|
|
|
|
|
def _wrap_fast_inference(generate, device_type, dtype, model):
    """Wrap `generate` in inference mode with autocast and sane defaults."""

    @torch.inference_mode()
    def _fast_generate(*args, **kwargs):
        # Some processors emit token_type_ids that generate() does not accept.
        kwargs.pop("token_type_ids", None)

        # Fall back to the model's EOS token as pad token if none was given.
        model_eos_token_id = getattr(model.config, "eos_token_id", None)
        if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"):
            model_eos_token_id = model_eos_token_id[0]
        kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)

        # Cast pixel values (when present) to the model dtype.
        if "pixel_values" in kwargs and torch.is_tensor(kwargs["pixel_values"]):
            kwargs["pixel_values"] = kwargs["pixel_values"].to(model.dtype)

        with torch.autocast(device_type=device_type, dtype=dtype):
            output = generate(*args, **kwargs)
        return output

    return _fast_generate
|
|
|
|
|
def for_inference(model):
    """Switch the model into an inference-friendly state: left padding, no
    gradient checkpointing, and an autocast-wrapped generate()."""
    model.gradient_checkpointing = False
    model.training = False
    for name, module in model.named_modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = False
        if hasattr(module, "training"):
            module.training = False

    dtype = model.config.torch_dtype
    if isinstance(dtype, str):
        if dtype == "float16":
            dtype = torch.float16
        elif dtype == "bfloat16":
            dtype = torch.bfloat16
    device_type = model.device.type

    # Wrap generate() once; for_training() restores the original.
    if model.generate.__name__ != "_fast_generate":
        model._unwrapped_old_generate = model.generate
        model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model)

    # Generation requires left padding on any tokenizer cached inside the model.
    internal_model = model
    while hasattr(internal_model, "model"):
        if hasattr(internal_model, "_saved_temp_tokenizer"):
            internal_model._saved_temp_tokenizer.tokenizer.padding_side = "left"
        internal_model = internal_model.model
    if hasattr(internal_model, "_saved_temp_tokenizer"):
        internal_model._saved_temp_tokenizer.tokenizer.padding_side = "left"

    if hasattr(model, "get_input_embeddings"):
        embeddings = model.get_input_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = False
    if hasattr(model, "get_output_embeddings"):
        embeddings = model.get_output_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = False

    return model
|
|
|
|
|
def for_training(model, use_gradient_checkpointing=True):
    """Restore the model to a training-ready state: right padding, gradient
    checkpointing, and the original generate()."""
    model.train()
    model.gradient_checkpointing = use_gradient_checkpointing
    model.training = True
    for name, module in model.named_modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = use_gradient_checkpointing
        if hasattr(module, "training"):
            module.training = True

    # Undo the generate() wrapper installed by for_inference().
    if hasattr(model, "_unwrapped_old_generate"):
        model.generate = model._unwrapped_old_generate
        del model._unwrapped_old_generate

    # Training uses right padding on any tokenizer cached inside the model.
    internal_model = model
    while hasattr(internal_model, "model"):
        if hasattr(internal_model, "_saved_temp_tokenizer"):
            internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right"
        internal_model = internal_model.model
    if hasattr(internal_model, "_saved_temp_tokenizer"):
        internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right"

    if hasattr(model, "get_input_embeddings"):
        embeddings = model.get_input_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = True
    if hasattr(model, "get_output_embeddings"):
        embeddings = model.get_output_embeddings()
        if hasattr(embeddings, "training"):
            embeddings.training = True

    return model
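
# Usage sketch (illustrative): the training callback below flips between the
# two states around each evaluation pass:
#
#   for_inference(model)
#   outputs = model.generate(**inputs, max_new_tokens=64)
#   for_training(model)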
|
|
|
|
|
class CustomTrainingCallback(TrainerCallback):
    def __init__(self, trainer, eval_epoch_interval=2):
        self.trainer = trainer
        self.eval_epoch_interval = eval_epoch_interval
        self.best_test_accuracy = 0.0
        self.best_test_epoch = 0
        self.best_metrics = {
            "test_accuracy": 0.0,
            "train_accuracy": 0.0,
            "epoch": 0,
            "global_step": 0,
        }

    def save_best_metrics(self, output_dir):
        """Save best metrics to a file in the checkpoint directory."""
        metrics_file = os.path.join(output_dir, "best_metrics.json")
        with open(metrics_file, "w") as f:
            json.dump(self.best_metrics, f, indent=4)
        print(f"Saved best metrics to {metrics_file}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log metrics to wandb at each logging step."""
        if logs is None:
            return
        if not wandb.run:
            wandb.init(
                project="qwen-vl-trainer",
                reinit=True,
                name=f"{os.environ.get('RANK', '0')}-training",
                group=os.environ.get("WANDB_RUN_GROUP", None),
            )

        step = getattr(state, "global_step", 0)

        # Namespace training metrics under train/; keep shared keys as-is.
        log_data = {}
        for key, value in logs.items():
            if key not in ["eval_loss", "epoch", "learning_rate"]:
                log_data[f"train/{key}"] = value
            else:
                log_data[key] = value
        wandb.log(log_data, step=step)

    def on_epoch_end(self, args, state, control, **kwargs):
        # state.epoch already counts completed epochs at this point.
        print(f"Epoch {int(state.epoch)} ended")
        was_training = self.trainer.model.training
        for_inference(self.trainer.model)
        self.trainer.model.eval()

        if (state.epoch + 1) % self.eval_epoch_interval == 0 and state.epoch > 4:
            test_accuracy = self.trainer.evaluate_step(
                dataset=self.trainer.eval_dataset, split="test"
            )
            train_accuracy = self.trainer.evaluate_step(
                dataset=self.trainer.train_dataset_eval, split="train"
            )
            print(
                f"Test accuracy: {test_accuracy:.4f}, "
                f"Train accuracy: {train_accuracy:.4f}"
            )

            if test_accuracy > self.best_test_accuracy:
                self.best_test_accuracy = test_accuracy
                self.best_test_epoch = state.epoch + 1
                self.best_metrics.update(
                    {
                        "test_accuracy": float(test_accuracy),
                        "train_accuracy": float(train_accuracy),
                        "epoch": int(state.epoch + 1),
                        "global_step": int(state.global_step),
                    }
                )
                self.save_best_metrics(args.output_dir)
                print(
                    f"\nNew best test accuracy: {self.best_test_accuracy:.4f} "
                    f"at epoch {self.best_test_epoch}"
                )

        if was_training:
            for_training(self.trainer.model)
            self.trainer.model.train()
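
# Example best_metrics.json written by save_best_metrics (values illustrative):
#
#   {
#       "test_accuracy": 0.8125,
#       "train_accuracy": 0.9050,
#       "epoch": 7,
#       "global_step": 1234
#   }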
|
|
|
|
|
class CustomSFTTrainer(SFTTrainer):
    def __init__(
        self,
        model,
        tokenizer,
        processor,
        data_collator,
        train_dataset=None,
        train_dataset_eval=None,
        eval_dataset=None,
        eval_epoch_interval=2,
        args=None,
    ):
        self.custom_callback = CustomTrainingCallback(
            self, eval_epoch_interval=eval_epoch_interval
        )
        callbacks = [self.custom_callback]

        super().__init__(
            model=model,
            tokenizer=tokenizer,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            callbacks=callbacks,
            args=args,
        )
        self.eval_dataset = eval_dataset
        self.train_dataset_eval = train_dataset_eval
        # Placeholder state so callbacks can read global_step before train()
        # populates the real TrainerState.
        self.state = type("State", (), {"global_step": 0})()
        self.processor = processor

    def evaluate_step(self, dataset, split):
        print(f"Evaluating {split} dataset")
        try:
            accuracy = evaluate_model(self.model, self.processor, dataset, split)

            if not wandb.run:
                wandb.init(
                    project="qwen-vl-trainer",
                    reinit=True,
                    name=f"{os.environ.get('RANK', '0')}-evaluation",
                    group=os.environ.get("WANDB_RUN_GROUP", None),
                )
            wandb.log({f"{split}/accuracy": accuracy})
            return accuracy
        except Exception as e:
            logger.error(f"Error evaluating: {e}")
            raise

    def cleanup(self):
        """Ensure wandb runs are properly closed."""
        if wandb.run:
            wandb.finish()
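
# evaluate_model comes from evaluate.py; it is assumed to take
# (model, processor, dataset, split) and return a scalar accuracy in [0, 1].
# A stub with that contract, for reference only:
#
#   def evaluate_model(model, processor, dataset, split):
#       correct, total = 0, 0
#       for datapoint in dataset:
#           ...  # generate a prediction and compare to the label
#       return correct / max(total, 1)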
|
|
|
|
|
def load_model(MODEL_ID: str, USE_QLORA: bool, training_args):
    # 4-bit NF4 quantization config, applied only when QLoRA is requested.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    lora_config = LoraConfig(
        r=200,
        lora_alpha=50,
        lora_dropout=0.001,
        bias="lora_only",
        target_modules=[
            "qkv_proj",
            "o_proj",
            "gate_up_proj",
            "down_proj",
            "gate_proj",
            "up_proj",
            "fc1",
            "fc2",
            "mlp.0",
            "mlp.2",
        ],
        task_type="CAUSAL_LM",
        inference_mode=False,
        modules_to_save=None,
    )

    torch.cuda.empty_cache()
    gc.collect()

    with open(training_args.deepspeed, "r") as f:
        ds_config = json.load(f)
    is_deepspeed_zero3_enabled = (
        ds_config.get("zero_optimization", {}).get("stage", 0) == 3
    )

    # bitsandbytes 4-bit quantization is incompatible with DeepSpeed ZeRO-3,
    # so only pass the quantization config when it can actually be used.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        quantization_config=(
            bnb_config if USE_QLORA and not is_deepspeed_zero3_enabled else None
        ),
        use_cache=False,
        attn_implementation="flash_attention_2",
    )

    # Strip sampling parameters so generation defaults to greedy decoding.
    model.generation_config = GenerationConfig.from_model_config(model.config)
    model.generation_config.temperature = None
    model.generation_config.top_p = None
    model.generation_config.top_k = None
    model.generation_config.early_stopping = False

    processor = Qwen2_5_VLProcessor.from_pretrained(MODEL_ID)

    model.enable_input_require_grads()
    model = get_peft_model(model, lora_config)
    model.gradient_checkpointing_enable()

    model.config.use_cache = False
    model.config.pretraining_tp = 1
    model.config.gradient_checkpointing = True
    model.config.use_reentrant = False
    model.config.gradient_checkpointing_kwargs = {
        "use_reentrant": False,
        "checkpoint_every_n_layers": 1,
        "offload_to_cpu": True,
    }

    return model, processor
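
# A minimal sketch of the deepspeed_config.json this script reads (assumed;
# the real file may enable more features):
#
#   {
#       "bf16": {"enabled": true},
#       "zero_optimization": {"stage": 2},
#       "train_micro_batch_size_per_gpu": "auto",
#       "gradient_accumulation_steps": "auto"
#   }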
|
|
|
|
|
def train(args):
    if args.local_rank != -1:
        torch.cuda.set_device(args.local_rank)

    if not torch.distributed.is_initialized():
        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
        rank = int(os.environ.get("RANK", args.local_rank))
        print(f"Initializing process group with rank={rank}, world_size={world_size}")
        try:
            torch.distributed.init_process_group(
                backend="nccl",
                init_method="env://",
                world_size=world_size,
                rank=rank,
            )
            print(f"Successfully initialized process group for rank {rank}")
        except Exception as e:
            print(f"Could not initialize process group: {e}")

    # Clear env vars that can interfere with allocator tuning and kernel builds.
    os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None)
    os.environ.pop("MAX_JOBS", None)
    os.environ.pop("CUDA_LAUNCH_BLOCKING", None)
|
|
|
|
|
    ds_config_path = "deepspeed_config.json"

    os.environ["WANDB_MODE"] = "online"

    # Group all ranks' wandb runs under a shared, timestamped group.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    run_id = timestamp
    os.environ["WANDB_RUN_GROUP"] = f"qwen_training_{run_id}"

    timestamped_output_dir = os.path.join(args.output_dir, f"run_{timestamp}")
    os.makedirs(timestamped_output_dir, exist_ok=True)
    print(f"Model checkpoints will be saved to: {timestamped_output_dir}")

    os.environ["WANDB_PROJECT"] = "qwen-vl-trainer"
    os.environ["WANDB_LOG_MODEL"] = "end"
    os.environ["WANDB_WATCH"] = "all"
    os.environ["WANDB_NAME"] = f"run_{timestamp}_rank{os.environ.get('RANK', '0')}"
|
|
|
|
|
    if args.local_rank <= 0:
        wandb.init(
            project="qwen-vl-trainer",
            name=f"transformer_training_{timestamp}",
            group=os.environ.get("WANDB_RUN_GROUP"),
            settings=wandb.Settings(_disable_stats=True, _disable_meta=True),
        )
        wandb.config.update(
            {
                "model_id": args.model_id,
                "use_qlora": args.use_qlora,
                "output_dir": timestamped_output_dir,
            }
        )
        print(f"Initialized wandb with run ID: {wandb.run.id}")
|
|
|
|
|
    training_args = SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        logging_steps=1,
        logging_strategy="steps",
        log_level="info",
        num_train_epochs=2000,
        bf16=True,
        optim="adamw_8bit",
        lr_scheduler_type="linear",
        seed=3407,
        output_dir=timestamped_output_dir,
        overwrite_output_dir=True,
        report_to="wandb",
        remove_unused_columns=False,
        dataset_text_field="",
        dataset_kwargs={"skip_prepare_dataset": True},
        dataset_num_proc=4,
        max_seq_length=800000,
        save_strategy="epoch",
        evaluation_strategy="no",
        save_total_limit=2000,
        deepspeed=ds_config_path,
    )
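
    # Effective global batch size = per_device_train_batch_size (1)
    # x gradient_accumulation_steps (2) x number of GPUs; e.g. with 4 GPUs,
    # each optimizer step sees 8 samples.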
|
|
|
|
|
    model, processor = load_model(args.model_id, args.use_qlora, training_args)

    train_dataset = AgentDatapointDataset(split="train", num_samples=args.train_size)
    test_dataset = AgentEvalDatapointDataset(split="test", num_samples=args.test_size)
    train_dataset_eval = AgentEvalDatapointDataset(
        split="train", num_samples=args.train_size
    )

    for_training(model)

    trainer = CustomSFTTrainer(
        model=model,
        processor=processor,
        tokenizer=processor.tokenizer,
        data_collator=lambda examples: train_collate_fn(examples, processor),
        train_dataset_eval=train_dataset_eval,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        args=training_args,
    )
|
|
|
    training_stats = trainer.train()
    logger.info("Training completed.")
    print(f"Training Statistics: {training_stats}")

    final_model_path = os.path.join(timestamped_output_dir, "final_model")
    if args.local_rank <= 0:
        print(f"Saving final model to {final_model_path}")
        trainer.save_model(final_model_path)
        print(f"Final model saved to {final_model_path}")
        processor.save_pretrained(final_model_path)
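
        # Reloading the saved adapter later (sketch; final_model_path holds the
        # PEFT adapter plus the processor files):
        #
        #   from peft import PeftModel
        #   base = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        #       args.model_id, torch_dtype=torch.bfloat16
        #   )
        #   model = PeftModel.from_pretrained(base, final_model_path)
        #   processor = Qwen2_5_VLProcessor.from_pretrained(final_model_path)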
|
|
    trainer.cleanup()

    if args.local_rank <= 0:
        if wandb.run:
            print("Finalizing main wandb run...")
            wandb.finish()

    print("Training process completed successfully.")
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Training configuration")
    parser.add_argument(
        "--model_id",
        type=str,
        default="Qwen/Qwen2.5-VL-7B-Instruct",
        help="Model ID to use",
    )
    # NOTE: argparse's type=bool treats any non-empty string (including
    # "False") as True, so parse the flag explicitly.
    parser.add_argument(
        "--use_qlora",
        type=lambda x: str(x).lower() in ("1", "true", "yes", "y"),
        default=True,
        help="Whether to use QLoRA",
    )
    parser.add_argument(
        "--output_dir", type=str, default="checkpoints_27feb", help="Output directory"
    )
    parser.add_argument(
        "--local_rank", type=int, default=-1, help="Local rank for distributed training"
    )
    parser.add_argument(
        "--train_size", type=int, default=10000000, help="Number of training samples"
    )
    parser.add_argument(
        "--test_size", type=int, default=10000000, help="Number of test samples"
    )
    args = parser.parse_args()
    train(args)
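
# Example launch (illustrative; adjust GPU count and paths to your setup):
#
#   torchrun --nproc_per_node=4 train.py \
#       --model_id Qwen/Qwen2.5-VL-7B-Instruct \
#       --output_dir checkpoints_27feb
#
# The DeepSpeed config path is hard-coded as deepspeed_config.json in train().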
|
|