|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
Train and eval functions used in main.py
|
|
"""
|
|
import math
|
|
import sys
|
|
from typing import Iterable
|
|
|
|
import torch
|
|
|
|
import rfdetr.util.misc as utils
|
|
from rfdetr.datasets.coco_eval import CocoEvaluator
|
|
|
|
try:
|
|
from torch.amp import autocast, GradScaler
|
|
DEPRECATED_AMP = False
|
|
except ImportError:
|
|
from torch.cuda.amp import autocast, GradScaler
|
|
DEPRECATED_AMP = True
|
|
from typing import DefaultDict, List, Callable
|
|
from rfdetr.util.misc import NestedTensor
|
|
|
|
|
|
|
|
def get_autocast_args(args):
    """Build the keyword arguments for ``autocast`` matching the imported AMP API.

    The legacy ``torch.cuda.amp.autocast`` takes no ``device_type`` argument,
    while the newer ``torch.amp.autocast`` requires one; ``DEPRECATED_AMP``
    records which module was importable at load time.
    """
    kwargs = {'enabled': args.amp, 'dtype': torch.bfloat16}
    if not DEPRECATED_AMP:
        kwargs['device_type'] = 'cuda'
    return kwargs
|
|
|
|
|
|
def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    batch_size: int,
    max_norm: float = 0,
    ema_m: torch.nn.Module = None,
    schedules: dict = None,
    num_training_steps_per_epoch=None,
    vit_encoder_num_layers=None,
    args=None,
    callbacks: DefaultDict[str, List[Callable]] = None,
):
    """Train ``model`` for one epoch with AMP and gradient accumulation.

    Each batch from ``data_loader`` is split into ``args.grad_accum_steps``
    sub-batches; each sub-batch loss is scaled by ``1 / grad_accum_steps``
    so the accumulated gradients equal the full-batch average before the
    single optimizer step taken per outer iteration.

    Args:
        model: Detection model (possibly DDP-wrapped; reached via
            ``model.module`` when ``args.distributed`` is set).
        criterion: Loss module exposing a ``weight_dict`` of loss weights.
        lr_scheduler: Stepped once after every optimizer step.
        data_loader: Yields ``(samples, targets)`` where ``samples`` is a
            ``NestedTensor`` covering the whole (pre-accumulation) batch.
        optimizer: Optimizer for ``model``'s parameters.
        device: Device each sub-batch is moved to.
        epoch: Current epoch index (logging and global-step offset).
        batch_size: Per-process batch size; must be divisible by
            ``args.grad_accum_steps``.
        max_norm: If > 0, clip gradients to this L2 norm before stepping.
        ema_m: Optional EMA wrapper updated after every optimizer step.
        schedules: Optional per-step schedules; keys ``"dp"`` (drop path)
            and ``"do"`` (dropout) map to arrays indexed by global step.
        num_training_steps_per_epoch: Steps per epoch, used for the
            global-step offset.
        vit_encoder_num_layers: Forwarded to ``update_drop_path``.
        args: Namespace providing at least ``amp``, ``grad_accum_steps``
            and ``distributed``.
        callbacks: Optional mapping whose ``"on_train_batch_start"`` list of
            callables is invoked with a small state dict each iteration.

    Returns:
        Dict mapping metric names to their global averages over the epoch.

    Raises:
        ValueError: If the reduced loss becomes non-finite.
    """
    # Normalize the optional schedules argument. (The previous signature used
    # a mutable default `schedules: dict = {}`, which is the classic
    # shared-mutable-default pitfall.)
    schedules = {} if schedules is None else schedules

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    metric_logger.add_meter(
        "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
    )
    header = "Epoch: [{}]".format(epoch)
    print_freq = 10
    # Global-step offset so schedule lookups line up across epochs.
    start_steps = epoch * num_training_steps_per_epoch

    print("Grad accum steps: ", args.grad_accum_steps)
    print("Total batch size: ", batch_size * utils.get_world_size())

    # GradScaler constructor signature differs between the deprecated
    # torch.cuda.amp module and the newer torch.amp module.
    if DEPRECATED_AMP:
        scaler = GradScaler(enabled=args.amp)
    else:
        scaler = GradScaler('cuda', enabled=args.amp)

    optimizer.zero_grad()
    assert batch_size % args.grad_accum_steps == 0, (
        "batch_size must be divisible by args.grad_accum_steps"
    )
    sub_batch_size = batch_size // args.grad_accum_steps
    print("LENGTH OF DATA LOADER:", len(data_loader))
    for data_iter_step, (samples, targets) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        it = start_steps + data_iter_step
        callback_dict = {
            "step": it,
            "model": model,
            "epoch": epoch,
        }
        # Guard against the documented default of None (previously this
        # subscripted callbacks unconditionally and crashed when omitted).
        if callbacks is not None:
            for callback in callbacks["on_train_batch_start"]:
                callback(callback_dict)
        # Apply per-step drop-path / dropout schedules, reaching through the
        # DDP wrapper when necessary.
        if "dp" in schedules:
            if args.distributed:
                model.module.update_drop_path(
                    schedules["dp"][it], vit_encoder_num_layers
                )
            else:
                model.update_drop_path(schedules["dp"][it], vit_encoder_num_layers)
        if "do" in schedules:
            if args.distributed:
                model.module.update_dropout(schedules["do"][it])
            else:
                model.update_dropout(schedules["do"][it])

        for i in range(args.grad_accum_steps):
            # Slice this accumulation step's sub-batch out of the full batch.
            start_idx = i * sub_batch_size
            final_idx = start_idx + sub_batch_size
            new_samples_tensors = samples.tensors[start_idx:final_idx]
            new_samples = NestedTensor(new_samples_tensors, samples.mask[start_idx:final_idx])
            new_samples = new_samples.to(device)
            new_targets = [{k: v.to(device) for k, v in t.items()} for t in targets[start_idx:final_idx]]

            with autocast(**get_autocast_args(args)):
                outputs = model(new_samples, new_targets)
                loss_dict = criterion(outputs, new_targets)
                weight_dict = criterion.weight_dict
                # Scale by 1/grad_accum_steps so accumulated gradients match
                # the full-batch average.
                losses = sum(
                    (1 / args.grad_accum_steps) * loss_dict[k] * weight_dict[k]
                    for k in loss_dict.keys()
                    if k in weight_dict
                )

            scaler.scale(losses).backward()

        # NOTE: only the loss_dict of the *last* sub-batch is logged below.
        # Reduce losses over all processes for consistent logging.
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items()
            if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print(loss_dict_reduced)
            raise ValueError("Loss is {}, stopping training".format(loss_value))

        if max_norm > 0:
            # Unscale first so the clip threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()
        optimizer.zero_grad()
        if ema_m is not None:
            if epoch >= 0:
                ema_m.update(model)
        metric_logger.update(
            loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled
        )
        metric_logger.update(class_error=loss_dict_reduced["class_error"])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # Gather the stats from all processes before averaging.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
|
|
|
|
|
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, args=None):
    """Evaluate ``model`` on ``data_loader`` and score detections with COCO metrics.

    Args:
        model: Detection model; put into eval mode (and cast to half
            precision when ``args.fp16_eval`` is set).
        criterion: Loss module (eval mode) exposing ``weight_dict``; used
            only for logging validation losses.
        postprocessors: Dict of output converters; must contain ``"bbox"``
            and may contain ``"segm"``.
        data_loader: Yields ``(samples, targets)`` batches.
        base_ds: Ground-truth dataset handed to ``CocoEvaluator``.
        device: Device inputs are moved to.
        args: Namespace providing ``fp16_eval`` and ``amp``.
            NOTE(review): defaults to None but is dereferenced
            unconditionally, so in practice it must be supplied.

    Returns:
        Tuple of (stats dict — averaged loss metrics plus COCO summary
        arrays under ``coco_eval_bbox``/``coco_eval_masks`` — and the
        ``CocoEvaluator`` instance).
    """
    model.eval()
    if args.fp16_eval:
        model.half()
    criterion.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
    )
    header = "Test:"

    # Only build evaluators for the IoU types the postprocessors can produce.
    iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(base_ds, iou_types)

    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        if args.fp16_eval:
            # Inputs must match the half-precision model weights.
            samples.tensors = samples.tensors.half()

        with autocast(**get_autocast_args(args)):
            outputs = model(samples)

        if args.fp16_eval:
            # Cast every output tensor (including the nested enc_outputs dict
            # and the list of aux_outputs dicts) back to float32 so the
            # criterion and postprocessors run in full precision.
            for key in outputs.keys():
                if key == "enc_outputs":
                    for sub_key in outputs[key].keys():
                        outputs[key][sub_key] = outputs[key][sub_key].float()
                elif key == "aux_outputs":
                    for idx in range(len(outputs[key])):
                        for sub_key in outputs[key][idx].keys():
                            outputs[key][idx][sub_key] = outputs[key][idx][
                                sub_key
                            ].float()
                else:
                    outputs[key] = outputs[key].float()

        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict

        # Reduce losses over all processes for consistent logging.
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items()
            if k in weight_dict
        }
        loss_dict_reduced_unscaled = {
            f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
        }
        metric_logger.update(
            loss=sum(loss_dict_reduced_scaled.values()),
            **loss_dict_reduced_scaled,
            **loss_dict_reduced_unscaled,
        )
        metric_logger.update(class_error=loss_dict_reduced["class_error"])

        # Convert outputs to detections in original-image coordinates and
        # feed them to the evaluator keyed by image id.
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors["bbox"](outputs, orig_target_sizes)
        res = {
            target["image_id"].item(): output
            for target, output in zip(targets, results)
        }
        if coco_evaluator is not None:
            coco_evaluator.update(res)

    # Gather the stats from all processes before summarizing.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()

    if coco_evaluator is not None:
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    if coco_evaluator is not None:
        if "bbox" in postprocessors.keys():
            stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist()
        if "segm" in postprocessors.keys():
            stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist()
    return stats, coco_evaluator
|
|
|