|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
cleaned main file
|
|
"""
|
|
import argparse
|
|
import ast
|
|
import copy
|
|
import datetime
|
|
import json
|
|
import math
|
|
import os
|
|
import random
|
|
import shutil
|
|
import time
|
|
from copy import deepcopy
|
|
from logging import getLogger
|
|
from pathlib import Path
|
|
from typing import DefaultDict, List, Callable
|
|
|
|
import numpy as np
|
|
import torch
|
|
from peft import LoraConfig, get_peft_model
|
|
from torch.utils.data import DataLoader, DistributedSampler
|
|
|
|
import rfdetr.util.misc as utils
|
|
from rfdetr.datasets import build_dataset, get_coco_api_from_dataset
|
|
from rfdetr.engine import evaluate, train_one_epoch
|
|
from rfdetr.models import build_model, build_criterion_and_postprocessors
|
|
from rfdetr.util.benchmark import benchmark
|
|
from rfdetr.util.drop_scheduler import drop_scheduler
|
|
from rfdetr.util.files import download_file
|
|
from rfdetr.util.get_param_dicts import get_param_dict
|
|
from rfdetr.util.utils import ModelEma, BestMetricHolder, clean_state_dict
|
|
|
|
if str(os.environ.get("USE_FILE_SYSTEM_SHARING", "False")).lower() in ["true", "1"]:
|
|
import torch.multiprocessing
|
|
torch.multiprocessing.set_sharing_strategy('file_system')
|
|
|
|
logger = getLogger(__name__)
|
|
|
|
HOSTED_MODELS = {
|
|
"rf-detr-base.pth": "https://storage.googleapis.com/rfdetr/rf-detr-base-coco.pth",
|
|
|
|
"rf-detr-base-2.pth": "https://storage.googleapis.com/rfdetr/rf-detr-base-2.pth",
|
|
"rf-detr-large.pth": "https://storage.googleapis.com/rfdetr/rf-detr-large.pth"
|
|
}
|
|
|
|
def download_pretrain_weights(pretrain_weights: str, redownload=False):
|
|
if pretrain_weights in HOSTED_MODELS:
|
|
if redownload or not os.path.exists(pretrain_weights):
|
|
logger.info(
|
|
f"Downloading pretrained weights for {pretrain_weights}"
|
|
)
|
|
download_file(
|
|
HOSTED_MODELS[pretrain_weights],
|
|
pretrain_weights,
|
|
)
|
|
|
|
class Model:
|
|
def __init__(self, **kwargs):
|
|
args = populate_args(**kwargs)
|
|
self.resolution = args.resolution
|
|
self.model = build_model(args)
|
|
self.device = torch.device(args.device)
|
|
if args.pretrain_weights is not None:
|
|
print("Loading pretrain weights")
|
|
try:
|
|
checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False)
|
|
except Exception as e:
|
|
print(f"Failed to load pretrain weights: {e}")
|
|
|
|
print("Failed to load pretrain weights, re-downloading")
|
|
download_pretrain_weights(args.pretrain_weights, redownload=True)
|
|
checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False)
|
|
|
|
|
|
if 'args' in checkpoint and hasattr(checkpoint['args'], 'class_names'):
|
|
self.class_names = checkpoint['args'].class_names
|
|
|
|
checkpoint_num_classes = checkpoint['model']['class_embed.bias'].shape[0]
|
|
if checkpoint_num_classes != args.num_classes + 1:
|
|
logger.warning(
|
|
f"num_classes mismatch: pretrain weights has {checkpoint_num_classes - 1} classes, but your model has {args.num_classes} classes\n"
|
|
f"reinitializing detection head with {checkpoint_num_classes - 1} classes"
|
|
)
|
|
self.reinitialize_detection_head(checkpoint_num_classes)
|
|
|
|
|
|
if args.pretrain_exclude_keys is not None:
|
|
assert isinstance(args.pretrain_exclude_keys, list)
|
|
for exclude_key in args.pretrain_exclude_keys:
|
|
checkpoint['model'].pop(exclude_key)
|
|
if args.pretrain_keys_modify_to_load is not None:
|
|
from util.obj365_to_coco_model import get_coco_pretrain_from_obj365
|
|
assert isinstance(args.pretrain_keys_modify_to_load, list)
|
|
for modify_key_to_load in args.pretrain_keys_modify_to_load:
|
|
try:
|
|
checkpoint['model'][modify_key_to_load] = get_coco_pretrain_from_obj365(
|
|
model_without_ddp.state_dict()[modify_key_to_load],
|
|
checkpoint['model'][modify_key_to_load]
|
|
)
|
|
except:
|
|
print(f"Failed to load {modify_key_to_load}, deleting from checkpoint")
|
|
checkpoint['model'].pop(modify_key_to_load)
|
|
|
|
|
|
num_desired_queries = args.num_queries * args.group_detr
|
|
query_param_names = ["refpoint_embed.weight", "query_feat.weight"]
|
|
for name, state in checkpoint['model'].items():
|
|
if any(name.endswith(x) for x in query_param_names):
|
|
checkpoint['model'][name] = state[:num_desired_queries]
|
|
|
|
self.model.load_state_dict(checkpoint['model'], strict=False)
|
|
|
|
if args.backbone_lora:
|
|
print("Applying LORA to backbone")
|
|
lora_config = LoraConfig(
|
|
r=16,
|
|
lora_alpha=16,
|
|
use_dora=True,
|
|
target_modules=[
|
|
"q_proj", "v_proj", "k_proj",
|
|
"qkv",
|
|
"query", "key", "value", "cls_token", "register_tokens",
|
|
]
|
|
)
|
|
self.model.backbone[0].encoder = get_peft_model(self.model.backbone[0].encoder, lora_config)
|
|
self.model = self.model.to(self.device)
|
|
self.criterion, self.postprocessors = build_criterion_and_postprocessors(args)
|
|
self.stop_early = False
|
|
|
|
def reinitialize_detection_head(self, num_classes):
|
|
self.model.reinitialize_detection_head(num_classes)
|
|
|
|
def request_early_stop(self):
|
|
self.stop_early = True
|
|
print("Early stopping requested, will complete current epoch and stop")
|
|
|
|
def train(self, callbacks: DefaultDict[str, List[Callable]], **kwargs):
|
|
currently_supported_callbacks = ["on_fit_epoch_end", "on_train_batch_start", "on_train_end"]
|
|
for key in callbacks.keys():
|
|
if key not in currently_supported_callbacks:
|
|
raise ValueError(
|
|
f"Callback {key} is not currently supported, please file an issue if you need it!\n"
|
|
f"Currently supported callbacks: {currently_supported_callbacks}"
|
|
)
|
|
args = populate_args(**kwargs)
|
|
utils.init_distributed_mode(args)
|
|
print("git:\n {}\n".format(utils.get_sha()))
|
|
print(args)
|
|
device = torch.device(args.device)
|
|
|
|
|
|
seed = args.seed + utils.get_rank()
|
|
torch.manual_seed(seed)
|
|
np.random.seed(seed)
|
|
random.seed(seed)
|
|
|
|
criterion, postprocessors = build_criterion_and_postprocessors(args)
|
|
model = self.model
|
|
model.to(device)
|
|
|
|
model_without_ddp = model
|
|
if args.distributed:
|
|
if args.sync_bn:
|
|
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
|
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
|
|
model_without_ddp = model.module
|
|
|
|
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
print('number of params:', n_parameters)
|
|
param_dicts = get_param_dict(args, model_without_ddp)
|
|
|
|
param_dicts = [p for p in param_dicts if p['params'].requires_grad]
|
|
|
|
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
|
|
weight_decay=args.weight_decay)
|
|
|
|
|
|
dataset_train = build_dataset(image_set='train', args=args, resolution=args.resolution)
|
|
dataset_val = build_dataset(image_set='val', args=args, resolution=args.resolution)
|
|
|
|
|
|
total_batch_size_for_lr = args.batch_size * utils.get_world_size() * args.grad_accum_steps
|
|
num_training_steps_per_epoch_lr = (len(dataset_train) + total_batch_size_for_lr - 1) // total_batch_size_for_lr
|
|
total_training_steps_lr = num_training_steps_per_epoch_lr * args.epochs
|
|
warmup_steps_lr = num_training_steps_per_epoch_lr * args.warmup_epochs
|
|
def lr_lambda(current_step: int):
|
|
if current_step < warmup_steps_lr:
|
|
|
|
return float(current_step) / float(max(1, warmup_steps_lr))
|
|
else:
|
|
|
|
if args.lr_scheduler == 'cosine':
|
|
progress = float(current_step - warmup_steps_lr) / float(max(1, total_training_steps_lr - warmup_steps_lr))
|
|
return args.lr_min_factor + (1 - args.lr_min_factor) * 0.5 * (1 + math.cos(math.pi * progress))
|
|
elif args.lr_scheduler == 'step':
|
|
if current_step < args.lr_drop * num_training_steps_per_epoch_lr:
|
|
return 1.0
|
|
else:
|
|
return 0.1
|
|
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
|
|
|
|
if args.distributed:
|
|
sampler_train = DistributedSampler(dataset_train)
|
|
sampler_val = DistributedSampler(dataset_val, shuffle=False)
|
|
else:
|
|
sampler_train = torch.utils.data.RandomSampler(dataset_train)
|
|
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
|
|
|
effective_batch_size = args.batch_size * args.grad_accum_steps
|
|
min_batches = kwargs.get('min_batches', 5)
|
|
if len(dataset_train) < effective_batch_size * min_batches:
|
|
logger.info(
|
|
f"Training with uniform sampler because dataset is too small: {len(dataset_train)} < {effective_batch_size * min_batches}"
|
|
)
|
|
sampler = torch.utils.data.RandomSampler(
|
|
dataset_train,
|
|
replacement=True,
|
|
num_samples=effective_batch_size * min_batches,
|
|
)
|
|
data_loader_train = DataLoader(
|
|
dataset_train,
|
|
batch_size=effective_batch_size,
|
|
collate_fn=utils.collate_fn,
|
|
num_workers=args.num_workers,
|
|
sampler=sampler,
|
|
)
|
|
else:
|
|
batch_sampler_train = torch.utils.data.BatchSampler(
|
|
sampler_train, effective_batch_size, drop_last=True)
|
|
data_loader_train = DataLoader(
|
|
dataset_train,
|
|
batch_sampler=batch_sampler_train,
|
|
collate_fn=utils.collate_fn,
|
|
num_workers=args.num_workers
|
|
)
|
|
|
|
data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
|
|
drop_last=False, collate_fn=utils.collate_fn,
|
|
num_workers=args.num_workers)
|
|
|
|
base_ds = get_coco_api_from_dataset(dataset_val)
|
|
|
|
if args.use_ema:
|
|
self.ema_m = ModelEma(model_without_ddp, decay=args.ema_decay, tau=args.ema_tau)
|
|
else:
|
|
self.ema_m = None
|
|
|
|
|
|
output_dir = Path(args.output_dir)
|
|
|
|
if utils.is_main_process():
|
|
print("Get benchmark")
|
|
if args.do_benchmark:
|
|
benchmark_model = copy.deepcopy(model_without_ddp)
|
|
bm = benchmark(benchmark_model.float(), dataset_val, output_dir)
|
|
print(json.dumps(bm, indent=2))
|
|
del benchmark_model
|
|
|
|
if args.resume:
|
|
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
|
|
model_without_ddp.load_state_dict(checkpoint['model'], strict=True)
|
|
if args.use_ema:
|
|
if 'ema_model' in checkpoint:
|
|
self.ema_m.module.load_state_dict(clean_state_dict(checkpoint['ema_model']))
|
|
else:
|
|
del self.ema_m
|
|
self.ema_m = ModelEma(model, decay=args.ema_decay, tau=args.ema_tau)
|
|
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
|
|
optimizer.load_state_dict(checkpoint['optimizer'])
|
|
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
|
|
args.start_epoch = checkpoint['epoch'] + 1
|
|
|
|
if args.eval:
|
|
test_stats, coco_evaluator = evaluate(
|
|
model, criterion, postprocessors, data_loader_val, base_ds, device, args)
|
|
if args.output_dir:
|
|
utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
|
|
return
|
|
|
|
|
|
total_batch_size = effective_batch_size * utils.get_world_size()
|
|
num_training_steps_per_epoch = (len(dataset_train) + total_batch_size - 1) // total_batch_size
|
|
schedules = {}
|
|
if args.dropout > 0:
|
|
schedules['do'] = drop_scheduler(
|
|
args.dropout, args.epochs, num_training_steps_per_epoch,
|
|
args.cutoff_epoch, args.drop_mode, args.drop_schedule)
|
|
print("Min DO = %.7f, Max DO = %.7f" % (min(schedules['do']), max(schedules['do'])))
|
|
|
|
if args.drop_path > 0:
|
|
schedules['dp'] = drop_scheduler(
|
|
args.drop_path, args.epochs, num_training_steps_per_epoch,
|
|
args.cutoff_epoch, args.drop_mode, args.drop_schedule)
|
|
print("Min DP = %.7f, Max DP = %.7f" % (min(schedules['dp']), max(schedules['dp'])))
|
|
|
|
print("Start training")
|
|
start_time = time.time()
|
|
best_map_holder = BestMetricHolder(use_ema=args.use_ema)
|
|
best_map_5095 = 0
|
|
best_map_50 = 0
|
|
best_map_ema_5095 = 0
|
|
best_map_ema_50 = 0
|
|
for epoch in range(args.start_epoch, args.epochs):
|
|
epoch_start_time = time.time()
|
|
if args.distributed:
|
|
sampler_train.set_epoch(epoch)
|
|
|
|
model.train()
|
|
criterion.train()
|
|
train_stats = train_one_epoch(
|
|
model, criterion, lr_scheduler, data_loader_train, optimizer, device, epoch,
|
|
effective_batch_size, args.clip_max_norm, ema_m=self.ema_m, schedules=schedules,
|
|
num_training_steps_per_epoch=num_training_steps_per_epoch,
|
|
vit_encoder_num_layers=args.vit_encoder_num_layers, args=args, callbacks=callbacks)
|
|
train_epoch_time = time.time() - epoch_start_time
|
|
train_epoch_time_str = str(datetime.timedelta(seconds=int(train_epoch_time)))
|
|
if args.output_dir:
|
|
checkpoint_paths = [output_dir / 'checkpoint.pth']
|
|
|
|
if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % args.checkpoint_interval == 0:
|
|
checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
|
|
for checkpoint_path in checkpoint_paths:
|
|
weights = {
|
|
'model': model_without_ddp.state_dict(),
|
|
'optimizer': optimizer.state_dict(),
|
|
'lr_scheduler': lr_scheduler.state_dict(),
|
|
'epoch': epoch,
|
|
'args': args,
|
|
}
|
|
if args.use_ema:
|
|
weights.update({
|
|
'ema_model': self.ema_m.module.state_dict(),
|
|
})
|
|
if not args.dont_save_weights:
|
|
|
|
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
utils.save_on_master(weights, checkpoint_path)
|
|
|
|
with torch.inference_mode():
|
|
test_stats, coco_evaluator = evaluate(
|
|
model, criterion, postprocessors, data_loader_val, base_ds, device, args=args
|
|
)
|
|
|
|
map_regular = test_stats['coco_eval_bbox'][0]
|
|
_isbest = best_map_holder.update(map_regular, epoch, is_ema=False)
|
|
if _isbest:
|
|
best_map_5095 = max(best_map_5095, map_regular)
|
|
best_map_50 = max(best_map_50, test_stats["coco_eval_bbox"][1])
|
|
checkpoint_path = output_dir / 'checkpoint0009.pth'
|
|
if not args.dont_save_weights:
|
|
utils.save_on_master({
|
|
'model': model_without_ddp.state_dict(),
|
|
'optimizer': optimizer.state_dict(),
|
|
'lr_scheduler': lr_scheduler.state_dict(),
|
|
'epoch': epoch,
|
|
'args': args,
|
|
}, checkpoint_path)
|
|
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
|
**{f'test_{k}': v for k, v in test_stats.items()},
|
|
'epoch': epoch,
|
|
'n_parameters': n_parameters}
|
|
if args.use_ema:
|
|
ema_test_stats, _ = evaluate(
|
|
self.ema_m.module, criterion, postprocessors, data_loader_val, base_ds, device, args=args
|
|
)
|
|
log_stats.update({f'ema_test_{k}': v for k,v in ema_test_stats.items()})
|
|
map_ema = ema_test_stats['coco_eval_bbox'][0]
|
|
best_map_ema_5095 = max(best_map_ema_5095, map_ema)
|
|
_isbest = best_map_holder.update(map_ema, epoch, is_ema=True)
|
|
if _isbest:
|
|
best_map_ema_50 = max(best_map_ema_50, ema_test_stats["coco_eval_bbox"][1])
|
|
checkpoint_path = output_dir / 'checkpoint_best_ema.pth'
|
|
if not args.dont_save_weights:
|
|
utils.save_on_master({
|
|
'model': self.ema_m.module.state_dict(),
|
|
'optimizer': optimizer.state_dict(),
|
|
'lr_scheduler': lr_scheduler.state_dict(),
|
|
'epoch': epoch,
|
|
'args': args,
|
|
}, checkpoint_path)
|
|
log_stats.update(best_map_holder.summary())
|
|
|
|
|
|
ep_paras = {
|
|
'epoch': epoch,
|
|
'n_parameters': n_parameters
|
|
}
|
|
log_stats.update(ep_paras)
|
|
try:
|
|
log_stats.update({'now_time': str(datetime.datetime.now())})
|
|
except:
|
|
pass
|
|
log_stats['train_epoch_time'] = train_epoch_time_str
|
|
epoch_time = time.time() - epoch_start_time
|
|
epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time)))
|
|
log_stats['epoch_time'] = epoch_time_str
|
|
if args.output_dir and utils.is_main_process():
|
|
with (output_dir / "log.txt").open("a") as f:
|
|
f.write(json.dumps(log_stats) + "\n")
|
|
|
|
|
|
if coco_evaluator is not None:
|
|
(output_dir / 'eval').mkdir(exist_ok=True)
|
|
if "bbox" in coco_evaluator.coco_eval:
|
|
filenames = ['latest.pth']
|
|
if epoch % 50 == 0:
|
|
filenames.append(f'{epoch:03}.pth')
|
|
for name in filenames:
|
|
torch.save(coco_evaluator.coco_eval["bbox"].eval,
|
|
output_dir / "eval" / name)
|
|
|
|
for callback in callbacks["on_fit_epoch_end"]:
|
|
callback(log_stats)
|
|
|
|
if self.stop_early:
|
|
print(f"Early stopping requested, stopping at epoch {epoch}")
|
|
break
|
|
|
|
best_is_ema = best_map_ema_5095 > best_map_5095
|
|
|
|
if utils.is_main_process():
|
|
if best_is_ema:
|
|
shutil.copy2(output_dir / 'checkpoint_best_ema.pth', output_dir / 'checkpoint_best_total.pth')
|
|
else:
|
|
shutil.copy2(output_dir / 'checkpoint0009.pth', output_dir / 'checkpoint_best_total.pth')
|
|
|
|
utils.strip_checkpoint(output_dir / 'checkpoint_best_total.pth')
|
|
|
|
best_map_5095 = max(best_map_5095, best_map_ema_5095)
|
|
best_map_50 = max(best_map_50, best_map_ema_50)
|
|
|
|
results_json = {
|
|
"map95": best_map_5095,
|
|
"map50": best_map_50,
|
|
"class": "all"
|
|
}
|
|
results = {
|
|
"class_map": {
|
|
"valid": [
|
|
results_json
|
|
]
|
|
}
|
|
}
|
|
with open(output_dir / "results.json", "w") as f:
|
|
json.dump(results, f)
|
|
|
|
total_time = time.time() - start_time
|
|
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
|
print('Training time {}'.format(total_time_str))
|
|
print('Results saved to {}'.format(output_dir / "results.json"))
|
|
|
|
if best_is_ema:
|
|
self.model = self.ema_m.module
|
|
self.model.eval()
|
|
|
|
for callback in callbacks["on_train_end"]:
|
|
callback()
|
|
|
|
def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_only=False, opset_version=17, verbose=True, force=False, shape=None, batch_size=1, **kwargs):
|
|
"""Export the trained model to ONNX format"""
|
|
print(f"Exporting model to ONNX format")
|
|
try:
|
|
from rfdetr.deploy.export import export_onnx, onnx_simplify, make_infer_image
|
|
except ImportError:
|
|
print("It seems some dependencies for ONNX export are missing. Please run `pip install rfdetr[onnxexport]` and try again.")
|
|
raise
|
|
|
|
|
|
device = self.device
|
|
model = deepcopy(self.model.to("cpu"))
|
|
model.to(device)
|
|
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
output_dir = Path(output_dir)
|
|
if shape is None:
|
|
shape = (self.resolution, self.resolution)
|
|
else:
|
|
if shape[0] % 14 != 0 or shape[1] % 14 != 0:
|
|
raise ValueError("Shape must be divisible by 14")
|
|
|
|
input_tensors = make_infer_image(infer_dir, shape, batch_size, device).to(device)
|
|
input_names = ['input']
|
|
output_names = ['features'] if backbone_only else ['dets', 'labels']
|
|
dynamic_axes = None
|
|
self.model.eval()
|
|
with torch.no_grad():
|
|
if backbone_only:
|
|
features = model(input_tensors)
|
|
print(f"PyTorch inference output shape: {features.shape}")
|
|
else:
|
|
outputs = model(input_tensors)
|
|
dets = outputs['pred_boxes']
|
|
labels = outputs['pred_logits']
|
|
print(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}")
|
|
model.cpu()
|
|
input_tensors = input_tensors.cpu()
|
|
|
|
|
|
output_file = export_onnx(
|
|
output_dir=output_dir,
|
|
model=model,
|
|
input_names=input_names,
|
|
input_tensors=input_tensors,
|
|
output_names=output_names,
|
|
dynamic_axes=dynamic_axes,
|
|
backbone_only=backbone_only,
|
|
verbose=verbose,
|
|
opset_version=opset_version
|
|
)
|
|
|
|
print(f"Successfully exported ONNX model to: {output_file}")
|
|
|
|
if simplify:
|
|
sim_output_file = onnx_simplify(
|
|
onnx_dir=output_file,
|
|
input_names=input_names,
|
|
input_tensors=input_tensors,
|
|
force=force
|
|
)
|
|
print(f"Successfully simplified ONNX model to: {sim_output_file}")
|
|
|
|
print("ONNX export completed successfully")
|
|
self.model = self.model.to(device)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
parser = argparse.ArgumentParser('LWDETR training and evaluation script', parents=[get_args_parser()])
|
|
args = parser.parse_args()
|
|
|
|
if args.output_dir:
|
|
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
|
|
|
config = vars(args)
|
|
|
|
if args.subcommand == 'distill':
|
|
distill(**config)
|
|
elif args.subcommand is None:
|
|
main(**config)
|
|
elif args.subcommand == 'export_model':
|
|
filter_keys = [
|
|
"num_classes",
|
|
"grad_accum_steps",
|
|
"lr",
|
|
"lr_encoder",
|
|
"weight_decay",
|
|
"epochs",
|
|
"lr_drop",
|
|
"clip_max_norm",
|
|
"lr_vit_layer_decay",
|
|
"lr_component_decay",
|
|
"dropout",
|
|
"drop_path",
|
|
"drop_mode",
|
|
"drop_schedule",
|
|
"cutoff_epoch",
|
|
"pretrained_encoder",
|
|
"pretrain_weights",
|
|
"pretrain_exclude_keys",
|
|
"pretrain_keys_modify_to_load",
|
|
"freeze_florence",
|
|
"freeze_aimv2",
|
|
"decoder_norm",
|
|
"set_cost_class",
|
|
"set_cost_bbox",
|
|
"set_cost_giou",
|
|
"cls_loss_coef",
|
|
"bbox_loss_coef",
|
|
"giou_loss_coef",
|
|
"focal_alpha",
|
|
"aux_loss",
|
|
"sum_group_losses",
|
|
"use_varifocal_loss",
|
|
"use_position_supervised_loss",
|
|
"ia_bce_loss",
|
|
"dataset_file",
|
|
"coco_path",
|
|
"dataset_dir",
|
|
"square_resize_div_64",
|
|
"output_dir",
|
|
"checkpoint_interval",
|
|
"seed",
|
|
"resume",
|
|
"start_epoch",
|
|
"eval",
|
|
"use_ema",
|
|
"ema_decay",
|
|
"ema_tau",
|
|
"num_workers",
|
|
"device",
|
|
"world_size",
|
|
"dist_url",
|
|
"sync_bn",
|
|
"fp16_eval",
|
|
"infer_dir",
|
|
"verbose",
|
|
"opset_version",
|
|
"dry_run",
|
|
"shape",
|
|
]
|
|
for key in filter_keys:
|
|
config.pop(key, None)
|
|
|
|
from deploy.export import main as export_main
|
|
if args.batch_size != 1:
|
|
config['batch_size'] = 1
|
|
print(f"Only batch_size 1 is supported for onnx export, \
|
|
but got batchsize = {args.batch_size}. batch_size is forcibly set to 1.")
|
|
export_main(**config)
|
|
|
|
def get_args_parser():
|
|
parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
|
|
parser.add_argument('--num_classes', default=2, type=int)
|
|
parser.add_argument('--grad_accum_steps', default=1, type=int)
|
|
parser.add_argument('--amp', default=False, type=bool)
|
|
parser.add_argument('--lr', default=1e-4, type=float)
|
|
parser.add_argument('--lr_encoder', default=1.5e-4, type=float)
|
|
parser.add_argument('--batch_size', default=2, type=int)
|
|
parser.add_argument('--weight_decay', default=1e-4, type=float)
|
|
parser.add_argument('--epochs', default=12, type=int)
|
|
parser.add_argument('--lr_drop', default=11, type=int)
|
|
parser.add_argument('--clip_max_norm', default=0.1, type=float,
|
|
help='gradient clipping max norm')
|
|
parser.add_argument('--lr_vit_layer_decay', default=0.8, type=float)
|
|
parser.add_argument('--lr_component_decay', default=1.0, type=float)
|
|
parser.add_argument('--do_benchmark', action='store_true', help='benchmark the model')
|
|
|
|
|
|
|
|
parser.add_argument('--dropout', type=float, default=0,
|
|
help='Drop path rate (default: 0.0)')
|
|
parser.add_argument('--drop_path', type=float, default=0,
|
|
help='Drop path rate (default: 0.0)')
|
|
|
|
|
|
parser.add_argument('--drop_mode', type=str, default='standard',
|
|
choices=['standard', 'early', 'late'], help='drop mode')
|
|
parser.add_argument('--drop_schedule', type=str, default='constant',
|
|
choices=['constant', 'linear'],
|
|
help='drop schedule for early dropout / s.d. only')
|
|
parser.add_argument('--cutoff_epoch', type=int, default=0,
|
|
help='if drop_mode is early / late, this is the epoch where dropout ends / starts')
|
|
|
|
|
|
parser.add_argument('--pretrained_encoder', type=str, default=None,
|
|
help="Path to the pretrained encoder.")
|
|
parser.add_argument('--pretrain_weights', type=str, default=None,
|
|
help="Path to the pretrained model.")
|
|
parser.add_argument('--pretrain_exclude_keys', type=str, default=None, nargs='+',
|
|
help="Keys you do not want to load.")
|
|
parser.add_argument('--pretrain_keys_modify_to_load', type=str, default=None, nargs='+',
|
|
help="Keys you want to modify to load. Only used when loading objects365 pre-trained weights.")
|
|
|
|
|
|
parser.add_argument('--encoder', default='vit_tiny', type=str,
|
|
help="Name of the transformer or convolutional encoder to use")
|
|
parser.add_argument('--vit_encoder_num_layers', default=12, type=int,
|
|
help="Number of layers used in ViT encoder")
|
|
parser.add_argument('--window_block_indexes', default=None, type=int, nargs='+')
|
|
parser.add_argument('--position_embedding', default='sine', type=str,
|
|
choices=('sine', 'learned'),
|
|
help="Type of positional embedding to use on top of the image features")
|
|
parser.add_argument('--out_feature_indexes', default=[-1], type=int, nargs='+', help='only for vit now')
|
|
parser.add_argument("--freeze_encoder", action="store_true", dest="freeze_encoder")
|
|
parser.add_argument("--layer_norm", action="store_true", dest="layer_norm")
|
|
parser.add_argument("--rms_norm", action="store_true", dest="rms_norm")
|
|
parser.add_argument("--backbone_lora", action="store_true", dest="backbone_lora")
|
|
parser.add_argument("--force_no_pretrain", action="store_true", dest="force_no_pretrain")
|
|
|
|
|
|
parser.add_argument('--dec_layers', default=3, type=int,
|
|
help="Number of decoding layers in the transformer")
|
|
parser.add_argument('--dim_feedforward', default=2048, type=int,
|
|
help="Intermediate size of the feedforward layers in the transformer blocks")
|
|
parser.add_argument('--hidden_dim', default=256, type=int,
|
|
help="Size of the embeddings (dimension of the transformer)")
|
|
parser.add_argument('--sa_nheads', default=8, type=int,
|
|
help="Number of attention heads inside the transformer's self-attentions")
|
|
parser.add_argument('--ca_nheads', default=8, type=int,
|
|
help="Number of attention heads inside the transformer's cross-attentions")
|
|
parser.add_argument('--num_queries', default=300, type=int,
|
|
help="Number of query slots")
|
|
parser.add_argument('--group_detr', default=13, type=int,
|
|
help="Number of groups to speed up detr training")
|
|
parser.add_argument('--two_stage', action='store_true')
|
|
parser.add_argument('--projector_scale', default='P4', type=str, nargs='+', choices=('P3', 'P4', 'P5', 'P6'))
|
|
parser.add_argument('--lite_refpoint_refine', action='store_true', help='lite refpoint refine mode for speed-up')
|
|
parser.add_argument('--num_select', default=100, type=int,
|
|
help='the number of predictions selected for evaluation')
|
|
parser.add_argument('--dec_n_points', default=4, type=int,
|
|
help='the number of sampling points')
|
|
parser.add_argument('--decoder_norm', default='LN', type=str)
|
|
parser.add_argument('--bbox_reparam', action='store_true')
|
|
parser.add_argument('--freeze_batch_norm', action='store_true')
|
|
|
|
parser.add_argument('--set_cost_class', default=2, type=float,
|
|
help="Class coefficient in the matching cost")
|
|
parser.add_argument('--set_cost_bbox', default=5, type=float,
|
|
help="L1 box coefficient in the matching cost")
|
|
parser.add_argument('--set_cost_giou', default=2, type=float,
|
|
help="giou box coefficient in the matching cost")
|
|
|
|
|
|
parser.add_argument('--cls_loss_coef', default=2, type=float)
|
|
parser.add_argument('--bbox_loss_coef', default=5, type=float)
|
|
parser.add_argument('--giou_loss_coef', default=2, type=float)
|
|
parser.add_argument('--focal_alpha', default=0.25, type=float)
|
|
|
|
|
|
parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
|
|
help="Disables auxiliary decoding losses (loss at each layer)")
|
|
parser.add_argument('--sum_group_losses', action='store_true',
|
|
help="To sum losses across groups or mean losses.")
|
|
parser.add_argument('--use_varifocal_loss', action='store_true')
|
|
parser.add_argument('--use_position_supervised_loss', action='store_true')
|
|
parser.add_argument('--ia_bce_loss', action='store_true')
|
|
|
|
|
|
parser.add_argument('--dataset_file', default='coco')
|
|
parser.add_argument('--coco_path', type=str)
|
|
parser.add_argument('--dataset_dir', type=str)
|
|
parser.add_argument('--square_resize_div_64', action='store_true')
|
|
|
|
parser.add_argument('--output_dir', default='output',
|
|
help='path where to save, empty for no saving')
|
|
parser.add_argument('--dont_save_weights', action='store_true')
|
|
parser.add_argument('--checkpoint_interval', default=10, type=int,
|
|
help='epoch interval to save checkpoint')
|
|
parser.add_argument('--seed', default=42, type=int)
|
|
parser.add_argument('--resume', default='', help='resume from checkpoint')
|
|
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
|
|
help='start epoch')
|
|
parser.add_argument('--eval', action='store_true')
|
|
parser.add_argument('--use_ema', action='store_true')
|
|
parser.add_argument('--ema_decay', default=0.9997, type=float)
|
|
parser.add_argument('--ema_tau', default=0, type=float)
|
|
|
|
parser.add_argument('--num_workers', default=2, type=int)
|
|
|
|
|
|
parser.add_argument('--device', default='cuda',
|
|
help='device to use for training / testing')
|
|
parser.add_argument('--world_size', default=1, type=int,
|
|
help='number of distributed processes')
|
|
parser.add_argument('--dist_url', default='env://',
|
|
help='url used to set up distributed training')
|
|
parser.add_argument('--sync_bn', default=True, type=bool,
|
|
help='setup synchronized BatchNorm for distributed training')
|
|
|
|
|
|
parser.add_argument('--fp16_eval', default=False, action='store_true',
|
|
help='evaluate in fp16 precision.')
|
|
|
|
|
|
parser.add_argument('--encoder_only', action='store_true', help='Export and benchmark encoder only')
|
|
parser.add_argument('--backbone_only', action='store_true', help='Export and benchmark backbone only')
|
|
parser.add_argument('--resolution', type=int, default=640, help="input resolution")
|
|
parser.add_argument('--use_cls_token', action='store_true', help='use cls token')
|
|
parser.add_argument('--multi_scale', action='store_true', help='use multi scale')
|
|
parser.add_argument('--expanded_scales', action='store_true', help='use expanded scales')
|
|
parser.add_argument('--warmup_epochs', default=1, type=float,
|
|
help='Number of warmup epochs for linear warmup before cosine annealing')
|
|
|
|
parser.add_argument(
|
|
'--lr_scheduler',
|
|
default='step',
|
|
choices=['step', 'cosine'],
|
|
help="Type of learning rate scheduler to use: 'step' (default) or 'cosine'"
|
|
)
|
|
parser.add_argument('--lr_min_factor', default=0.0, type=float,
|
|
help='Minimum learning rate factor (as a fraction of initial lr) at the end of cosine annealing')
|
|
|
|
parser.add_argument('--early_stopping', action='store_true',
|
|
help='Enable early stopping based on mAP improvement')
|
|
parser.add_argument('--early_stopping_patience', default=10, type=int,
|
|
help='Number of epochs with no improvement after which training will be stopped')
|
|
parser.add_argument('--early_stopping_min_delta', default=0.001, type=float,
|
|
help='Minimum change in mAP to qualify as an improvement')
|
|
parser.add_argument('--early_stopping_use_ema', action='store_true',
|
|
help='Use EMA model metrics for early stopping')
|
|
|
|
subparsers = parser.add_subparsers(title='sub-commands', dest='subcommand',
|
|
description='valid subcommands', help='additional help')
|
|
|
|
|
|
parser_export = subparsers.add_parser('export_model', help='LWDETR model export')
|
|
parser_export.add_argument('--infer_dir', type=str, default=None)
|
|
parser_export.add_argument('--verbose', type=ast.literal_eval, default=False, nargs="?", const=True)
|
|
parser_export.add_argument('--opset_version', type=int, default=17)
|
|
parser_export.add_argument('--simplify', action='store_true', help="Simplify onnx model")
|
|
parser_export.add_argument('--tensorrt', '--trtexec', '--trt', action='store_true',
|
|
help="build tensorrt engine")
|
|
parser_export.add_argument('--dry-run', '--test', '-t', action='store_true', help="just print command")
|
|
parser_export.add_argument('--profile', action='store_true', help='Run nsys profiling during TensorRT export')
|
|
parser_export.add_argument('--shape', type=int, nargs=2, default=(640, 640), help="input shape (width, height)")
|
|
return parser
|
|
|
|
def populate_args(
|
|
|
|
num_classes=2,
|
|
grad_accum_steps=1,
|
|
amp=False,
|
|
lr=1e-4,
|
|
lr_encoder=1.5e-4,
|
|
batch_size=2,
|
|
weight_decay=1e-4,
|
|
epochs=12,
|
|
lr_drop=11,
|
|
clip_max_norm=0.1,
|
|
lr_vit_layer_decay=0.8,
|
|
lr_component_decay=1.0,
|
|
do_benchmark=False,
|
|
|
|
|
|
dropout=0,
|
|
drop_path=0,
|
|
drop_mode='standard',
|
|
drop_schedule='constant',
|
|
cutoff_epoch=0,
|
|
|
|
|
|
pretrained_encoder=None,
|
|
pretrain_weights=None,
|
|
pretrain_exclude_keys=None,
|
|
pretrain_keys_modify_to_load=None,
|
|
pretrained_distiller=None,
|
|
|
|
|
|
encoder='vit_tiny',
|
|
vit_encoder_num_layers=12,
|
|
window_block_indexes=None,
|
|
position_embedding='sine',
|
|
out_feature_indexes=[-1],
|
|
freeze_encoder=False,
|
|
layer_norm=False,
|
|
rms_norm=False,
|
|
backbone_lora=False,
|
|
force_no_pretrain=False,
|
|
|
|
|
|
dec_layers=3,
|
|
dim_feedforward=2048,
|
|
hidden_dim=256,
|
|
sa_nheads=8,
|
|
ca_nheads=8,
|
|
num_queries=300,
|
|
group_detr=13,
|
|
two_stage=False,
|
|
projector_scale='P4',
|
|
lite_refpoint_refine=False,
|
|
num_select=100,
|
|
dec_n_points=4,
|
|
decoder_norm='LN',
|
|
bbox_reparam=False,
|
|
freeze_batch_norm=False,
|
|
|
|
|
|
set_cost_class=2,
|
|
set_cost_bbox=5,
|
|
set_cost_giou=2,
|
|
|
|
|
|
cls_loss_coef=2,
|
|
bbox_loss_coef=5,
|
|
giou_loss_coef=2,
|
|
focal_alpha=0.25,
|
|
aux_loss=True,
|
|
sum_group_losses=False,
|
|
use_varifocal_loss=False,
|
|
use_position_supervised_loss=False,
|
|
ia_bce_loss=False,
|
|
|
|
|
|
dataset_file='coco',
|
|
coco_path=None,
|
|
dataset_dir=None,
|
|
square_resize_div_64=False,
|
|
|
|
|
|
output_dir='output',
|
|
dont_save_weights=False,
|
|
checkpoint_interval=10,
|
|
seed=42,
|
|
resume='',
|
|
start_epoch=0,
|
|
eval=False,
|
|
use_ema=False,
|
|
ema_decay=0.9997,
|
|
ema_tau=0,
|
|
num_workers=2,
|
|
|
|
|
|
device='cuda',
|
|
world_size=1,
|
|
dist_url='env://',
|
|
sync_bn=True,
|
|
|
|
|
|
fp16_eval=False,
|
|
|
|
|
|
encoder_only=False,
|
|
backbone_only=False,
|
|
resolution=640,
|
|
use_cls_token=False,
|
|
multi_scale=False,
|
|
expanded_scales=False,
|
|
warmup_epochs=1,
|
|
lr_scheduler='step',
|
|
lr_min_factor=0.0,
|
|
|
|
early_stopping=True,
|
|
early_stopping_patience=10,
|
|
early_stopping_min_delta=0.001,
|
|
early_stopping_use_ema=False,
|
|
gradient_checkpointing=False,
|
|
|
|
subcommand=None,
|
|
**extra_kwargs
|
|
):
|
|
args = argparse.Namespace(
|
|
num_classes=num_classes,
|
|
grad_accum_steps=grad_accum_steps,
|
|
amp=amp,
|
|
lr=lr,
|
|
lr_encoder=lr_encoder,
|
|
batch_size=batch_size,
|
|
weight_decay=weight_decay,
|
|
epochs=epochs,
|
|
lr_drop=lr_drop,
|
|
clip_max_norm=clip_max_norm,
|
|
lr_vit_layer_decay=lr_vit_layer_decay,
|
|
lr_component_decay=lr_component_decay,
|
|
do_benchmark=do_benchmark,
|
|
dropout=dropout,
|
|
drop_path=drop_path,
|
|
drop_mode=drop_mode,
|
|
drop_schedule=drop_schedule,
|
|
cutoff_epoch=cutoff_epoch,
|
|
pretrained_encoder=pretrained_encoder,
|
|
pretrain_weights=pretrain_weights,
|
|
pretrain_exclude_keys=pretrain_exclude_keys,
|
|
pretrain_keys_modify_to_load=pretrain_keys_modify_to_load,
|
|
pretrained_distiller=pretrained_distiller,
|
|
encoder=encoder,
|
|
vit_encoder_num_layers=vit_encoder_num_layers,
|
|
window_block_indexes=window_block_indexes,
|
|
position_embedding=position_embedding,
|
|
out_feature_indexes=out_feature_indexes,
|
|
freeze_encoder=freeze_encoder,
|
|
layer_norm=layer_norm,
|
|
rms_norm=rms_norm,
|
|
backbone_lora=backbone_lora,
|
|
force_no_pretrain=force_no_pretrain,
|
|
dec_layers=dec_layers,
|
|
dim_feedforward=dim_feedforward,
|
|
hidden_dim=hidden_dim,
|
|
sa_nheads=sa_nheads,
|
|
ca_nheads=ca_nheads,
|
|
num_queries=num_queries,
|
|
group_detr=group_detr,
|
|
two_stage=two_stage,
|
|
projector_scale=projector_scale,
|
|
lite_refpoint_refine=lite_refpoint_refine,
|
|
num_select=num_select,
|
|
dec_n_points=dec_n_points,
|
|
decoder_norm=decoder_norm,
|
|
bbox_reparam=bbox_reparam,
|
|
freeze_batch_norm=freeze_batch_norm,
|
|
set_cost_class=set_cost_class,
|
|
set_cost_bbox=set_cost_bbox,
|
|
set_cost_giou=set_cost_giou,
|
|
cls_loss_coef=cls_loss_coef,
|
|
bbox_loss_coef=bbox_loss_coef,
|
|
giou_loss_coef=giou_loss_coef,
|
|
focal_alpha=focal_alpha,
|
|
aux_loss=aux_loss,
|
|
sum_group_losses=sum_group_losses,
|
|
use_varifocal_loss=use_varifocal_loss,
|
|
use_position_supervised_loss=use_position_supervised_loss,
|
|
ia_bce_loss=ia_bce_loss,
|
|
dataset_file=dataset_file,
|
|
coco_path=coco_path,
|
|
dataset_dir=dataset_dir,
|
|
square_resize_div_64=square_resize_div_64,
|
|
output_dir=output_dir,
|
|
dont_save_weights=dont_save_weights,
|
|
checkpoint_interval=checkpoint_interval,
|
|
seed=seed,
|
|
resume=resume,
|
|
start_epoch=start_epoch,
|
|
eval=eval,
|
|
use_ema=use_ema,
|
|
ema_decay=ema_decay,
|
|
ema_tau=ema_tau,
|
|
num_workers=num_workers,
|
|
device=device,
|
|
world_size=world_size,
|
|
dist_url=dist_url,
|
|
sync_bn=sync_bn,
|
|
fp16_eval=fp16_eval,
|
|
encoder_only=encoder_only,
|
|
backbone_only=backbone_only,
|
|
resolution=resolution,
|
|
use_cls_token=use_cls_token,
|
|
multi_scale=multi_scale,
|
|
expanded_scales=expanded_scales,
|
|
warmup_epochs=warmup_epochs,
|
|
lr_scheduler=lr_scheduler,
|
|
lr_min_factor=lr_min_factor,
|
|
early_stopping=early_stopping,
|
|
early_stopping_patience=early_stopping_patience,
|
|
early_stopping_min_delta=early_stopping_min_delta,
|
|
early_stopping_use_ema=early_stopping_use_ema,
|
|
gradient_checkpointing=gradient_checkpointing,
|
|
**extra_kwargs
|
|
)
|
|
return args |