File size: 46,834 Bytes
3f2c461 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 |
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
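"""
Cleaned main file: model construction, the training/evaluation loop, ONNX
export, and the command-line argument definitions for RF-DETR.
"""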
import argparse
import ast
import copy
import datetime
import json
import math
import os
import random
import shutil
import time
from copy import deepcopy
from logging import getLogger
from pathlib import Path
from typing import DefaultDict, List, Callable
import numpy as np
import torch
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader, DistributedSampler
import rfdetr.util.misc as utils
from rfdetr.datasets import build_dataset, get_coco_api_from_dataset
from rfdetr.engine import evaluate, train_one_epoch
from rfdetr.models import build_model, build_criterion_and_postprocessors
from rfdetr.util.benchmark import benchmark
from rfdetr.util.drop_scheduler import drop_scheduler
from rfdetr.util.files import download_file
from rfdetr.util.get_param_dicts import get_param_dict
from rfdetr.util.utils import ModelEma, BestMetricHolder, clean_state_dict
# Opt-in switch: when the USE_FILE_SYSTEM_SHARING environment variable is
# "true" or "1" (case-insensitive), torch.multiprocessing is told to use the
# 'file_system' sharing strategy.
if os.environ.get("USE_FILE_SYSTEM_SHARING", "False").lower() in ("true", "1"):
    import torch.multiprocessing
    torch.multiprocessing.set_sharing_strategy('file_system')

logger = getLogger(__name__)

# Maps a local checkpoint file name to the URL it is downloaded from.
HOSTED_MODELS = {
    "rf-detr-base.pth":
        "https://storage.googleapis.com/rfdetr/rf-detr-base-coco.pth",
    # below is a less converged model that may be better for finetuning but worse for inference
    "rf-detr-base-2.pth":
        "https://storage.googleapis.com/rfdetr/rf-detr-base-2.pth",
    "rf-detr-large.pth":
        "https://storage.googleapis.com/rfdetr/rf-detr-large.pth",
}
def download_pretrain_weights(pretrain_weights: str, redownload=False):
    """Download a hosted checkpoint to the local path ``pretrain_weights``.

    Only names that are keys of ``HOSTED_MODELS`` are handled; any other
    value is ignored. A hosted file is fetched when ``redownload`` is true
    or when nothing exists at that path yet.
    """
    if pretrain_weights not in HOSTED_MODELS:
        return
    if not redownload and os.path.exists(pretrain_weights):
        # Already on disk and no refresh requested.
        return

    logger.info(
        f"Downloading pretrained weights for {pretrain_weights}"
    )
    source_url = HOSTED_MODELS[pretrain_weights]
    download_file(source_url, pretrain_weights)
class Model:
    """Builds an RF-DETR model and exposes training and ONNX-export helpers.

    Attributes set by ``__init__``: ``resolution``, ``model``, ``device``,
    ``criterion``, ``postprocessors`` and ``stop_early``; ``class_names`` is
    set only when the loaded checkpoint's ``args`` carries it. ``train``
    additionally sets ``ema_m``.
    """

    def __init__(self, **kwargs):
        """Build the model and optionally load pretrained weights.

        Keyword arguments are forwarded to ``populate_args``; the resulting
        namespace drives model construction, checkpoint loading, optional
        LoRA wrapping of the backbone encoder and device placement.
        """
        args = populate_args(**kwargs)
        self.resolution = args.resolution
        self.model = build_model(args)
        self.device = torch.device(args.device)
        if args.pretrain_weights is not None:
            print("Loading pretrain weights")
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from trusted sources.
            try:
                checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False)
            except Exception as e:
                print(f"Failed to load pretrain weights: {e}")
                # re-download weights if they are corrupted
                print("Failed to load pretrain weights, re-downloading")
                download_pretrain_weights(args.pretrain_weights, redownload=True)
                checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False)

            # Extract class_names from checkpoint if available
            if 'args' in checkpoint and hasattr(checkpoint['args'], 'class_names'):
                self.class_names = checkpoint['args'].class_names

            checkpoint_num_classes = checkpoint['model']['class_embed.bias'].shape[0]
            if checkpoint_num_classes != args.num_classes + 1:
                logger.warning(
                    f"num_classes mismatch: pretrain weights has {checkpoint_num_classes - 1} classes, but your model has {args.num_classes} classes\n"
                    f"reinitializing detection head with {checkpoint_num_classes - 1} classes"
                )
                self.reinitialize_detection_head(checkpoint_num_classes)

            # add support to exclude_keys
            # e.g., when load object365 pretrain, do not load `class_embed.[weight, bias]`
            if args.pretrain_exclude_keys is not None:
                assert isinstance(args.pretrain_exclude_keys, list)
                for exclude_key in args.pretrain_exclude_keys:
                    checkpoint['model'].pop(exclude_key)

            if args.pretrain_keys_modify_to_load is not None:
                # Package-qualified path, consistent with the other
                # rfdetr.util imports at the top of this file.
                from rfdetr.util.obj365_to_coco_model import get_coco_pretrain_from_obj365
                assert isinstance(args.pretrain_keys_modify_to_load, list)
                # Use this instance's freshly built model as the shape
                # reference (no DDP wrapper exists at construction time).
                model_state = self.model.state_dict()
                for modify_key_to_load in args.pretrain_keys_modify_to_load:
                    try:
                        checkpoint['model'][modify_key_to_load] = get_coco_pretrain_from_obj365(
                            model_state[modify_key_to_load],
                            checkpoint['model'][modify_key_to_load]
                        )
                    except Exception:
                        # Best effort: a key that cannot be converted is dropped
                        # so load_state_dict(strict=False) simply skips it.
                        print(f"Failed to load {modify_key_to_load}, deleting from checkpoint")
                        checkpoint['model'].pop(modify_key_to_load, None)

            # we may want to resume training with a smaller number of groups for group detr
            num_desired_queries = args.num_queries * args.group_detr
            query_param_names = ["refpoint_embed.weight", "query_feat.weight"]
            for name, state in checkpoint['model'].items():
                if any(name.endswith(x) for x in query_param_names):
                    # Only values of existing keys are replaced, so mutating
                    # during iteration is safe.
                    checkpoint['model'][name] = state[:num_desired_queries]

            self.model.load_state_dict(checkpoint['model'], strict=False)

        if args.backbone_lora:
            print("Applying LORA to backbone")
            lora_config = LoraConfig(
                r=16,
                lora_alpha=16,
                use_dora=True,
                target_modules=[
                    "q_proj", "v_proj", "k_proj",  # covers OWL-ViT
                    "qkv",  # covers open_clip ie Siglip2
                    "query", "key", "value", "cls_token", "register_tokens",  # covers Dinov2 with windowed attn
                ]
            )
            self.model.backbone[0].encoder = get_peft_model(self.model.backbone[0].encoder, lora_config)
        self.model = self.model.to(self.device)
        self.criterion, self.postprocessors = build_criterion_and_postprocessors(args)
        self.stop_early = False

    def reinitialize_detection_head(self, num_classes):
        """Delegate to ``self.model.reinitialize_detection_head(num_classes)``."""
        self.model.reinitialize_detection_head(num_classes)

    def request_early_stop(self):
        """Set ``stop_early`` so ``train`` breaks out after the current epoch."""
        self.stop_early = True
        print("Early stopping requested, will complete current epoch and stop")

    def train(self, callbacks: DefaultDict[str, List[Callable]], **kwargs):
        """Run the training (or evaluation-only) loop.

        Args:
            callbacks: mapping from event name to a list of callables. Only
                "on_fit_epoch_end", "on_train_batch_start" and "on_train_end"
                are accepted as keys.
            **kwargs: forwarded to ``populate_args``; ``min_batches``
                (default 5) is additionally read directly from ``kwargs``.

        Raises:
            ValueError: if ``callbacks`` contains an unsupported key.

        Side effects: writes checkpoints, ``log.txt``, evaluation dumps and
        ``results.json`` under ``args.output_dir``; sets ``self.ema_m``; may
        replace ``self.model`` with the EMA module when that scored best.
        """
        currently_supported_callbacks = ["on_fit_epoch_end", "on_train_batch_start", "on_train_end"]
        for key in callbacks:
            if key not in currently_supported_callbacks:
                raise ValueError(
                    f"Callback {key} is not currently supported, please file an issue if you need it!\n"
                    f"Currently supported callbacks: {currently_supported_callbacks}"
                )
        args = populate_args(**kwargs)
        utils.init_distributed_mode(args)
        print("git:\n {}\n".format(utils.get_sha()))
        print(args)
        device = torch.device(args.device)

        # fix the seed for reproducibility
        seed = args.seed + utils.get_rank()
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        criterion, postprocessors = build_criterion_and_postprocessors(args)
        model = self.model
        model.to(device)

        model_without_ddp = model
        if args.distributed:
            if args.sync_bn:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
            model_without_ddp = model.module

        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('number of params:', n_parameters)
        param_dicts = get_param_dict(args, model_without_ddp)
        param_dicts = [p for p in param_dicts if p['params'].requires_grad]
        optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                      weight_decay=args.weight_decay)

        dataset_train = build_dataset(image_set='train', args=args, resolution=args.resolution)
        dataset_val = build_dataset(image_set='val', args=args, resolution=args.resolution)

        # for cosine annealing, calculate total training steps and warmup steps
        total_batch_size_for_lr = args.batch_size * utils.get_world_size() * args.grad_accum_steps
        num_training_steps_per_epoch_lr = (len(dataset_train) + total_batch_size_for_lr - 1) // total_batch_size_for_lr
        total_training_steps_lr = num_training_steps_per_epoch_lr * args.epochs
        warmup_steps_lr = num_training_steps_per_epoch_lr * args.warmup_epochs

        def lr_lambda(current_step: int):
            # Learning-rate multiplier chosen by args.lr_scheduler.
            if current_step < warmup_steps_lr:
                # Linear warmup
                return float(current_step) / float(max(1, warmup_steps_lr))
            else:
                # Cosine annealing from multiplier 1.0 down to lr_min_factor
                if args.lr_scheduler == 'cosine':
                    progress = float(current_step - warmup_steps_lr) / float(max(1, total_training_steps_lr - warmup_steps_lr))
                    return args.lr_min_factor + (1 - args.lr_min_factor) * 0.5 * (1 + math.cos(math.pi * progress))
                elif args.lr_scheduler == 'step':
                    if current_step < args.lr_drop * num_training_steps_per_epoch_lr:
                        return 1.0
                    else:
                        return 0.1

        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

        if args.distributed:
            sampler_train = DistributedSampler(dataset_train)
            sampler_val = DistributedSampler(dataset_val, shuffle=False)
        else:
            sampler_train = torch.utils.data.RandomSampler(dataset_train)
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)

        effective_batch_size = args.batch_size * args.grad_accum_steps
        min_batches = kwargs.get('min_batches', 5)
        if len(dataset_train) < effective_batch_size * min_batches:
            # Too few samples for `min_batches` full batches: sample with
            # replacement so every epoch still yields that many batches.
            logger.info(
                f"Training with uniform sampler because dataset is too small: {len(dataset_train)} < {effective_batch_size * min_batches}"
            )
            sampler = torch.utils.data.RandomSampler(
                dataset_train,
                replacement=True,
                num_samples=effective_batch_size * min_batches,
            )
            data_loader_train = DataLoader(
                dataset_train,
                batch_size=effective_batch_size,
                collate_fn=utils.collate_fn,
                num_workers=args.num_workers,
                sampler=sampler,
            )
        else:
            batch_sampler_train = torch.utils.data.BatchSampler(
                sampler_train, effective_batch_size, drop_last=True)
            data_loader_train = DataLoader(
                dataset_train,
                batch_sampler=batch_sampler_train,
                collate_fn=utils.collate_fn,
                num_workers=args.num_workers
            )

        data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                     drop_last=False, collate_fn=utils.collate_fn,
                                     num_workers=args.num_workers)

        base_ds = get_coco_api_from_dataset(dataset_val)

        if args.use_ema:
            self.ema_m = ModelEma(model_without_ddp, decay=args.ema_decay, tau=args.ema_tau)
        else:
            self.ema_m = None

        output_dir = Path(args.output_dir)

        if utils.is_main_process():
            print("Get benchmark")
            if args.do_benchmark:
                benchmark_model = copy.deepcopy(model_without_ddp)
                bm = benchmark(benchmark_model.float(), dataset_val, output_dir)
                print(json.dumps(bm, indent=2))
                del benchmark_model

        if args.resume:
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only resume from trusted checkpoints.
            checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
            model_without_ddp.load_state_dict(checkpoint['model'], strict=True)
            if args.use_ema:
                if 'ema_model' in checkpoint:
                    self.ema_m.module.load_state_dict(clean_state_dict(checkpoint['ema_model']))
                else:
                    del self.ema_m
                    # Rebuild from the unwrapped model, matching the initial
                    # EMA construction above (never from the DDP wrapper).
                    self.ema_m = ModelEma(model_without_ddp, decay=args.ema_decay, tau=args.ema_tau)

            if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
                args.start_epoch = checkpoint['epoch'] + 1

        if args.eval:
            test_stats, coco_evaluator = evaluate(
                model, criterion, postprocessors, data_loader_val, base_ds, device, args)
            if args.output_dir:
                utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
            return

        # for drop
        total_batch_size = effective_batch_size * utils.get_world_size()
        num_training_steps_per_epoch = (len(dataset_train) + total_batch_size - 1) // total_batch_size
        schedules = {}
        if args.dropout > 0:
            schedules['do'] = drop_scheduler(
                args.dropout, args.epochs, num_training_steps_per_epoch,
                args.cutoff_epoch, args.drop_mode, args.drop_schedule)
            print("Min DO = %.7f, Max DO = %.7f" % (min(schedules['do']), max(schedules['do'])))

        if args.drop_path > 0:
            schedules['dp'] = drop_scheduler(
                args.drop_path, args.epochs, num_training_steps_per_epoch,
                args.cutoff_epoch, args.drop_mode, args.drop_schedule)
            print("Min DP = %.7f, Max DP = %.7f" % (min(schedules['dp']), max(schedules['dp'])))

        print("Start training")
        start_time = time.time()
        best_map_holder = BestMetricHolder(use_ema=args.use_ema)
        best_map_5095 = 0
        best_map_50 = 0
        best_map_ema_5095 = 0
        best_map_ema_50 = 0
        # Dedicated file names for the best checkpoints. The regular one must
        # not share a name with the periodic `checkpoint{epoch:04}.pth` files,
        # otherwise one would overwrite the other.
        best_regular_path = output_dir / 'checkpoint_best_regular.pth'
        best_ema_path = output_dir / 'checkpoint_best_ema.pth'
        for epoch in range(args.start_epoch, args.epochs):
            epoch_start_time = time.time()
            if args.distributed:
                sampler_train.set_epoch(epoch)

            model.train()
            criterion.train()
            train_stats = train_one_epoch(
                model, criterion, lr_scheduler, data_loader_train, optimizer, device, epoch,
                effective_batch_size, args.clip_max_norm, ema_m=self.ema_m, schedules=schedules,
                num_training_steps_per_epoch=num_training_steps_per_epoch,
                vit_encoder_num_layers=args.vit_encoder_num_layers, args=args, callbacks=callbacks)
            train_epoch_time = time.time() - epoch_start_time
            train_epoch_time_str = str(datetime.timedelta(seconds=int(train_epoch_time)))
            if args.output_dir:
                checkpoint_paths = [output_dir / 'checkpoint.pth']
                # extra checkpoint before LR drop and every `checkpoint_interval` epochs
                if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % args.checkpoint_interval == 0:
                    checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
                for checkpoint_path in checkpoint_paths:
                    weights = {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }
                    if args.use_ema:
                        weights.update({
                            'ema_model': self.ema_m.module.state_dict(),
                        })
                    if not args.dont_save_weights:
                        # create checkpoint dir
                        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
                        utils.save_on_master(weights, checkpoint_path)

            with torch.inference_mode():
                test_stats, coco_evaluator = evaluate(
                    model, criterion, postprocessors, data_loader_val, base_ds, device, args=args
                )

            map_regular = test_stats['coco_eval_bbox'][0]
            _isbest = best_map_holder.update(map_regular, epoch, is_ema=False)
            if _isbest:
                best_map_5095 = max(best_map_5095, map_regular)
                best_map_50 = max(best_map_50, test_stats["coco_eval_bbox"][1])
                if not args.dont_save_weights:
                    utils.save_on_master({
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, best_regular_path)
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}
            if args.use_ema:
                ema_test_stats, _ = evaluate(
                    self.ema_m.module, criterion, postprocessors, data_loader_val, base_ds, device, args=args
                )
                log_stats.update({f'ema_test_{k}': v for k, v in ema_test_stats.items()})
                map_ema = ema_test_stats['coco_eval_bbox'][0]
                best_map_ema_5095 = max(best_map_ema_5095, map_ema)
                _isbest = best_map_holder.update(map_ema, epoch, is_ema=True)
                if _isbest:
                    best_map_ema_50 = max(best_map_ema_50, ema_test_stats["coco_eval_bbox"][1])
                    if not args.dont_save_weights:
                        utils.save_on_master({
                            'model': self.ema_m.module.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'lr_scheduler': lr_scheduler.state_dict(),
                            'epoch': epoch,
                            'args': args,
                        }, best_ema_path)
            log_stats.update(best_map_holder.summary())

            # epoch parameters
            ep_paras = {
                'epoch': epoch,
                'n_parameters': n_parameters
            }
            log_stats.update(ep_paras)
            log_stats['now_time'] = str(datetime.datetime.now())
            log_stats['train_epoch_time'] = train_epoch_time_str
            epoch_time = time.time() - epoch_start_time
            epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time)))
            log_stats['epoch_time'] = epoch_time_str
            if args.output_dir and utils.is_main_process():
                with (output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

                # for evaluation logs
                if coco_evaluator is not None:
                    (output_dir / 'eval').mkdir(exist_ok=True)
                    if "bbox" in coco_evaluator.coco_eval:
                        filenames = ['latest.pth']
                        if epoch % 50 == 0:
                            filenames.append(f'{epoch:03}.pth')
                        for name in filenames:
                            torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                       output_dir / "eval" / name)

            for callback in callbacks["on_fit_epoch_end"]:
                callback(log_stats)

            if self.stop_early:
                print(f"Early stopping requested, stopping at epoch {epoch}")
                break

        best_is_ema = best_map_ema_5095 > best_map_5095

        if utils.is_main_process():
            best_source = best_ema_path if best_is_ema else best_regular_path
            best_total_path = output_dir / 'checkpoint_best_total.pth'
            if best_source.exists():
                shutil.copy2(best_source, best_total_path)
                utils.strip_checkpoint(best_total_path)
            else:
                # Happens e.g. with dont_save_weights or when no epoch ran.
                logger.warning(
                    f"Best checkpoint {best_source} was not written; skipping checkpoint_best_total.pth"
                )

            best_map_5095 = max(best_map_5095, best_map_ema_5095)
            best_map_50 = max(best_map_50, best_map_ema_50)

            results_json = {
                "map95": best_map_5095,
                "map50": best_map_50,
                "class": "all"
            }
            results = {
                "class_map": {
                    "valid": [
                        results_json
                    ]
                }
            }
            # The directory may not exist yet when no checkpoint was saved.
            output_dir.mkdir(parents=True, exist_ok=True)
            with open(output_dir / "results.json", "w") as f:
                json.dump(results, f)

            total_time = time.time() - start_time
            total_time_str = str(datetime.timedelta(seconds=int(total_time)))
            print('Training time {}'.format(total_time_str))
            print('Results saved to {}'.format(output_dir / "results.json"))

        if best_is_ema:
            self.model = self.ema_m.module
        self.model.eval()

        for callback in callbacks["on_train_end"]:
            callback()

    def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_only=False, opset_version=17, verbose=True, force=False, shape=None, batch_size=1, **kwargs):
        """Export the trained model to ONNX format.

        A deep copy of ``self.model`` is run once on a sample input and then
        handed to ``export_onnx``; with ``simplify`` the result is also passed
        through ``onnx_simplify``. ``shape`` defaults to
        ``(self.resolution, self.resolution)``.

        Raises:
            ImportError: if ``rfdetr.deploy.export`` cannot be imported.
            ValueError: if an explicit ``shape`` has a side not divisible by 14.
        """
        print("Exporting model to ONNX format")
        try:
            from rfdetr.deploy.export import export_onnx, onnx_simplify, make_infer_image
        except ImportError:
            print("It seems some dependencies for ONNX export are missing. Please run `pip install rfdetr[onnxexport]` and try again.")
            raise

        device = self.device
        # Assuming self.model is a torch.nn.Module, .to("cpu") moves it in
        # place; the `finally` below puts it back on `device` even on failure.
        model = deepcopy(self.model.to("cpu"))
        try:
            model.to(device)

            os.makedirs(output_dir, exist_ok=True)
            output_dir = Path(output_dir)
            if shape is None:
                shape = (self.resolution, self.resolution)
            else:
                if shape[0] % 14 != 0 or shape[1] % 14 != 0:
                    raise ValueError("Shape must be divisible by 14")

            input_tensors = make_infer_image(infer_dir, shape, batch_size, device).to(device)
            input_names = ['input']
            output_names = ['features'] if backbone_only else ['dets', 'labels']
            dynamic_axes = None
            # Kept for callers that rely on self.model ending up in eval mode.
            self.model.eval()
            # The copy is what is actually run and exported, so it is the one
            # that must be in eval mode.
            model.eval()
            with torch.no_grad():
                if backbone_only:
                    features = model(input_tensors)
                    print(f"PyTorch inference output shape: {features.shape}")
                else:
                    outputs = model(input_tensors)
                    dets = outputs['pred_boxes']
                    labels = outputs['pred_logits']
                    print(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}")
            model.cpu()
            input_tensors = input_tensors.cpu()

            # Export to ONNX
            output_file = export_onnx(
                output_dir=output_dir,
                model=model,
                input_names=input_names,
                input_tensors=input_tensors,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                backbone_only=backbone_only,
                verbose=verbose,
                opset_version=opset_version
            )

            print(f"Successfully exported ONNX model to: {output_file}")

            if simplify:
                sim_output_file = onnx_simplify(
                    onnx_dir=output_file,
                    input_names=input_names,
                    input_tensors=input_tensors,
                    force=force
                )
                print(f"Successfully simplified ONNX model to: {sim_output_file}")

            print("ONNX export completed successfully")
        finally:
            self.model = self.model.to(device)
# Script entry point: parse the command line and dispatch on the sub-command.
if __name__ == '__main__':
    # NOTE(review): get_args_parser is defined further down in this file, so
    # when the file is run as a script this block executes before that
    # definition is reached. Confirm the intended placement.
    parser = argparse.ArgumentParser('LWDETR training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    config = vars(args)  # Convert Namespace to dictionary

    # NOTE(review): `distill` and `main` are not defined in the visible part
    # of this file; verify they exist before relying on these branches.
    if args.subcommand == 'distill':
        distill(**config)
    elif args.subcommand is None:
        main(**config)
    elif args.subcommand == 'export_model':
        # Options dropped from `config` before it is handed to the exporter.
        filter_keys = [
            "num_classes",
            "grad_accum_steps",
            "lr",
            "lr_encoder",
            "weight_decay",
            "epochs",
            "lr_drop",
            "clip_max_norm",
            "lr_vit_layer_decay",
            "lr_component_decay",
            "dropout",
            "drop_path",
            "drop_mode",
            "drop_schedule",
            "cutoff_epoch",
            "pretrained_encoder",
            "pretrain_weights",
            "pretrain_exclude_keys",
            "pretrain_keys_modify_to_load",
            "freeze_florence",
            "freeze_aimv2",
            "decoder_norm",
            "set_cost_class",
            "set_cost_bbox",
            "set_cost_giou",
            "cls_loss_coef",
            "bbox_loss_coef",
            "giou_loss_coef",
            "focal_alpha",
            "aux_loss",
            "sum_group_losses",
            "use_varifocal_loss",
            "use_position_supervised_loss",
            "ia_bce_loss",
            "dataset_file",
            "coco_path",
            "dataset_dir",
            "square_resize_div_64",
            "output_dir",
            "checkpoint_interval",
            "seed",
            "resume",
            "start_epoch",
            "eval",
            "use_ema",
            "ema_decay",
            "ema_tau",
            "num_workers",
            "device",
            "world_size",
            "dist_url",
            "sync_bn",
            "fp16_eval",
            "infer_dir",
            "verbose",
            "opset_version",
            "dry_run",
            "shape",
        ]
        for key in filter_keys:
            config.pop(key, None)  # Use pop with None to avoid KeyError

        # Package-qualified path, consistent with Model.export and the other
        # rfdetr.* imports in this file.
        from rfdetr.deploy.export import main as export_main
        if args.batch_size != 1:
            config['batch_size'] = 1
            # Adjacent literals instead of a backslash continuation, which
            # would embed the source indentation in the message.
            print("Only batch_size 1 is supported for onnx export, "
                  f"but got batchsize = {args.batch_size}. batch_size is forcibly set to 1.")
        export_main(**config)
def _str_to_bool(value):
    """Convert a command-line token to ``bool``.

    ``type=bool`` treats every non-empty string (including "False") as
    ``True``; this converter maps the usual spellings explicitly and raises
    ``argparse.ArgumentTypeError`` for anything else.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', '1', 'yes', 'y', 't', 'on'):
        return True
    if text in ('false', '0', 'no', 'n', 'f', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Build the argument parser for training, evaluation and export.

    Returns an ``argparse.ArgumentParser`` created with ``add_help=False`` so
    it can be used as a parent parser. It also registers the optional
    ``export_model`` sub-command, whose name is stored in ``subcommand``.
    """
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    parser.add_argument('--num_classes', default=2, type=int)
    parser.add_argument('--grad_accum_steps', default=1, type=int)
    # Explicit converter so that `--amp False` really yields False.
    parser.add_argument('--amp', default=False, type=_str_to_bool)
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--lr_encoder', default=1.5e-4, type=float)
    parser.add_argument('--batch_size', default=2, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=12, type=int)
    parser.add_argument('--lr_drop', default=11, type=int)
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')
    parser.add_argument('--lr_vit_layer_decay', default=0.8, type=float)
    parser.add_argument('--lr_component_decay', default=1.0, type=float)
    parser.add_argument('--do_benchmark', action='store_true', help='benchmark the model')

    # drop args
    # dropout and stochastic depth drop rate; set at most one to non-zero
    parser.add_argument('--dropout', type=float, default=0,
                        help='Dropout rate (default: 0.0)')
    parser.add_argument('--drop_path', type=float, default=0,
                        help='Drop path rate (default: 0.0)')

    # early / late dropout and stochastic depth settings
    parser.add_argument('--drop_mode', type=str, default='standard',
                        choices=['standard', 'early', 'late'], help='drop mode')
    parser.add_argument('--drop_schedule', type=str, default='constant',
                        choices=['constant', 'linear'],
                        help='drop schedule for early dropout / s.d. only')
    parser.add_argument('--cutoff_epoch', type=int, default=0,
                        help='if drop_mode is early / late, this is the epoch where dropout ends / starts')

    # Model parameters
    parser.add_argument('--pretrained_encoder', type=str, default=None,
                        help="Path to the pretrained encoder.")
    parser.add_argument('--pretrain_weights', type=str, default=None,
                        help="Path to the pretrained model.")
    parser.add_argument('--pretrain_exclude_keys', type=str, default=None, nargs='+',
                        help="Keys you do not want to load.")
    parser.add_argument('--pretrain_keys_modify_to_load', type=str, default=None, nargs='+',
                        help="Keys you want to modify to load. Only used when loading objects365 pre-trained weights.")

    # * Backbone
    parser.add_argument('--encoder', default='vit_tiny', type=str,
                        help="Name of the transformer or convolutional encoder to use")
    parser.add_argument('--vit_encoder_num_layers', default=12, type=int,
                        help="Number of layers used in ViT encoder")
    parser.add_argument('--window_block_indexes', default=None, type=int, nargs='+')
    parser.add_argument('--position_embedding', default='sine', type=str,
                        choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")
    parser.add_argument('--out_feature_indexes', default=[-1], type=int, nargs='+', help='only for vit now')
    parser.add_argument("--freeze_encoder", action="store_true", dest="freeze_encoder")
    parser.add_argument("--layer_norm", action="store_true", dest="layer_norm")
    parser.add_argument("--rms_norm", action="store_true", dest="rms_norm")
    parser.add_argument("--backbone_lora", action="store_true", dest="backbone_lora")
    parser.add_argument("--force_no_pretrain", action="store_true", dest="force_no_pretrain")

    # * Transformer
    parser.add_argument('--dec_layers', default=3, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--sa_nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's self-attentions")
    parser.add_argument('--ca_nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's cross-attentions")
    parser.add_argument('--num_queries', default=300, type=int,
                        help="Number of query slots")
    parser.add_argument('--group_detr', default=13, type=int,
                        help="Number of groups to speed up detr training")
    parser.add_argument('--two_stage', action='store_true')
    parser.add_argument('--projector_scale', default='P4', type=str, nargs='+', choices=('P3', 'P4', 'P5', 'P6'))
    parser.add_argument('--lite_refpoint_refine', action='store_true', help='lite refpoint refine mode for speed-up')
    parser.add_argument('--num_select', default=100, type=int,
                        help='the number of predictions selected for evaluation')
    parser.add_argument('--dec_n_points', default=4, type=int,
                        help='the number of sampling points')
    parser.add_argument('--decoder_norm', default='LN', type=str)
    parser.add_argument('--bbox_reparam', action='store_true')
    parser.add_argument('--freeze_batch_norm', action='store_true')

    # * Matcher
    parser.add_argument('--set_cost_class', default=2, type=float,
                        help="Class coefficient in the matching cost")
    parser.add_argument('--set_cost_bbox', default=5, type=float,
                        help="L1 box coefficient in the matching cost")
    parser.add_argument('--set_cost_giou', default=2, type=float,
                        help="giou box coefficient in the matching cost")

    # * Loss coefficients
    parser.add_argument('--cls_loss_coef', default=2, type=float)
    parser.add_argument('--bbox_loss_coef', default=5, type=float)
    parser.add_argument('--giou_loss_coef', default=2, type=float)
    parser.add_argument('--focal_alpha', default=0.25, type=float)

    # Loss
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                        help="Disables auxiliary decoding losses (loss at each layer)")
    parser.add_argument('--sum_group_losses', action='store_true',
                        help="To sum losses across groups or mean losses.")
    parser.add_argument('--use_varifocal_loss', action='store_true')
    parser.add_argument('--use_position_supervised_loss', action='store_true')
    parser.add_argument('--ia_bce_loss', action='store_true')

    # dataset parameters
    parser.add_argument('--dataset_file', default='coco')
    parser.add_argument('--coco_path', type=str)
    parser.add_argument('--dataset_dir', type=str)
    parser.add_argument('--square_resize_div_64', action='store_true')

    parser.add_argument('--output_dir', default='output',
                        help='path where to save, empty for no saving')
    parser.add_argument('--dont_save_weights', action='store_true')
    parser.add_argument('--checkpoint_interval', default=10, type=int,
                        help='epoch interval to save checkpoint')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--use_ema', action='store_true')
    parser.add_argument('--ema_decay', default=0.9997, type=float)
    parser.add_argument('--ema_tau', default=0, type=float)

    parser.add_argument('--num_workers', default=2, type=int)

    # distributed training parameters
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    # Explicit converter so that `--sync_bn False` really disables it.
    parser.add_argument('--sync_bn', default=True, type=_str_to_bool,
                        help='setup synchronized BatchNorm for distributed training')

    # fp16
    parser.add_argument('--fp16_eval', default=False, action='store_true',
                        help='evaluate in fp16 precision.')

    # custom args
    parser.add_argument('--encoder_only', action='store_true', help='Export and benchmark encoder only')
    parser.add_argument('--backbone_only', action='store_true', help='Export and benchmark backbone only')
    parser.add_argument('--resolution', type=int, default=640, help="input resolution")
    parser.add_argument('--use_cls_token', action='store_true', help='use cls token')
    parser.add_argument('--multi_scale', action='store_true', help='use multi scale')
    parser.add_argument('--expanded_scales', action='store_true', help='use expanded scales')
    parser.add_argument('--warmup_epochs', default=1, type=float,
                        help='Number of warmup epochs for linear warmup before cosine annealing')
    # Add scheduler type argument: 'step' or 'cosine'
    parser.add_argument(
        '--lr_scheduler',
        default='step',
        choices=['step', 'cosine'],
        help="Type of learning rate scheduler to use: 'step' (default) or 'cosine'"
    )
    parser.add_argument('--lr_min_factor', default=0.0, type=float,
                        help='Minimum learning rate factor (as a fraction of initial lr) at the end of cosine annealing')

    # Early stopping parameters
    parser.add_argument('--early_stopping', action='store_true',
                        help='Enable early stopping based on mAP improvement')
    parser.add_argument('--early_stopping_patience', default=10, type=int,
                        help='Number of epochs with no improvement after which training will be stopped')
    parser.add_argument('--early_stopping_min_delta', default=0.001, type=float,
                        help='Minimum change in mAP to qualify as an improvement')
    parser.add_argument('--early_stopping_use_ema', action='store_true',
                        help='Use EMA model metrics for early stopping')

    # subparsers
    subparsers = parser.add_subparsers(title='sub-commands', dest='subcommand',
                                       description='valid subcommands', help='additional help')

    # subparser for export model
    parser_export = subparsers.add_parser('export_model', help='LWDETR model export')
    parser_export.add_argument('--infer_dir', type=str, default=None)
    parser_export.add_argument('--verbose', type=ast.literal_eval, default=False, nargs="?", const=True)
    parser_export.add_argument('--opset_version', type=int, default=17)
    parser_export.add_argument('--simplify', action='store_true', help="Simplify onnx model")
    parser_export.add_argument('--tensorrt', '--trtexec', '--trt', action='store_true',
                               help="build tensorrt engine")
    parser_export.add_argument('--dry-run', '--test', '-t', action='store_true', help="just print command")
    parser_export.add_argument('--profile', action='store_true', help='Run nsys profiling during TensorRT export')
    parser_export.add_argument('--shape', type=int, nargs=2, default=(640, 640), help="input shape (width, height)")
    return parser
def populate_args(
    # Basic training parameters
    num_classes=2,
    grad_accum_steps=1,
    amp=False,
    lr=1e-4,
    lr_encoder=1.5e-4,
    batch_size=2,
    weight_decay=1e-4,
    epochs=12,
    lr_drop=11,
    clip_max_norm=0.1,
    lr_vit_layer_decay=0.8,
    lr_component_decay=1.0,
    do_benchmark=False,
    # Drop parameters
    dropout=0,
    drop_path=0,
    drop_mode='standard',
    drop_schedule='constant',
    cutoff_epoch=0,
    # Model parameters
    pretrained_encoder=None,
    pretrain_weights=None,
    pretrain_exclude_keys=None,
    pretrain_keys_modify_to_load=None,
    pretrained_distiller=None,
    # Backbone parameters
    encoder='vit_tiny',
    vit_encoder_num_layers=12,
    window_block_indexes=None,
    position_embedding='sine',
    out_feature_indexes=[-1],
    freeze_encoder=False,
    layer_norm=False,
    rms_norm=False,
    backbone_lora=False,
    force_no_pretrain=False,
    # Transformer parameters
    dec_layers=3,
    dim_feedforward=2048,
    hidden_dim=256,
    sa_nheads=8,
    ca_nheads=8,
    num_queries=300,
    group_detr=13,
    two_stage=False,
    projector_scale='P4',
    lite_refpoint_refine=False,
    num_select=100,
    dec_n_points=4,
    decoder_norm='LN',
    bbox_reparam=False,
    freeze_batch_norm=False,
    # Matcher parameters
    set_cost_class=2,
    set_cost_bbox=5,
    set_cost_giou=2,
    # Loss coefficients
    cls_loss_coef=2,
    bbox_loss_coef=5,
    giou_loss_coef=2,
    focal_alpha=0.25,
    aux_loss=True,
    sum_group_losses=False,
    use_varifocal_loss=False,
    use_position_supervised_loss=False,
    ia_bce_loss=False,
    # Dataset parameters
    dataset_file='coco',
    coco_path=None,
    dataset_dir=None,
    square_resize_div_64=False,
    # Output parameters
    output_dir='output',
    dont_save_weights=False,
    checkpoint_interval=10,
    seed=42,
    resume='',
    start_epoch=0,
    eval=False,
    use_ema=False,
    ema_decay=0.9997,
    ema_tau=0,
    num_workers=2,
    # Distributed training parameters
    device='cuda',
    world_size=1,
    dist_url='env://',
    sync_bn=True,
    # FP16
    fp16_eval=False,
    # Custom args
    encoder_only=False,
    backbone_only=False,
    resolution=640,
    use_cls_token=False,
    multi_scale=False,
    expanded_scales=False,
    warmup_epochs=1,
    lr_scheduler='step',
    lr_min_factor=0.0,
    # Early stopping parameters
    # NOTE(review): the CLI flag `--early_stopping` is `store_true` (default
    # False) while this default is True; kept as-is for backward
    # compatibility -- confirm which default is intended.
    early_stopping=True,
    early_stopping_patience=10,
    early_stopping_min_delta=0.001,
    early_stopping_use_ema=False,
    gradient_checkpointing=False,
    # Additional
    subcommand=None,
    **extra_kwargs  # To handle any unexpected arguments
):
    """Build an ``argparse.Namespace`` from keyword arguments.

    Every named parameter becomes an attribute of the same name on the
    returned namespace, using the default from the signature when the caller
    does not supply a value. Keyword arguments that are not named parameters
    are collected in ``extra_kwargs`` and attached to the namespace unchanged.

    Args:
        subcommand: Stored as ``args.subcommand`` (``None`` by default), the
            same attribute name the command-line parser uses for its
            sub-command ``dest``.
        out_feature_indexes: If a list, it is shallow-copied so the namespace
            never shares the signature's default list object (or the
            caller's list) across calls. Non-list values pass through as-is.
        **extra_kwargs: Additional attributes to set on the namespace.

    All other parameters are stored verbatim under their own names.

    Returns:
        argparse.Namespace: The populated namespace.
    """
    # The default `[-1]` is a single list object shared by every call; copy it
    # so mutating one namespace's list cannot change later defaults.
    if isinstance(out_feature_indexes, list):
        out_feature_indexes = list(out_feature_indexes)

    args = argparse.Namespace(
        num_classes=num_classes,
        grad_accum_steps=grad_accum_steps,
        amp=amp,
        lr=lr,
        lr_encoder=lr_encoder,
        batch_size=batch_size,
        weight_decay=weight_decay,
        epochs=epochs,
        lr_drop=lr_drop,
        clip_max_norm=clip_max_norm,
        lr_vit_layer_decay=lr_vit_layer_decay,
        lr_component_decay=lr_component_decay,
        do_benchmark=do_benchmark,
        dropout=dropout,
        drop_path=drop_path,
        drop_mode=drop_mode,
        drop_schedule=drop_schedule,
        cutoff_epoch=cutoff_epoch,
        pretrained_encoder=pretrained_encoder,
        pretrain_weights=pretrain_weights,
        pretrain_exclude_keys=pretrain_exclude_keys,
        pretrain_keys_modify_to_load=pretrain_keys_modify_to_load,
        pretrained_distiller=pretrained_distiller,
        encoder=encoder,
        vit_encoder_num_layers=vit_encoder_num_layers,
        window_block_indexes=window_block_indexes,
        position_embedding=position_embedding,
        out_feature_indexes=out_feature_indexes,
        freeze_encoder=freeze_encoder,
        layer_norm=layer_norm,
        rms_norm=rms_norm,
        backbone_lora=backbone_lora,
        force_no_pretrain=force_no_pretrain,
        dec_layers=dec_layers,
        dim_feedforward=dim_feedforward,
        hidden_dim=hidden_dim,
        sa_nheads=sa_nheads,
        ca_nheads=ca_nheads,
        num_queries=num_queries,
        group_detr=group_detr,
        two_stage=two_stage,
        projector_scale=projector_scale,
        lite_refpoint_refine=lite_refpoint_refine,
        num_select=num_select,
        dec_n_points=dec_n_points,
        decoder_norm=decoder_norm,
        bbox_reparam=bbox_reparam,
        freeze_batch_norm=freeze_batch_norm,
        set_cost_class=set_cost_class,
        set_cost_bbox=set_cost_bbox,
        set_cost_giou=set_cost_giou,
        cls_loss_coef=cls_loss_coef,
        bbox_loss_coef=bbox_loss_coef,
        giou_loss_coef=giou_loss_coef,
        focal_alpha=focal_alpha,
        aux_loss=aux_loss,
        sum_group_losses=sum_group_losses,
        use_varifocal_loss=use_varifocal_loss,
        use_position_supervised_loss=use_position_supervised_loss,
        ia_bce_loss=ia_bce_loss,
        dataset_file=dataset_file,
        coco_path=coco_path,
        dataset_dir=dataset_dir,
        square_resize_div_64=square_resize_div_64,
        output_dir=output_dir,
        dont_save_weights=dont_save_weights,
        checkpoint_interval=checkpoint_interval,
        seed=seed,
        resume=resume,
        start_epoch=start_epoch,
        eval=eval,
        use_ema=use_ema,
        ema_decay=ema_decay,
        ema_tau=ema_tau,
        num_workers=num_workers,
        device=device,
        world_size=world_size,
        dist_url=dist_url,
        sync_bn=sync_bn,
        fp16_eval=fp16_eval,
        encoder_only=encoder_only,
        backbone_only=backbone_only,
        resolution=resolution,
        use_cls_token=use_cls_token,
        multi_scale=multi_scale,
        expanded_scales=expanded_scales,
        warmup_epochs=warmup_epochs,
        lr_scheduler=lr_scheduler,
        lr_min_factor=lr_min_factor,
        early_stopping=early_stopping,
        early_stopping_patience=early_stopping_patience,
        early_stopping_min_delta=early_stopping_min_delta,
        early_stopping_use_ema=early_stopping_use_ema,
        gradient_checkpointing=gradient_checkpointing,
        # Previously accepted but dropped; store it so `args.subcommand`
        # exists just as it does on a namespace produced by the CLI parser.
        subcommand=subcommand,
        **extra_kwargs
    )
    return args